Adding upstream version 9.0.1.

Signed-off-by: Daniel Baumann <daniel@debian.org>

parent 768d386bf5
commit fca0265317

87 changed files with 7994 additions and 421 deletions
2	.github/workflows/python-package.yml (vendored)

@@ -20,7 +20,7 @@ jobs:
      - name: Install dependencies
        run: |
          python -m pip install --upgrade pip
          python -m pip install -r requirements.txt
          python -m pip install -r dev-requirements.txt
      - name: Run checks (linter, code style, tests)
        run: |
          ./run_checks.sh
15	CHANGELOG.md

@@ -1,6 +1,21 @@
Changelog
=========

v9.0.0
------

Changes:

- Breaking: Changed the AST hierarchy of exp.Table with exp.Alias. Previously, Tables were children of their aliases; in order to simplify the AST and fix some issues, Tables now have an alias property.

v8.0.0
------

Changes:

- Breaking: New `add_table` method in the Schema ABC.
- New: SQLGlot now supports the [PySpark](https://github.com/tobymao/sqlglot/tree/main/sqlglot/dataframe) dataframe API. This is still relatively experimental.

v7.1.0
------
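As a hedged illustration of the v9.0.0 AST change (a minimal sketch, assuming sqlglot >= 9.0 is installed; the table and alias names are made up):

```python
import sqlglot
from sqlglot import expressions as exp

# Under the new hierarchy, the alias lives on the Table node itself.
table = sqlglot.parse_one("SELECT a FROM employees AS e").find(exp.Table)

print(table.name)   # employees
print(table.alias)  # e
```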
@@ -317,6 +317,7 @@ Dialect["custom"]

## Run Tests and Lint

```
pip install -r requirements.txt
# set `SKIP_INTEGRATION=1` to skip integration tests
./run_checks.sh
```
@@ -2,5 +2,7 @@ autoflake
black
duckdb
isort
mypy
pandas
pyspark
python-dateutil
@@ -11,4 +11,5 @@ python -m autoflake -i -r ${RETURN_ERROR_CODE} \
  sqlglot/ tests/
python -m isort --profile black sqlglot/ tests/
python -m black ${RETURN_ERROR_CODE} --line-length 120 sqlglot/ tests/
python -m mypy sqlglot tests
python -m unittest
15	setup.cfg (new file)

@@ -0,0 +1,15 @@
[mypy]
disallow_untyped_calls = False
no_implicit_optional = True

[mypy-sqlglot.*]
ignore_errors = True

[mypy-sqlglot.dataframe.*]
ignore_errors = False

[mypy-tests.*]
ignore_errors = True

[mypy-tests.dataframe.*]
ignore_errors = False
@@ -21,12 +21,15 @@ from sqlglot.expressions import table_ as table
from sqlglot.expressions import union
from sqlglot.generator import Generator
from sqlglot.parser import Parser
from sqlglot.schema import MappingSchema
from sqlglot.tokens import Tokenizer, TokenType

-__version__ = "7.1.3"
+__version__ = "9.0.1"

pretty = False

schema = MappingSchema()


def parse(sql, read=None, **opts):
    """
@@ -40,8 +40,8 @@ parser.add_argument(
    "--error-level",
    dest="error_level",
    type=str,
-   default="RAISE",
-   help="IGNORE, WARN, RAISE (default)",
+   default="IMMEDIATE",
+   help="IGNORE, WARN, RAISE, IMMEDIATE (default)",
)
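For context on the new default, a minimal sketch assuming the ErrorLevel enum exposed in `sqlglot.errors`: IMMEDIATE raises on the first parser error instead of collecting all errors before raising, as RAISE does.

```python
import sqlglot
from sqlglot.errors import ErrorLevel, ParseError

try:
    # Deliberately malformed SQL: IMMEDIATE raises as soon as the first error is hit.
    sqlglot.transpile("SELECT foo FROM (SELECT", error_level=ErrorLevel.IMMEDIATE)
except ParseError as e:
    print(e)
```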
224	sqlglot/dataframe/README.md (new file)

@@ -0,0 +1,224 @@
# PySpark DataFrame SQL Generator

This is a drop-in replacement for the PySpark DataFrame API that generates SQL instead of executing DataFrame operations directly. Combined with SQLGlot's transpiling support, this allows one to write PySpark DataFrame code and execute it on other engines like [DuckDB](https://duckdb.org/), [Presto](https://prestodb.io/), [Spark](https://spark.apache.org/), [Snowflake](https://www.snowflake.com/en/), and [BigQuery](https://cloud.google.com/bigquery/).

Many of the common operations are currently covered, and more functionality will be added over time. Please [open an issue](https://github.com/tobymao/sqlglot/issues) or [PR](https://github.com/tobymao/sqlglot/pulls) with your feedback or contribution to help influence what is prioritized next and to make sure your use case is properly supported.

# How to use

## Instructions
* [Install SQLGlot](https://github.com/tobymao/sqlglot/blob/main/README.md#install); that is all that is required to generate SQL. [The examples](#examples) also execute the generated SQL on a specific engine, which requires that engine's client library.
* Find/replace all `from pyspark.sql` with `from sqlglot.dataframe`.
* Prior to any `spark.read.table` or `spark.table` call, run `sqlglot.schema.add_table('<table_name>', <column_structure>)`.
  * The column structure can be defined in the following ways (see the sketch after this list):
    * A dictionary where the keys are column names and the values are strings of the Spark SQL type name.
      * Ex: `{'cola': 'string', 'colb': 'int'}`
    * A PySpark DataFrame `StructType`, similar to when using `createDataFrame`.
      * Ex: `StructType([StructField('cola', StringType()), StructField('colb', IntegerType())])`
    * A string of names and types similar to what is supported in `createDataFrame`.
      * Ex: `cola: STRING, colb: INT`
    * [Not Recommended] A list of string column names without types.
      * Ex: `['cola', 'colb']`
      * The lack of types may limit functionality in future releases.
  * See [Registering Custom Schema](#registering-custom-schema-class) for information on how to skip this step if the information is stored externally.
* Add `.sql(pretty=True)` to your final DataFrame command to return a list of SQL statements to run that command.
  * In most cases a single SQL statement is returned. Currently the only exception is when caching DataFrames, which isn't supported in other dialects.
  * Spark is the default output dialect. See [dialects](https://github.com/tobymao/sqlglot/tree/main/sqlglot/dialects) for a full list of dialects.
    * Ex: `.sql(pretty=True, dialect='bigquery')`
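As a quick sketch of those column-structure options (the table names `t1` through `t4` are hypothetical, used only to register each form once):

```python
import sqlglot
from sqlglot.dataframe.sql import types

# Each call registers the same two-column structure using a different input form.
sqlglot.schema.add_table('t1', {'cola': 'string', 'colb': 'int'})  # dict of name -> type
sqlglot.schema.add_table('t2', types.StructType([                  # PySpark-style StructType
    types.StructField('cola', types.StringType()),
    types.StructField('colb', types.IntegerType()),
]))
sqlglot.schema.add_table('t3', 'cola: STRING, colb: INT')          # "name: type" string
sqlglot.schema.add_table('t4', ['cola', 'colb'])                   # names only (not recommended)
```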
## Examples

```python
import sqlglot
from sqlglot.dataframe.sql.session import SparkSession
from sqlglot.dataframe.sql import functions as F

sqlglot.schema.add_table('employee', {
    'employee_id': 'INT',
    'fname': 'STRING',
    'lname': 'STRING',
    'age': 'INT',
})  # Register the table structure prior to reading from the table

spark = SparkSession()

df = (
    spark
    .table('employee')
    .groupBy(F.col("age"))
    .agg(F.countDistinct(F.col("employee_id")).alias("num_employees"))
)

print(df.sql(pretty=True))  # Spark will be the dialect used by default
```
Output:
```sparksql
SELECT
  `employee`.`age` AS `age`,
  COUNT(DISTINCT `employee`.`employee_id`) AS `num_employees`
FROM `employee` AS `employee`
GROUP BY
  `employee`.`age`
```
## Registering Custom Schema Class

The `sqlglot.schema.add_table` step can be skipped if you have the column structure stored externally, such as in a file or an external metadata table. To do so, write a class that implements the `sqlglot.schema.Schema` abstract class and then assign an instance of that class to `sqlglot.schema`.

```python
import sqlglot
from sqlglot.dataframe.sql.session import SparkSession
from sqlglot.dataframe.sql import functions as F
from sqlglot.schema import Schema


class ExternalSchema(Schema):
    ...

sqlglot.schema = ExternalSchema()

spark = SparkSession()

df = (
    spark
    .table('employee')
    .groupBy(F.col("age"))
    .agg(F.countDistinct(F.col("employee_id")).alias("num_employees"))
)

print(df.sql(pretty=True))
```
## Example Implementations

### BigQuery
```python
from google.cloud import bigquery
from sqlglot.dataframe.sql.session import SparkSession
from sqlglot.dataframe.sql import types
from sqlglot.dataframe.sql import functions as F

client = bigquery.Client()

data = [
    (1, "Jack", "Shephard", 34),
    (2, "John", "Locke", 48),
    (3, "Kate", "Austen", 34),
    (4, "Claire", "Littleton", 22),
    (5, "Hugo", "Reyes", 26),
]
schema = types.StructType([
    types.StructField('employee_id', types.IntegerType(), False),
    types.StructField('fname', types.StringType(), False),
    types.StructField('lname', types.StringType(), False),
    types.StructField('age', types.IntegerType(), False),
])

sql_statements = (
    SparkSession()
    .createDataFrame(data, schema)
    .groupBy(F.col("age"))
    .agg(F.countDistinct(F.col("employee_id")).alias("num_employees"))
    .sql(dialect="bigquery")
)

result = None
for sql in sql_statements:
    result = client.query(sql)

assert result is not None
for row in result:
    print(f"Age: {row['age']}, Num Employees: {row['num_employees']}")
```
### Snowflake
```python
import os

import snowflake.connector
from sqlglot.dataframe.sql.session import SparkSession
from sqlglot.dataframe.sql import types
from sqlglot.dataframe.sql import functions as F

ctx = snowflake.connector.connect(
    user=os.environ["SNOWFLAKE_USER"],
    password=os.environ["SNOWFLAKE_PASS"],
    account=os.environ["SNOWFLAKE_ACCOUNT"],
)
cs = ctx.cursor()

data = [
    (1, "Jack", "Shephard", 34),
    (2, "John", "Locke", 48),
    (3, "Kate", "Austen", 34),
    (4, "Claire", "Littleton", 22),
    (5, "Hugo", "Reyes", 26),
]
schema = types.StructType([
    types.StructField('employee_id', types.IntegerType(), False),
    types.StructField('fname', types.StringType(), False),
    types.StructField('lname', types.StringType(), False),
    types.StructField('age', types.IntegerType(), False),
])

sql_statements = (
    SparkSession()
    .createDataFrame(data, schema)
    .groupBy(F.col("age"))
    .agg(F.countDistinct(F.col("lname")).alias("num_employees"))
    .sql(dialect="snowflake")
)

try:
    for sql in sql_statements:
        cs.execute(sql)
    results = cs.fetchall()
    for row in results:
        print(f"Age: {row[0]}, Num Employees: {row[1]}")
finally:
    cs.close()
    ctx.close()
```
### Spark
```python
from pyspark.sql.session import SparkSession as PySparkSession
from sqlglot.dataframe.sql.session import SparkSession
from sqlglot.dataframe.sql import types
from sqlglot.dataframe.sql import functions as F

data = [
    (1, "Jack", "Shephard", 34),
    (2, "John", "Locke", 48),
    (3, "Kate", "Austen", 34),
    (4, "Claire", "Littleton", 22),
    (5, "Hugo", "Reyes", 26),
]
schema = types.StructType([
    types.StructField('employee_id', types.IntegerType(), False),
    types.StructField('fname', types.StringType(), False),
    types.StructField('lname', types.StringType(), False),
    types.StructField('age', types.IntegerType(), False),
])

sql_statements = (
    SparkSession()
    .createDataFrame(data, schema)
    .groupBy(F.col("age"))
    .agg(F.countDistinct(F.col("employee_id")).alias("num_employees"))
    .sql(dialect="spark")
)

pyspark = PySparkSession.builder.master("local[*]").getOrCreate()

df = None
for sql in sql_statements:
    df = pyspark.sql(sql)

assert df is not None
df.show()
```

# Unsupportable Operations

Any operation that cannot be represented in SQL cannot be supported by this tool; RDD operations are one example. Since the DataFrame API is mostly modeled around SQL concepts, however, most operations can be supported.
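A minimal sketch of the distinction, reusing the `employee` schema from the examples above (the RDD line is shown only as a comment because it has no SQL equivalent):

```python
import sqlglot
from sqlglot.dataframe.sql.session import SparkSession
from sqlglot.dataframe.sql import functions as F

sqlglot.schema.add_table('employee', {'employee_id': 'INT', 'age': 'INT'})
df = SparkSession().table('employee')

# Expressible in SQL, so it is supported:
print(df.where(F.col("age") > 30).sql())

# Not expressible in SQL, so it cannot be: arbitrary Python such as
# df.rdd.map(lambda row: row.age + 1) has no SQL representation.
```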
0	sqlglot/dataframe/__init__.py (new file)

18	sqlglot/dataframe/sql/__init__.py (new file)

@@ -0,0 +1,18 @@
from sqlglot.dataframe.sql.column import Column
from sqlglot.dataframe.sql.dataframe import DataFrame, DataFrameNaFunctions
from sqlglot.dataframe.sql.group import GroupedData
from sqlglot.dataframe.sql.readwriter import DataFrameReader, DataFrameWriter
from sqlglot.dataframe.sql.session import SparkSession
from sqlglot.dataframe.sql.window import Window, WindowSpec

__all__ = [
    "SparkSession",
    "DataFrame",
    "GroupedData",
    "Column",
    "DataFrameNaFunctions",
    "Window",
    "WindowSpec",
    "DataFrameReader",
    "DataFrameWriter",
]
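A small sketch of how these re-exports are meant to be used; the imports below mirror `from pyspark.sql import ...` and nothing here touches a real Spark cluster:

```python
# These mirror `from pyspark.sql import SparkSession, Window, functions as F`.
from sqlglot.dataframe.sql import SparkSession, Window
from sqlglot.dataframe.sql import functions as F

spark = SparkSession()  # builds SQL strings; no Spark cluster is involved
```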
20	sqlglot/dataframe/sql/_typing.pyi (new file)

@@ -0,0 +1,20 @@
from __future__ import annotations

import datetime
import typing as t

from sqlglot import expressions as exp

if t.TYPE_CHECKING:
    from sqlglot.dataframe.sql.column import Column
    from sqlglot.dataframe.sql.types import StructType

ColumnLiterals = t.TypeVar(
    "ColumnLiterals", bound=t.Union[str, float, int, bool, t.List, t.Tuple, datetime.date, datetime.datetime]
)
ColumnOrName = t.TypeVar("ColumnOrName", bound=t.Union[Column, str])
ColumnOrLiteral = t.TypeVar(
    "ColumnOrLiteral", bound=t.Union[Column, str, float, int, bool, t.List, t.Tuple, datetime.date, datetime.datetime]
)
SchemaInput = t.TypeVar("SchemaInput", bound=t.Union[str, t.List[str], StructType, t.Dict[str, str]])
OutputExpressionContainer = t.TypeVar("OutputExpressionContainer", bound=t.Union[exp.Select, exp.Create, exp.Insert])
295	sqlglot/dataframe/sql/column.py (new file)

@@ -0,0 +1,295 @@
from __future__ import annotations

import typing as t

import sqlglot
from sqlglot import expressions as exp
from sqlglot.dataframe.sql.types import DataType
from sqlglot.helper import flatten

if t.TYPE_CHECKING:
    from sqlglot.dataframe.sql._typing import ColumnOrLiteral
    from sqlglot.dataframe.sql.window import WindowSpec


class Column:
    def __init__(self, expression: t.Optional[t.Union[ColumnOrLiteral, exp.Expression]]):
        if isinstance(expression, Column):
            expression = expression.expression  # type: ignore
        elif expression is None or not isinstance(expression, (str, exp.Expression)):
            expression = self._lit(expression).expression  # type: ignore
        self.expression: exp.Expression = sqlglot.maybe_parse(expression, dialect="spark")

    def __repr__(self):
        return repr(self.expression)

    def __hash__(self):
        return hash(self.expression)

    def __eq__(self, other: ColumnOrLiteral) -> Column:  # type: ignore
        return self.binary_op(exp.EQ, other)

    def __ne__(self, other: ColumnOrLiteral) -> Column:  # type: ignore
        return self.binary_op(exp.NEQ, other)

    def __gt__(self, other: ColumnOrLiteral) -> Column:
        return self.binary_op(exp.GT, other)

    def __ge__(self, other: ColumnOrLiteral) -> Column:
        return self.binary_op(exp.GTE, other)

    def __lt__(self, other: ColumnOrLiteral) -> Column:
        return self.binary_op(exp.LT, other)

    def __le__(self, other: ColumnOrLiteral) -> Column:
        return self.binary_op(exp.LTE, other)

    def __and__(self, other: ColumnOrLiteral) -> Column:
        return self.binary_op(exp.And, other)

    def __or__(self, other: ColumnOrLiteral) -> Column:
        return self.binary_op(exp.Or, other)

    def __mod__(self, other: ColumnOrLiteral) -> Column:
        return self.binary_op(exp.Mod, other)

    def __add__(self, other: ColumnOrLiteral) -> Column:
        return self.binary_op(exp.Add, other)

    def __sub__(self, other: ColumnOrLiteral) -> Column:
        return self.binary_op(exp.Sub, other)

    def __mul__(self, other: ColumnOrLiteral) -> Column:
        return self.binary_op(exp.Mul, other)

    def __truediv__(self, other: ColumnOrLiteral) -> Column:
        return self.binary_op(exp.Div, other)

    def __div__(self, other: ColumnOrLiteral) -> Column:
        return self.binary_op(exp.Div, other)

    def __neg__(self) -> Column:
        return self.unary_op(exp.Neg)

    def __radd__(self, other: ColumnOrLiteral) -> Column:
        return self.inverse_binary_op(exp.Add, other)

    def __rsub__(self, other: ColumnOrLiteral) -> Column:
        return self.inverse_binary_op(exp.Sub, other)

    def __rmul__(self, other: ColumnOrLiteral) -> Column:
        return self.inverse_binary_op(exp.Mul, other)

    def __rdiv__(self, other: ColumnOrLiteral) -> Column:
        return self.inverse_binary_op(exp.Div, other)

    def __rtruediv__(self, other: ColumnOrLiteral) -> Column:
        return self.inverse_binary_op(exp.Div, other)

    def __rmod__(self, other: ColumnOrLiteral) -> Column:
        return self.inverse_binary_op(exp.Mod, other)

    def __pow__(self, power: ColumnOrLiteral, modulo=None):
        return Column(exp.Pow(this=self.expression, power=Column(power).expression))

    def __rpow__(self, power: ColumnOrLiteral):
        return Column(exp.Pow(this=Column(power).expression, power=self.expression))

    def __invert__(self):
        return self.unary_op(exp.Not)

    def __rand__(self, other: ColumnOrLiteral) -> Column:
        return self.inverse_binary_op(exp.And, other)

    def __ror__(self, other: ColumnOrLiteral) -> Column:
        return self.inverse_binary_op(exp.Or, other)

    @classmethod
    def ensure_col(cls, value: t.Optional[t.Union[ColumnOrLiteral, exp.Expression]]):
        return cls(value)

    @classmethod
    def ensure_cols(cls, args: t.List[t.Union[ColumnOrLiteral, exp.Expression]]) -> t.List[Column]:
        return [cls.ensure_col(x) if not isinstance(x, Column) else x for x in args]

    @classmethod
    def _lit(cls, value: ColumnOrLiteral) -> Column:
        if isinstance(value, dict):
            columns = [cls._lit(v).alias(k).expression for k, v in value.items()]
            return cls(exp.Struct(expressions=columns))
        return cls(exp.convert(value))

    @classmethod
    def invoke_anonymous_function(
        cls, column: t.Optional[ColumnOrLiteral], func_name: str, *args: t.Optional[ColumnOrLiteral]
    ) -> Column:
        columns = [] if column is None else [cls.ensure_col(column)]
        column_args = [cls.ensure_col(arg) for arg in args]
        expressions = [x.expression for x in columns + column_args]
        new_expression = exp.Anonymous(this=func_name.upper(), expressions=expressions)
        return Column(new_expression)

    @classmethod
    def invoke_expression_over_column(
        cls, column: t.Optional[ColumnOrLiteral], callable_expression: t.Callable, **kwargs
    ) -> Column:
        ensured_column = None if column is None else cls.ensure_col(column)
        new_expression = (
            callable_expression(**kwargs)
            if ensured_column is None
            else callable_expression(this=ensured_column.column_expression, **kwargs)
        )
        return Column(new_expression)

    def binary_op(self, klass: t.Callable, other: ColumnOrLiteral, **kwargs) -> Column:
        return Column(klass(this=self.column_expression, expression=Column(other).column_expression, **kwargs))

    def inverse_binary_op(self, klass: t.Callable, other: ColumnOrLiteral, **kwargs) -> Column:
        return Column(klass(this=Column(other).column_expression, expression=self.column_expression, **kwargs))

    def unary_op(self, klass: t.Callable, **kwargs) -> Column:
        return Column(klass(this=self.column_expression, **kwargs))

    @property
    def is_alias(self):
        return isinstance(self.expression, exp.Alias)

    @property
    def is_column(self):
        return isinstance(self.expression, exp.Column)

    @property
    def column_expression(self) -> exp.Column:
        return self.expression.unalias()

    @property
    def alias_or_name(self) -> str:
        return self.expression.alias_or_name

    @classmethod
    def ensure_literal(cls, value) -> Column:
        from sqlglot.dataframe.sql.functions import lit

        if isinstance(value, cls):
            value = value.expression
        if not isinstance(value, exp.Literal):
            return lit(value)
        return Column(value)

    def copy(self) -> Column:
        return Column(self.expression.copy())

    def set_table_name(self, table_name: str, copy=False) -> Column:
        expression = self.expression.copy() if copy else self.expression
        expression.set("table", exp.to_identifier(table_name))
        return Column(expression)

    def sql(self, **kwargs) -> str:
        return self.expression.sql(**{"dialect": "spark", **kwargs})

    def alias(self, name: str) -> Column:
        new_expression = exp.alias_(self.column_expression, name)
        return Column(new_expression)

    def asc(self) -> Column:
        new_expression = exp.Ordered(this=self.column_expression, desc=False, nulls_first=True)
        return Column(new_expression)

    def desc(self) -> Column:
        new_expression = exp.Ordered(this=self.column_expression, desc=True, nulls_first=False)
        return Column(new_expression)

    asc_nulls_first = asc

    def asc_nulls_last(self) -> Column:
        new_expression = exp.Ordered(this=self.column_expression, desc=False, nulls_first=False)
        return Column(new_expression)

    def desc_nulls_first(self) -> Column:
        new_expression = exp.Ordered(this=self.column_expression, desc=True, nulls_first=True)
        return Column(new_expression)

    desc_nulls_last = desc

    def when(self, condition: Column, value: t.Any) -> Column:
        from sqlglot.dataframe.sql.functions import when

        column_with_if = when(condition, value)
        if not isinstance(self.expression, exp.Case):
            return column_with_if
        new_column = self.copy()
        new_column.expression.args["ifs"].extend(column_with_if.expression.args["ifs"])
        return new_column

    def otherwise(self, value: t.Any) -> Column:
        from sqlglot.dataframe.sql.functions import lit

        true_value = value if isinstance(value, Column) else lit(value)
        new_column = self.copy()
        new_column.expression.set("default", true_value.column_expression)
        return new_column

    def isNull(self) -> Column:
        new_expression = exp.Is(this=self.column_expression, expression=exp.Null())
        return Column(new_expression)

    def isNotNull(self) -> Column:
        new_expression = exp.Not(this=exp.Is(this=self.column_expression, expression=exp.Null()))
        return Column(new_expression)

    def cast(self, dataType: t.Union[str, DataType]):
        """
        Functionality Difference: PySpark cast accepts a datatype instance of the datatype class.
        Sqlglot doesn't currently replicate this class, so it only accepts a string.
        """
        if isinstance(dataType, DataType):
            dataType = dataType.simpleString()
        new_expression = exp.Cast(this=self.column_expression, to=dataType)
        return Column(new_expression)

    def startswith(self, value: t.Union[str, Column]) -> Column:
        value = self._lit(value) if not isinstance(value, Column) else value
        return self.invoke_anonymous_function(self, "STARTSWITH", value)

    def endswith(self, value: t.Union[str, Column]) -> Column:
        value = self._lit(value) if not isinstance(value, Column) else value
        return self.invoke_anonymous_function(self, "ENDSWITH", value)

    def rlike(self, regexp: str) -> Column:
        return self.invoke_expression_over_column(
            column=self, callable_expression=exp.RegexpLike, expression=self._lit(regexp).expression
        )

    def like(self, other: str):
        return self.invoke_expression_over_column(self, exp.Like, expression=self._lit(other).expression)

    def ilike(self, other: str):
        return self.invoke_expression_over_column(self, exp.ILike, expression=self._lit(other).expression)

    def substr(self, startPos: t.Union[int, Column], length: t.Union[int, Column]) -> Column:
        startPos = self._lit(startPos) if not isinstance(startPos, Column) else startPos
        length = self._lit(length) if not isinstance(length, Column) else length
        return Column.invoke_expression_over_column(
            self, exp.Substring, start=startPos.expression, length=length.expression
        )

    def isin(self, *cols: t.Union[ColumnOrLiteral, t.Iterable[ColumnOrLiteral]]):
        columns = flatten(cols) if isinstance(cols[0], (list, set, tuple)) else cols  # type: ignore
        expressions = [self._lit(x).expression for x in columns]
        return Column.invoke_expression_over_column(self, exp.In, expressions=expressions)  # type: ignore

    def between(
        self,
        lowerBound: t.Union[ColumnOrLiteral],
        upperBound: t.Union[ColumnOrLiteral],
    ) -> Column:
        lower_bound_exp = self._lit(lowerBound) if not isinstance(lowerBound, Column) else lowerBound
        upper_bound_exp = self._lit(upperBound) if not isinstance(upperBound, Column) else upperBound
        return Column(
            exp.Between(this=self.column_expression, low=lower_bound_exp.expression, high=upper_bound_exp.expression)
        )

    def over(self, window: WindowSpec) -> Column:
        window_expression = window.expression.copy()
        window_expression.set("this", self.column_expression)
        return Column(window_expression)
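A brief usage sketch of the class above: operator overloads build sqlglot expressions rather than evaluating anything, so a Column can be rendered straight to SQL (the column name `a` is arbitrary):

```python
from sqlglot.dataframe.sql.column import Column

expr = (Column("a") + 1).alias("a_plus_one")
print(expr.sql())  # a + 1 AS a_plus_one  (spark dialect by default)
```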
730
sqlglot/dataframe/sql/dataframe.py
Normal file
730
sqlglot/dataframe/sql/dataframe.py
Normal file
|
@ -0,0 +1,730 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import functools
|
||||
import typing as t
|
||||
import zlib
|
||||
from copy import copy
|
||||
|
||||
import sqlglot
|
||||
from sqlglot import expressions as exp
|
||||
from sqlglot.dataframe.sql import functions as F
|
||||
from sqlglot.dataframe.sql.column import Column
|
||||
from sqlglot.dataframe.sql.group import GroupedData
|
||||
from sqlglot.dataframe.sql.normalize import normalize
|
||||
from sqlglot.dataframe.sql.operations import Operation, operation
|
||||
from sqlglot.dataframe.sql.readwriter import DataFrameWriter
|
||||
from sqlglot.dataframe.sql.transforms import replace_id_value
|
||||
from sqlglot.dataframe.sql.util import get_tables_from_expression_with_join
|
||||
from sqlglot.dataframe.sql.window import Window
|
||||
from sqlglot.helper import ensure_list, object_to_dict
|
||||
from sqlglot.optimizer import optimize as optimize_func
|
||||
from sqlglot.optimizer.qualify_columns import qualify_columns
|
||||
|
||||
if t.TYPE_CHECKING:
|
||||
from sqlglot.dataframe.sql._typing import ColumnLiterals, OutputExpressionContainer
|
||||
from sqlglot.dataframe.sql.session import SparkSession
|
||||
|
||||
|
||||
JOIN_HINTS = {
|
||||
"BROADCAST",
|
||||
"BROADCASTJOIN",
|
||||
"MAPJOIN",
|
||||
"MERGE",
|
||||
"SHUFFLEMERGE",
|
||||
"MERGEJOIN",
|
||||
"SHUFFLE_HASH",
|
||||
"SHUFFLE_REPLICATE_NL",
|
||||
}
|
||||
|
||||
|
||||
class DataFrame:
|
||||
def __init__(
|
||||
self,
|
||||
spark: SparkSession,
|
||||
expression: exp.Select,
|
||||
branch_id: t.Optional[str] = None,
|
||||
sequence_id: t.Optional[str] = None,
|
||||
last_op: Operation = Operation.INIT,
|
||||
pending_hints: t.Optional[t.List[exp.Expression]] = None,
|
||||
output_expression_container: t.Optional[OutputExpressionContainer] = None,
|
||||
**kwargs,
|
||||
):
|
||||
self.spark = spark
|
||||
self.expression = expression
|
||||
self.branch_id = branch_id or self.spark._random_branch_id
|
||||
self.sequence_id = sequence_id or self.spark._random_sequence_id
|
||||
self.last_op = last_op
|
||||
self.pending_hints = pending_hints or []
|
||||
self.output_expression_container = output_expression_container or exp.Select()
|
||||
|
||||
def __getattr__(self, column_name: str) -> Column:
|
||||
return self[column_name]
|
||||
|
||||
def __getitem__(self, column_name: str) -> Column:
|
||||
column_name = f"{self.branch_id}.{column_name}"
|
||||
return Column(column_name)
|
||||
|
||||
def __copy__(self):
|
||||
return self.copy()
|
||||
|
||||
@property
|
||||
def sparkSession(self):
|
||||
return self.spark
|
||||
|
||||
@property
|
||||
def write(self):
|
||||
return DataFrameWriter(self)
|
||||
|
||||
@property
|
||||
def latest_cte_name(self) -> str:
|
||||
if not self.expression.ctes:
|
||||
from_exp = self.expression.args["from"]
|
||||
if from_exp.alias_or_name:
|
||||
return from_exp.alias_or_name
|
||||
table_alias = from_exp.find(exp.TableAlias)
|
||||
if not table_alias:
|
||||
raise RuntimeError(f"Could not find an alias name for this expression: {self.expression}")
|
||||
return table_alias.alias_or_name
|
||||
return self.expression.ctes[-1].alias
|
||||
|
||||
@property
|
||||
def pending_join_hints(self):
|
||||
return [hint for hint in self.pending_hints if isinstance(hint, exp.JoinHint)]
|
||||
|
||||
@property
|
||||
def pending_partition_hints(self):
|
||||
return [hint for hint in self.pending_hints if isinstance(hint, exp.Anonymous)]
|
||||
|
||||
@property
|
||||
def columns(self) -> t.List[str]:
|
||||
return self.expression.named_selects
|
||||
|
||||
@property
|
||||
def na(self) -> DataFrameNaFunctions:
|
||||
return DataFrameNaFunctions(self)
|
||||
|
||||
def _replace_cte_names_with_hashes(self, expression: exp.Select):
|
||||
expression = expression.copy()
|
||||
ctes = expression.ctes
|
||||
replacement_mapping = {}
|
||||
for cte in ctes:
|
||||
old_name_id = cte.args["alias"].this
|
||||
new_hashed_id = exp.to_identifier(
|
||||
self._create_hash_from_expression(cte.this), quoted=old_name_id.args["quoted"]
|
||||
)
|
||||
replacement_mapping[old_name_id] = new_hashed_id
|
||||
cte.set("alias", exp.TableAlias(this=new_hashed_id))
|
||||
expression = expression.transform(replace_id_value, replacement_mapping)
|
||||
return expression
|
||||
|
||||
def _create_cte_from_expression(
|
||||
self,
|
||||
expression: exp.Expression,
|
||||
branch_id: t.Optional[str] = None,
|
||||
sequence_id: t.Optional[str] = None,
|
||||
**kwargs,
|
||||
) -> t.Tuple[exp.CTE, str]:
|
||||
name = self.spark._random_name
|
||||
expression_to_cte = expression.copy()
|
||||
expression_to_cte.set("with", None)
|
||||
cte = exp.Select().with_(name, as_=expression_to_cte, **kwargs).ctes[0]
|
||||
cte.set("branch_id", branch_id or self.branch_id)
|
||||
cte.set("sequence_id", sequence_id or self.sequence_id)
|
||||
return cte, name
|
||||
|
||||
def _ensure_list_of_columns(
|
||||
self, cols: t.Union[str, t.Iterable[str], Column, t.Iterable[Column]]
|
||||
) -> t.List[Column]:
|
||||
columns = ensure_list(cols)
|
||||
columns = Column.ensure_cols(columns)
|
||||
return columns
|
||||
|
||||
def _ensure_and_normalize_cols(self, cols):
|
||||
cols = self._ensure_list_of_columns(cols)
|
||||
normalize(self.spark, self.expression, cols)
|
||||
return cols
|
||||
|
||||
def _ensure_and_normalize_col(self, col):
|
||||
col = Column.ensure_col(col)
|
||||
normalize(self.spark, self.expression, col)
|
||||
return col
|
||||
|
||||
def _convert_leaf_to_cte(self, sequence_id: t.Optional[str] = None) -> DataFrame:
|
||||
df = self._resolve_pending_hints()
|
||||
sequence_id = sequence_id or df.sequence_id
|
||||
expression = df.expression.copy()
|
||||
cte_expression, cte_name = df._create_cte_from_expression(expression=expression, sequence_id=sequence_id)
|
||||
new_expression = df._add_ctes_to_expression(exp.Select(), expression.ctes + [cte_expression])
|
||||
sel_columns = df._get_outer_select_columns(cte_expression)
|
||||
new_expression = new_expression.from_(cte_name).select(*[x.alias_or_name for x in sel_columns])
|
||||
return df.copy(expression=new_expression, sequence_id=sequence_id)
|
||||
|
||||
def _resolve_pending_hints(self) -> DataFrame:
|
||||
df = self.copy()
|
||||
if not self.pending_hints:
|
||||
return df
|
||||
expression = df.expression
|
||||
hint_expression = expression.args.get("hint") or exp.Hint(expressions=[])
|
||||
for hint in df.pending_partition_hints:
|
||||
hint_expression.args.get("expressions").append(hint)
|
||||
df.pending_hints.remove(hint)
|
||||
|
||||
join_aliases = {join_table.alias_or_name for join_table in get_tables_from_expression_with_join(expression)}
|
||||
if join_aliases:
|
||||
for hint in df.pending_join_hints:
|
||||
for sequence_id_expression in hint.expressions:
|
||||
sequence_id_or_name = sequence_id_expression.alias_or_name
|
||||
sequence_ids_to_match = [sequence_id_or_name]
|
||||
if sequence_id_or_name in df.spark.name_to_sequence_id_mapping:
|
||||
sequence_ids_to_match = df.spark.name_to_sequence_id_mapping[sequence_id_or_name]
|
||||
matching_ctes = [
|
||||
cte for cte in reversed(expression.ctes) if cte.args["sequence_id"] in sequence_ids_to_match
|
||||
]
|
||||
for matching_cte in matching_ctes:
|
||||
if matching_cte.alias_or_name in join_aliases:
|
||||
sequence_id_expression.set("this", matching_cte.args["alias"].this)
|
||||
df.pending_hints.remove(hint)
|
||||
break
|
||||
hint_expression.args.get("expressions").append(hint)
|
||||
if hint_expression.expressions:
|
||||
expression.set("hint", hint_expression)
|
||||
return df
|
||||
|
||||
def _hint(self, hint_name: str, args: t.List[Column]) -> DataFrame:
|
||||
hint_name = hint_name.upper()
|
||||
hint_expression = (
|
||||
exp.JoinHint(this=hint_name, expressions=[exp.to_table(parameter.alias_or_name) for parameter in args])
|
||||
if hint_name in JOIN_HINTS
|
||||
else exp.Anonymous(this=hint_name, expressions=[parameter.expression for parameter in args])
|
||||
)
|
||||
new_df = self.copy()
|
||||
new_df.pending_hints.append(hint_expression)
|
||||
return new_df
|
||||
|
||||
def _set_operation(self, klass: t.Callable, other: DataFrame, distinct: bool):
|
||||
other_df = other._convert_leaf_to_cte()
|
||||
base_expression = self.expression.copy()
|
||||
base_expression = self._add_ctes_to_expression(base_expression, other_df.expression.ctes)
|
||||
all_ctes = base_expression.ctes
|
||||
other_df.expression.set("with", None)
|
||||
base_expression.set("with", None)
|
||||
operation = klass(this=base_expression, distinct=distinct, expression=other_df.expression)
|
||||
operation.set("with", exp.With(expressions=all_ctes))
|
||||
return self.copy(expression=operation)._convert_leaf_to_cte()
|
||||
|
||||
def _cache(self, storage_level: str):
|
||||
df = self._convert_leaf_to_cte()
|
||||
df.expression.ctes[-1].set("cache_storage_level", storage_level)
|
||||
return df
|
||||
|
||||
@classmethod
|
||||
def _add_ctes_to_expression(cls, expression: exp.Select, ctes: t.List[exp.CTE]) -> exp.Select:
|
||||
expression = expression.copy()
|
||||
with_expression = expression.args.get("with")
|
||||
if with_expression:
|
||||
existing_ctes = with_expression.expressions
|
||||
existsing_cte_names = {x.alias_or_name for x in existing_ctes}
|
||||
for cte in ctes:
|
||||
if cte.alias_or_name not in existsing_cte_names:
|
||||
existing_ctes.append(cte)
|
||||
else:
|
||||
existing_ctes = ctes
|
||||
expression.set("with", exp.With(expressions=existing_ctes))
|
||||
return expression
|
||||
|
||||
@classmethod
|
||||
def _get_outer_select_columns(cls, item: t.Union[exp.Expression, DataFrame]) -> t.List[Column]:
|
||||
expression = item.expression if isinstance(item, DataFrame) else item
|
||||
return [Column(x) for x in expression.find(exp.Select).expressions]
|
||||
|
||||
@classmethod
|
||||
def _create_hash_from_expression(cls, expression: exp.Select):
|
||||
value = expression.sql(dialect="spark").encode("utf-8")
|
||||
return f"t{zlib.crc32(value)}"[:6]
|
||||
|
||||
def _get_select_expressions(
|
||||
self,
|
||||
) -> t.List[t.Tuple[t.Union[t.Type[exp.Cache], OutputExpressionContainer], exp.Select]]:
|
||||
select_expressions: t.List[t.Tuple[t.Union[t.Type[exp.Cache], OutputExpressionContainer], exp.Select]] = []
|
||||
main_select_ctes: t.List[exp.CTE] = []
|
||||
for cte in self.expression.ctes:
|
||||
cache_storage_level = cte.args.get("cache_storage_level")
|
||||
if cache_storage_level:
|
||||
select_expression = cte.this.copy()
|
||||
select_expression.set("with", exp.With(expressions=copy(main_select_ctes)))
|
||||
select_expression.set("cte_alias_name", cte.alias_or_name)
|
||||
select_expression.set("cache_storage_level", cache_storage_level)
|
||||
select_expressions.append((exp.Cache, select_expression))
|
||||
else:
|
||||
main_select_ctes.append(cte)
|
||||
main_select = self.expression.copy()
|
||||
if main_select_ctes:
|
||||
main_select.set("with", exp.With(expressions=main_select_ctes))
|
||||
expression_select_pair = (type(self.output_expression_container), main_select)
|
||||
select_expressions.append(expression_select_pair) # type: ignore
|
||||
return select_expressions
|
||||
|
||||
def sql(self, dialect="spark", optimize=True, **kwargs) -> t.List[str]:
|
||||
df = self._resolve_pending_hints()
|
||||
select_expressions = df._get_select_expressions()
|
||||
output_expressions: t.List[t.Union[exp.Select, exp.Cache, exp.Drop]] = []
|
||||
replacement_mapping: t.Dict[exp.Identifier, exp.Identifier] = {}
|
||||
for expression_type, select_expression in select_expressions:
|
||||
select_expression = select_expression.transform(replace_id_value, replacement_mapping)
|
||||
if optimize:
|
||||
select_expression = optimize_func(select_expression)
|
||||
select_expression = df._replace_cte_names_with_hashes(select_expression)
|
||||
expression: t.Union[exp.Select, exp.Cache, exp.Drop]
|
||||
if expression_type == exp.Cache:
|
||||
cache_table_name = df._create_hash_from_expression(select_expression)
|
||||
cache_table = exp.to_table(cache_table_name)
|
||||
original_alias_name = select_expression.args["cte_alias_name"]
|
||||
replacement_mapping[exp.to_identifier(original_alias_name)] = exp.to_identifier(cache_table_name)
|
||||
sqlglot.schema.add_table(cache_table_name, select_expression.named_selects)
|
||||
cache_storage_level = select_expression.args["cache_storage_level"]
|
||||
options = [
|
||||
exp.Literal.string("storageLevel"),
|
||||
exp.Literal.string(cache_storage_level),
|
||||
]
|
||||
expression = exp.Cache(this=cache_table, expression=select_expression, lazy=True, options=options)
|
||||
# We will drop the "view" if it exists before running the cache table
|
||||
output_expressions.append(exp.Drop(this=cache_table, exists=True, kind="VIEW"))
|
||||
elif expression_type == exp.Create:
|
||||
expression = df.output_expression_container.copy()
|
||||
expression.set("expression", select_expression)
|
||||
elif expression_type == exp.Insert:
|
||||
expression = df.output_expression_container.copy()
|
||||
select_without_ctes = select_expression.copy()
|
||||
select_without_ctes.set("with", None)
|
||||
expression.set("expression", select_without_ctes)
|
||||
if select_expression.ctes:
|
||||
expression.set("with", exp.With(expressions=select_expression.ctes))
|
||||
elif expression_type == exp.Select:
|
||||
expression = select_expression
|
||||
else:
|
||||
raise ValueError(f"Invalid expression type: {expression_type}")
|
||||
output_expressions.append(expression)
|
||||
|
||||
return [expression.sql(**{"dialect": dialect, **kwargs}) for expression in output_expressions]
|
||||
|
||||
def copy(self, **kwargs) -> DataFrame:
|
||||
return DataFrame(**object_to_dict(self, **kwargs))
|
||||
|
||||
@operation(Operation.SELECT)
|
||||
def select(self, *cols, **kwargs) -> DataFrame:
|
||||
cols = self._ensure_and_normalize_cols(cols)
|
||||
kwargs["append"] = kwargs.get("append", False)
|
||||
if self.expression.args.get("joins"):
|
||||
ambiguous_cols = [col for col in cols if not col.column_expression.table]
|
||||
if ambiguous_cols:
|
||||
join_table_identifiers = [x.this for x in get_tables_from_expression_with_join(self.expression)]
|
||||
cte_names_in_join = [x.this for x in join_table_identifiers]
|
||||
for ambiguous_col in ambiguous_cols:
|
||||
ctes_with_column = [
|
||||
cte
|
||||
for cte in self.expression.ctes
|
||||
if cte.alias_or_name in cte_names_in_join
|
||||
and ambiguous_col.alias_or_name in cte.this.named_selects
|
||||
]
|
||||
# If the select column does not specify a table and there is a join
|
||||
# then we assume they are referring to the left table
|
||||
if len(ctes_with_column) > 1:
|
||||
table_identifier = self.expression.args["from"].args["expressions"][0].this
|
||||
else:
|
||||
table_identifier = ctes_with_column[0].args["alias"].this
|
||||
ambiguous_col.expression.set("table", table_identifier)
|
||||
expression = self.expression.select(*[x.expression for x in cols], **kwargs)
|
||||
qualify_columns(expression, sqlglot.schema)
|
||||
return self.copy(expression=expression, **kwargs)
|
||||
|
||||
@operation(Operation.NO_OP)
|
||||
def alias(self, name: str, **kwargs) -> DataFrame:
|
||||
new_sequence_id = self.spark._random_sequence_id
|
||||
df = self.copy()
|
||||
for join_hint in df.pending_join_hints:
|
||||
for expression in join_hint.expressions:
|
||||
if expression.alias_or_name == self.sequence_id:
|
||||
expression.set("this", Column.ensure_col(new_sequence_id).expression)
|
||||
df.spark._add_alias_to_mapping(name, new_sequence_id)
|
||||
return df._convert_leaf_to_cte(sequence_id=new_sequence_id)
|
||||
|
||||
@operation(Operation.WHERE)
|
||||
def where(self, column: t.Union[Column, bool], **kwargs) -> DataFrame:
|
||||
col = self._ensure_and_normalize_col(column)
|
||||
return self.copy(expression=self.expression.where(col.expression))
|
||||
|
||||
filter = where
|
||||
|
||||
@operation(Operation.GROUP_BY)
|
||||
def groupBy(self, *cols, **kwargs) -> GroupedData:
|
||||
columns = self._ensure_and_normalize_cols(cols)
|
||||
return GroupedData(self, columns, self.last_op)
|
||||
|
||||
@operation(Operation.SELECT)
|
||||
def agg(self, *exprs, **kwargs) -> DataFrame:
|
||||
cols = self._ensure_and_normalize_cols(exprs)
|
||||
return self.groupBy().agg(*cols)
|
||||
|
||||
@operation(Operation.FROM)
|
||||
def join(
|
||||
self, other_df: DataFrame, on: t.Union[str, t.List[str], Column, t.List[Column]], how: str = "inner", **kwargs
|
||||
) -> DataFrame:
|
||||
other_df = other_df._convert_leaf_to_cte()
|
||||
pre_join_self_latest_cte_name = self.latest_cte_name
|
||||
columns = self._ensure_and_normalize_cols(on)
|
||||
join_type = how.replace("_", " ")
|
||||
if isinstance(columns[0].expression, exp.Column):
|
||||
join_columns = [Column(x).set_table_name(pre_join_self_latest_cte_name) for x in columns]
|
||||
join_clause = functools.reduce(
|
||||
lambda x, y: x & y,
|
||||
[
|
||||
col.copy().set_table_name(pre_join_self_latest_cte_name)
|
||||
== col.copy().set_table_name(other_df.latest_cte_name)
|
||||
for col in columns
|
||||
],
|
||||
)
|
||||
else:
|
||||
if len(columns) > 1:
|
||||
columns = [functools.reduce(lambda x, y: x & y, columns)]
|
||||
join_clause = columns[0]
|
||||
join_columns = [
|
||||
Column(x).set_table_name(pre_join_self_latest_cte_name)
|
||||
if i % 2 == 0
|
||||
else Column(x).set_table_name(other_df.latest_cte_name)
|
||||
for i, x in enumerate(join_clause.expression.find_all(exp.Column))
|
||||
]
|
||||
self_columns = [
|
||||
column.set_table_name(pre_join_self_latest_cte_name, copy=True)
|
||||
for column in self._get_outer_select_columns(self)
|
||||
]
|
||||
other_columns = [
|
||||
column.set_table_name(other_df.latest_cte_name, copy=True)
|
||||
for column in self._get_outer_select_columns(other_df)
|
||||
]
|
||||
column_value_mapping = {
|
||||
column.alias_or_name if not isinstance(column.expression.this, exp.Star) else column.sql(): column
|
||||
for column in other_columns + self_columns + join_columns
|
||||
}
|
||||
all_columns = [
|
||||
column_value_mapping[name]
|
||||
for name in {x.alias_or_name: None for x in join_columns + self_columns + other_columns}
|
||||
]
|
||||
new_df = self.copy(
|
||||
expression=self.expression.join(other_df.latest_cte_name, on=join_clause.expression, join_type=join_type)
|
||||
)
|
||||
new_df.expression = new_df._add_ctes_to_expression(new_df.expression, other_df.expression.ctes)
|
||||
new_df.pending_hints.extend(other_df.pending_hints)
|
||||
new_df = new_df.select.__wrapped__(new_df, *all_columns)
|
||||
return new_df
|
||||
|
||||
@operation(Operation.ORDER_BY)
|
||||
def orderBy(
|
||||
self, *cols: t.Union[str, Column], ascending: t.Optional[t.Union[t.Any, t.List[t.Any]]] = None
|
||||
) -> DataFrame:
|
||||
"""
|
||||
This implementation lets any ordered columns take priority over whatever is provided in `ascending`. Spark
|
||||
has irregular behavior and can result in runtime errors. Users shouldn't be mixing the two anyways so this
|
||||
is unlikely to come up.
|
||||
"""
|
||||
columns = self._ensure_and_normalize_cols(cols)
|
||||
pre_ordered_col_indexes = [
|
||||
x
|
||||
for x in [i if isinstance(col.expression, exp.Ordered) else None for i, col in enumerate(columns)]
|
||||
if x is not None
|
||||
]
|
||||
if ascending is None:
|
||||
ascending = [True] * len(columns)
|
||||
elif not isinstance(ascending, list):
|
||||
ascending = [ascending] * len(columns)
|
||||
ascending = [bool(x) for i, x in enumerate(ascending)]
|
||||
assert len(columns) == len(
|
||||
ascending
|
||||
), "The length of items in ascending must equal the number of columns provided"
|
||||
col_and_ascending = list(zip(columns, ascending))
|
||||
order_by_columns = [
|
||||
exp.Ordered(this=col.expression, desc=not asc)
|
||||
if i not in pre_ordered_col_indexes
|
||||
else columns[i].column_expression
|
||||
for i, (col, asc) in enumerate(col_and_ascending)
|
||||
]
|
||||
return self.copy(expression=self.expression.order_by(*order_by_columns))
|
||||
|
||||
sort = orderBy
|
||||
|
||||
@operation(Operation.FROM)
|
||||
def union(self, other: DataFrame) -> DataFrame:
|
||||
return self._set_operation(exp.Union, other, False)
|
||||
|
||||
unionAll = union
|
||||
|
||||
@operation(Operation.FROM)
|
||||
def unionByName(self, other: DataFrame, allowMissingColumns: bool = False):
|
||||
l_columns = self.columns
|
||||
r_columns = other.columns
|
||||
if not allowMissingColumns:
|
||||
l_expressions = l_columns
|
||||
r_expressions = l_columns
|
||||
else:
|
||||
l_expressions = []
|
||||
r_expressions = []
|
||||
r_columns_unused = copy(r_columns)
|
||||
for l_column in l_columns:
|
||||
l_expressions.append(l_column)
|
||||
if l_column in r_columns:
|
||||
r_expressions.append(l_column)
|
||||
r_columns_unused.remove(l_column)
|
||||
else:
|
||||
r_expressions.append(exp.alias_(exp.Null(), l_column))
|
||||
for r_column in r_columns_unused:
|
||||
l_expressions.append(exp.alias_(exp.Null(), r_column))
|
||||
r_expressions.append(r_column)
|
||||
r_df = other.copy()._convert_leaf_to_cte().select(*self._ensure_list_of_columns(r_expressions))
|
||||
l_df = self.copy()
|
||||
if allowMissingColumns:
|
||||
l_df = l_df._convert_leaf_to_cte().select(*self._ensure_list_of_columns(l_expressions))
|
||||
return l_df._set_operation(exp.Union, r_df, False)
|
||||
|
||||
@operation(Operation.FROM)
|
||||
def intersect(self, other: DataFrame) -> DataFrame:
|
||||
return self._set_operation(exp.Intersect, other, True)
|
||||
|
||||
@operation(Operation.FROM)
|
||||
def intersectAll(self, other: DataFrame) -> DataFrame:
|
||||
return self._set_operation(exp.Intersect, other, False)
|
||||
|
||||
@operation(Operation.FROM)
|
||||
def exceptAll(self, other: DataFrame) -> DataFrame:
|
||||
return self._set_operation(exp.Except, other, False)
|
||||
|
||||
@operation(Operation.SELECT)
|
||||
def distinct(self) -> DataFrame:
|
||||
return self.copy(expression=self.expression.distinct())
|
||||
|
||||
@operation(Operation.SELECT)
|
||||
def dropDuplicates(self, subset: t.Optional[t.List[str]] = None):
|
||||
if not subset:
|
||||
return self.distinct()
|
||||
column_names = ensure_list(subset)
|
||||
window = Window.partitionBy(*column_names).orderBy(*column_names)
|
||||
return (
|
||||
self.copy()
|
||||
.withColumn("row_num", F.row_number().over(window))
|
||||
.where(F.col("row_num") == F.lit(1))
|
||||
.drop("row_num")
|
||||
)
|
||||
|
||||
@operation(Operation.FROM)
|
||||
def dropna(
|
||||
self,
|
||||
how: str = "any",
|
||||
thresh: t.Optional[int] = None,
|
||||
subset: t.Optional[t.Union[str, t.Tuple[str, ...], t.List[str]]] = None,
|
||||
) -> DataFrame:
|
||||
minimum_non_null = thresh or 0 # will be determined later if thresh is null
|
||||
new_df = self.copy()
|
||||
all_columns = self._get_outer_select_columns(new_df.expression)
|
||||
if subset:
|
||||
null_check_columns = self._ensure_and_normalize_cols(subset)
|
||||
else:
|
||||
null_check_columns = all_columns
|
||||
if thresh is None:
|
||||
minimum_num_nulls = 1 if how == "any" else len(null_check_columns)
|
||||
else:
|
||||
minimum_num_nulls = len(null_check_columns) - minimum_non_null + 1
|
||||
if minimum_num_nulls > len(null_check_columns):
|
||||
raise RuntimeError(
|
||||
f"The minimum num nulls for dropna must be less than or equal to the number of columns. "
|
||||
f"Minimum num nulls: {minimum_num_nulls}, Num Columns: {len(null_check_columns)}"
|
||||
)
|
||||
if_null_checks = [F.when(column.isNull(), F.lit(1)).otherwise(F.lit(0)) for column in null_check_columns]
|
||||
nulls_added_together = functools.reduce(lambda x, y: x + y, if_null_checks)
|
||||
num_nulls = nulls_added_together.alias("num_nulls")
|
||||
new_df = new_df.select(num_nulls, append=True)
|
||||
filtered_df = new_df.where(F.col("num_nulls") < F.lit(minimum_num_nulls))
|
||||
final_df = filtered_df.select(*all_columns)
|
||||
return final_df
|
||||
|
||||
@operation(Operation.FROM)
|
||||
def fillna(
|
||||
self,
|
||||
value: t.Union[ColumnLiterals],
|
||||
subset: t.Optional[t.Union[str, t.Tuple[str, ...], t.List[str]]] = None,
|
||||
) -> DataFrame:
|
||||
"""
|
||||
Functionality Difference: If you provide a value to replace a null and that type conflicts
|
||||
with the type of the column then PySpark will just ignore your replacement.
|
||||
This will try to cast them to be the same in some cases. So they won't always match.
|
||||
Best to not mix types so make sure replacement is the same type as the column
|
||||
|
||||
Possibility for improvement: Use `typeof` function to get the type of the column
|
||||
and check if it matches the type of the value provided. If not then make it null.
|
||||
"""
|
||||
from sqlglot.dataframe.sql.functions import lit
|
||||
|
||||
values = None
|
||||
columns = None
|
||||
new_df = self.copy()
|
||||
all_columns = self._get_outer_select_columns(new_df.expression)
|
||||
all_column_mapping = {column.alias_or_name: column for column in all_columns}
|
||||
if isinstance(value, dict):
|
||||
values = value.values()
|
||||
columns = self._ensure_and_normalize_cols(list(value))
|
||||
if not columns:
|
||||
columns = self._ensure_and_normalize_cols(subset) if subset else all_columns
|
||||
if not values:
|
||||
values = [value] * len(columns)
|
||||
value_columns = [lit(value) for value in values]
|
||||
|
||||
null_replacement_mapping = {
|
||||
column.alias_or_name: (F.when(column.isNull(), value).otherwise(column).alias(column.alias_or_name))
|
||||
for column, value in zip(columns, value_columns)
|
||||
}
|
||||
null_replacement_mapping = {**all_column_mapping, **null_replacement_mapping}
|
||||
null_replacement_columns = [null_replacement_mapping[column.alias_or_name] for column in all_columns]
|
||||
new_df = new_df.select(*null_replacement_columns)
|
||||
return new_df
|
||||
|
||||
@operation(Operation.FROM)
|
||||
def replace(
|
||||
self,
|
||||
to_replace: t.Union[bool, int, float, str, t.List, t.Dict],
|
||||
value: t.Optional[t.Union[bool, int, float, str, t.List]] = None,
|
||||
subset: t.Optional[t.Union[str, t.List[str]]] = None,
|
||||
) -> DataFrame:
|
||||
from sqlglot.dataframe.sql.functions import lit
|
||||
|
||||
old_values = None
|
||||
subset = ensure_list(subset)
|
||||
new_df = self.copy()
|
||||
all_columns = self._get_outer_select_columns(new_df.expression)
|
||||
all_column_mapping = {column.alias_or_name: column for column in all_columns}
|
||||
|
||||
columns = self._ensure_and_normalize_cols(subset) if subset else all_columns
|
||||
if isinstance(to_replace, dict):
|
||||
old_values = list(to_replace)
|
||||
new_values = list(to_replace.values())
|
||||
elif not old_values and isinstance(to_replace, list):
|
||||
assert isinstance(value, list), "value must be a list since the replacements are a list"
|
||||
assert len(to_replace) == len(value), "the replacements and values must be the same length"
|
||||
old_values = to_replace
|
||||
new_values = value
|
||||
else:
|
||||
old_values = [to_replace] * len(columns)
|
||||
new_values = [value] * len(columns)
|
||||
old_values = [lit(value) for value in old_values]
|
||||
new_values = [lit(value) for value in new_values]
|
||||
|
||||
replacement_mapping = {}
|
||||
for column in columns:
|
||||
expression = Column(None)
|
||||
for i, (old_value, new_value) in enumerate(zip(old_values, new_values)):
|
||||
if i == 0:
|
||||
expression = F.when(column == old_value, new_value)
|
||||
else:
|
||||
expression = expression.when(column == old_value, new_value) # type: ignore
|
||||
replacement_mapping[column.alias_or_name] = expression.otherwise(column).alias(
|
||||
column.expression.alias_or_name
|
||||
)
|
||||
|
||||
replacement_mapping = {**all_column_mapping, **replacement_mapping}
|
||||
replacement_columns = [replacement_mapping[column.alias_or_name] for column in all_columns]
|
||||
new_df = new_df.select(*replacement_columns)
|
||||
return new_df
|
||||
|
||||
@operation(Operation.SELECT)
|
||||
def withColumn(self, colName: str, col: Column) -> DataFrame:
|
||||
col = self._ensure_and_normalize_col(col)
|
||||
existing_col_names = self.expression.named_selects
|
||||
existing_col_index = existing_col_names.index(colName) if colName in existing_col_names else None
|
||||
if existing_col_index:
|
||||
expression = self.expression.copy()
|
||||
expression.expressions[existing_col_index] = col.expression
|
||||
return self.copy(expression=expression)
|
||||
return self.copy().select(col.alias(colName), append=True)
|
||||
|
||||
@operation(Operation.SELECT)
|
||||
def withColumnRenamed(self, existing: str, new: str):
|
||||
expression = self.expression.copy()
|
||||
existing_columns = [expression for expression in expression.expressions if expression.alias_or_name == existing]
|
||||
if not existing_columns:
|
||||
raise ValueError("Tried to rename a column that doesn't exist")
|
||||
for existing_column in existing_columns:
|
||||
if isinstance(existing_column, exp.Column):
|
||||
existing_column.replace(exp.alias_(existing_column.copy(), new))
|
||||
else:
|
||||
existing_column.set("alias", exp.to_identifier(new))
|
||||
return self.copy(expression=expression)
|
||||
|
||||
@operation(Operation.SELECT)
|
||||
def drop(self, *cols: t.Union[str, Column]) -> DataFrame:
|
||||
all_columns = self._get_outer_select_columns(self.expression)
|
||||
drop_cols = self._ensure_and_normalize_cols(cols)
|
||||
new_columns = [
|
||||
col
|
||||
for col in all_columns
|
||||
if col.alias_or_name not in [drop_column.alias_or_name for drop_column in drop_cols]
|
||||
]
|
||||
return self.copy().select(*new_columns, append=False)
|
||||
|
||||
    @operation(Operation.LIMIT)
    def limit(self, num: int) -> DataFrame:
        return self.copy(expression=self.expression.limit(num))

    @operation(Operation.NO_OP)
    def hint(self, name: str, *parameters: t.Optional[t.Union[str, int]]) -> DataFrame:
        parameter_list = ensure_list(parameters)
        parameter_columns = (
            self._ensure_list_of_columns(parameter_list) if parameters else Column.ensure_cols([self.sequence_id])
        )
        return self._hint(name, parameter_columns)

    @operation(Operation.NO_OP)
    def repartition(self, numPartitions: t.Union[int, str], *cols: t.Union[int, str]) -> DataFrame:
        num_partitions = Column.ensure_cols(ensure_list(numPartitions))
        columns = self._ensure_and_normalize_cols(cols)
        args = num_partitions + columns
        return self._hint("repartition", args)

    @operation(Operation.NO_OP)
    def coalesce(self, numPartitions: int) -> DataFrame:
        num_partitions = Column.ensure_cols([numPartitions])
        return self._hint("coalesce", num_partitions)

    @operation(Operation.NO_OP)
    def cache(self) -> DataFrame:
        return self._cache(storage_level="MEMORY_AND_DISK")

    @operation(Operation.NO_OP)
    def persist(self, storageLevel: str = "MEMORY_AND_DISK_SER") -> DataFrame:
        """
        Storage Level Options: https://spark.apache.org/docs/3.0.0-preview/sql-ref-syntax-aux-cache-cache-table.html
        """
        return self._cache(storageLevel)

class DataFrameNaFunctions:
    def __init__(self, df: DataFrame):
        self.df = df

    def drop(
        self,
        how: str = "any",
        thresh: t.Optional[int] = None,
        subset: t.Optional[t.Union[str, t.Tuple[str, ...], t.List[str]]] = None,
    ) -> DataFrame:
        return self.df.dropna(how=how, thresh=thresh, subset=subset)

    def fill(
        self,
        value: t.Union[int, bool, float, str, t.Dict[str, t.Any]],
        subset: t.Optional[t.Union[str, t.Tuple[str, ...], t.List[str]]] = None,
    ) -> DataFrame:
        return self.df.fillna(value=value, subset=subset)

    def replace(
        self,
        to_replace: t.Union[bool, int, float, str, t.List, t.Dict],
        value: t.Optional[t.Union[bool, int, float, str, t.List]] = None,
        subset: t.Optional[t.Union[str, t.List[str]]] = None,
    ) -> DataFrame:
        return self.df.replace(to_replace=to_replace, value=value, subset=subset)
1258
sqlglot/dataframe/sql/functions.py
Normal file
File diff suppressed because it is too large

57
sqlglot/dataframe/sql/group.py
Normal file

@ -0,0 +1,57 @@
from __future__ import annotations

import typing as t

from sqlglot.dataframe.sql import functions as F
from sqlglot.dataframe.sql.column import Column
from sqlglot.dataframe.sql.operations import Operation, operation

if t.TYPE_CHECKING:
    from sqlglot.dataframe.sql.dataframe import DataFrame


class GroupedData:
    def __init__(self, df: DataFrame, group_by_cols: t.List[Column], last_op: Operation):
        self._df = df.copy()
        self.spark = df.spark
        self.last_op = last_op
        self.group_by_cols = group_by_cols

    def _get_function_applied_columns(self, func_name: str, cols: t.Tuple[str, ...]) -> t.List[Column]:
        func_name = func_name.lower()
        return [getattr(F, func_name)(name).alias(f"{func_name}({name})") for name in cols]

    @operation(Operation.SELECT)
    def agg(self, *exprs: t.Union[Column, t.Dict[str, str]]) -> DataFrame:
        columns = (
            [Column(f"{agg_func}({column_name})") for column_name, agg_func in exprs[0].items()]
            if isinstance(exprs[0], dict)
            else exprs
        )
        cols = self._df._ensure_and_normalize_cols(columns)

        expression = self._df.expression.group_by(*[x.expression for x in self.group_by_cols]).select(
            *[x.expression for x in self.group_by_cols + cols], append=False
        )
        return self._df.copy(expression=expression)

    def count(self) -> DataFrame:
        return self.agg(F.count("*").alias("count"))

    def mean(self, *cols: str) -> DataFrame:
        return self.avg(*cols)

    def avg(self, *cols: str) -> DataFrame:
        return self.agg(*self._get_function_applied_columns("avg", cols))

    def max(self, *cols: str) -> DataFrame:
        return self.agg(*self._get_function_applied_columns("max", cols))

    def min(self, *cols: str) -> DataFrame:
        return self.agg(*self._get_function_applied_columns("min", cols))

    def sum(self, *cols: str) -> DataFrame:
        return self.agg(*self._get_function_applied_columns("sum", cols))

    def pivot(self, *cols: str) -> DataFrame:
        raise NotImplementedError("Pivot is not currently implemented")
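A short sketch of GroupedData in use, again with an assumed schema: groupBy().agg() folds the grouping columns and aggregates into a single SELECT ... GROUP BY expression.

```
import sqlglot
from sqlglot.dataframe.sql import functions as F
from sqlglot.dataframe.sql.session import SparkSession

sqlglot.schema.add_table("employee", {"id": "INT", "region": "STRING"})
spark = SparkSession()
df = spark.table("employee").groupBy("region").agg(F.count("*").alias("cnt"))
print(df.sql(dialect="spark")[0])  # a SELECT with a GROUP BY region clause, roughly
```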
72
sqlglot/dataframe/sql/normalize.py
Normal file

@ -0,0 +1,72 @@
from __future__ import annotations

import typing as t

from sqlglot import expressions as exp
from sqlglot.dataframe.sql.column import Column
from sqlglot.dataframe.sql.util import get_tables_from_expression_with_join
from sqlglot.helper import ensure_list

NORMALIZE_INPUT = t.TypeVar("NORMALIZE_INPUT", bound=t.Union[str, exp.Expression, Column])

if t.TYPE_CHECKING:
    from sqlglot.dataframe.sql.session import SparkSession


def normalize(spark: SparkSession, expression_context: exp.Select, expr: t.List[NORMALIZE_INPUT]):
    expr = ensure_list(expr)
    expressions = _ensure_expressions(expr)
    for expression in expressions:
        identifiers = expression.find_all(exp.Identifier)
        for identifier in identifiers:
            replace_alias_name_with_cte_name(spark, expression_context, identifier)
            replace_branch_and_sequence_ids_with_cte_name(spark, expression_context, identifier)


def replace_alias_name_with_cte_name(spark: SparkSession, expression_context: exp.Select, id: exp.Identifier):
    if id.alias_or_name in spark.name_to_sequence_id_mapping:
        for cte in reversed(expression_context.ctes):
            if cte.args["sequence_id"] in spark.name_to_sequence_id_mapping[id.alias_or_name]:
                _set_alias_name(id, cte.alias_or_name)
                break


def replace_branch_and_sequence_ids_with_cte_name(
    spark: SparkSession, expression_context: exp.Select, id: exp.Identifier
):
    if id.alias_or_name in spark.known_ids:
        # Check if we have a join and if both the tables in that join share a common branch id.
        # If so we need to have this reference the left table by default unless the id is a sequence
        # id then it keeps that reference. This handles the weird edge case in Spark that shouldn't
        # be common in practice.
        if expression_context.args.get("joins") and id.alias_or_name in spark.known_branch_ids:
            join_table_aliases = [x.alias_or_name for x in get_tables_from_expression_with_join(expression_context)]
            ctes_in_join = [cte for cte in expression_context.ctes if cte.alias_or_name in join_table_aliases]
            if ctes_in_join[0].args["branch_id"] == ctes_in_join[1].args["branch_id"]:
                assert len(ctes_in_join) == 2
                _set_alias_name(id, ctes_in_join[0].alias_or_name)
                return

        for cte in reversed(expression_context.ctes):
            if id.alias_or_name in (cte.args["branch_id"], cte.args["sequence_id"]):
                _set_alias_name(id, cte.alias_or_name)
                return


def _set_alias_name(id: exp.Identifier, name: str):
    id.set("this", name)


def _ensure_expressions(values: t.List[NORMALIZE_INPUT]) -> t.List[exp.Expression]:
    values = ensure_list(values)
    results = []
    for value in values:
        if isinstance(value, str):
            results.append(Column.ensure_col(value).expression)
        elif isinstance(value, Column):
            results.append(value.expression)
        elif isinstance(value, exp.Expression):
            results.append(value)
        else:
            raise ValueError(f"Got an invalid type to normalize: {type(value)}")
    return results
53
sqlglot/dataframe/sql/operations.py
Normal file

@ -0,0 +1,53 @@
from __future__ import annotations

import functools
import typing as t
from enum import IntEnum

if t.TYPE_CHECKING:
    from sqlglot.dataframe.sql.dataframe import DataFrame
    from sqlglot.dataframe.sql.group import GroupedData


class Operation(IntEnum):
    INIT = -1
    NO_OP = 0
    FROM = 1
    WHERE = 2
    GROUP_BY = 3
    HAVING = 4
    SELECT = 5
    ORDER_BY = 6
    LIMIT = 7


def operation(op: Operation):
    """
    Decorator used around DataFrame methods to indicate what type of operation is being performed from the
    ordered Operation enums. This is used to determine which operations should be performed on a CTE vs.
    included with the previous operation.

    Ex: After a user does a join we want to allow them to select which columns for the different
    tables that they want to carry through to the following operation. If we put that join in
    a CTE preemptively then the user would not have a chance to select which column they want
    in cases where there is overlap in names.
    """

    def decorator(func: t.Callable):
        @functools.wraps(func)
        def wrapper(self: DataFrame, *args, **kwargs):
            if self.last_op == Operation.INIT:
                self = self._convert_leaf_to_cte()
                self.last_op = Operation.NO_OP
            last_op = self.last_op
            new_op = op if op != Operation.NO_OP else last_op
            if new_op < last_op or (last_op == new_op and new_op == Operation.SELECT):
                self = self._convert_leaf_to_cte()
            df: t.Union[DataFrame, GroupedData] = func(self, *args, **kwargs)
            df.last_op = new_op  # type: ignore
            return df

        wrapper.__wrapped__ = func  # type: ignore
        return wrapper

    return decorator
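The CTE decision in isolation, as a minimal sketch of the comparison the wrapper performs:

```
from sqlglot.dataframe.sql.operations import Operation

# WHERE (2) arriving after SELECT (5) sorts lower, so the wrapper converts the
# current expression into a CTE before applying the new operation.
last_op, new_op = Operation.SELECT, Operation.WHERE
print(new_op < last_op or (last_op == new_op and new_op == Operation.SELECT))  # True
```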
79
sqlglot/dataframe/sql/readwriter.py
Normal file

@ -0,0 +1,79 @@
from __future__ import annotations

import typing as t

import sqlglot
from sqlglot import expressions as exp
from sqlglot.helper import object_to_dict

if t.TYPE_CHECKING:
    from sqlglot.dataframe.sql.dataframe import DataFrame
    from sqlglot.dataframe.sql.session import SparkSession


class DataFrameReader:
    def __init__(self, spark: SparkSession):
        self.spark = spark

    def table(self, tableName: str) -> DataFrame:
        from sqlglot.dataframe.sql.dataframe import DataFrame

        sqlglot.schema.add_table(tableName)
        return DataFrame(self.spark, exp.Select().from_(tableName).select(*sqlglot.schema.column_names(tableName)))


class DataFrameWriter:
    def __init__(
        self, df: DataFrame, spark: t.Optional[SparkSession] = None, mode: t.Optional[str] = None, by_name: bool = False
    ):
        self._df = df
        self._spark = spark or df.spark
        self._mode = mode
        self._by_name = by_name

    def copy(self, **kwargs) -> DataFrameWriter:
        return DataFrameWriter(
            **{k[1:] if k.startswith("_") else k: v for k, v in object_to_dict(self, **kwargs).items()}
        )

    def sql(self, **kwargs) -> t.List[str]:
        return self._df.sql(**kwargs)

    def mode(self, saveMode: t.Optional[str]) -> DataFrameWriter:
        return self.copy(_mode=saveMode)

    @property
    def byName(self):
        return self.copy(by_name=True)

    def insertInto(self, tableName: str, overwrite: t.Optional[bool] = None) -> DataFrameWriter:
        output_expression_container = exp.Insert(
            **{
                "this": exp.to_table(tableName),
                "overwrite": overwrite,
            }
        )
        df = self._df.copy(output_expression_container=output_expression_container)
        if self._by_name:
            columns = sqlglot.schema.column_names(tableName, only_visible=True)
            df = df._convert_leaf_to_cte().select(*columns)

        return self.copy(_df=df)

    def saveAsTable(self, name: str, format: t.Optional[str] = None, mode: t.Optional[str] = None):
        if format is not None:
            raise NotImplementedError("Providing Format in the save as table is not supported")
        exists, replace, mode = None, None, mode or str(self._mode)
        if mode == "append":
            return self.insertInto(name)
        if mode == "ignore":
            exists = True
        if mode == "overwrite":
            replace = True
        output_expression_container = exp.Create(
            this=exp.to_table(name),
            kind="TABLE",
            exists=exists,
            replace=replace,
        )
        return self.copy(_df=self._df.copy(output_expression_container=output_expression_container))
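A hedged sketch of the writer path; the src/dst table names and the one-column schema are assumptions:

```
import sqlglot
from sqlglot.dataframe.sql.session import SparkSession

sqlglot.schema.add_table("src", {"a": "INT"})
spark = SparkSession()

# mode("overwrite") sets replace=True, so the output container is a
# CREATE OR REPLACE TABLE wrapping the select.
writer = spark.table("src").write.mode("overwrite").saveAsTable("dst")
print(writer.sql(dialect="spark")[0])
```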
148
sqlglot/dataframe/sql/session.py
Normal file

@ -0,0 +1,148 @@
from __future__ import annotations

import typing as t
import uuid
from collections import defaultdict

import sqlglot
from sqlglot import expressions as exp
from sqlglot.dataframe.sql import functions as F
from sqlglot.dataframe.sql.dataframe import DataFrame
from sqlglot.dataframe.sql.readwriter import DataFrameReader
from sqlglot.dataframe.sql.types import StructType
from sqlglot.dataframe.sql.util import get_column_mapping_from_schema_input

if t.TYPE_CHECKING:
    from sqlglot.dataframe.sql._typing import ColumnLiterals, SchemaInput


class SparkSession:
    known_ids: t.ClassVar[t.Set[str]] = set()
    known_branch_ids: t.ClassVar[t.Set[str]] = set()
    known_sequence_ids: t.ClassVar[t.Set[str]] = set()
    name_to_sequence_id_mapping: t.ClassVar[t.Dict[str, t.List[str]]] = defaultdict(list)

    def __init__(self):
        self.incrementing_id = 1

    def __getattr__(self, name: str) -> SparkSession:
        return self

    def __call__(self, *args, **kwargs) -> SparkSession:
        return self

    @property
    def read(self) -> DataFrameReader:
        return DataFrameReader(self)

    def table(self, tableName: str) -> DataFrame:
        return self.read.table(tableName)

    def createDataFrame(
        self,
        data: t.Sequence[t.Union[t.Dict[str, ColumnLiterals], t.List[ColumnLiterals], t.Tuple]],
        schema: t.Optional[SchemaInput] = None,
        samplingRatio: t.Optional[float] = None,
        verifySchema: bool = False,
    ) -> DataFrame:
        from sqlglot.dataframe.sql.dataframe import DataFrame

        if samplingRatio is not None or verifySchema:
            raise NotImplementedError("Sampling Ratio and Verify Schema are not supported")
        if schema is not None and (
            not isinstance(schema, (StructType, str, list))
            or (isinstance(schema, list) and not isinstance(schema[0], str))
        ):
            raise NotImplementedError("Only schema of either list or string of list supported")
        if not data:
            raise ValueError("Must provide data to create into a DataFrame")

        column_mapping: t.Dict[str, t.Optional[str]]
        if schema is not None:
            column_mapping = get_column_mapping_from_schema_input(schema)
        elif isinstance(data[0], dict):
            column_mapping = {col_name.strip(): None for col_name in data[0]}
        else:
            column_mapping = {f"_{i}": None for i in range(1, len(data[0]) + 1)}

        data_expressions = [
            exp.Tuple(
                expressions=list(map(lambda x: F.lit(x).expression, row if not isinstance(row, dict) else row.values()))
            )
            for row in data
        ]

        sel_columns = [
            F.col(name).cast(data_type).alias(name).expression if data_type is not None else F.col(name).expression
            for name, data_type in column_mapping.items()
        ]

        select_kwargs = {
            "expressions": sel_columns,
            "from": exp.From(
                expressions=[
                    exp.Subquery(
                        this=exp.Values(expressions=data_expressions),
                        alias=exp.TableAlias(
                            this=exp.to_identifier(self._auto_incrementing_name),
                            columns=[exp.to_identifier(col_name) for col_name in column_mapping],
                        ),
                    )
                ]
            ),
        }

        sel_expression = exp.Select(**select_kwargs)
        return DataFrame(self, sel_expression)

    def sql(self, sqlQuery: str) -> DataFrame:
        expression = sqlglot.parse_one(sqlQuery, read="spark")
        if isinstance(expression, exp.Select):
            df = DataFrame(self, expression)
            df = df._convert_leaf_to_cte()
        elif isinstance(expression, (exp.Create, exp.Insert)):
            select_expression = expression.expression.copy()
            if isinstance(expression, exp.Insert):
                select_expression.set("with", expression.args.get("with"))
                expression.set("with", None)
            del expression.args["expression"]
            df = DataFrame(self, select_expression, output_expression_container=expression)
            df = df._convert_leaf_to_cte()
        else:
            raise ValueError("Unknown expression type provided in the SQL. Please create an issue with the SQL.")
        return df

    @property
    def _auto_incrementing_name(self) -> str:
        name = f"a{self.incrementing_id}"
        self.incrementing_id += 1
        return name

    @property
    def _random_name(self) -> str:
        return f"a{str(uuid.uuid4())[:8]}"

    @property
    def _random_branch_id(self) -> str:
        id = self._random_id
        self.known_branch_ids.add(id)
        return id

    @property
    def _random_sequence_id(self):
        id = self._random_id
        self.known_sequence_ids.add(id)
        return id

    @property
    def _random_id(self) -> str:
        id = f"a{str(uuid.uuid4())[:8]}"
        self.known_ids.add(id)
        return id

    @property
    def _join_hint_names(self) -> t.Set[str]:
        return {"BROADCAST", "MERGE", "SHUFFLE_HASH", "SHUFFLE_REPLICATE_NL"}

    def _add_alias_to_mapping(self, name: str, sequence_id: str):
        self.name_to_sequence_id_mapping[name].append(sequence_id)
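A minimal createDataFrame() sketch showing the VALUES subquery it builds; the data rows and schema string are illustrative only:

```
from sqlglot.dataframe.sql.session import SparkSession

spark = SparkSession()
# Rows become a VALUES clause behind an auto-generated alias, and the schema
# string drives a CAST per column.
df = spark.createDataFrame([(1, "Jack"), (2, "Jill")], schema="id: int, name: string")
print(df.sql(dialect="spark")[0])
```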
9
sqlglot/dataframe/sql/transforms.py
Normal file

@ -0,0 +1,9 @@
import typing as t

from sqlglot import expressions as exp


def replace_id_value(node, replacement_mapping: t.Dict[exp.Identifier, exp.Identifier]):
    if isinstance(node, exp.Identifier) and node in replacement_mapping:
        node = node.replace(replacement_mapping[node].copy())
    return node
208
sqlglot/dataframe/sql/types.py
Normal file

@ -0,0 +1,208 @@
import typing as t


class DataType:
    def __repr__(self) -> str:
        return self.__class__.__name__ + "()"

    def __hash__(self) -> int:
        return hash(str(self))

    def __eq__(self, other: t.Any) -> bool:
        return isinstance(other, self.__class__) and self.__dict__ == other.__dict__

    def __ne__(self, other: t.Any) -> bool:
        return not self.__eq__(other)

    def __str__(self) -> str:
        return self.typeName()

    @classmethod
    def typeName(cls) -> str:
        return cls.__name__[:-4].lower()

    def simpleString(self) -> str:
        return str(self)

    def jsonValue(self) -> t.Union[str, t.Dict[str, t.Any]]:
        return str(self)


class DataTypeWithLength(DataType):
    def __init__(self, length: int):
        self.length = length

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}({self.length})"

    def __str__(self) -> str:
        return f"{self.typeName()}({self.length})"


class StringType(DataType):
    pass


class CharType(DataTypeWithLength):
    pass


class VarcharType(DataTypeWithLength):
    pass


class BinaryType(DataType):
    pass


class BooleanType(DataType):
    pass


class DateType(DataType):
    pass


class TimestampType(DataType):
    pass


class TimestampNTZType(DataType):
    @classmethod
    def typeName(cls) -> str:
        return "timestamp_ntz"


class DecimalType(DataType):
    def __init__(self, precision: int = 10, scale: int = 0):
        self.precision = precision
        self.scale = scale

    def simpleString(self) -> str:
        return f"decimal({self.precision}, {self.scale})"

    def jsonValue(self) -> str:
        return f"decimal({self.precision}, {self.scale})"

    def __repr__(self) -> str:
        return f"DecimalType({self.precision}, {self.scale})"


class DoubleType(DataType):
    pass


class FloatType(DataType):
    pass


class ByteType(DataType):
    def __str__(self) -> str:
        return "tinyint"


class IntegerType(DataType):
    def __str__(self) -> str:
        return "int"


class LongType(DataType):
    def __str__(self) -> str:
        return "bigint"


class ShortType(DataType):
    def __str__(self) -> str:
        return "smallint"


class ArrayType(DataType):
    def __init__(self, elementType: DataType, containsNull: bool = True):
        self.elementType = elementType
        self.containsNull = containsNull

    def __repr__(self) -> str:
        return f"ArrayType({self.elementType}, {str(self.containsNull)})"

    def simpleString(self) -> str:
        return f"array<{self.elementType.simpleString()}>"

    def jsonValue(self) -> t.Dict[str, t.Any]:
        return {
            "type": self.typeName(),
            "elementType": self.elementType.jsonValue(),
            "containsNull": self.containsNull,
        }


class MapType(DataType):
    def __init__(self, keyType: DataType, valueType: DataType, valueContainsNull: bool = True):
        self.keyType = keyType
        self.valueType = valueType
        self.valueContainsNull = valueContainsNull

    def __repr__(self) -> str:
        return f"MapType({self.keyType}, {self.valueType}, {str(self.valueContainsNull)})"

    def simpleString(self) -> str:
        return f"map<{self.keyType.simpleString()}, {self.valueType.simpleString()}>"

    def jsonValue(self) -> t.Dict[str, t.Any]:
        return {
            "type": self.typeName(),
            "keyType": self.keyType.jsonValue(),
            "valueType": self.valueType.jsonValue(),
            "valueContainsNull": self.valueContainsNull,
        }


class StructField(DataType):
    def __init__(
        self, name: str, dataType: DataType, nullable: bool = True, metadata: t.Optional[t.Dict[str, t.Any]] = None
    ):
        self.name = name
        self.dataType = dataType
        self.nullable = nullable
        self.metadata = metadata or {}

    def __repr__(self) -> str:
        return f"StructField('{self.name}', {self.dataType}, {str(self.nullable)})"

    def simpleString(self) -> str:
        return f"{self.name}:{self.dataType.simpleString()}"

    def jsonValue(self) -> t.Dict[str, t.Any]:
        return {
            "name": self.name,
            "type": self.dataType.jsonValue(),
            "nullable": self.nullable,
            "metadata": self.metadata,
        }


class StructType(DataType):
    def __init__(self, fields: t.Optional[t.List[StructField]] = None):
        if not fields:
            self.fields = []
            self.names = []
        else:
            self.fields = fields
            self.names = [f.name for f in fields]

    def __iter__(self) -> t.Iterator[StructField]:
        return iter(self.fields)

    def __len__(self) -> int:
        return len(self.fields)

    def __repr__(self) -> str:
        return f"StructType({', '.join(str(field) for field in self)})"

    def simpleString(self) -> str:
        return f"struct<{', '.join(x.simpleString() for x in self)}>"

    def jsonValue(self) -> t.Dict[str, t.Any]:
        return {"type": self.typeName(), "fields": [x.jsonValue() for x in self]}

    def fieldNames(self) -> t.List[str]:
        return list(self.names)
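These types mirror pyspark.sql.types closely enough that simpleString() round-trips into the schema strings used elsewhere in this module; a quick grounded check:

```
from sqlglot.dataframe.sql import types

schema = types.StructType([
    types.StructField("id", types.IntegerType()),
    types.StructField("tags", types.ArrayType(types.StringType())),
])
print(schema.simpleString())  # struct<id:int, tags:array<string>>
```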
32
sqlglot/dataframe/sql/util.py
Normal file

@ -0,0 +1,32 @@
from __future__ import annotations

import typing as t

from sqlglot import expressions as exp
from sqlglot.dataframe.sql import types

if t.TYPE_CHECKING:
    from sqlglot.dataframe.sql._typing import SchemaInput


def get_column_mapping_from_schema_input(schema: SchemaInput) -> t.Dict[str, t.Optional[str]]:
    if isinstance(schema, dict):
        return schema
    elif isinstance(schema, str):
        col_name_type_strs = [x.strip() for x in schema.split(",")]
        return {
            name_type_str.split(":")[0].strip(): name_type_str.split(":")[1].strip()
            for name_type_str in col_name_type_strs
        }
    elif isinstance(schema, types.StructType):
        return {struct_field.name: struct_field.dataType.simpleString() for struct_field in schema}
    return {x.strip(): None for x in schema}  # type: ignore


def get_tables_from_expression_with_join(expression: exp.Select) -> t.List[exp.Table]:
    if not expression.args.get("joins"):
        return []

    left_table = expression.args["from"].args["expressions"][0]
    other_tables = [join.this for join in expression.args["joins"]]
    return [left_table] + other_tables
117
sqlglot/dataframe/sql/window.py
Normal file

@ -0,0 +1,117 @@
from __future__ import annotations

import sys
import typing as t

from sqlglot import expressions as exp
from sqlglot.dataframe.sql import functions as F
from sqlglot.helper import flatten

if t.TYPE_CHECKING:
    from sqlglot.dataframe.sql._typing import ColumnOrName


class Window:
    _JAVA_MIN_LONG = -(1 << 63)  # -9223372036854775808
    _JAVA_MAX_LONG = (1 << 63) - 1  # 9223372036854775807
    _PRECEDING_THRESHOLD = max(-sys.maxsize, _JAVA_MIN_LONG)
    _FOLLOWING_THRESHOLD = min(sys.maxsize, _JAVA_MAX_LONG)

    unboundedPreceding: int = _JAVA_MIN_LONG

    unboundedFollowing: int = _JAVA_MAX_LONG

    currentRow: int = 0

    @classmethod
    def partitionBy(cls, *cols: t.Union[ColumnOrName, t.List[ColumnOrName]]) -> WindowSpec:
        return WindowSpec().partitionBy(*cols)

    @classmethod
    def orderBy(cls, *cols: t.Union[ColumnOrName, t.List[ColumnOrName]]) -> WindowSpec:
        return WindowSpec().orderBy(*cols)

    @classmethod
    def rowsBetween(cls, start: int, end: int) -> WindowSpec:
        return WindowSpec().rowsBetween(start, end)

    @classmethod
    def rangeBetween(cls, start: int, end: int) -> WindowSpec:
        return WindowSpec().rangeBetween(start, end)


class WindowSpec:
    def __init__(self, expression: exp.Expression = exp.Window()):
        self.expression = expression

    def copy(self):
        return WindowSpec(self.expression.copy())

    def sql(self, **kwargs) -> str:
        return self.expression.sql(dialect="spark", **kwargs)

    def partitionBy(self, *cols: t.Union[ColumnOrName, t.List[ColumnOrName]]) -> WindowSpec:
        from sqlglot.dataframe.sql.column import Column

        cols = flatten(cols) if isinstance(cols[0], (list, set, tuple)) else cols  # type: ignore
        expressions = [Column.ensure_col(x).expression for x in cols]
        window_spec = self.copy()
        partition_by_expressions = window_spec.expression.args.get("partition_by", [])
        partition_by_expressions.extend(expressions)
        window_spec.expression.set("partition_by", partition_by_expressions)
        return window_spec

    def orderBy(self, *cols: t.Union[ColumnOrName, t.List[ColumnOrName]]) -> WindowSpec:
        from sqlglot.dataframe.sql.column import Column

        cols = flatten(cols) if isinstance(cols[0], (list, set, tuple)) else cols  # type: ignore
        expressions = [Column.ensure_col(x).expression for x in cols]
        window_spec = self.copy()
        if window_spec.expression.args.get("order") is None:
            window_spec.expression.set("order", exp.Order(expressions=[]))
        order_by = window_spec.expression.args["order"].expressions
        order_by.extend(expressions)
        window_spec.expression.args["order"].set("expressions", order_by)
        return window_spec

    def _calc_start_end(self, start: int, end: int) -> t.Dict[str, t.Optional[t.Union[str, exp.Expression]]]:
        kwargs: t.Dict[str, t.Optional[t.Union[str, exp.Expression]]] = {"start_side": None, "end_side": None}
        if start == Window.currentRow:
            kwargs["start"] = "CURRENT ROW"
        else:
            kwargs = {
                **kwargs,
                **{
                    "start_side": "PRECEDING",
                    "start": "UNBOUNDED" if start <= Window.unboundedPreceding else F.lit(start).expression,
                },
            }
        if end == Window.currentRow:
            kwargs["end"] = "CURRENT ROW"
        else:
            kwargs = {
                **kwargs,
                **{
                    "end_side": "FOLLOWING",
                    "end": "UNBOUNDED" if end >= Window.unboundedFollowing else F.lit(end).expression,
                },
            }
        return kwargs

    def rowsBetween(self, start: int, end: int) -> WindowSpec:
        window_spec = self.copy()
        spec = self._calc_start_end(start, end)
        spec["kind"] = "ROWS"
        window_spec.expression.set(
            "spec", exp.WindowSpec(**{**window_spec.expression.args.get("spec", exp.WindowSpec()).args, **spec})
        )
        return window_spec

    def rangeBetween(self, start: int, end: int) -> WindowSpec:
        window_spec = self.copy()
        spec = self._calc_start_end(start, end)
        spec["kind"] = "RANGE"
        window_spec.expression.set(
            "spec", exp.WindowSpec(**{**window_spec.expression.args.get("spec", exp.WindowSpec()).args, **spec})
        )
        return window_spec
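A small sketch of the builder chain; dept and salary are made-up column names, and the rendered clause text is approximate:

```
from sqlglot.dataframe.sql.window import Window

spec = (
    Window.partitionBy("dept")
    .orderBy("salary")
    .rowsBetween(Window.unboundedPreceding, Window.currentRow)
)
# Roughly: OVER (PARTITION BY dept ORDER BY salary
#                ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW)
print(spec.sql())
```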
sqlglot/dialects/bigquery.py

@ -78,6 +78,16 @@ def _create_sql(self, expression):

class BigQuery(Dialect):
    unnest_column_only = True
    time_mapping = {
        "%M": "%-M",
        "%d": "%-d",
        "%m": "%-m",
        "%y": "%-y",
        "%H": "%-H",
        "%I": "%-I",
        "%S": "%-S",
        "%j": "%-j",
    }

    class Tokenizer(Tokenizer):
        QUOTES = [

@ -113,6 +123,7 @@ class BigQuery(Dialect):
            "DATETIME_SUB": _date_add(exp.DatetimeSub),
            "TIME_SUB": _date_add(exp.TimeSub),
            "TIMESTAMP_SUB": _date_add(exp.TimestampSub),
            "PARSE_TIMESTAMP": lambda args: exp.StrToTime(this=list_get(args, 1), format=list_get(args, 0)),
        }

        NO_PAREN_FUNCTIONS = {

@ -137,6 +148,7 @@ class BigQuery(Dialect):
            exp.DatetimeSub: _date_add_sql("DATETIME", "SUB"),
            exp.DateDiff: lambda self, e: f"DATE_DIFF({self.sql(e, 'this')}, {self.sql(e, 'expression')}, {self.sql(e.args.get('unit', 'DAY'))})",
            exp.ILike: no_ilike_sql,
            exp.StrToTime: lambda self, e: f"PARSE_TIMESTAMP({self.format_time(e)}, {self.sql(e, 'this')})",
            exp.TimeAdd: _date_add_sql("TIME", "ADD"),
            exp.TimeSub: _date_add_sql("TIME", "SUB"),
            exp.TimestampAdd: _date_add_sql("TIMESTAMP", "ADD"),
sqlglot/dialects/dialect.py

@ -2,7 +2,7 @@ from enum import Enum

from sqlglot import exp
from sqlglot.generator import Generator
from sqlglot.helper import list_get
from sqlglot.helper import flatten, list_get
from sqlglot.parser import Parser
from sqlglot.time import format_time
from sqlglot.tokens import Tokenizer

@ -67,6 +67,11 @@ class _Dialect(type):
            klass.generator_class.TRANSFORMS[
                exp.HexString
            ] = lambda self, e: f"{hs_start}{int(self.sql(e, 'this')):X}{hs_end}"
        if klass.tokenizer_class._BYTE_STRINGS and exp.ByteString not in klass.generator_class.TRANSFORMS:
            be_start, be_end = list(klass.tokenizer_class._BYTE_STRINGS.items())[0]
            klass.generator_class.TRANSFORMS[
                exp.ByteString
            ] = lambda self, e: f"{be_start}{self.sql(e, 'this')}{be_end}"

        return klass

@ -176,11 +181,7 @@ class Dialect(metaclass=_Dialect):

def rename_func(name):
    def _rename(self, expression):
        args = (
            expression.expressions
            if isinstance(expression, exp.Func) and expression.is_var_len_args
            else expression.args.values()
        )
        args = flatten(expression.args.values())
        return f"{name}({self.format_args(*args)})"

    return _rename
sqlglot/dialects/hive.py

@ -121,6 +121,9 @@ class Hive(Dialect):
        "ss": "%S",
        "s": "%-S",
        "S": "%f",
        "a": "%p",
        "DD": "%j",
        "D": "%-j",
    }

    date_format = "'yyyy-MM-dd'"

@ -200,6 +203,7 @@ class Hive(Dialect):
            exp.AnonymousProperty: _property_sql,
            exp.ApproxDistinct: approx_count_distinct_sql,
            exp.ArrayAgg: rename_func("COLLECT_LIST"),
            exp.ArrayConcat: rename_func("CONCAT"),
            exp.ArraySize: rename_func("SIZE"),
            exp.ArraySort: _array_sort,
            exp.With: no_recursive_cte_sql,
sqlglot/dialects/mysql.py

@ -97,6 +97,8 @@ class MySQL(Dialect):
        "%s": "%S",
        "%S": "%S",
        "%u": "%W",
        "%k": "%-H",
        "%l": "%-I",
    }

    class Tokenizer(Tokenizer):

@ -145,6 +147,9 @@ class MySQL(Dialect):
            "_TIS620": TokenType.INTRODUCER,
            "_UCS2": TokenType.INTRODUCER,
            "_UJIS": TokenType.INTRODUCER,
            # https://dev.mysql.com/doc/refman/8.0/en/string-literals.html
            "N": TokenType.INTRODUCER,
            "n": TokenType.INTRODUCER,
            "_UTF8": TokenType.INTRODUCER,
            "_UTF16": TokenType.INTRODUCER,
            "_UTF16LE": TokenType.INTRODUCER,
sqlglot/dialects/oracle.py

@ -80,17 +80,12 @@ class Oracle(Dialect):
                sep="",
            )

        def alias_sql(self, expression):
            if isinstance(expression.this, exp.Table):
                to_sql = self.sql(expression, "alias")
                # oracle does not allow "AS" between table and alias
                to_sql = f" {to_sql}" if to_sql else ""
                return f"{self.sql(expression, 'this')}{to_sql}"
            return super().alias_sql(expression)

        def offset_sql(self, expression):
            return f"{super().offset_sql(expression)} ROWS"

        def table_sql(self, expression):
            return super().table_sql(expression, sep=" ")

    class Tokenizer(Tokenizer):
        KEYWORDS = {
            **Tokenizer.KEYWORDS,
sqlglot/dialects/postgres.py

@ -163,6 +163,7 @@ class Postgres(Dialect):
    class Tokenizer(Tokenizer):
        BIT_STRINGS = [("b'", "'"), ("B'", "'")]
        HEX_STRINGS = [("x'", "'"), ("X'", "'")]
        BYTE_STRINGS = [("e'", "'"), ("E'", "'")]
        KEYWORDS = {
            **Tokenizer.KEYWORDS,
            "ALWAYS": TokenType.ALWAYS,

@ -176,6 +177,11 @@ class Postgres(Dialect):
            "SMALLSERIAL": TokenType.SMALLSERIAL,
            "UUID": TokenType.UUID,
        }
        QUOTES = ["'", "$$"]
        SINGLE_TOKENS = {
            **Tokenizer.SINGLE_TOKENS,
            "$": TokenType.PARAMETER,
        }

    class Parser(Parser):
        STRICT_CAST = False
sqlglot/dialects/presto.py

@ -172,6 +172,7 @@ class Presto(Dialect):
            **transforms.UNALIAS_GROUP,
            exp.ApproxDistinct: _approx_distinct_sql,
            exp.Array: lambda self, e: f"ARRAY[{self.expressions(e, flat=True)}]",
            exp.ArrayConcat: rename_func("CONCAT"),
            exp.ArrayContains: rename_func("CONTAINS"),
            exp.ArraySize: rename_func("CARDINALITY"),
            exp.BitwiseAnd: lambda self, e: f"BITWISE_AND({self.sql(e, 'this')}, {self.sql(e, 'expression')})",
sqlglot/dialects/snowflake.py

@ -69,6 +69,35 @@ def _unix_to_time(self, expression):
    raise ValueError("Improper scale for timestamp")


# https://docs.snowflake.com/en/sql-reference/functions/date_part.html
# https://docs.snowflake.com/en/sql-reference/functions-date-time.html#label-supported-date-time-parts
def _parse_date_part(self):
    this = self._parse_var() or self._parse_type()
    self._match(TokenType.COMMA)
    expression = self._parse_bitwise()

    name = this.name.upper()
    if name.startswith("EPOCH"):
        if name.startswith("EPOCH_MILLISECOND"):
            scale = 10**3
        elif name.startswith("EPOCH_MICROSECOND"):
            scale = 10**6
        elif name.startswith("EPOCH_NANOSECOND"):
            scale = 10**9
        else:
            scale = None

        ts = self.expression(exp.Cast, this=expression, to=exp.DataType.build("TIMESTAMP"))
        to_unix = self.expression(exp.TimeToUnix, this=ts)

        if scale:
            to_unix = exp.Mul(this=to_unix, expression=exp.Literal.number(scale))

        return to_unix

    return self.expression(exp.Extract, this=this, expression=expression)


class Snowflake(Dialect):
    null_ordering = "nulls_are_large"
    time_format = "'yyyy-mm-dd hh24:mi:ss'"

@ -115,7 +144,7 @@ class Snowflake(Dialect):

        FUNCTION_PARSERS = {
            **Parser.FUNCTION_PARSERS,
            "DATE_PART": lambda self: self._parse_extract(),
            "DATE_PART": _parse_date_part,
        }

        FUNC_TOKENS = {

@ -161,9 +190,11 @@ class Snowflake(Dialect):
    class Generator(Generator):
        TRANSFORMS = {
            **Generator.TRANSFORMS,
            exp.ArrayConcat: rename_func("ARRAY_CAT"),
            exp.If: rename_func("IFF"),
            exp.StrToTime: lambda self, e: f"TO_TIMESTAMP({self.sql(e, 'this')}, {self.format_time(e)})",
            exp.UnixToTime: _unix_to_time,
            exp.TimeToUnix: lambda self, e: f"EXTRACT(epoch_second FROM {self.sql(e, 'this')})",
            exp.Array: inline_array_sql,
            exp.StrPosition: rename_func("POSITION"),
            exp.Parameter: lambda self, e: f"${self.sql(e, 'this')}",
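A hedged example of the new DATE_PART behavior: EPOCH_* parts now become a scaled epoch expression instead of a plain EXTRACT. The exact output text is approximate:

```
import sqlglot

sql = "SELECT DATE_PART(EPOCH_MILLISECOND, created_at) FROM t"
# Roughly: SELECT UNIX_TIMESTAMP(CAST(created_at AS TIMESTAMP)) * 1000 FROM t
print(sqlglot.transpile(sql, read="snowflake", write="spark")[0])
```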
sqlglot/dialects/spark.py

@ -1,9 +1,5 @@
from sqlglot import exp
from sqlglot.dialects.dialect import (
    create_with_partitions_sql,
    no_ilike_sql,
    rename_func,
)
from sqlglot.dialects.dialect import create_with_partitions_sql, rename_func
from sqlglot.dialects.hive import Hive
from sqlglot.helper import list_get
from sqlglot.parser import Parser
@ -98,13 +94,14 @@ class Spark(Hive):
        }

        TRANSFORMS = {
            **{k: v for k, v in Hive.Generator.TRANSFORMS.items() if k not in {exp.ArraySort}},
            **{k: v for k, v in Hive.Generator.TRANSFORMS.items() if k not in {exp.ArraySort, exp.ILike}},
            exp.ApproxDistinct: rename_func("APPROX_COUNT_DISTINCT"),
            exp.FileFormatProperty: lambda self, e: f"USING {e.text('value').upper()}",
            exp.ArraySum: lambda self, e: f"AGGREGATE({self.sql(e, 'this')}, 0, (acc, x) -> acc + x, acc -> acc)",
            exp.BitwiseLeftShift: rename_func("SHIFTLEFT"),
            exp.BitwiseRightShift: rename_func("SHIFTRIGHT"),
            exp.DateTrunc: rename_func("TRUNC"),
            exp.Hint: lambda self, e: f" /*+ {self.expressions(e).strip()} */",
            exp.ILike: no_ilike_sql,
            exp.StrToDate: _str_to_date,
            exp.StrToTime: lambda self, e: f"TO_TIMESTAMP({self.sql(e, 'this')}, {self.format_time(e)})",
            exp.UnixToTime: _unix_to_time,

@ -112,6 +109,8 @@ class Spark(Hive):
            exp.Map: _map_sql,
            exp.Reduce: rename_func("AGGREGATE"),
            exp.StructKwarg: lambda self, e: f"{self.sql(e, 'this')}: {self.sql(e, 'expression')}",
            exp.TimestampTrunc: lambda self, e: f"DATE_TRUNC({self.sql(e, 'unit')}, {self.sql(e, 'this')})",
            exp.VariancePop: rename_func("VAR_POP"),
        }

        WRAP_DERIVED_VALUES = False
sqlglot/dialects/tsql.py

@ -32,6 +32,11 @@ class TSQL(Dialect):
    }

    class Parser(Parser):
        FUNCTIONS = {
            **Parser.FUNCTIONS,
            "CHARINDEX": exp.StrPosition.from_arg_list,
        }

        def _parse_convert(self):
            to = self._parse_types()
            self._match(TokenType.COMMA)
sqlglot/executor/env.py

@ -19,6 +19,7 @@ ENV = {
    "datetime": datetime,
    "locals": locals,
    "re": re,
    "bool": bool,
    "float": float,
    "int": int,
    "str": str,
sqlglot/executor/python.py

@ -80,9 +80,10 @@ class PythonExecutor:
        source = step.source

        if isinstance(source, exp.Expression):
            source = source.this.name or source.alias
            source = source.name or source.alias
        else:
            source = step.name

        condition = self.generate(step.condition)
        projections = self.generate_tuple(step.projections)

@ -121,7 +122,7 @@ class PythonExecutor:
        source = step.source
        alias = source.alias

        with csv_reader(source.this) as reader:
        with csv_reader(source) as reader:
            columns = next(reader)
            table = Table(columns)
            context = self.context({alias: table})

@ -308,7 +309,7 @@ def _interval_py(self, expression):
def _like_py(self, expression):
    this = self.sql(expression, "this")
    expression = self.sql(expression, "expression")
    return f"""re.match({expression}.replace("_", ".").replace("%", ".*"), {this})"""
    return f"""bool(re.match({expression}.replace("_", ".").replace("%", ".*"), {this}))"""


def _ordered_py(self, expression):

@ -330,6 +331,7 @@ class Python(Dialect):
            exp.Cast: _cast_py,
            exp.Column: _column_py,
            exp.EQ: lambda self, e: self.binary(e, "=="),
            exp.In: lambda self, e: f"{self.sql(e, 'this')} in {self.expressions(e)}",
            exp.Interval: _interval_py,
            exp.Is: lambda self, e: self.binary(e, "is"),
            exp.Like: _like_py,
sqlglot/expressions.py

@ -11,6 +11,7 @@ from sqlglot.helper import (
    camel_to_snake_case,
    ensure_list,
    list_get,
    split_num_words,
    subclasses,
)

@ -108,6 +109,8 @@ class Expression(metaclass=_Expression):

    @property
    def alias_or_name(self):
        if isinstance(self, Null):
            return "NULL"
        return self.alias or self.name

    def __deepcopy__(self, memo):

@ -659,6 +662,10 @@ class HexString(Condition):
    pass


class ByteString(Condition):
    pass


class Column(Condition):
    arg_types = {"this": True, "table": False}

@ -725,7 +732,7 @@ class Constraint(Expression):


class Delete(Expression):
    arg_types = {"with": False, "this": True, "where": False}
    arg_types = {"with": False, "this": True, "using": False, "where": False}


class Drop(Expression):

@ -1192,6 +1199,7 @@ QUERY_MODIFIERS = {
class Table(Expression):
    arg_types = {
        "this": True,
        "alias": False,
        "db": False,
        "catalog": False,
        "laterals": False,

@ -1323,6 +1331,7 @@ class Select(Subqueryable):
            *expressions (str or Expression): the SQL code strings to parse.
                If a `Group` instance is passed, this is used as-is.
                If another `Expression` instance is passed, it will be wrapped in a `Group`.
                If nothing is passed in then a group by is not applied to the expression
            append (bool): if `True`, add to any existing expressions.
                Otherwise, this flattens all the `Group` expression into a single expression.
            dialect (str): the dialect used to parse the input expression.

@ -1332,6 +1341,8 @@ class Select(Subqueryable):
        Returns:
            Select: the modified expression.
        """
        if not expressions:
            return self if not copy else self.copy()
        return _apply_child_list_builder(
            *expressions,
            instance=self,

@ -2239,6 +2250,11 @@ class ArrayAny(Func):
    arg_types = {"this": True, "expression": True}


class ArrayConcat(Func):
    arg_types = {"this": True, "expressions": False}
    is_var_len_args = True


class ArrayContains(Func):
    arg_types = {"this": True, "expression": True}

@ -2570,7 +2586,7 @@ class SortArray(Func):


class Split(Func):
    arg_types = {"this": True, "expression": True}
    arg_types = {"this": True, "expression": True, "limit": False}


# Start may be omitted in the case of postgres

@ -3209,29 +3225,49 @@ def to_identifier(alias, quoted=None):
    return identifier


def to_table(sql_path, **kwargs):
def to_table(sql_path: str, **kwargs) -> Table:
    """
    Create a table expression from a `[catalog].[schema].[table]` sql path. Catalog and schema are optional.
    Example:
        >>> to_table('catalog.db.table_name').sql()
        'catalog.db.table_name'

    If a table is passed in then that table is returned.

    Args:
        sql_path(str): `[catalog].[schema].[table]` string
        sql_path(str|Table): `[catalog].[schema].[table]` string
    Returns:
        Table: A table expression
    """
    table_parts = sql_path.split(".")
    catalog, db, table_name = [
        to_identifier(x) if x is not None else x for x in [None] * (3 - len(table_parts)) + table_parts
    ]
    if sql_path is None or isinstance(sql_path, Table):
        return sql_path
    if not isinstance(sql_path, str):
        raise ValueError(f"Invalid type provided for a table: {type(sql_path)}")

    catalog, db, table_name = [to_identifier(x) for x in split_num_words(sql_path, ".", 3)]
    return Table(this=table_name, db=db, catalog=catalog, **kwargs)


def to_column(sql_path: str, **kwargs) -> Column:
    """
    Create a column from a `[table].[column]` sql path. Table is optional.

    If a column is passed in then that column is returned.

    Args:
        sql_path: `[table].[column]` string
    Returns:
        Column: A column expression
    """
    if sql_path is None or isinstance(sql_path, Column):
        return sql_path
    if not isinstance(sql_path, str):
        raise ValueError(f"Invalid type provided for column: {type(sql_path)}")
    table_name, column_name = [to_identifier(x) for x in split_num_words(sql_path, ".", 2)]
    return Column(this=column_name, table=table_name, **kwargs)


def alias_(expression, alias, table=False, dialect=None, quoted=None, **opts):
    """
    Create an Alias expression.
    Expample:
    Example:
        >>> alias_('foo', 'bar').sql()
        'foo AS bar'

@ -3249,7 +3285,16 @@ def alias_(expression, alias, table=False, dialect=None, quoted=None, **opts):
    """
    exp = maybe_parse(expression, dialect=dialect, **opts)
    alias = to_identifier(alias, quoted=quoted)
    alias = TableAlias(this=alias) if table else alias

    if table:
        expression.set("alias", TableAlias(this=alias))
        return expression

    # We don't set the "alias" arg for Window expressions, because that would add an IDENTIFIER node in
    # the AST, representing a "named_window" [1] construct (eg. bigquery). What we want is an ALIAS node
    # for the complete Window expression.
    #
    # [1]: https://cloud.google.com/bigquery/docs/reference/standard-sql/window-function-calls

    if "alias" in exp.arg_types and not isinstance(exp, Window):
        exp = exp.copy()

@ -3295,7 +3340,7 @@ def column(col, table=None, quoted=None):
    )


def table_(table, db=None, catalog=None, quoted=None):
def table_(table, db=None, catalog=None, quoted=None, alias=None):
    """Build a Table.

    Args:

@ -3310,6 +3355,7 @@ def table_(table, db=None, catalog=None, quoted=None):
        this=to_identifier(table, quoted=quoted),
        db=to_identifier(db, quoted=quoted),
        catalog=to_identifier(catalog, quoted=quoted),
        alias=TableAlias(this=to_identifier(alias)) if alias else None,
    )


@ -3453,7 +3499,7 @@ def replace_tables(expression, mapping):
    Examples:
        >>> from sqlglot import exp, parse_one
        >>> replace_tables(parse_one("select * from a.b"), {"a.b": "c"}).sql()
        'SELECT * FROM "c"'
        'SELECT * FROM c'

    Returns:
        The mapped expression

@ -3463,7 +3509,10 @@ def replace_tables(expression, mapping):
        if isinstance(node, Table):
            new_name = mapping.get(table_name(node))
            if new_name:
                return table_(*reversed(new_name.split(".")), quoted=True)
                return to_table(
                    new_name,
                    **{k: v for k, v in node.args.items() if k not in ("this", "db", "catalog")},
                )
        return node

    return expression.transform(_replace_tables)
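The to_table() doctest above generalizes to the new to_column() helper; a quick grounded check:

```
from sqlglot import expressions as exp

print(exp.to_table("catalog.db.tbl").sql())  # catalog.db.tbl
print(exp.to_column("tbl.col").sql())        # tbl.col
```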
sqlglot/generator.py

@ -47,6 +47,8 @@ class Generator:
            The default is on the smaller end because the length only represents a segment and not the true
            line length.
            Default: 80
        annotations: Whether or not to show annotations in the SQL.
            Default: True
    """

    TRANSFORMS = {

@ -116,6 +118,7 @@ class Generator:
        "_escaped_quote_end",
        "_leading_comma",
        "_max_text_width",
        "_annotations",
    )

    def __init__(

@ -141,6 +144,7 @@ class Generator:
        max_unsupported=3,
        leading_comma=False,
        max_text_width=80,
        annotations=True,
    ):
        import sqlglot

@ -169,6 +173,7 @@ class Generator:
        self._escaped_quote_end = self.escape + self.quote_end
        self._leading_comma = leading_comma
        self._max_text_width = max_text_width
        self._annotations = annotations

    def generate(self, expression):
        """

@ -275,7 +280,9 @@ class Generator:
        raise ValueError(f"Unsupported expression type {expression.__class__.__name__}")

    def annotation_sql(self, expression):
        return f"{self.sql(expression, 'expression')} # {expression.name.strip()}"
        if self._annotations:
            return f"{self.sql(expression, 'expression')} # {expression.name}"
        return self.sql(expression, "expression")

    def uncache_sql(self, expression):
        table = self.sql(expression, "this")

@ -423,8 +430,11 @@ class Generator:

    def delete_sql(self, expression):
        this = self.sql(expression, "this")
        using_sql = (
            f" USING {self.expressions(expression, 'using', sep=', USING ')}" if expression.args.get("using") else ""
        )
        where_sql = self.sql(expression, "where")
        sql = f"DELETE FROM {this}{where_sql}"
        sql = f"DELETE FROM {this}{using_sql}{where_sql}"
        return self.prepend_ctes(expression, sql)

    def drop_sql(self, expression):

@ -571,7 +581,7 @@ class Generator:
        null = f" NULL DEFINED AS {null}" if null else ""
        return f"ROW FORMAT DELIMITED{fields}{escaped}{items}{keys}{lines}{null}"

    def table_sql(self, expression):
    def table_sql(self, expression, sep=" AS "):
        table = ".".join(
            part
            for part in [

@ -582,13 +592,20 @@ class Generator:
            if part
        )

        alias = self.sql(expression, "alias")
        alias = f"{sep}{alias}" if alias else ""
        laterals = self.expressions(expression, key="laterals", sep="")
        joins = self.expressions(expression, key="joins", sep="")
        pivots = self.expressions(expression, key="pivots", sep="")
        return f"{table}{laterals}{joins}{pivots}"

        if alias and pivots:
            pivots = f"{pivots}{alias}"
            alias = ""

        return f"{table}{alias}{laterals}{joins}{pivots}"

    def tablesample_sql(self, expression):
        if self.alias_post_tablesample and isinstance(expression.this, exp.Alias):
        if self.alias_post_tablesample and expression.this.alias:
            this = self.sql(expression.this, "this")
            alias = f" AS {self.sql(expression.this, 'alias')}"
        else:

@ -1188,7 +1205,7 @@ class Generator:
            if isinstance(arg_value, list):
                for value in arg_value:
                    args.append(value)
            elif arg_value:
            else:
                args.append(arg_value)

        return f"{self.normalize_func(expression.sql_name())}({self.format_args(*args)})"
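Assuming the matching parser change elsewhere in this commit, the new using_sql branch lets a Postgres-style DELETE round-trip; a hedged sketch:

```
import sqlglot

# The USING clause is emitted between the target table and the WHERE clause.
print(sqlglot.transpile("DELETE FROM t USING u WHERE t.id = u.id")[0])
```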
sqlglot/helper.py

@ -2,7 +2,9 @@ import inspect
import logging
import re
import sys
import typing as t
from contextlib import contextmanager
from copy import copy
from enum import Enum

CAMEL_CASE_PATTERN = re.compile("(?<!^)(?=[A-Z])")

@ -162,3 +164,54 @@ def find_new_name(taken, base):
        i += 1
        new = f"{base}_{i}"
    return new


def object_to_dict(obj, **kwargs):
    return {**{k: copy(v) for k, v in vars(obj).copy().items()}, **kwargs}


def split_num_words(value: str, sep: str, min_num_words: int, fill_from_start: bool = True) -> t.List[t.Optional[str]]:
    """
    Perform a split on a value and return N words as a result with None used for words that don't exist.

    Args:
        value: The value to be split
        sep: The value to use to split on
        min_num_words: The minimum number of words that are going to be in the result
        fill_from_start: Indicates whether None values should be inserted at the start or end of the list

    Examples:
        >>> split_num_words("db.table", ".", 3)
        [None, 'db', 'table']
        >>> split_num_words("db.table", ".", 3, fill_from_start=False)
        ['db', 'table', None]
        >>> split_num_words("db.table", ".", 1)
        ['db', 'table']
    """
    words = value.split(sep)
    if fill_from_start:
        return [None] * (min_num_words - len(words)) + words
    return words + [None] * (min_num_words - len(words))


def flatten(values: t.Iterable[t.Union[t.Iterable[t.Any], t.Any]]) -> t.Generator[t.Any, None, None]:
    """
    Flattens a list that can contain both iterables and non-iterable elements

    Examples:
        >>> list(flatten([[1, 2], 3]))
        [1, 2, 3]
        >>> list(flatten([1, 2, 3]))
        [1, 2, 3]

    Args:
        values: The value to be flattened

    Returns:
        Yields non-iterable elements (not including str or byte as iterable)
    """
    for value in values:
        if hasattr(value, "__iter__") and not isinstance(value, (str, bytes)):
            yield from flatten(value)
        else:
            yield value
sqlglot/optimizer/__init__.py

@ -1,2 +1 @@
from sqlglot.optimizer.optimizer import RULES, optimize
from sqlglot.optimizer.schema import Schema
sqlglot/optimizer/annotate_types.py

@ -1,7 +1,7 @@
from sqlglot import exp
from sqlglot.helper import ensure_list, subclasses
from sqlglot.optimizer.schema import ensure_schema
from sqlglot.optimizer.scope import Scope, traverse_scope
from sqlglot.schema import ensure_schema


def annotate_types(expression, schema=None, annotators=None, coerces_to=None):
sqlglot/optimizer/eliminate_subqueries.py

@ -86,7 +86,7 @@ def _eliminate(scope, existing_ctes, taken):
    if scope.is_union:
        return _eliminate_union(scope, existing_ctes, taken)

    if scope.is_derived_table and not isinstance(scope.expression, (exp.Unnest, exp.Lateral)):
    if scope.is_derived_table and not isinstance(scope.expression, exp.UDTF):
        return _eliminate_derived_table(scope, existing_ctes, taken)
sqlglot/optimizer/isolate_table_selects.py

@ -12,18 +12,16 @@ def isolate_table_selects(expression):
            if not isinstance(source, exp.Table):
                continue

            if not isinstance(source.parent, exp.Alias):
            if not source.alias:
                raise OptimizeError("Tables require an alias. Run qualify_tables optimization.")

            parent = source.parent

            parent.replace(
            source.replace(
                exp.select("*")
                .from_(
                    alias(source, source.name or parent.alias, table=True),
                    alias(source.copy(), source.name or source.alias, table=True),
                    copy=False,
                )
                .subquery(parent.alias, copy=False)
                .subquery(source.alias, copy=False)
            )

    return expression
@ -70,15 +70,10 @@ def merge_ctes(expression, leave_tables_isolated=False)
            inner_select = inner_scope.expression.unnest()
            from_or_join = table.find_ancestor(exp.From, exp.Join)
            if _mergeable(outer_scope, inner_select, leave_tables_isolated, from_or_join):
                node_to_replace = table
                if isinstance(node_to_replace.parent, exp.Alias):
                    node_to_replace = node_to_replace.parent
                    alias = node_to_replace.alias
                else:
                    alias = table.name
                alias = table.alias_or_name

                _rename_inner_sources(outer_scope, inner_scope, alias)
                _merge_from(outer_scope, inner_scope, node_to_replace, alias)
                _merge_from(outer_scope, inner_scope, table, alias)
                _merge_expressions(outer_scope, inner_scope, alias)
                _merge_joins(outer_scope, inner_scope, from_or_join)
                _merge_where(outer_scope, inner_scope, from_or_join)

@ -179,8 +174,8 @@ def _rename_inner_sources(outer_scope, inner_scope, alias)

        if isinstance(source, exp.Subquery):
            source.set("alias", exp.TableAlias(this=new_alias))
        elif isinstance(source, exp.Table) and isinstance(source.parent, exp.Alias):
            source.parent.set("alias", new_alias)
        elif isinstance(source, exp.Table) and source.alias:
            source.set("alias", new_alias)
        elif isinstance(source, exp.Table):
            source.replace(exp.alias_(source.copy(), new_alias))

@ -206,8 +201,7 @@ def _merge_from(outer_scope, inner_scope, node_to_replace, alias)
        tables = join_hint.find_all(exp.Table)
        for table in tables:
            if table.alias_or_name == node_to_replace.alias_or_name:
                new_table = new_subquery.this if isinstance(new_subquery, exp.Alias) else new_subquery
                table.set("this", exp.to_identifier(new_table.alias_or_name))
                table.set("this", exp.to_identifier(new_subquery.alias_or_name))
    outer_scope.remove_source(alias)
    outer_scope.add_source(new_subquery.alias_or_name, inner_scope.sources[new_subquery.alias_or_name])
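The alias handling above now relies on `table.alias_or_name` instead of walking up to an `exp.Alias` parent. A minimal sketch of the pass this diff touches; the commented output is the expected shape, not a captured run:

```
from sqlglot import parse_one
from sqlglot.optimizer.merge_subqueries import merge_subqueries

# A derived table that is a simple projection should fold into the outer query.
expression = parse_one("SELECT a FROM (SELECT x.a FROM x) AS x")
print(merge_subqueries(expression).sql())  # SELECT x.a FROM x
```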
@ -1,3 +1,4 @@
import sqlglot
from sqlglot.optimizer.eliminate_ctes import eliminate_ctes
from sqlglot.optimizer.eliminate_joins import eliminate_joins
from sqlglot.optimizer.eliminate_subqueries import eliminate_subqueries

@ -43,6 +44,7 @@ def optimize(expression, schema=None, db=None, catalog=None, rules=RULES, **kwar
                1. {table: {col: type}}
                2. {db: {table: {col: type}}}
                3. {catalog: {db: {table: {col: type}}}}
            If no schema is provided then the default schema defined at `sqlglot.schema` will be used
        db (str): specify the default database, as might be set by a `USE DATABASE db` statement
        catalog (str): specify the default catalog, as might be set by a `USE CATALOG c` statement
        rules (list): sequence of optimizer rules to use

@ -50,13 +52,12 @@ def optimize(expression, schema=None, db=None, catalog=None, rules=RULES, **kwar
    Returns:
        sqlglot.Expression: optimized expression
    """
    possible_kwargs = {"db": db, "catalog": catalog, "schema": schema, **kwargs}
    possible_kwargs = {"db": db, "catalog": catalog, "schema": schema or sqlglot.schema, **kwargs}
    expression = expression.copy()
    for rule in rules:

        # Find any additional rule parameters, beyond `expression`
        rule_params = rule.__code__.co_varnames
        rule_kwargs = {param: possible_kwargs[param] for param in rule_params if param in possible_kwargs}

        expression = rule(expression, **rule_kwargs)
    return expression
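With the fallback above, a table registered on the module-level `sqlglot.schema` is picked up by `optimize` when no explicit schema is passed. A short sketch with made-up table and column names:

```
import sqlglot
from sqlglot import parse_one
from sqlglot.optimizer import optimize

# Register a table on the global default schema; optimize() falls back to
# sqlglot.schema when no schema argument is given.
sqlglot.schema.add_table("employees", {"id": "INT", "name": "VARCHAR"})
print(optimize(parse_one("SELECT name FROM employees")).sql())
```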
@ -6,6 +6,9 @@ from sqlglot.optimizer.scope import Scope, traverse_scope
# Sentinel value that means an outer query selecting ALL columns
SELECT_ALL = object()

# Selection to use if the selection list is empty
DEFAULT_SELECTION = alias("1", "_")


def pushdown_projections(expression):
    """

@ -25,7 +28,8 @@ def pushdown_projections(expression)
    """
    # Map of Scope to all columns being selected by outer queries.
    referenced_columns = defaultdict(set)

    left_union = None
    right_union = None
    # We build the scope tree (which is traversed in DFS postorder), then iterate
    # over the result in reverse order. This should ensure that the set of selected
    # columns for a particular scope are completely built by the time we get to it.

@ -37,12 +41,16 @@ def pushdown_projections(expression)
            parent_selections = {SELECT_ALL}

        if isinstance(scope.expression, exp.Union):
            left, right = scope.union_scopes
            referenced_columns[left] = parent_selections
            referenced_columns[right] = parent_selections
            left_union, right_union = scope.union_scopes
            referenced_columns[left_union] = parent_selections
            referenced_columns[right_union] = parent_selections

        if isinstance(scope.expression, exp.Select):
            _remove_unused_selections(scope, parent_selections)
        if isinstance(scope.expression, exp.Select) and scope != right_union:
            removed_indexes = _remove_unused_selections(scope, parent_selections)
            # The left union is used for column names to select and if we remove columns from the left
            # we need to also remove those same columns in the right that were at the same position
            if scope is left_union:
                _remove_indexed_selections(right_union, removed_indexes)

        # Group columns by source name
        selects = defaultdict(set)

@ -61,6 +69,7 @@ def pushdown_projections(expression)


def _remove_unused_selections(scope, parent_selections):
    removed_indexes = []
    order = scope.expression.args.get("order")

    if order:

@ -70,16 +79,26 @@ def _remove_unused_selections(scope, parent_selections)
        order_refs = set()

    new_selections = []
    for selection in scope.selects:
    for i, selection in enumerate(scope.selects):
        if (
            SELECT_ALL in parent_selections
            or selection.alias_or_name in parent_selections
            or selection.alias_or_name in order_refs
        ):
            new_selections.append(selection)
        else:
            removed_indexes.append(i)

    # If there are no remaining selections, just select a single constant
    if not new_selections:
        new_selections.append(alias("1", "_"))
        new_selections.append(DEFAULT_SELECTION)

    scope.expression.set("expressions", new_selections)
    return removed_indexes


def _remove_indexed_selections(scope, indexes_to_remove):
    new_selections = [selection for i, selection in enumerate(scope.selects) if i not in indexes_to_remove]
    if not new_selections:
        new_selections.append(DEFAULT_SELECTION)
    scope.expression.set("expressions", new_selections)
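Because the left side of a union supplies the output column names, pruning a projection from the left requires pruning the same position on the right, which is what `_remove_indexed_selections` handles. A hedged sketch; the unused column `b` should be removed from both sides:

```
from sqlglot import parse_one
from sqlglot.optimizer.pushdown_projections import pushdown_projections

sql = "SELECT a FROM (SELECT a, b FROM x UNION ALL SELECT a, b FROM y) AS s"
print(pushdown_projections(parse_one(sql)).sql())
```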
@ -2,8 +2,8 @@ import itertools

from sqlglot import alias, exp
from sqlglot.errors import OptimizeError
from sqlglot.optimizer.schema import ensure_schema
from sqlglot.optimizer.scope import traverse_scope
from sqlglot.optimizer.scope import Scope, traverse_scope
from sqlglot.schema import ensure_schema


def qualify_columns(expression, schema):

@ -48,7 +48,7 @@ def _pop_table_column_aliases(derived_tables)
    (e.g. SELECT ... FROM (SELECT ...) AS foo(col1, col2)
    """
    for derived_table in derived_tables:
        if isinstance(derived_table, exp.UDTF):
        if isinstance(derived_table.unnest(), exp.UDTF):
            continue
        table_alias = derived_table.args.get("alias")
        if table_alias:

@ -211,6 +211,22 @@ def _qualify_columns(scope, resolver)
        if column_table:
            column.set("table", exp.to_identifier(column_table))

    # Determine whether each reference in the order by clause is to a column or an alias.
    for ordered in scope.find_all(exp.Ordered):
        for column in ordered.find_all(exp.Column):
            column_table = column.table
            column_name = column.name

            if column_table or column.parent is ordered or column_name not in resolver.all_columns:
                continue

            column_table = resolver.get_table(column_name)

            if column_table is None:
                raise OptimizeError(f"Ambiguous column: {column_name}")

            column.set("table", exp.to_identifier(column_table))


def _expand_stars(scope, resolver):
    """Expand stars to lists of column selections"""

@ -346,6 +362,11 @@ class _Resolver:
        except Exception as e:
            raise OptimizeError(str(e)) from e

        if isinstance(source, Scope) and isinstance(source.expression, exp.Values):
            values_alias = source.expression.parent
            if hasattr(values_alias, "alias_column_names"):
                return values_alias.alias_column_names

        # Otherwise, if referencing another scope, return that scope's named selects
        return source.expression.named_selects
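The new `exp.Ordered` loop qualifies columns that appear inside ORDER BY expressions (direct children of `Ordered` may be alias references and are skipped). A hedged sketch; the printed SQL should qualify `a` as `x.a` inside the ORDER BY expression:

```
from sqlglot import parse_one
from sqlglot.optimizer.qualify_columns import qualify_columns

expression = parse_one("SELECT a AS b FROM x ORDER BY a + 1")
print(qualify_columns(expression, schema={"x": {"a": "INT"}}).sql())
```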
@ -40,7 +40,7 @@ def qualify_tables(expression, db=None, catalog=None)
            if not source.args.get("catalog"):
                source.set("catalog", exp.to_identifier(catalog))

            if not isinstance(source.parent, exp.Alias):
            if not source.alias:
                source.replace(
                    alias(
                        source.copy(),
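The check above now keys off the table's own `alias` property. A quick sketch of the rule; the commented output is the expected shape, not a captured run:

```
from sqlglot import parse_one
from sqlglot.optimizer.qualify_tables import qualify_tables

print(qualify_tables(parse_one("SELECT 1 FROM tbl"), db="db").sql())
# SELECT 1 FROM db.tbl AS tbl
```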
@ -1,180 +0,0 @@
import abc

from sqlglot import exp
from sqlglot.errors import OptimizeError
from sqlglot.helper import csv_reader


class Schema(abc.ABC):
    """Abstract base class for database schemas"""

    @abc.abstractmethod
    def column_names(self, table, only_visible=False):
        """
        Get the column names for a table.
        Args:
            table (sqlglot.expressions.Table): Table expression instance
            only_visible (bool): Whether to include invisible columns
        Returns:
            list[str]: list of column names
        """

    @abc.abstractmethod
    def get_column_type(self, table, column):
        """
        Get the exp.DataType type of a column in the schema.

        Args:
            table (sqlglot.expressions.Table): The source table.
            column (sqlglot.expressions.Column): The target column.
        Returns:
            sqlglot.expressions.DataType.Type: The resulting column type.
        """


class MappingSchema(Schema):
    """
    Schema based on a nested mapping.

    Args:
        schema (dict): Mapping in one of the following forms:
            1. {table: {col: type}}
            2. {db: {table: {col: type}}}
            3. {catalog: {db: {table: {col: type}}}}
        visible (dict): Optional mapping of which columns in the schema are visible. If not provided, all columns
            are assumed to be visible. The nesting should mirror that of the schema:
            1. {table: set(*cols)}}
            2. {db: {table: set(*cols)}}}
            3. {catalog: {db: {table: set(*cols)}}}}
        dialect (str): The dialect to be used for custom type mappings.
    """

    def __init__(self, schema, visible=None, dialect=None):
        self.schema = schema
        self.visible = visible
        self.dialect = dialect
        self._type_mapping_cache = {}

        depth = _dict_depth(schema)

        if not depth:  # {}
            self.supported_table_args = []
        elif depth == 2:  # {table: {col: type}}
            self.supported_table_args = ("this",)
        elif depth == 3:  # {db: {table: {col: type}}}
            self.supported_table_args = ("db", "this")
        elif depth == 4:  # {catalog: {db: {table: {col: type}}}}
            self.supported_table_args = ("catalog", "db", "this")
        else:
            raise OptimizeError(f"Invalid schema shape. Depth: {depth}")

        self.forbidden_args = {"catalog", "db", "this"} - set(self.supported_table_args)

    def column_names(self, table, only_visible=False):
        if not isinstance(table.this, exp.Identifier):
            return fs_get(table)

        args = tuple(table.text(p) for p in self.supported_table_args)

        for forbidden in self.forbidden_args:
            if table.text(forbidden):
                raise ValueError(f"Schema doesn't support {forbidden}. Received: {table.sql()}")

        columns = list(_nested_get(self.schema, *zip(self.supported_table_args, args)))
        if not only_visible or not self.visible:
            return columns

        visible = _nested_get(self.visible, *zip(self.supported_table_args, args))
        return [col for col in columns if col in visible]

    def get_column_type(self, table, column):
        try:
            schema_type = self.schema.get(table.name, {}).get(column.name).upper()
            return self._convert_type(schema_type)
        except:
            raise OptimizeError(f"Failed to get type for column {column.sql()}")

    def _convert_type(self, schema_type):
        """
        Convert a type represented as a string to the corresponding exp.DataType.Type object.

        Args:
            schema_type (str): The type we want to convert.
        Returns:
            sqlglot.expressions.DataType.Type: The resulting expression type.
        """
        if schema_type not in self._type_mapping_cache:
            try:
                self._type_mapping_cache[schema_type] = exp.maybe_parse(
                    schema_type, into=exp.DataType, dialect=self.dialect
                ).this
            except AttributeError:
                raise OptimizeError(f"Failed to convert type {schema_type}")

        return self._type_mapping_cache[schema_type]


def ensure_schema(schema):
    if isinstance(schema, Schema):
        return schema

    return MappingSchema(schema)


def fs_get(table):
    name = table.this.name

    if name.upper() == "READ_CSV":
        with csv_reader(table) as reader:
            return next(reader)

    raise ValueError(f"Cannot read schema for {table}")


def _nested_get(d, *path):
    """
    Get a value for a nested dictionary.

    Args:
        d (dict): dictionary
        *path (tuple[str, str]): tuples of (name, key)
            `key` is the key in the dictionary to get.
            `name` is a string to use in the error if `key` isn't found.
    """
    for name, key in path:
        d = d.get(key)
        if d is None:
            name = "table" if name == "this" else name
            raise ValueError(f"Unknown {name}")
    return d


def _dict_depth(d):
    """
    Get the nesting depth of a dictionary.

    For example:
        >>> _dict_depth(None)
        0
        >>> _dict_depth({})
        1
        >>> _dict_depth({"a": "b"})
        1
        >>> _dict_depth({"a": {}})
        2
        >>> _dict_depth({"a": {"b": {}}})
        3

    Args:
        d (dict): dictionary
    Returns:
        int: depth
    """
    try:
        return 1 + _dict_depth(next(iter(d.values())))
    except AttributeError:
        # d doesn't have attribute "values"
        return 0
    except StopIteration:
        # d.values() returns an empty sequence
        return 1
@ -257,12 +257,7 @@ class Scope:
        referenced_names = []

        for table in self.tables:
            referenced_names.append(
                (
                    table.parent.alias if isinstance(table.parent, exp.Alias) else table.name,
                    table,
                )
            )
            referenced_names.append((table.alias_or_name, table))
        for derived_table in self.derived_tables:
            referenced_names.append((derived_table.alias, derived_table.unnest()))

@ -538,8 +533,8 @@ def _add_table_sources(scope)
    for table in scope.tables:
        table_name = table.name

        if isinstance(table.parent, exp.Alias):
            source_name = table.parent.alias
        if table.alias:
            source_name = table.alias
        else:
            source_name = table_name
@ -329,6 +329,7 @@ class Parser:
        exp.DataType: lambda self: self._parse_types(),
        exp.From: lambda self: self._parse_from(),
        exp.Group: lambda self: self._parse_group(),
        exp.Identifier: lambda self: self._parse_id_var(),
        exp.Lateral: lambda self: self._parse_lateral(),
        exp.Join: lambda self: self._parse_join(),
        exp.Order: lambda self: self._parse_order(),

@ -371,11 +372,8 @@ class Parser:
        TokenType.PARAMETER: lambda self, _: exp.Parameter(this=self._parse_var() or self._parse_primary()),
        TokenType.BIT_STRING: lambda _, token: exp.BitString(this=token.text),
        TokenType.HEX_STRING: lambda _, token: exp.HexString(this=token.text),
        TokenType.INTRODUCER: lambda self, token: self.expression(
            exp.Introducer,
            this=token.text,
            expression=self._parse_var_or_string(),
        ),
        TokenType.BYTE_STRING: lambda _, token: exp.ByteString(this=token.text),
        TokenType.INTRODUCER: lambda self, token: self._parse_introducer(token),
    }

    RANGE_PARSERS = {

@ -500,7 +498,7 @@ class Parser:
        max_errors=3,
        null_ordering=None,
    ):
        self.error_level = error_level or ErrorLevel.RAISE
        self.error_level = error_level or ErrorLevel.IMMEDIATE
        self.error_message_context = error_message_context
        self.index_offset = index_offset
        self.unnest_column_only = unnest_column_only

@ -928,6 +926,7 @@ class Parser:
        return self.expression(
            exp.Delete,
            this=self._parse_table(schema=True),
            using=self._parse_csv(lambda: self._match(TokenType.USING) and self._parse_table(schema=True)),
            where=self._parse_where(),
        )

@ -1148,7 +1147,7 @@ class Parser:

    def _parse_annotation(self, expression):
        if self._match(TokenType.ANNOTATION):
            return self.expression(exp.Annotation, this=self._prev.text, expression=expression)
            return self.expression(exp.Annotation, this=self._prev.text.strip(), expression=expression)

        return expression

@ -1277,7 +1276,7 @@ class Parser:
        alias = self._parse_table_alias()

        if alias:
            this = self.expression(exp.Alias, this=this, alias=alias)
            this.set("alias", alias)

        if not self.alias_post_tablesample:
            table_sample = self._parse_table_sample()

@ -1876,6 +1875,17 @@ class Parser:
        self._match_r_paren()
        return self.expression(exp.UserDefinedFunction, this=this, expressions=expressions)

    def _parse_introducer(self, token):
        literal = self._parse_primary()
        if literal:
            return self.expression(
                exp.Introducer,
                this=token.text,
                expression=literal,
            )

        return self.expression(exp.Identifier, this=token.text)

    def _parse_udf_kwarg(self):
        this = self._parse_id_var()
        kind = self._parse_types()
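The parser's default error level is now `IMMEDIATE`, which raises on the first error rather than collecting all of them first. A hedged sketch; the broken query and exact error message are illustrative:

```
import sqlglot
from sqlglot.errors import ErrorLevel, ParseError

try:
    sqlglot.parse_one("SELECT * FROM")  # missing table name
except ParseError as e:
    print(e)

# The collect-then-raise behavior is still available explicitly:
sqlglot.parse_one("SELECT 1", error_level=ErrorLevel.RAISE)
```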
@ -199,13 +199,14 @@ class Step:
class Scan(Step):
    @classmethod
    def from_expression(cls, expression, ctes=None):
        table = expression.this
        table = expression
        alias_ = expression.alias

        if not alias_:
            raise UnsupportedError("Tables/Subqueries must be aliased. Run it through the optimizer")

        if isinstance(expression, exp.Subquery):
            table = expression.this
            step = Step.from_expression(table, ctes)
            step.name = alias_
            return step
297
sqlglot/schema.py
Normal file

@ -0,0 +1,297 @@
import abc

from sqlglot import expressions as exp
from sqlglot.errors import OptimizeError
from sqlglot.helper import csv_reader


class Schema(abc.ABC):
    """Abstract base class for database schemas"""

    @abc.abstractmethod
    def add_table(self, table, column_mapping=None):
        """
        Register or update a table. Some implementing classes may require column information to also be provided.

        Args:
            table (sqlglot.expressions.Table|str): Table expression instance or string representing the table
            column_mapping (dict|str|sqlglot.dataframe.sql.types.StructType|list): A column mapping that describes the structure of the table
        """

    @abc.abstractmethod
    def column_names(self, table, only_visible=False):
        """
        Get the column names for a table.
        Args:
            table (sqlglot.expressions.Table): Table expression instance
            only_visible (bool): Whether to include invisible columns
        Returns:
            list[str]: list of column names
        """

    @abc.abstractmethod
    def get_column_type(self, table, column):
        """
        Get the exp.DataType type of a column in the schema.

        Args:
            table (sqlglot.expressions.Table): The source table.
            column (sqlglot.expressions.Column): The target column.
        Returns:
            sqlglot.expressions.DataType.Type: The resulting column type.
        """


class MappingSchema(Schema):
    """
    Schema based on a nested mapping.

    Args:
        schema (dict): Mapping in one of the following forms:
            1. {table: {col: type}}
            2. {db: {table: {col: type}}}
            3. {catalog: {db: {table: {col: type}}}}
            4. None - Tables will be added later
        visible (dict): Optional mapping of which columns in the schema are visible. If not provided, all columns
            are assumed to be visible. The nesting should mirror that of the schema:
            1. {table: set(*cols)}}
            2. {db: {table: set(*cols)}}}
            3. {catalog: {db: {table: set(*cols)}}}}
        dialect (str): The dialect to be used for custom type mappings.
    """

    def __init__(self, schema=None, visible=None, dialect=None):
        self.schema = schema or {}
        self.visible = visible
        self.dialect = dialect
        self._type_mapping_cache = {}
        self.supported_table_args = []
        self.forbidden_table_args = set()
        if self.schema:
            self._initialize_supported_args()

    @classmethod
    def from_mapping_schema(cls, mapping_schema):
        return MappingSchema(
            schema=mapping_schema.schema, visible=mapping_schema.visible, dialect=mapping_schema.dialect
        )

    def copy(self, **kwargs):
        return MappingSchema(**{"schema": self.schema.copy(), **kwargs})

    def add_table(self, table, column_mapping=None):
        """
        Register or update a table. Updates are only performed if a new column mapping is provided.

        Args:
            table (sqlglot.expressions.Table|str): Table expression instance or string representing the table
            column_mapping (dict|str|sqlglot.dataframe.sql.types.StructType|list): A column mapping that describes the structure of the table
        """
        table = exp.to_table(table)
        self._validate_table(table)
        column_mapping = ensure_column_mapping(column_mapping)
        table_args = [table.text(p) for p in self.supported_table_args or self._get_table_args_from_table(table)]
        existing_column_mapping = _nested_get(
            self.schema, *zip(self.supported_table_args, table_args), raise_on_missing=False
        )
        if existing_column_mapping and not column_mapping:
            return
        _nested_set(
            self.schema,
            [table.text(p) for p in self.supported_table_args or self._get_table_args_from_table(table)],
            column_mapping,
        )
        self._initialize_supported_args()

    def _get_table_args_from_table(self, table):
        if table.args.get("catalog") is not None:
            return "catalog", "db", "this"
        if table.args.get("db") is not None:
            return "db", "this"
        return ("this",)

    def _validate_table(self, table):
        if not self.supported_table_args and isinstance(table, exp.Table):
            return
        for forbidden in self.forbidden_table_args:
            if table.text(forbidden):
                raise ValueError(f"Schema doesn't support {forbidden}. Received: {table.sql()}")
        for expected in self.supported_table_args:
            if not table.text(expected):
                raise ValueError(f"Table is expected to have {expected}. Received: {table.sql()}")

    def column_names(self, table, only_visible=False):
        table = exp.to_table(table)
        if not isinstance(table.this, exp.Identifier):
            return fs_get(table)

        args = tuple(table.text(p) for p in self.supported_table_args)

        for forbidden in self.forbidden_table_args:
            if table.text(forbidden):
                raise ValueError(f"Schema doesn't support {forbidden}. Received: {table.sql()}")

        columns = list(_nested_get(self.schema, *zip(self.supported_table_args, args)))
        if not only_visible or not self.visible:
            return columns

        visible = _nested_get(self.visible, *zip(self.supported_table_args, args))
        return [col for col in columns if col in visible]

    def get_column_type(self, table, column):
        try:
            schema_type = self.schema.get(table.name, {}).get(column.name).upper()
            return self._convert_type(schema_type)
        except:
            raise OptimizeError(f"Failed to get type for column {column.sql()}")

    def _convert_type(self, schema_type):
        """
        Convert a type represented as a string to the corresponding exp.DataType.Type object.
        Args:
            schema_type (str): The type we want to convert.
        Returns:
            sqlglot.expressions.DataType.Type: The resulting expression type.
        """
        if schema_type not in self._type_mapping_cache:
            try:
                self._type_mapping_cache[schema_type] = exp.maybe_parse(
                    schema_type, into=exp.DataType, dialect=self.dialect
                ).this
            except AttributeError:
                raise OptimizeError(f"Failed to convert type {schema_type}")

        return self._type_mapping_cache[schema_type]

    def _initialize_supported_args(self):
        if not self.supported_table_args:
            depth = _dict_depth(self.schema)

            all_args = ["this", "db", "catalog"]
            if not depth or depth == 1:  # {}
                self.supported_table_args = []
            elif 2 <= depth <= 4:
                self.supported_table_args = tuple(reversed(all_args[: depth - 1]))
            else:
                raise OptimizeError(f"Invalid schema shape. Depth: {depth}")

            self.forbidden_table_args = {"catalog", "db", "this"} - set(self.supported_table_args)


def ensure_schema(schema):
    if isinstance(schema, Schema):
        return schema

    return MappingSchema(schema)


def ensure_column_mapping(mapping):
    if isinstance(mapping, dict):
        return mapping
    elif isinstance(mapping, str):
        col_name_type_strs = [x.strip() for x in mapping.split(",")]
        return {
            name_type_str.split(":")[0].strip(): name_type_str.split(":")[1].strip()
            for name_type_str in col_name_type_strs
        }
    # Check if mapping looks like a DataFrame StructType
    elif hasattr(mapping, "simpleString"):
        return {struct_field.name: struct_field.dataType.simpleString() for struct_field in mapping}
    elif isinstance(mapping, list):
        return {x.strip(): None for x in mapping}
    elif mapping is None:
        return {}
    raise ValueError(f"Invalid mapping provided: {type(mapping)}")


def fs_get(table):
    name = table.this.name

    if name.upper() == "READ_CSV":
        with csv_reader(table) as reader:
            return next(reader)

    raise ValueError(f"Cannot read schema for {table}")


def _nested_get(d, *path, raise_on_missing=True):
    """
    Get a value for a nested dictionary.

    Args:
        d (dict): dictionary
        *path (tuple[str, str]): tuples of (name, key)
            `key` is the key in the dictionary to get.
            `name` is a string to use in the error if `key` isn't found.

    Returns:
        The value or None if it doesn't exist
    """
    for name, key in path:
        d = d.get(key)
        if d is None:
            if raise_on_missing:
                name = "table" if name == "this" else name
                raise ValueError(f"Unknown {name}")
            return None
    return d


def _nested_set(d, keys, value):
    """
    In-place set a value for a nested dictionary

    Ex:
        >>> _nested_set({}, ["top_key", "second_key"], "value")
        {'top_key': {'second_key': 'value'}}
        >>> _nested_set({"top_key": {"third_key": "third_value"}}, ["top_key", "second_key"], "value")
        {'top_key': {'third_key': 'third_value', 'second_key': 'value'}}

    Args:
        d (dict): dictionary
        keys (Iterable[str]): ordered iterable of keys that make up the path to the value
        value (Any): The value to set in the dictionary for the given key path
    """
    if not keys:
        return
    if len(keys) == 1:
        d[keys[0]] = value
        return
    subd = d
    for key in keys[:-1]:
        if key not in subd:
            subd = subd.setdefault(key, {})
        else:
            subd = subd[key]
    subd[keys[-1]] = value
    return d


def _dict_depth(d):
    """
    Get the nesting depth of a dictionary.

    For example:
        >>> _dict_depth(None)
        0
        >>> _dict_depth({})
        1
        >>> _dict_depth({"a": "b"})
        1
        >>> _dict_depth({"a": {}})
        2
        >>> _dict_depth({"a": {"b": {}}})
        3

    Args:
        d (dict): dictionary
    Returns:
        int: depth
    """
    try:
        return 1 + _dict_depth(next(iter(d.values())))
    except AttributeError:
        # d doesn't have attribute "values"
        return 0
    except StopIteration:
        # d.values() returns an empty sequence
        return 1
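`ensure_column_mapping` above accepts several input shapes, so the same table can be registered from a dict, a `name: type` string, or a bare list of column names. A minimal sketch with made-up table names:

```
from sqlglot.schema import MappingSchema

schema = MappingSchema()
schema.add_table("db.users", {"id": "INT", "email": "VARCHAR"})
schema.add_table("db.orders", "id: INT, total: DOUBLE")
schema.add_table("db.tags", ["id", "label"])  # types unknown, stored as None

print(schema.column_names("db.users"))  # ['id', 'email']
```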
@ -56,6 +56,7 @@ class TokenType(AutoName):
    VAR = auto()
    BIT_STRING = auto()
    HEX_STRING = auto()
    BYTE_STRING = auto()

    # types
    BOOLEAN = auto()

@ -320,6 +321,7 @@ class _Tokenizer(type):
        klass._QUOTES = cls._delimeter_list_to_dict(klass.QUOTES)
        klass._BIT_STRINGS = cls._delimeter_list_to_dict(klass.BIT_STRINGS)
        klass._HEX_STRINGS = cls._delimeter_list_to_dict(klass.HEX_STRINGS)
        klass._BYTE_STRINGS = cls._delimeter_list_to_dict(klass.BYTE_STRINGS)
        klass._IDENTIFIERS = cls._delimeter_list_to_dict(klass.IDENTIFIERS)
        klass._COMMENTS = dict(
            (comment, None) if isinstance(comment, str) else (comment[0], comment[1]) for comment in klass.COMMENTS

@ -333,6 +335,7 @@ class _Tokenizer(type):
            **{quote: TokenType.QUOTE for quote in klass._QUOTES},
            **{bit_string: TokenType.BIT_STRING for bit_string in klass._BIT_STRINGS},
            **{hex_string: TokenType.HEX_STRING for hex_string in klass._HEX_STRINGS},
            **{byte_string: TokenType.BYTE_STRING for byte_string in klass._BYTE_STRINGS},
        }.items()
        if " " in key or any(single in key for single in klass.SINGLE_TOKENS)
    )

@ -385,6 +388,8 @@ class Tokenizer(metaclass=_Tokenizer):

    HEX_STRINGS = []

    BYTE_STRINGS = []

    IDENTIFIERS = ['"']

    ESCAPE = "'"

@ -799,7 +804,7 @@ class Tokenizer(metaclass=_Tokenizer):

        if self._scan_string(word):
            return
        if self._scan_numeric_string(word):
        if self._scan_formatted_string(word):
            return
        if self._scan_comment(word):
            return

@ -906,7 +911,8 @@ class Tokenizer(metaclass=_Tokenizer):
            self._add(TokenType.STRING, text)
            return True

    def _scan_numeric_string(self, string_start):
    # X'1234', b'0110', E'\\\\\' etc.
    def _scan_formatted_string(self, string_start):
        if string_start in self._HEX_STRINGS:
            delimiters = self._HEX_STRINGS
            token_type = TokenType.HEX_STRING

@ -915,6 +921,10 @@ class Tokenizer(metaclass=_Tokenizer):
            delimiters = self._BIT_STRINGS
            token_type = TokenType.BIT_STRING
            base = 2
        elif string_start in self._BYTE_STRINGS:
            delimiters = self._BYTE_STRINGS
            token_type = TokenType.BYTE_STRING
            base = None
        else:
            return False

@ -922,10 +932,14 @@ class Tokenizer(metaclass=_Tokenizer):
        string_end = delimiters.get(string_start)
        text = self._extract_string(string_end)

        try:
            self._add(token_type, f"{int(text, base)}")
        except ValueError:
            raise RuntimeError(f"Numeric string contains invalid characters from {self._line}:{self._start}")
        if base is None:
            self._add(token_type, text)
        else:
            try:
                self._add(token_type, f"{int(text, base)}")
            except:
                raise RuntimeError(f"Numeric string contains invalid characters from {self._line}:{self._start}")

        return True

    def _scan_identifier(self, identifier_end):
0
tests/dataframe/__init__.py
Normal file
0
tests/dataframe/integration/__init__.py
Normal file
149
tests/dataframe/integration/dataframe_validator.py
Normal file
@ -0,0 +1,149 @@
import typing as t
import unittest
import warnings

import sqlglot
from tests.helpers import SKIP_INTEGRATION

if t.TYPE_CHECKING:
    from pyspark.sql import DataFrame as SparkDataFrame


@unittest.skipIf(SKIP_INTEGRATION, "Skipping Integration Tests since `SKIP_INTEGRATION` is set")
class DataFrameValidator(unittest.TestCase):
    spark = None
    sqlglot = None
    df_employee = None
    df_store = None
    df_district = None
    spark_employee_schema = None
    sqlglot_employee_schema = None
    spark_store_schema = None
    sqlglot_store_schema = None
    spark_district_schema = None
    sqlglot_district_schema = None

    @classmethod
    def setUpClass(cls):
        from pyspark import SparkConf
        from pyspark.sql import SparkSession, types

        from sqlglot.dataframe.sql import types as sqlglotSparkTypes
        from sqlglot.dataframe.sql.session import SparkSession as SqlglotSparkSession

        # This is for test `test_branching_root_dataframes`
        config = SparkConf().setAll([("spark.sql.analyzer.failAmbiguousSelfJoin", "false")])
        cls.spark = SparkSession.builder.master("local[*]").appName("Unit-tests").config(conf=config).getOrCreate()
        cls.spark.sparkContext.setLogLevel("ERROR")
        cls.sqlglot = SqlglotSparkSession()
        cls.spark_employee_schema = types.StructType(
            [
                types.StructField("employee_id", types.IntegerType(), False),
                types.StructField("fname", types.StringType(), False),
                types.StructField("lname", types.StringType(), False),
                types.StructField("age", types.IntegerType(), False),
                types.StructField("store_id", types.IntegerType(), False),
            ]
        )
        cls.sqlglot_employee_schema = sqlglotSparkTypes.StructType(
            [
                sqlglotSparkTypes.StructField("employee_id", sqlglotSparkTypes.IntegerType(), False),
                sqlglotSparkTypes.StructField("fname", sqlglotSparkTypes.StringType(), False),
                sqlglotSparkTypes.StructField("lname", sqlglotSparkTypes.StringType(), False),
                sqlglotSparkTypes.StructField("age", sqlglotSparkTypes.IntegerType(), False),
                sqlglotSparkTypes.StructField("store_id", sqlglotSparkTypes.IntegerType(), False),
            ]
        )
        employee_data = [
            (1, "Jack", "Shephard", 37, 1),
            (2, "John", "Locke", 65, 1),
            (3, "Kate", "Austen", 37, 2),
            (4, "Claire", "Littleton", 27, 2),
            (5, "Hugo", "Reyes", 29, 100),
        ]
        cls.df_employee = cls.spark.createDataFrame(data=employee_data, schema=cls.spark_employee_schema)
        cls.dfs_employee = cls.sqlglot.createDataFrame(data=employee_data, schema=cls.sqlglot_employee_schema)
        cls.df_employee.createOrReplaceTempView("employee")

        cls.spark_store_schema = types.StructType(
            [
                types.StructField("store_id", types.IntegerType(), False),
                types.StructField("store_name", types.StringType(), False),
                types.StructField("district_id", types.IntegerType(), False),
                types.StructField("num_sales", types.IntegerType(), False),
            ]
        )
        cls.sqlglot_store_schema = sqlglotSparkTypes.StructType(
            [
                sqlglotSparkTypes.StructField("store_id", sqlglotSparkTypes.IntegerType(), False),
                sqlglotSparkTypes.StructField("store_name", sqlglotSparkTypes.StringType(), False),
                sqlglotSparkTypes.StructField("district_id", sqlglotSparkTypes.IntegerType(), False),
                sqlglotSparkTypes.StructField("num_sales", sqlglotSparkTypes.IntegerType(), False),
            ]
        )
        store_data = [
            (1, "Hydra", 1, 37),
            (2, "Arrow", 2, 2000),
        ]
        cls.df_store = cls.spark.createDataFrame(data=store_data, schema=cls.spark_store_schema)
        cls.dfs_store = cls.sqlglot.createDataFrame(data=store_data, schema=cls.sqlglot_store_schema)
        cls.df_store.createOrReplaceTempView("store")

        cls.spark_district_schema = types.StructType(
            [
                types.StructField("district_id", types.IntegerType(), False),
                types.StructField("district_name", types.StringType(), False),
                types.StructField("manager_name", types.StringType(), False),
            ]
        )
        cls.sqlglot_district_schema = sqlglotSparkTypes.StructType(
            [
                sqlglotSparkTypes.StructField("district_id", sqlglotSparkTypes.IntegerType(), False),
                sqlglotSparkTypes.StructField("district_name", sqlglotSparkTypes.StringType(), False),
                sqlglotSparkTypes.StructField("manager_name", sqlglotSparkTypes.StringType(), False),
            ]
        )
        district_data = [
            (1, "Temple", "Dogen"),
            (2, "Lighthouse", "Jacob"),
        ]
        cls.df_district = cls.spark.createDataFrame(data=district_data, schema=cls.spark_district_schema)
        cls.dfs_district = cls.sqlglot.createDataFrame(data=district_data, schema=cls.sqlglot_district_schema)
        cls.df_district.createOrReplaceTempView("district")
        sqlglot.schema.add_table("employee", cls.sqlglot_employee_schema)
        sqlglot.schema.add_table("store", cls.sqlglot_store_schema)
        sqlglot.schema.add_table("district", cls.sqlglot_district_schema)

    def setUp(self) -> None:
        warnings.filterwarnings("ignore", category=ResourceWarning)
        self.df_spark_store = self.df_store.alias("df_store")  # type: ignore
        self.df_spark_employee = self.df_employee.alias("df_employee")  # type: ignore
        self.df_spark_district = self.df_district.alias("df_district")  # type: ignore
        self.df_sqlglot_store = self.dfs_store.alias("store")  # type: ignore
        self.df_sqlglot_employee = self.dfs_employee.alias("employee")  # type: ignore
        self.df_sqlglot_district = self.dfs_district.alias("district")  # type: ignore

    def compare_spark_with_sqlglot(
        self, df_spark, df_sqlglot, no_empty=True, skip_schema_compare=False
    ) -> t.Tuple["SparkDataFrame", "SparkDataFrame"]:
        def compare_schemas(schema_1, schema_2):
            for schema in [schema_1, schema_2]:
                for struct_field in schema.fields:
                    struct_field.metadata = {}
            self.assertEqual(schema_1, schema_2)

        for statement in df_sqlglot.sql():
            actual_df_sqlglot = self.spark.sql(statement)  # type: ignore
        df_sqlglot_results = actual_df_sqlglot.collect()
        df_spark_results = df_spark.collect()
        if not skip_schema_compare:
            compare_schemas(df_spark.schema, actual_df_sqlglot.schema)
        self.assertEqual(df_spark_results, df_sqlglot_results)
        if no_empty:
            self.assertNotEqual(len(df_spark_results), 0)
            self.assertNotEqual(len(df_sqlglot_results), 0)
        return df_spark, actual_df_sqlglot

    @classmethod
    def get_explain_plan(cls, df: "SparkDataFrame", mode: str = "extended") -> str:
        return df._sc._jvm.PythonSQLUtils.explainString(df._jdf.queryExecution(), mode)  # type: ignore
1103
tests/dataframe/integration/test_dataframe.py
Normal file
File diff suppressed because it is too large
71
tests/dataframe/integration/test_grouped_data.py
Normal file
@ -0,0 +1,71 @@
from pyspark.sql import functions as F

from sqlglot.dataframe.sql import functions as SF
from tests.dataframe.integration.dataframe_validator import DataFrameValidator


class TestDataframeFunc(DataFrameValidator):
    def test_group_by(self):
        df_employee = self.df_spark_employee.groupBy(self.df_spark_employee.age).agg(
            F.min(self.df_spark_employee.employee_id)
        )
        dfs_employee = self.df_sqlglot_employee.groupBy(self.df_sqlglot_employee.age).agg(
            SF.min(self.df_sqlglot_employee.employee_id)
        )
        self.compare_spark_with_sqlglot(df_employee, dfs_employee, skip_schema_compare=True)

    def test_group_by_where_non_aggregate(self):
        df_employee = (
            self.df_spark_employee.groupBy(self.df_spark_employee.age)
            .agg(F.min(self.df_spark_employee.employee_id).alias("min_employee_id"))
            .where(F.col("age") > F.lit(50))
        )
        dfs_employee = (
            self.df_sqlglot_employee.groupBy(self.df_sqlglot_employee.age)
            .agg(SF.min(self.df_sqlglot_employee.employee_id).alias("min_employee_id"))
            .where(SF.col("age") > SF.lit(50))
        )
        self.compare_spark_with_sqlglot(df_employee, dfs_employee)

    def test_group_by_where_aggregate_like_having(self):
        df_employee = (
            self.df_spark_employee.groupBy(self.df_spark_employee.age)
            .agg(F.min(self.df_spark_employee.employee_id).alias("min_employee_id"))
            .where(F.col("min_employee_id") > F.lit(1))
        )
        dfs_employee = (
            self.df_sqlglot_employee.groupBy(self.df_sqlglot_employee.age)
            .agg(SF.min(self.df_sqlglot_employee.employee_id).alias("min_employee_id"))
            .where(SF.col("min_employee_id") > SF.lit(1))
        )
        self.compare_spark_with_sqlglot(df_employee, dfs_employee)

    def test_count(self):
        df = self.df_spark_employee.groupBy(self.df_spark_employee.age).count()
        dfs = self.df_sqlglot_employee.groupBy(self.df_sqlglot_employee.age).count()
        self.compare_spark_with_sqlglot(df, dfs)

    def test_mean(self):
        df = self.df_spark_employee.groupBy().mean("age", "store_id")
        dfs = self.df_sqlglot_employee.groupBy().mean("age", "store_id")
        self.compare_spark_with_sqlglot(df, dfs)

    def test_avg(self):
        df = self.df_spark_employee.groupBy("age").avg("store_id")
        dfs = self.df_sqlglot_employee.groupBy("age").avg("store_id")
        self.compare_spark_with_sqlglot(df, dfs)

    def test_max(self):
        df = self.df_spark_employee.groupBy("age").max("store_id")
        dfs = self.df_sqlglot_employee.groupBy("age").max("store_id")
        self.compare_spark_with_sqlglot(df, dfs)

    def test_min(self):
        df = self.df_spark_employee.groupBy("age").min("store_id")
        dfs = self.df_sqlglot_employee.groupBy("age").min("store_id")
        self.compare_spark_with_sqlglot(df, dfs)

    def test_sum(self):
        df = self.df_spark_employee.groupBy("age").sum("store_id")
        dfs = self.df_sqlglot_employee.groupBy("age").sum("store_id")
        self.compare_spark_with_sqlglot(df, dfs)
28
tests/dataframe/integration/test_session.py
Normal file

@ -0,0 +1,28 @@
from pyspark.sql import functions as F

from sqlglot.dataframe.sql import functions as SF
from tests.dataframe.integration.dataframe_validator import DataFrameValidator


class TestSessionFunc(DataFrameValidator):
    def test_sql_simple_select(self):
        query = "SELECT fname, lname FROM employee"
        df = self.spark.sql(query)
        dfs = self.sqlglot.sql(query)
        self.compare_spark_with_sqlglot(df, dfs)

    def test_sql_with_join(self):
        query = """
        SELECT
            e.employee_id
            , s.store_id
        FROM
            employee e
        INNER JOIN
            store s
        ON
            e.store_id = s.store_id
        """
        df = self.spark.sql(query).groupBy(F.col("store_id")).agg(F.countDistinct(F.col("employee_id")))
        dfs = self.sqlglot.sql(query).groupBy(SF.col("store_id")).agg(SF.countDistinct(SF.col("employee_id")))
        self.compare_spark_with_sqlglot(df, dfs, skip_schema_compare=True)
0
tests/dataframe/unit/__init__.py
Normal file
35
tests/dataframe/unit/dataframe_sql_validator.py
Normal file

@ -0,0 +1,35 @@
import typing as t
import unittest

from sqlglot.dataframe.sql import types
from sqlglot.dataframe.sql.dataframe import DataFrame
from sqlglot.dataframe.sql.session import SparkSession


class DataFrameSQLValidator(unittest.TestCase):
    def setUp(self) -> None:
        self.spark = SparkSession()
        self.employee_schema = types.StructType(
            [
                types.StructField("employee_id", types.IntegerType(), False),
                types.StructField("fname", types.StringType(), False),
                types.StructField("lname", types.StringType(), False),
                types.StructField("age", types.IntegerType(), False),
                types.StructField("store_id", types.IntegerType(), False),
            ]
        )
        employee_data = [
            (1, "Jack", "Shephard", 37, 1),
            (2, "John", "Locke", 65, 1),
            (3, "Kate", "Austen", 37, 2),
            (4, "Claire", "Littleton", 27, 2),
            (5, "Hugo", "Reyes", 29, 100),
        ]
        self.df_employee = self.spark.createDataFrame(data=employee_data, schema=self.employee_schema)

    def compare_sql(self, df: DataFrame, expected_statements: t.Union[str, t.List[str]], pretty=False):
        actual_sqls = df.sql(pretty=pretty)
        expected_statements = [expected_statements] if isinstance(expected_statements, str) else expected_statements
        self.assertEqual(len(expected_statements), len(actual_sqls))
        for expected, actual in zip(expected_statements, actual_sqls):
            self.assertEqual(expected, actual)
167
tests/dataframe/unit/test_column.py
Normal file

@ -0,0 +1,167 @@
import datetime
import unittest

from sqlglot.dataframe.sql import functions as F
from sqlglot.dataframe.sql.window import Window


class TestDataframeColumn(unittest.TestCase):
    def test_eq(self):
        self.assertEqual("cola = 1", (F.col("cola") == 1).sql())

    def test_neq(self):
        self.assertEqual("cola <> 1", (F.col("cola") != 1).sql())

    def test_gt(self):
        self.assertEqual("cola > 1", (F.col("cola") > 1).sql())

    def test_lt(self):
        self.assertEqual("cola < 1", (F.col("cola") < 1).sql())

    def test_le(self):
        self.assertEqual("cola <= 1", (F.col("cola") <= 1).sql())

    def test_ge(self):
        self.assertEqual("cola >= 1", (F.col("cola") >= 1).sql())

    def test_and(self):
        self.assertEqual(
            "cola = colb AND colc = cold", ((F.col("cola") == F.col("colb")) & (F.col("colc") == F.col("cold"))).sql()
        )

    def test_or(self):
        self.assertEqual(
            "cola = colb OR colc = cold", ((F.col("cola") == F.col("colb")) | (F.col("colc") == F.col("cold"))).sql()
        )

    def test_mod(self):
        self.assertEqual("cola % 2", (F.col("cola") % 2).sql())

    def test_add(self):
        self.assertEqual("cola + 1", (F.col("cola") + 1).sql())

    def test_sub(self):
        self.assertEqual("cola - 1", (F.col("cola") - 1).sql())

    def test_mul(self):
        self.assertEqual("cola * 2", (F.col("cola") * 2).sql())

    def test_div(self):
        self.assertEqual("cola / 2", (F.col("cola") / 2).sql())

    def test_radd(self):
        self.assertEqual("1 + cola", (1 + F.col("cola")).sql())

    def test_rsub(self):
        self.assertEqual("1 - cola", (1 - F.col("cola")).sql())

    def test_rmul(self):
        self.assertEqual("1 * cola", (1 * F.col("cola")).sql())

    def test_rdiv(self):
        self.assertEqual("1 / cola", (1 / F.col("cola")).sql())

    def test_pow(self):
        self.assertEqual("POWER(cola, 2)", (F.col("cola") ** 2).sql())

    def test_rpow(self):
        self.assertEqual("POWER(2, cola)", (2 ** F.col("cola")).sql())

    def test_invert(self):
        self.assertEqual("NOT cola", (~F.col("cola")).sql())

    def test_startswith(self):
        self.assertEqual("STARTSWITH(cola, 'test')", F.col("cola").startswith("test").sql())

    def test_endswith(self):
        self.assertEqual("ENDSWITH(cola, 'test')", F.col("cola").endswith("test").sql())

    def test_rlike(self):
        self.assertEqual("cola RLIKE 'foo'", F.col("cola").rlike("foo").sql())

    def test_like(self):
        self.assertEqual("cola LIKE 'foo%'", F.col("cola").like("foo%").sql())

    def test_ilike(self):
        self.assertEqual("cola ILIKE 'foo%'", F.col("cola").ilike("foo%").sql())

    def test_substring(self):
        self.assertEqual("SUBSTRING(cola, 2, 3)", F.col("cola").substr(2, 3).sql())

    def test_isin(self):
        self.assertEqual("cola IN (1, 2, 3)", F.col("cola").isin([1, 2, 3]).sql())
        self.assertEqual("cola IN (1, 2, 3)", F.col("cola").isin(1, 2, 3).sql())

    def test_asc(self):
        self.assertEqual("cola", F.col("cola").asc().sql())

    def test_desc(self):
        self.assertEqual("cola DESC", F.col("cola").desc().sql())

    def test_asc_nulls_first(self):
        self.assertEqual("cola", F.col("cola").asc_nulls_first().sql())

    def test_asc_nulls_last(self):
        self.assertEqual("cola NULLS LAST", F.col("cola").asc_nulls_last().sql())

    def test_desc_nulls_first(self):
        self.assertEqual("cola DESC NULLS FIRST", F.col("cola").desc_nulls_first().sql())

    def test_desc_nulls_last(self):
        self.assertEqual("cola DESC", F.col("cola").desc_nulls_last().sql())

    def test_when_otherwise(self):
        self.assertEqual("CASE WHEN cola = 1 THEN 2 END", F.when(F.col("cola") == 1, 2).sql())
        self.assertEqual("CASE WHEN cola = 1 THEN 2 END", F.col("cola").when(F.col("cola") == 1, 2).sql())
        self.assertEqual(
            "CASE WHEN cola = 1 THEN 2 WHEN colb = 2 THEN 3 END",
            (F.when(F.col("cola") == 1, 2).when(F.col("colb") == 2, 3)).sql(),
        )
        self.assertEqual(
            "CASE WHEN cola = 1 THEN 2 WHEN colb = 2 THEN 3 END",
            F.col("cola").when(F.col("cola") == 1, 2).when(F.col("colb") == 2, 3).sql(),
        )
        self.assertEqual(
            "CASE WHEN cola = 1 THEN 2 WHEN colb = 2 THEN 3 ELSE 4 END",
            F.when(F.col("cola") == 1, 2).when(F.col("colb") == 2, 3).otherwise(4).sql(),
        )

    def test_is_null(self):
        self.assertEqual("cola IS NULL", F.col("cola").isNull().sql())

    def test_is_not_null(self):
        self.assertEqual("NOT cola IS NULL", F.col("cola").isNotNull().sql())

    def test_cast(self):
        self.assertEqual("CAST(cola AS INT)", F.col("cola").cast("INT").sql())

    def test_alias(self):
        self.assertEqual("cola AS new_name", F.col("cola").alias("new_name").sql())

    def test_between(self):
        self.assertEqual("cola BETWEEN 1 AND 3", F.col("cola").between(1, 3).sql())
        self.assertEqual("cola BETWEEN 10.1 AND 12.1", F.col("cola").between(10.1, 12.1).sql())
        self.assertEqual(
            "cola BETWEEN TO_DATE('2022-01-01') AND TO_DATE('2022-03-01')",
            F.col("cola").between(datetime.date(2022, 1, 1), datetime.date(2022, 3, 1)).sql(),
        )
        self.assertEqual(
            "cola BETWEEN CAST('2022-01-01 01:01:01' AS TIMESTAMP) AND CAST('2022-03-01 01:01:01' AS TIMESTAMP)",
            F.col("cola").between(datetime.datetime(2022, 1, 1, 1, 1, 1), datetime.datetime(2022, 3, 1, 1, 1, 1)).sql(),
        )

    def test_over(self):
        over_rows = F.sum("cola").over(
            Window.partitionBy("colb").orderBy("colc").rowsBetween(1, Window.unboundedFollowing)
        )
        self.assertEqual(
            "SUM(cola) OVER (PARTITION BY colb ORDER BY colc ROWS BETWEEN 1 PRECEDING AND UNBOUNDED FOLLOWING)",
            over_rows.sql(),
        )
        over_range = F.sum("cola").over(
            Window.partitionBy("colb").orderBy("colc").rangeBetween(1, Window.unboundedFollowing)
        )
        self.assertEqual(
            "SUM(cola) OVER (PARTITION BY colb ORDER BY colc RANGE BETWEEN 1 PRECEDING AND UNBOUNDED FOLLOWING)",
            over_range.sql(),
        )
39
tests/dataframe/unit/test_dataframe.py
Normal file

@ -0,0 +1,39 @@
from sqlglot import expressions as exp
from sqlglot.dataframe.sql.dataframe import DataFrame
from tests.dataframe.unit.dataframe_sql_validator import DataFrameSQLValidator


class TestDataframe(DataFrameSQLValidator):
    def test_hash_select_expression(self):
        expression = exp.select("cola").from_("table")
        self.assertEqual("t17051", DataFrame._create_hash_from_expression(expression))

    def test_columns(self):
        self.assertEqual(["employee_id", "fname", "lname", "age", "store_id"], self.df_employee.columns)

    def test_cache(self):
        df = self.df_employee.select("fname").cache()
        expected_statements = [
            "DROP VIEW IF EXISTS t11623",
            "CACHE LAZY TABLE t11623 OPTIONS('storageLevel' = 'MEMORY_AND_DISK') AS SELECT CAST(`a1`.`fname` AS string) AS `fname` FROM (VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100)) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)",
            "SELECT `t11623`.`fname` AS `fname` FROM `t11623` AS `t11623`",
        ]
        self.compare_sql(df, expected_statements)

    def test_persist_default(self):
        df = self.df_employee.select("fname").persist()
        expected_statements = [
            "DROP VIEW IF EXISTS t11623",
            "CACHE LAZY TABLE t11623 OPTIONS('storageLevel' = 'MEMORY_AND_DISK_SER') AS SELECT CAST(`a1`.`fname` AS string) AS `fname` FROM (VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100)) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)",
            "SELECT `t11623`.`fname` AS `fname` FROM `t11623` AS `t11623`",
        ]
        self.compare_sql(df, expected_statements)

    def test_persist_storagelevel(self):
        df = self.df_employee.select("fname").persist("DISK_ONLY_2")
        expected_statements = [
            "DROP VIEW IF EXISTS t11623",
            "CACHE LAZY TABLE t11623 OPTIONS('storageLevel' = 'DISK_ONLY_2') AS SELECT CAST(`a1`.`fname` AS string) AS `fname` FROM (VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100)) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)",
            "SELECT `t11623`.`fname` AS `fname` FROM `t11623` AS `t11623`",
        ]
        self.compare_sql(df, expected_statements)
86
tests/dataframe/unit/test_dataframe_writer.py
Normal file
86
tests/dataframe/unit/test_dataframe_writer.py
Normal file
@@ -0,0 +1,86 @@
from unittest import mock

import sqlglot
from sqlglot.schema import MappingSchema
from tests.dataframe.unit.dataframe_sql_validator import DataFrameSQLValidator


class TestDataFrameWriter(DataFrameSQLValidator):
    def test_insertInto_full_path(self):
        df = self.df_employee.write.insertInto("catalog.db.table_name")
        expected = "INSERT INTO catalog.db.table_name SELECT CAST(`a1`.`employee_id` AS int) AS `employee_id`, CAST(`a1`.`fname` AS string) AS `fname`, CAST(`a1`.`lname` AS string) AS `lname`, CAST(`a1`.`age` AS int) AS `age`, CAST(`a1`.`store_id` AS int) AS `store_id` FROM (VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100)) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)"
        self.compare_sql(df, expected)

    def test_insertInto_db_table(self):
        df = self.df_employee.write.insertInto("db.table_name")
        expected = "INSERT INTO db.table_name SELECT CAST(`a1`.`employee_id` AS int) AS `employee_id`, CAST(`a1`.`fname` AS string) AS `fname`, CAST(`a1`.`lname` AS string) AS `lname`, CAST(`a1`.`age` AS int) AS `age`, CAST(`a1`.`store_id` AS int) AS `store_id` FROM (VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100)) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)"
        self.compare_sql(df, expected)

    def test_insertInto_table(self):
        df = self.df_employee.write.insertInto("table_name")
        expected = "INSERT INTO table_name SELECT CAST(`a1`.`employee_id` AS int) AS `employee_id`, CAST(`a1`.`fname` AS string) AS `fname`, CAST(`a1`.`lname` AS string) AS `lname`, CAST(`a1`.`age` AS int) AS `age`, CAST(`a1`.`store_id` AS int) AS `store_id` FROM (VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100)) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)"
        self.compare_sql(df, expected)

    def test_insertInto_overwrite(self):
        df = self.df_employee.write.insertInto("table_name", overwrite=True)
        expected = "INSERT OVERWRITE TABLE table_name SELECT CAST(`a1`.`employee_id` AS int) AS `employee_id`, CAST(`a1`.`fname` AS string) AS `fname`, CAST(`a1`.`lname` AS string) AS `lname`, CAST(`a1`.`age` AS int) AS `age`, CAST(`a1`.`store_id` AS int) AS `store_id` FROM (VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100)) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)"
        self.compare_sql(df, expected)

    @mock.patch("sqlglot.schema", MappingSchema())
    def test_insertInto_byName(self):
        sqlglot.schema.add_table("table_name", {"employee_id": "INT"})
        df = self.df_employee.write.byName.insertInto("table_name")
        expected = "INSERT INTO table_name SELECT CAST(`a1`.`employee_id` AS int) AS `employee_id` FROM (VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100)) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)"
        self.compare_sql(df, expected)

    def test_insertInto_cache(self):
        df = self.df_employee.cache().write.insertInto("table_name")
        expected_statements = [
            "DROP VIEW IF EXISTS t35612",
            "CACHE LAZY TABLE t35612 OPTIONS('storageLevel' = 'MEMORY_AND_DISK') AS SELECT CAST(`a1`.`employee_id` AS int) AS `employee_id`, CAST(`a1`.`fname` AS string) AS `fname`, CAST(`a1`.`lname` AS string) AS `lname`, CAST(`a1`.`age` AS int) AS `age`, CAST(`a1`.`store_id` AS int) AS `store_id` FROM (VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100)) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)",
            "INSERT INTO table_name SELECT `t35612`.`employee_id` AS `employee_id`, `t35612`.`fname` AS `fname`, `t35612`.`lname` AS `lname`, `t35612`.`age` AS `age`, `t35612`.`store_id` AS `store_id` FROM `t35612` AS `t35612`",
        ]
        self.compare_sql(df, expected_statements)

    def test_saveAsTable_format(self):
        with self.assertRaises(NotImplementedError):
            self.df_employee.write.saveAsTable("table_name", format="parquet").sql(pretty=False)[0]

    def test_saveAsTable_append(self):
        df = self.df_employee.write.saveAsTable("table_name", mode="append")
        expected = "INSERT INTO table_name SELECT CAST(`a1`.`employee_id` AS int) AS `employee_id`, CAST(`a1`.`fname` AS string) AS `fname`, CAST(`a1`.`lname` AS string) AS `lname`, CAST(`a1`.`age` AS int) AS `age`, CAST(`a1`.`store_id` AS int) AS `store_id` FROM (VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100)) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)"
        self.compare_sql(df, expected)

    def test_saveAsTable_overwrite(self):
        df = self.df_employee.write.saveAsTable("table_name", mode="overwrite")
        expected = "CREATE OR REPLACE TABLE table_name AS SELECT CAST(`a1`.`employee_id` AS int) AS `employee_id`, CAST(`a1`.`fname` AS string) AS `fname`, CAST(`a1`.`lname` AS string) AS `lname`, CAST(`a1`.`age` AS int) AS `age`, CAST(`a1`.`store_id` AS int) AS `store_id` FROM (VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100)) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)"
        self.compare_sql(df, expected)

    def test_saveAsTable_error(self):
        df = self.df_employee.write.saveAsTable("table_name", mode="error")
        expected = "CREATE TABLE table_name AS SELECT CAST(`a1`.`employee_id` AS int) AS `employee_id`, CAST(`a1`.`fname` AS string) AS `fname`, CAST(`a1`.`lname` AS string) AS `lname`, CAST(`a1`.`age` AS int) AS `age`, CAST(`a1`.`store_id` AS int) AS `store_id` FROM (VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100)) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)"
        self.compare_sql(df, expected)

    def test_saveAsTable_ignore(self):
        df = self.df_employee.write.saveAsTable("table_name", mode="ignore")
        expected = "CREATE TABLE IF NOT EXISTS table_name AS SELECT CAST(`a1`.`employee_id` AS int) AS `employee_id`, CAST(`a1`.`fname` AS string) AS `fname`, CAST(`a1`.`lname` AS string) AS `lname`, CAST(`a1`.`age` AS int) AS `age`, CAST(`a1`.`store_id` AS int) AS `store_id` FROM (VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100)) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)"
        self.compare_sql(df, expected)

    def test_mode_standalone(self):
        df = self.df_employee.write.mode("ignore").saveAsTable("table_name")
        expected = "CREATE TABLE IF NOT EXISTS table_name AS SELECT CAST(`a1`.`employee_id` AS int) AS `employee_id`, CAST(`a1`.`fname` AS string) AS `fname`, CAST(`a1`.`lname` AS string) AS `lname`, CAST(`a1`.`age` AS int) AS `age`, CAST(`a1`.`store_id` AS int) AS `store_id` FROM (VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100)) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)"
        self.compare_sql(df, expected)

    def test_mode_override(self):
        df = self.df_employee.write.mode("ignore").saveAsTable("table_name", mode="overwrite")
        expected = "CREATE OR REPLACE TABLE table_name AS SELECT CAST(`a1`.`employee_id` AS int) AS `employee_id`, CAST(`a1`.`fname` AS string) AS `fname`, CAST(`a1`.`lname` AS string) AS `lname`, CAST(`a1`.`age` AS int) AS `age`, CAST(`a1`.`store_id` AS int) AS `store_id` FROM (VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100)) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)"
        self.compare_sql(df, expected)

    def test_saveAsTable_cache(self):
        df = self.df_employee.cache().write.saveAsTable("table_name")
        expected_statements = [
            "DROP VIEW IF EXISTS t35612",
            "CACHE LAZY TABLE t35612 OPTIONS('storageLevel' = 'MEMORY_AND_DISK') AS SELECT CAST(`a1`.`employee_id` AS int) AS `employee_id`, CAST(`a1`.`fname` AS string) AS `fname`, CAST(`a1`.`lname` AS string) AS `lname`, CAST(`a1`.`age` AS int) AS `age`, CAST(`a1`.`store_id` AS int) AS `store_id` FROM (VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100)) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)",
            "CREATE TABLE table_name AS SELECT `t35612`.`employee_id` AS `employee_id`, `t35612`.`fname` AS `fname`, `t35612`.`lname` AS `lname`, `t35612`.`age` AS `age`, `t35612`.`store_id` AS `store_id` FROM `t35612` AS `t35612`",
        ]
        self.compare_sql(df, expected_statements)
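The writer maps each save mode onto a distinct statement: `append` becomes `INSERT INTO`, `overwrite` becomes `CREATE OR REPLACE TABLE`, `error` becomes `CREATE TABLE`, and `ignore` becomes `CREATE TABLE IF NOT EXISTS`. A sketch of the standalone and per-call forms (illustrative, reusing the session API from these tests):

```python
from sqlglot.dataframe.sql.session import SparkSession

spark = SparkSession()
df = spark.createDataFrame([[1, "Jack"]], ["employee_id", "fname"])

writer = df.write.mode("ignore")
print(writer.saveAsTable("table_name").sql(pretty=False)[0])
# CREATE TABLE IF NOT EXISTS table_name AS SELECT ...

# A mode passed to saveAsTable overrides the one set on the writer.
print(writer.saveAsTable("table_name", mode="overwrite").sql(pretty=False)[0])
# CREATE OR REPLACE TABLE table_name AS SELECT ...
```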
1593
tests/dataframe/unit/test_functions.py
Normal file
File diff suppressed because it is too large
114
tests/dataframe/unit/test_session.py
Normal file
@@ -0,0 +1,114 @@
from unittest import mock

import sqlglot
from sqlglot.dataframe.sql import functions as F
from sqlglot.dataframe.sql import types
from sqlglot.dataframe.sql.session import SparkSession
from sqlglot.schema import MappingSchema
from tests.dataframe.unit.dataframe_sql_validator import DataFrameSQLValidator


class TestDataframeSession(DataFrameSQLValidator):
    def test_cdf_one_row(self):
        df = self.spark.createDataFrame([[1, 2]], ["cola", "colb"])
        expected = "SELECT `a2`.`cola` AS `cola`, `a2`.`colb` AS `colb` FROM (VALUES (1, 2)) AS `a2`(`cola`, `colb`)"
        self.compare_sql(df, expected)

    def test_cdf_multiple_rows(self):
        df = self.spark.createDataFrame([[1, 2], [3, 4], [None, 6]], ["cola", "colb"])
        expected = "SELECT `a2`.`cola` AS `cola`, `a2`.`colb` AS `colb` FROM (VALUES (1, 2), (3, 4), (NULL, 6)) AS `a2`(`cola`, `colb`)"
        self.compare_sql(df, expected)

    def test_cdf_no_schema(self):
        df = self.spark.createDataFrame([[1, 2], [3, 4], [None, 6]])
        expected = (
            "SELECT `a2`.`_1` AS `_1`, `a2`.`_2` AS `_2` FROM (VALUES (1, 2), (3, 4), (NULL, 6)) AS `a2`(`_1`, `_2`)"
        )
        self.compare_sql(df, expected)

    def test_cdf_row_mixed_primitives(self):
        df = self.spark.createDataFrame([[1, 10.1, "test", False, None]])
        expected = "SELECT `a2`.`_1` AS `_1`, `a2`.`_2` AS `_2`, `a2`.`_3` AS `_3`, `a2`.`_4` AS `_4`, `a2`.`_5` AS `_5` FROM (VALUES (1, 10.1, 'test', FALSE, NULL)) AS `a2`(`_1`, `_2`, `_3`, `_4`, `_5`)"
        self.compare_sql(df, expected)

    def test_cdf_dict_rows(self):
        df = self.spark.createDataFrame([{"cola": 1, "colb": "test"}, {"cola": 2, "colb": "test2"}])
        expected = "SELECT `a2`.`cola` AS `cola`, `a2`.`colb` AS `colb` FROM (VALUES (1, 'test'), (2, 'test2')) AS `a2`(`cola`, `colb`)"
        self.compare_sql(df, expected)

    def test_cdf_str_schema(self):
        df = self.spark.createDataFrame([[1, "test"]], "cola: INT, colb: STRING")
        expected = "SELECT CAST(`a2`.`cola` AS INT) AS `cola`, CAST(`a2`.`colb` AS STRING) AS `colb` FROM (VALUES (1, 'test')) AS `a2`(`cola`, `colb`)"
        self.compare_sql(df, expected)

    def test_typed_schema_basic(self):
        schema = types.StructType(
            [
                types.StructField("cola", types.IntegerType()),
                types.StructField("colb", types.StringType()),
            ]
        )
        df = self.spark.createDataFrame([[1, "test"]], schema)
        expected = "SELECT CAST(`a2`.`cola` AS int) AS `cola`, CAST(`a2`.`colb` AS string) AS `colb` FROM (VALUES (1, 'test')) AS `a2`(`cola`, `colb`)"
        self.compare_sql(df, expected)

    def test_typed_schema_nested(self):
        schema = types.StructType(
            [
                types.StructField(
                    "cola",
                    types.StructType(
                        [
                            types.StructField("sub_cola", types.IntegerType()),
                            types.StructField("sub_colb", types.StringType()),
                        ]
                    ),
                )
            ]
        )
        df = self.spark.createDataFrame([[{"sub_cola": 1, "sub_colb": "test"}]], schema)
        expected = "SELECT CAST(`a2`.`cola` AS struct<sub_cola:int, sub_colb:string>) AS `cola` FROM (VALUES (STRUCT(1 AS `sub_cola`, 'test' AS `sub_colb`))) AS `a2`(`cola`)"
        self.compare_sql(df, expected)

    @mock.patch("sqlglot.schema", MappingSchema())
    def test_sql_select_only(self):
        # TODO: Do exact matches once CTE names are deterministic
        query = "SELECT cola, colb FROM table"
        sqlglot.schema.add_table("table", {"cola": "string", "colb": "string"})
        df = self.spark.sql(query)
        self.assertIn(
            "SELECT `table`.`cola` AS `cola`, `table`.`colb` AS `colb` FROM `table` AS `table`", df.sql(pretty=False)
        )

    @mock.patch("sqlglot.schema", MappingSchema())
    def test_sql_with_aggs(self):
        # TODO: Do exact matches once CTE names are deterministic
        query = "SELECT cola, colb FROM table"
        sqlglot.schema.add_table("table", {"cola": "string", "colb": "string"})
        df = self.spark.sql(query).groupBy(F.col("cola")).agg(F.sum("colb"))
        result = df.sql(pretty=False, optimize=False)[0]
        self.assertIn("SELECT cola, colb FROM table", result)
        self.assertIn("SUM(colb)", result)
        self.assertIn("GROUP BY cola", result)

    @mock.patch("sqlglot.schema", MappingSchema())
    def test_sql_create(self):
        query = "CREATE TABLE new_table AS WITH t1 AS (SELECT cola, colb FROM table) SELECT cola, colb, FROM t1"
        sqlglot.schema.add_table("table", {"cola": "string", "colb": "string"})
        df = self.spark.sql(query)
        expected = "CREATE TABLE new_table AS SELECT `table`.`cola` AS `cola`, `table`.`colb` AS `colb` FROM `table` AS `table`"
        self.compare_sql(df, expected)

    @mock.patch("sqlglot.schema", MappingSchema())
    def test_sql_insert(self):
        query = "WITH t1 AS (SELECT cola, colb FROM table) INSERT INTO new_table SELECT cola, colb FROM t1"
        sqlglot.schema.add_table("table", {"cola": "string", "colb": "string"})
        df = self.spark.sql(query)
        expected = (
            "INSERT INTO new_table SELECT `table`.`cola` AS `cola`, `table`.`colb` AS `colb` FROM `table` AS `table`"
        )
        self.compare_sql(df, expected)

    def test_session_create_builder_patterns(self):
        spark = SparkSession()
        self.assertEqual(spark.builder.appName("abc").getOrCreate(), spark)
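`SparkSession.sql` resolves the query against the global `sqlglot.schema`, so tables must be registered first; the result is an ordinary DataFrame that can be chained further. A sketch based on the aggregation test above:

```python
import sqlglot
from sqlglot.dataframe.sql import functions as F
from sqlglot.dataframe.sql.session import SparkSession

# Register the table so column resolution has something to work with.
sqlglot.schema.add_table("table", {"cola": "string", "colb": "string"})

spark = SparkSession()
df = spark.sql("SELECT cola, colb FROM table").groupBy(F.col("cola")).agg(F.sum("colb"))
print(df.sql(pretty=False)[0])
```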
70
tests/dataframe/unit/test_types.py
Normal file
@@ -0,0 +1,70 @@
import unittest

from sqlglot.dataframe.sql import types


class TestDataframeTypes(unittest.TestCase):
    def test_string(self):
        self.assertEqual("string", types.StringType().simpleString())

    def test_char(self):
        self.assertEqual("char(100)", types.CharType(100).simpleString())

    def test_varchar(self):
        self.assertEqual("varchar(65)", types.VarcharType(65).simpleString())

    def test_binary(self):
        self.assertEqual("binary", types.BinaryType().simpleString())

    def test_boolean(self):
        self.assertEqual("boolean", types.BooleanType().simpleString())

    def test_date(self):
        self.assertEqual("date", types.DateType().simpleString())

    def test_timestamp(self):
        self.assertEqual("timestamp", types.TimestampType().simpleString())

    def test_timestamp_ntz(self):
        self.assertEqual("timestamp_ntz", types.TimestampNTZType().simpleString())

    def test_decimal(self):
        self.assertEqual("decimal(10, 3)", types.DecimalType(10, 3).simpleString())

    def test_double(self):
        self.assertEqual("double", types.DoubleType().simpleString())

    def test_float(self):
        self.assertEqual("float", types.FloatType().simpleString())

    def test_byte(self):
        self.assertEqual("tinyint", types.ByteType().simpleString())

    def test_integer(self):
        self.assertEqual("int", types.IntegerType().simpleString())

    def test_long(self):
        self.assertEqual("bigint", types.LongType().simpleString())

    def test_short(self):
        self.assertEqual("smallint", types.ShortType().simpleString())

    def test_array(self):
        self.assertEqual("array<int>", types.ArrayType(types.IntegerType()).simpleString())

    def test_map(self):
        self.assertEqual("map<int, string>", types.MapType(types.IntegerType(), types.StringType()).simpleString())

    def test_struct_field(self):
        self.assertEqual("cola:int", types.StructField("cola", types.IntegerType()).simpleString())

    def test_struct_type(self):
        self.assertEqual(
            "struct<cola:int, colb:string>",
            types.StructType(
                [
                    types.StructField("cola", types.IntegerType()),
                    types.StructField("colb", types.StringType()),
                ]
            ).simpleString(),
        )
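Each type's `simpleString()` matches Spark's simple type strings, and container types nest their element types. A short sketch combining a few of the cases above:

```python
from sqlglot.dataframe.sql import types

schema = types.StructType(
    [
        types.StructField("id", types.LongType()),
        types.StructField("tags", types.ArrayType(types.StringType())),
    ]
)
# LongType renders as bigint, ArrayType nests its element type.
print(schema.simpleString())  # struct<id:bigint, tags:array<string>>
```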
60
tests/dataframe/unit/test_window.py
Normal file
@@ -0,0 +1,60 @@
import unittest

from sqlglot.dataframe.sql import functions as F
from sqlglot.dataframe.sql.window import Window, WindowSpec


class TestDataframeWindow(unittest.TestCase):
    def test_window_spec_partition_by(self):
        partition_by = WindowSpec().partitionBy(F.col("cola"), F.col("colb"))
        self.assertEqual("OVER (PARTITION BY cola, colb)", partition_by.sql())

    def test_window_spec_order_by(self):
        order_by = WindowSpec().orderBy("cola", "colb")
        self.assertEqual("OVER (ORDER BY cola, colb)", order_by.sql())

    def test_window_spec_rows_between(self):
        rows_between = WindowSpec().rowsBetween(3, 5)
        self.assertEqual("OVER ( ROWS BETWEEN 3 PRECEDING AND 5 FOLLOWING)", rows_between.sql())

    def test_window_spec_range_between(self):
        range_between = WindowSpec().rangeBetween(3, 5)
        self.assertEqual("OVER ( RANGE BETWEEN 3 PRECEDING AND 5 FOLLOWING)", range_between.sql())

    def test_window_partition_by(self):
        partition_by = Window.partitionBy(F.col("cola"), F.col("colb"))
        self.assertEqual("OVER (PARTITION BY cola, colb)", partition_by.sql())

    def test_window_order_by(self):
        order_by = Window.orderBy("cola", "colb")
        self.assertEqual("OVER (ORDER BY cola, colb)", order_by.sql())

    def test_window_rows_between(self):
        rows_between = Window.rowsBetween(3, 5)
        self.assertEqual("OVER ( ROWS BETWEEN 3 PRECEDING AND 5 FOLLOWING)", rows_between.sql())

    def test_window_range_between(self):
        range_between = Window.rangeBetween(3, 5)
        self.assertEqual("OVER ( RANGE BETWEEN 3 PRECEDING AND 5 FOLLOWING)", range_between.sql())

    def test_window_rows_unbounded(self):
        rows_between_unbounded_start = Window.rowsBetween(Window.unboundedPreceding, 2)
        self.assertEqual("OVER ( ROWS BETWEEN UNBOUNDED PRECEDING AND 2 FOLLOWING)", rows_between_unbounded_start.sql())
        rows_between_unbounded_end = Window.rowsBetween(1, Window.unboundedFollowing)
        self.assertEqual("OVER ( ROWS BETWEEN 1 PRECEDING AND UNBOUNDED FOLLOWING)", rows_between_unbounded_end.sql())
        rows_between_unbounded_both = Window.rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing)
        self.assertEqual(
            "OVER ( ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING)", rows_between_unbounded_both.sql()
        )

    def test_window_range_unbounded(self):
        range_between_unbounded_start = Window.rangeBetween(Window.unboundedPreceding, 2)
        self.assertEqual(
            "OVER ( RANGE BETWEEN UNBOUNDED PRECEDING AND 2 FOLLOWING)", range_between_unbounded_start.sql()
        )
        range_between_unbounded_end = Window.rangeBetween(1, Window.unboundedFollowing)
        self.assertEqual("OVER ( RANGE BETWEEN 1 PRECEDING AND UNBOUNDED FOLLOWING)", range_between_unbounded_end.sql())
        range_between_unbounded_both = Window.rangeBetween(Window.unboundedPreceding, Window.unboundedFollowing)
        self.assertEqual(
            "OVER ( RANGE BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING)", range_between_unbounded_both.sql()
        )
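`Window.unboundedPreceding` and `Window.unboundedFollowing` are sentinel bounds that render as `UNBOUNDED PRECEDING` and `UNBOUNDED FOLLOWING`; note that the expected strings above keep a space after `OVER (` when the spec has only a frame clause. A sketch:

```python
from sqlglot.dataframe.sql.window import Window

spec = Window.rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing)
print(spec.sql())  # OVER ( ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING)
```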
@@ -694,29 +694,6 @@ class TestDialect(Validator):
            },
        )

    # https://dev.mysql.com/doc/refman/8.0/en/join.html
    # https://www.postgresql.org/docs/current/queries-table-expressions.html
    def test_joined_tables(self):
        self.validate_identity("SELECT * FROM (tbl1 LEFT JOIN tbl2 ON 1 = 1)")
        self.validate_identity("SELECT * FROM (tbl1 JOIN tbl2 JOIN tbl3)")
        self.validate_identity("SELECT * FROM (tbl1 JOIN (tbl2 JOIN tbl3) ON bla = foo)")
        self.validate_identity("SELECT * FROM (tbl1 JOIN LATERAL (SELECT * FROM bla) AS tbl)")

        self.validate_all(
            "SELECT * FROM (tbl1 LEFT JOIN tbl2 ON 1 = 1)",
            write={
                "postgres": "SELECT * FROM (tbl1 LEFT JOIN tbl2 ON 1 = 1)",
                "mysql": "SELECT * FROM (tbl1 LEFT JOIN tbl2 ON 1 = 1)",
            },
        )
        self.validate_all(
            "SELECT * FROM (tbl1 JOIN LATERAL (SELECT * FROM bla) AS tbl)",
            write={
                "postgres": "SELECT * FROM (tbl1 JOIN LATERAL (SELECT * FROM bla) AS tbl)",
                "mysql": "SELECT * FROM (tbl1 JOIN LATERAL (SELECT * FROM bla) AS tbl)",
            },
        )

    def test_lateral_subquery(self):
        self.validate_identity(
            "SELECT art FROM tbl1 INNER JOIN LATERAL (SELECT art FROM tbl2) AS tbl2 ON tbl1.art = tbl2.art"
@@ -856,7 +833,7 @@ class TestDialect(Validator):
                "postgres": "x ILIKE '%y'",
                "presto": "LOWER(x) LIKE '%y'",
                "snowflake": "x ILIKE '%y'",
                "spark": "LOWER(x) LIKE '%y'",
                "spark": "x ILIKE '%y'",
                "sqlite": "LOWER(x) LIKE '%y'",
                "starrocks": "LOWER(x) LIKE '%y'",
                "trino": "LOWER(x) LIKE '%y'",
@@ -48,7 +48,7 @@ class TestDuckDB(Validator):
        self.validate_all(
            "STRPTIME(x, '%y-%-m')",
            write={
                "bigquery": "STR_TO_TIME(x, '%y-%-m')",
                "bigquery": "PARSE_TIMESTAMP('%y-%m', x)",
                "duckdb": "STRPTIME(x, '%y-%-m')",
                "presto": "DATE_PARSE(x, '%y-%c')",
                "hive": "CAST(FROM_UNIXTIME(UNIX_TIMESTAMP(x, 'yy-M')) AS TIMESTAMP)",
@@ -63,6 +63,16 @@ class TestDuckDB(Validator):
                "hive": "CAST(x AS TIMESTAMP)",
            },
        )
        self.validate_all(
            "STRPTIME(x, '%-m/%-d/%y %-I:%M %p')",
            write={
                "bigquery": "PARSE_TIMESTAMP('%m/%d/%y %I:%M %p', x)",
                "duckdb": "STRPTIME(x, '%-m/%-d/%y %-I:%M %p')",
                "presto": "DATE_PARSE(x, '%c/%e/%y %l:%i %p')",
                "hive": "CAST(FROM_UNIXTIME(UNIX_TIMESTAMP(x, 'M/d/yy h:mm a')) AS TIMESTAMP)",
                "spark": "TO_TIMESTAMP(x, 'M/d/yy h:mm a')",
            },
        )

    def test_duckdb(self):
        self.validate_all(
@@ -268,6 +278,17 @@ class TestDuckDB(Validator):
                "spark": "MONTH('2021-03-01')",
            },
        )
        self.validate_all(
            "ARRAY_CONCAT(LIST_VALUE(1, 2), LIST_VALUE(3, 4))",
            write={
                "duckdb": "ARRAY_CONCAT(LIST_VALUE(1, 2), LIST_VALUE(3, 4))",
                "presto": "CONCAT(ARRAY[1, 2], ARRAY[3, 4])",
                "hive": "CONCAT(ARRAY(1, 2), ARRAY(3, 4))",
                "spark": "CONCAT(ARRAY(1, 2), ARRAY(3, 4))",
                "snowflake": "ARRAY_CAT([1, 2], [3, 4])",
                "bigquery": "ARRAY_CONCAT([1, 2], [3, 4])",
            },
        )

        with self.assertRaises(UnsupportedError):
            transpile(
@@ -31,6 +31,24 @@ class TestMySQL(Validator):
                "mysql": "_utf8mb4 'hola'",
            },
        )
        self.validate_all(
            "N 'some text'",
            read={
                "mysql": "N'some text'",
            },
            write={
                "mysql": "N 'some text'",
            },
        )
        self.validate_all(
            "_latin1 x'4D7953514C'",
            read={
                "mysql": "_latin1 X'4D7953514C'",
            },
            write={
                "mysql": "_latin1 x'4D7953514C'",
            },
        )

    def test_hexadecimal_literal(self):
        self.validate_all(
@@ -69,6 +69,8 @@ class TestPostgres(Validator):
        self.validate_identity("SELECT TRIM(LEADING 'bla' FROM ' XXX ' COLLATE utf8_bin)")
        self.validate_identity("SELECT TO_TIMESTAMP(1284352323.5), TO_TIMESTAMP('05 Dec 2000', 'DD Mon YYYY')")
        self.validate_identity("COMMENT ON TABLE mytable IS 'this'")
        self.validate_identity("SELECT e'\\xDEADBEEF'")
        self.validate_identity("SELECT CAST(e'\\176' AS BYTEA)")

        self.validate_all(
            "CREATE TABLE x (a UUID, b BYTEA)",
@@ -204,3 +206,11 @@ class TestPostgres(Validator):
            """'{"a":[1,2,3],"b":[4,5,6]}'::json#>>'{a,2}'""",
            write={"postgres": """CAST('{"a":[1,2,3],"b":[4,5,6]}' AS JSON)#>>'{a,2}'"""},
        )
        self.validate_all(
            "SELECT $$a$$",
            write={"postgres": "SELECT 'a'"},
        )
        self.validate_all(
            "SELECT $$Dianne's horse$$",
            write={"postgres": "SELECT 'Dianne''s horse'"},
        )
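The new dollar-quoting cases can also be exercised directly through `sqlglot.transpile`; embedded single quotes come out doubled, as the last expectation shows. An illustrative check:

```python
import sqlglot

# Dollar-quoted strings normalize to standard single-quoted literals.
print(sqlglot.transpile("SELECT $$Dianne's horse$$", read="postgres", write="postgres")[0])
# SELECT 'Dianne''s horse'
```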
@@ -321,7 +321,7 @@ class TestPresto(Validator):
                "duckdb": "SELECT APPROX_COUNT_DISTINCT(a) FROM foo",
                "presto": "SELECT APPROX_DISTINCT(a, 0.1) FROM foo",
                "hive": "SELECT APPROX_COUNT_DISTINCT(a) FROM foo",
                "spark": "SELECT APPROX_COUNT_DISTINCT(a) FROM foo",
                "spark": "SELECT APPROX_COUNT_DISTINCT(a, 0.1) FROM foo",
            },
        )
        self.validate_all(
@@ -329,7 +329,7 @@ class TestPresto(Validator):
            write={
                "presto": "SELECT APPROX_DISTINCT(a, 0.1) FROM foo",
                "hive": UnsupportedError,
                "spark": UnsupportedError,
                "spark": "SELECT APPROX_COUNT_DISTINCT(a, 0.1) FROM foo",
            },
        )
        self.validate_all(
@@ -65,7 +65,7 @@ class TestSnowflake(Validator):
        self.validate_all(
            "SELECT TO_TIMESTAMP('2013-04-05 01:02:03')",
            write={
                "bigquery": "SELECT STR_TO_TIME('2013-04-05 01:02:03', '%Y-%m-%d %H:%M:%S')",
                "bigquery": "SELECT PARSE_TIMESTAMP('%Y-%m-%d %H:%M:%S', '2013-04-05 01:02:03')",
                "snowflake": "SELECT TO_TIMESTAMP('2013-04-05 01:02:03', 'yyyy-mm-dd hh24:mi:ss')",
                "spark": "SELECT TO_TIMESTAMP('2013-04-05 01:02:03', 'yyyy-MM-dd HH:mm:ss')",
            },
@@ -73,16 +73,17 @@ class TestSnowflake(Validator):
        self.validate_all(
            "SELECT TO_TIMESTAMP('04/05/2013 01:02:03', 'mm/dd/yyyy hh24:mi:ss')",
            read={
                "bigquery": "SELECT STR_TO_TIME('04/05/2013 01:02:03', '%m/%d/%Y %H:%M:%S')",
                "bigquery": "SELECT PARSE_TIMESTAMP('%m/%d/%Y %H:%M:%S', '04/05/2013 01:02:03')",
                "duckdb": "SELECT STRPTIME('04/05/2013 01:02:03', '%m/%d/%Y %H:%M:%S')",
                "snowflake": "SELECT TO_TIMESTAMP('04/05/2013 01:02:03', 'mm/dd/yyyy hh24:mi:ss')",
            },
            write={
                "bigquery": "SELECT STR_TO_TIME('04/05/2013 01:02:03', '%m/%d/%Y %H:%M:%S')",
                "bigquery": "SELECT PARSE_TIMESTAMP('%m/%d/%Y %H:%M:%S', '04/05/2013 01:02:03')",
                "snowflake": "SELECT TO_TIMESTAMP('04/05/2013 01:02:03', 'mm/dd/yyyy hh24:mi:ss')",
                "spark": "SELECT TO_TIMESTAMP('04/05/2013 01:02:03', 'MM/dd/yyyy HH:mm:ss')",
            },
        )

        self.validate_all(
            "SELECT IFF(TRUE, 'true', 'false')",
            write={
@@ -240,11 +241,25 @@ class TestSnowflake(Validator):
            },
        )
        self.validate_all(
            "SELECT DATE_PART(month FROM a::DATETIME)",
            "SELECT DATE_PART(month, a::DATETIME)",
            write={
                "snowflake": "SELECT EXTRACT(month FROM CAST(a AS DATETIME))",
            },
        )
        self.validate_all(
            "SELECT DATE_PART(epoch_second, foo) as ddate from table_name",
            write={
                "snowflake": "SELECT EXTRACT(epoch_second FROM CAST(foo AS TIMESTAMPNTZ)) AS ddate FROM table_name",
                "presto": "SELECT TO_UNIXTIME(CAST(foo AS TIMESTAMP)) AS ddate FROM table_name",
            },
        )
        self.validate_all(
            "SELECT DATE_PART(epoch_milliseconds, foo) as ddate from table_name",
            write={
                "snowflake": "SELECT EXTRACT(epoch_second FROM CAST(foo AS TIMESTAMPNTZ)) * 1000 AS ddate FROM table_name",
                "presto": "SELECT TO_UNIXTIME(CAST(foo AS TIMESTAMP)) * 1000 AS ddate FROM table_name",
            },
        )

    def test_semi_structured_types(self):
        self.validate_identity("SELECT CAST(a AS VARIANT)")
@@ -45,3 +45,29 @@ class TestTSQL(Validator):
                "tsql": "CAST(x AS DATETIME2)",
            },
        )

    def test_charindex(self):
        self.validate_all(
            "CHARINDEX(x, y, 9)",
            write={
                "spark": "LOCATE(x, y, 9)",
            },
        )
        self.validate_all(
            "CHARINDEX(x, y)",
            write={
                "spark": "LOCATE(x, y)",
            },
        )
        self.validate_all(
            "CHARINDEX('sub', 'testsubstring', 3)",
            write={
                "spark": "LOCATE('sub', 'testsubstring', 3)",
            },
        )
        self.validate_all(
            "CHARINDEX('sub', 'testsubstring')",
            write={
                "spark": "LOCATE('sub', 'testsubstring')",
            },
        )
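The `CHARINDEX` mapping is argument-preserving, including the optional start position. A quick illustrative check via `sqlglot.transpile`:

```python
import sqlglot

# T-SQL CHARINDEX transpiles to Spark's LOCATE with identical arguments.
print(sqlglot.transpile("CHARINDEX('sub', 'testsubstring', 3)", read="tsql", write="spark")[0])
# LOCATE('sub', 'testsubstring', 3)
```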
7
tests/fixtures/identity.sql
vendored
@@ -513,6 +513,8 @@ ALTER TYPE electronic_mail RENAME TO email
ANALYZE a.y
DELETE FROM x WHERE y > 1
DELETE FROM y
DELETE FROM event USING sales WHERE event.eventid = sales.eventid
DELETE FROM event USING sales, USING bla WHERE event.eventid = sales.eventid
DROP TABLE a
DROP TABLE a.b
DROP TABLE IF EXISTS a
@@ -563,3 +565,8 @@ WITH a AS ((SELECT 1 AS b) UNION ALL (SELECT 1 AS b)) SELECT * FROM a
SELECT (WITH x AS (SELECT 1 AS y) SELECT * FROM x) AS z
SELECT ((SELECT 1) + 1)
SELECT * FROM project.dataset.INFORMATION_SCHEMA.TABLES
SELECT * FROM (table1 AS t1 LEFT JOIN table2 AS t2 ON 1 = 1)
SELECT * FROM (tbl1 LEFT JOIN tbl2 ON 1 = 1)
SELECT * FROM (tbl1 JOIN tbl2 JOIN tbl3)
SELECT * FROM (tbl1 JOIN (tbl2 JOIN tbl3) ON bla = foo)
SELECT * FROM (tbl1 JOIN LATERAL (SELECT * FROM bla) AS tbl)
24
tests/fixtures/optimizer/merge_subqueries.sql
vendored
@@ -287,3 +287,27 @@ SELECT
FROM
  t1;
SELECT x.a AS a, x.b AS b, ROW_NUMBER() OVER (PARTITION BY x.a ORDER BY x.a) AS row_num FROM x AS x;

# title: Values Test
# dialect: spark
WITH t1 AS (
  SELECT
    a1.cola
  FROM
    VALUES (1) AS a1(cola)
), t2 AS (
  SELECT
    a2.cola
  FROM
    VALUES (1) AS a2(cola)
)
SELECT /*+ BROADCAST(t2) */
  t1.cola,
  t2.cola,
FROM
  t1
JOIN
  t2
ON
  t1.cola = t2.cola;
SELECT /*+ BROADCAST(a2) */ a1.cola AS cola, a2.cola AS cola FROM VALUES (1) AS a1(cola) JOIN VALUES (1) AS a2(cola) ON a1.cola = a2.cola;
@@ -33,3 +33,6 @@ SELECT x.a AS a FROM x AS x JOIN (SELECT * FROM y AS y WHERE y.a = 1) AS y ON y.

with t1 as (SELECT x.a, x.b, ROW_NUMBER() OVER (PARTITION BY x.a ORDER BY x.a) as row_num FROM x) SELECT t1.a, t1.b FROM t1 WHERE row_num = 1;
WITH t1 AS (SELECT x.a, x.b, ROW_NUMBER() OVER (PARTITION BY x.a ORDER BY x.a) AS row_num FROM x) SELECT t1.a, t1.b FROM t1 WHERE row_num = 1;

WITH m AS (SELECT a, b FROM (VALUES (1, 2)) AS a1(a, b)), n AS (SELECT a, b FROM m WHERE m.a = 1), o AS (SELECT a, b FROM m WHERE m.a = 2) SELECT n.a, n.b, n.a, o.b FROM n FULL OUTER JOIN o ON n.a = o.a;
WITH m AS (SELECT a, b FROM (VALUES (1, 2)) AS a1(a, b)), n AS (SELECT a, b FROM m WHERE m.a = 1), o AS (SELECT a, b FROM m WHERE m.a = 2) SELECT n.a, n.b, n.a, o.b FROM n FULL OUTER JOIN o ON n.a = o.a;
@ -22,6 +22,9 @@ SELECT "_q_0".a AS a FROM (SELECT DISTINCT x.a AS a, x.b AS b FROM x AS x) AS "_
|
|||
SELECT a FROM (SELECT a, b FROM x UNION ALL SELECT a, b FROM x);
|
||||
SELECT "_q_0".a AS a FROM (SELECT x.a AS a FROM x AS x UNION ALL SELECT x.a AS a FROM x AS x) AS "_q_0";
|
||||
|
||||
WITH t1 AS (SELECT x.a AS a, x.b AS b FROM x UNION ALL SELECT z.b AS b, z.c AS c FROM z) SELECT a, b FROM t1;
|
||||
WITH t1 AS (SELECT x.a AS a, x.b AS b FROM x AS x UNION ALL SELECT z.b AS b, z.c AS c FROM z AS z) SELECT t1.a AS a, t1.b AS b FROM t1;
|
||||
|
||||
SELECT a FROM (SELECT a, b FROM x UNION SELECT a, b FROM x);
|
||||
SELECT "_q_0".a AS a FROM (SELECT x.a AS a, x.b AS b FROM x AS x UNION SELECT x.a AS a, x.b AS b FROM x AS x) AS "_q_0";
|
||||
|
||||
|
|
3
tests/fixtures/optimizer/qualify_columns.sql
vendored
@@ -72,6 +72,9 @@ SELECT x.a AS a, x.b AS b FROM x AS x ORDER BY a;
SELECT a FROM x ORDER BY b;
SELECT x.a AS a FROM x AS x ORDER BY x.b;

SELECT SUM(a) AS a FROM x ORDER BY SUM(a);
SELECT SUM(x.a) AS a FROM x AS x ORDER BY SUM(x.a);

# dialect: bigquery
SELECT ROW_NUMBER() OVER (PARTITION BY a ORDER BY b) AS row_num FROM x QUALIFY row_num = 1;
SELECT ROW_NUMBER() OVER (PARTITION BY x.a ORDER BY x.b) AS row_num FROM x AS x QUALIFY row_num = 1;
@@ -53,6 +53,8 @@ def string_to_bool(string):
    return string and string.lower() in ("true", "1")


SKIP_INTEGRATION = string_to_bool(os.environ.get("SKIP_INTEGRATION", "0").lower())

TPCH_SCHEMA = {
    "lineitem": {
        "l_orderkey": "uint64",
@@ -7,11 +7,17 @@ from pandas.testing import assert_frame_equal
from sqlglot import exp, parse_one
from sqlglot.executor import execute
from sqlglot.executor.python import Python
from tests.helpers import FIXTURES_DIR, TPCH_SCHEMA, load_sql_fixture_pairs
from tests.helpers import (
    FIXTURES_DIR,
    SKIP_INTEGRATION,
    TPCH_SCHEMA,
    load_sql_fixture_pairs,
)

DIR = FIXTURES_DIR + "/optimizer/tpc-h/"


@unittest.skipIf(SKIP_INTEGRATION, "Skipping Integration Tests since `SKIP_INTEGRATION` is set")
class TestExecutor(unittest.TestCase):
    @classmethod
    def setUpClass(cls):
@@ -123,13 +123,16 @@ class TestExpressions(unittest.TestCase):
        self.assertEqual(exp.table_name(parse_one("a.b.c", into=exp.Table)), "a.b.c")
        self.assertEqual(exp.table_name("a.b.c"), "a.b.c")

    def test_table(self):
        self.assertEqual(exp.table_("a", alias="b"), parse_one("select * from a b").find(exp.Table))

    def test_replace_tables(self):
        self.assertEqual(
            exp.replace_tables(
                parse_one("select * from a join b join c.a join d.a join e.a"),
                parse_one("select * from a AS a join b join c.a join d.a join e.a"),
                {"a": "a1", "b": "b.a", "c.a": "c.a2", "d.a": "d2"},
            ).sql(),
            'SELECT * FROM "a1" JOIN "b"."a" JOIN "c"."a2" JOIN "d2" JOIN e.a',
            "SELECT * FROM a1 AS a JOIN b.a JOIN c.a2 JOIN d2 JOIN e.a",
        )

    def test_named_selects(self):
@@ -495,11 +498,15 @@ class TestExpressions(unittest.TestCase):
        self.assertEqual(exp.convert(value).sql(), expected)

    def test_annotation_alias(self):
        expression = parse_one("SELECT a, b AS B, c #comment, d AS D #another_comment FROM foo")
        sql = "SELECT a, b AS B, c # comment, d AS D # another_comment FROM foo"
        expression = parse_one(sql)
        self.assertEqual(
            [e.alias_or_name for e in expression.expressions],
            ["a", "B", "c", "D"],
        )
        self.assertEqual(expression.sql(), sql)
        self.assertEqual(expression.expressions[2].name, "comment")
        self.assertEqual(expression.sql(annotations=False), "SELECT a, b AS B, c, d AS D")

    def test_to_table(self):
        table_only = exp.to_table("table_name")
@@ -514,6 +521,18 @@ class TestExpressions(unittest.TestCase):
        self.assertEqual(catalog_db_and_table.name, "table_name")
        self.assertEqual(catalog_db_and_table.args.get("db"), exp.to_identifier("db"))
        self.assertEqual(catalog_db_and_table.args.get("catalog"), exp.to_identifier("catalog"))
        with self.assertRaises(ValueError):
            exp.to_table(1)

    def test_to_column(self):
        column_only = exp.to_column("column_name")
        self.assertEqual(column_only.name, "column_name")
        self.assertIsNone(column_only.args.get("table"))
        table_and_column = exp.to_column("table_name.column_name")
        self.assertEqual(table_and_column.name, "column_name")
        self.assertEqual(table_and_column.args.get("table"), exp.to_identifier("table_name"))
        with self.assertRaises(ValueError):
            exp.to_column(1)

    def test_union(self):
        expression = parse_one("SELECT cola, colb UNION SELECT colx, coly")
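The new `exp.to_table` and `exp.to_column` helpers split dotted names into identifier parts and raise `ValueError` for non-string input. A sketch of the behavior these tests pin down:

```python
from sqlglot import exp

t = exp.to_table("catalog.db.table_name")
print(t.name)                                        # table_name
print(t.args.get("db") == exp.to_identifier("db"))   # True

c = exp.to_column("table_name.column_name")
print(c.name)                                        # column_name
print(c.args.get("table") == exp.to_identifier("table_name"))  # True
```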
@@ -5,11 +5,11 @@ import duckdb
from pandas.testing import assert_frame_equal

import sqlglot
from sqlglot import exp, optimizer, parse_one, table
from sqlglot import exp, optimizer, parse_one
from sqlglot.errors import OptimizeError
from sqlglot.optimizer.annotate_types import annotate_types
from sqlglot.optimizer.schema import MappingSchema, ensure_schema
from sqlglot.optimizer.scope import build_scope, traverse_scope, walk_in_scope
from sqlglot.schema import MappingSchema
from tests.helpers import (
    TPCH_SCHEMA,
    load_sql_fixture_pairs,
@@ -29,19 +29,19 @@ class TestOptimizer(unittest.TestCase):
            CREATE TABLE x (a INT, b INT);
            CREATE TABLE y (b INT, c INT);
            CREATE TABLE z (b INT, c INT);

            INSERT INTO x VALUES (1, 1);
            INSERT INTO x VALUES (2, 2);
            INSERT INTO x VALUES (2, 2);
            INSERT INTO x VALUES (3, 3);
            INSERT INTO x VALUES (null, null);

            INSERT INTO y VALUES (2, 2);
            INSERT INTO y VALUES (2, 2);
            INSERT INTO y VALUES (3, 3);
            INSERT INTO y VALUES (4, 4);
            INSERT INTO y VALUES (null, null);

            INSERT INTO y VALUES (3, 3);
            INSERT INTO y VALUES (3, 3);
            INSERT INTO y VALUES (4, 4);
@@ -80,8 +80,8 @@ class TestOptimizer(unittest.TestCase):

            with self.subTest(title):
                self.assertEqual(
                    optimized.sql(pretty=pretty, dialect=dialect),
                    expected,
                    optimized.sql(pretty=pretty, dialect=dialect),
                )

            should_execute = meta.get("execute")
@@ -223,85 +223,6 @@ class TestOptimizer(unittest.TestCase):
    def test_tpch(self):
        self.check_file("tpc-h/tpc-h", optimizer.optimize, schema=TPCH_SCHEMA, pretty=True)

    def test_schema(self):
        schema = ensure_schema(
            {
                "x": {
                    "a": "uint64",
                }
            }
        )
        self.assertEqual(
            schema.column_names(
                table(
                    "x",
                )
            ),
            ["a"],
        )
        with self.assertRaises(ValueError):
            schema.column_names(table("x", db="db", catalog="c"))
        with self.assertRaises(ValueError):
            schema.column_names(table("x", db="db"))
        with self.assertRaises(ValueError):
            schema.column_names(table("x2"))

        schema = ensure_schema(
            {
                "db": {
                    "x": {
                        "a": "uint64",
                    }
                }
            }
        )
        self.assertEqual(schema.column_names(table("x", db="db")), ["a"])
        with self.assertRaises(ValueError):
            schema.column_names(table("x", db="db", catalog="c"))
        with self.assertRaises(ValueError):
            schema.column_names(table("x"))
        with self.assertRaises(ValueError):
            schema.column_names(table("x", db="db2"))
        with self.assertRaises(ValueError):
            schema.column_names(table("x2", db="db"))

        schema = ensure_schema(
            {
                "c": {
                    "db": {
                        "x": {
                            "a": "uint64",
                        }
                    }
                }
            }
        )
        self.assertEqual(schema.column_names(table("x", db="db", catalog="c")), ["a"])
        with self.assertRaises(ValueError):
            schema.column_names(table("x", db="db"))
        with self.assertRaises(ValueError):
            schema.column_names(table("x"))
        with self.assertRaises(ValueError):
            schema.column_names(table("x", db="db", catalog="c2"))
        with self.assertRaises(ValueError):
            schema.column_names(table("x", db="db2"))
        with self.assertRaises(ValueError):
            schema.column_names(table("x2", db="db"))

        schema = ensure_schema(
            MappingSchema(
                {
                    "x": {
                        "a": "uint64",
                    }
                }
            )
        )
        self.assertEqual(schema.column_names(table("x")), ["a"])

        with self.assertRaises(OptimizeError):
            ensure_schema({})

    def test_file_schema(self):
        expression = parse_one(
            """
@@ -327,6 +248,8 @@ FROM READ_CSV('tests/fixtures/optimizer/tpc-h/nation.csv.gz', 'delimiter', '|')
            SELECT x.b FROM x
        ), r AS (
            SELECT y.b FROM y
        ), z as (
            SELECT cola, colb FROM (VALUES(1, 'test')) AS tab(cola, colb)
        )
        SELECT
            r.b,
@@ -340,19 +263,23 @@ FROM READ_CSV('tests/fixtures/optimizer/tpc-h/nation.csv.gz', 'delimiter', '|')
        """
        expression = parse_one(sql)
        for scopes in traverse_scope(expression), list(build_scope(expression).traverse()):
            self.assertEqual(len(scopes), 5)
            self.assertEqual(len(scopes), 7)
            self.assertEqual(scopes[0].expression.sql(), "SELECT x.b FROM x")
            self.assertEqual(scopes[1].expression.sql(), "SELECT y.b FROM y")
            self.assertEqual(scopes[2].expression.sql(), "SELECT y.c AS b FROM y")
            self.assertEqual(scopes[3].expression.sql(), "SELECT MAX(x.a) FROM x WHERE x.b = s.b")
            self.assertEqual(scopes[4].expression.sql(), parse_one(sql).sql())
            self.assertEqual(scopes[2].expression.sql(), "(VALUES (1, 'test')) AS tab(cola, colb)")
            self.assertEqual(
                scopes[3].expression.sql(), "SELECT cola, colb FROM (VALUES (1, 'test')) AS tab(cola, colb)"
            )
            self.assertEqual(scopes[4].expression.sql(), "SELECT y.c AS b FROM y")
            self.assertEqual(scopes[5].expression.sql(), "SELECT MAX(x.a) FROM x WHERE x.b = s.b")
            self.assertEqual(scopes[6].expression.sql(), parse_one(sql).sql())

            self.assertEqual(set(scopes[4].sources), {"q", "r", "s"})
            self.assertEqual(len(scopes[4].columns), 6)
            self.assertEqual(set(c.table for c in scopes[4].columns), {"r", "s"})
            self.assertEqual(scopes[4].source_columns("q"), [])
            self.assertEqual(len(scopes[4].source_columns("r")), 2)
            self.assertEqual(set(c.table for c in scopes[4].source_columns("r")), {"r"})
            self.assertEqual(set(scopes[6].sources), {"q", "z", "r", "s"})
            self.assertEqual(len(scopes[6].columns), 6)
            self.assertEqual(set(c.table for c in scopes[6].columns), {"r", "s"})
            self.assertEqual(scopes[6].source_columns("q"), [])
            self.assertEqual(len(scopes[6].source_columns("r")), 2)
            self.assertEqual(set(c.table for c in scopes[6].source_columns("r")), {"r"})

            self.assertEqual({c.sql() for c in scopes[-1].find_all(exp.Column)}, {"r.b", "s.b"})
            self.assertEqual(scopes[-1].find(exp.Column).sql(), "r.b")
@@ -81,7 +81,7 @@ class TestParser(unittest.TestCase):
        self.assertIsInstance(ignore.expression(exp.Hint, y=""), exp.Hint)
        self.assertIsInstance(ignore.expression(exp.Hint), exp.Hint)

        default = Parser()
        default = Parser(error_level=ErrorLevel.RAISE)
        self.assertIsInstance(default.expression(exp.Hint, expressions=[""]), exp.Hint)
        default.expression(exp.Hint, y="")
        default.expression(exp.Hint)
@@ -139,12 +139,12 @@ class TestParser(unittest.TestCase):
        )

        assert expression.expressions[0].name == "annotation1"
        assert expression.expressions[1].name == "annotation2:testing "
        assert expression.expressions[1].name == "annotation2:testing"
        assert expression.expressions[2].name == "test#annotation"
        assert expression.expressions[3].name == "annotation3"
        assert expression.expressions[4].name == "annotation4"
        assert expression.expressions[5].name == ""
        assert expression.expressions[6].name == " space"
        assert expression.expressions[6].name == "space"

    def test_pretty_config_override(self):
        self.assertEqual(parse_one("SELECT col FROM x").sql(), "SELECT col FROM x")
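The test now constructs the parser with an explicit error level rather than relying on the default. A minimal sketch, assuming `ErrorLevel` is importable from `sqlglot.errors`:

```python
from sqlglot import exp
from sqlglot.errors import ErrorLevel
from sqlglot.parser import Parser

# With ErrorLevel.RAISE, building a malformed expression raises
# instead of the error being silently collected.
parser = Parser(error_level=ErrorLevel.RAISE)
hint = parser.expression(exp.Hint, expressions=[""])
print(isinstance(hint, exp.Hint))  # True
```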
290
tests/test_schema.py
Normal file
@@ -0,0 +1,290 @@
import unittest

from sqlglot import table
from sqlglot.dataframe.sql import types as df_types
from sqlglot.schema import MappingSchema, ensure_schema


class TestSchema(unittest.TestCase):
    def test_schema(self):
        schema = ensure_schema(
            {
                "x": {
                    "a": "uint64",
                }
            }
        )
        self.assertEqual(
            schema.column_names(
                table(
                    "x",
                )
            ),
            ["a"],
        )
        with self.assertRaises(ValueError):
            schema.column_names(table("x", db="db", catalog="c"))
        with self.assertRaises(ValueError):
            schema.column_names(table("x", db="db"))
        with self.assertRaises(ValueError):
            schema.column_names(table("x2"))

        with self.assertRaises(ValueError):
            schema.add_table(table("y", db="db"), {"b": "string"})
        with self.assertRaises(ValueError):
            schema.add_table(table("y", db="db", catalog="c"), {"b": "string"})

        schema.add_table(table("y"), {"b": "string"})
        schema_with_y = {
            "x": {
                "a": "uint64",
            },
            "y": {
                "b": "string",
            },
        }
        self.assertEqual(schema.schema, schema_with_y)

        new_schema = schema.copy()
        new_schema.add_table(table("z"), {"c": "string"})
        self.assertEqual(schema.schema, schema_with_y)
        self.assertEqual(
            new_schema.schema,
            {
                "x": {
                    "a": "uint64",
                },
                "y": {
                    "b": "string",
                },
                "z": {
                    "c": "string",
                },
            },
        )
        schema.add_table(table("m"), {"d": "string"})
        schema.add_table(table("n"), {"e": "string"})
        schema_with_m_n = {
            "x": {
                "a": "uint64",
            },
            "y": {
                "b": "string",
            },
            "m": {
                "d": "string",
            },
            "n": {
                "e": "string",
            },
        }
        self.assertEqual(schema.schema, schema_with_m_n)
        new_schema = schema.copy()
        new_schema.add_table(table("o"), {"f": "string"})
        new_schema.add_table(table("p"), {"g": "string"})
        self.assertEqual(schema.schema, schema_with_m_n)
        self.assertEqual(
            new_schema.schema,
            {
                "x": {
                    "a": "uint64",
                },
                "y": {
                    "b": "string",
                },
                "m": {
                    "d": "string",
                },
                "n": {
                    "e": "string",
                },
                "o": {
                    "f": "string",
                },
                "p": {
                    "g": "string",
                },
            },
        )

        schema = ensure_schema(
            {
                "db": {
                    "x": {
                        "a": "uint64",
                    }
                }
            }
        )
        self.assertEqual(schema.column_names(table("x", db="db")), ["a"])
        with self.assertRaises(ValueError):
            schema.column_names(table("x", db="db", catalog="c"))
        with self.assertRaises(ValueError):
            schema.column_names(table("x"))
        with self.assertRaises(ValueError):
            schema.column_names(table("x", db="db2"))
        with self.assertRaises(ValueError):
            schema.column_names(table("x2", db="db"))

        with self.assertRaises(ValueError):
            schema.add_table(table("y"), {"b": "string"})
        with self.assertRaises(ValueError):
            schema.add_table(table("y", db="db", catalog="c"), {"b": "string"})

        schema.add_table(table("y", db="db"), {"b": "string"})
        self.assertEqual(
            schema.schema,
            {
                "db": {
                    "x": {
                        "a": "uint64",
                    },
                    "y": {
                        "b": "string",
                    },
                }
            },
        )

        schema = ensure_schema(
            {
                "c": {
                    "db": {
                        "x": {
                            "a": "uint64",
                        }
                    }
                }
            }
        )
        self.assertEqual(schema.column_names(table("x", db="db", catalog="c")), ["a"])
        with self.assertRaises(ValueError):
            schema.column_names(table("x", db="db"))
        with self.assertRaises(ValueError):
            schema.column_names(table("x"))
        with self.assertRaises(ValueError):
            schema.column_names(table("x", db="db", catalog="c2"))
        with self.assertRaises(ValueError):
            schema.column_names(table("x", db="db2"))
        with self.assertRaises(ValueError):
            schema.column_names(table("x2", db="db"))

        with self.assertRaises(ValueError):
            schema.add_table(table("x"), {"b": "string"})
        with self.assertRaises(ValueError):
            schema.add_table(table("x", db="db"), {"b": "string"})

        schema.add_table(table("y", db="db", catalog="c"), {"a": "string", "b": "int"})
        self.assertEqual(
            schema.schema,
            {
                "c": {
                    "db": {
                        "x": {
                            "a": "uint64",
                        },
                        "y": {
                            "a": "string",
                            "b": "int",
                        },
                    }
                }
            },
        )
        schema.add_table(table("z", db="db2", catalog="c"), {"c": "string", "d": "int"})
        self.assertEqual(
            schema.schema,
            {
                "c": {
                    "db": {
                        "x": {
                            "a": "uint64",
                        },
                        "y": {
                            "a": "string",
                            "b": "int",
                        },
                    },
                    "db2": {
                        "z": {
                            "c": "string",
                            "d": "int",
                        }
                    },
                }
            },
        )
        schema.add_table(table("m", db="db2", catalog="c2"), {"e": "string", "f": "int"})
        self.assertEqual(
            schema.schema,
            {
                "c": {
                    "db": {
                        "x": {
                            "a": "uint64",
                        },
                        "y": {
                            "a": "string",
                            "b": "int",
                        },
                    },
                    "db2": {
                        "z": {
                            "c": "string",
                            "d": "int",
                        }
                    },
                },
                "c2": {
                    "db2": {
                        "m": {
                            "e": "string",
                            "f": "int",
                        }
                    }
                },
            },
        )

        schema = ensure_schema(
            {
                "x": {
                    "a": "uint64",
                }
            }
        )
        self.assertEqual(schema.column_names(table("x")), ["a"])

        schema = MappingSchema()
        schema.add_table(table("x"), {"a": "string"})
        self.assertEqual(
            schema.schema,
            {
                "x": {
                    "a": "string",
                }
            },
        )
        schema.add_table(table("y"), df_types.StructType([df_types.StructField("b", df_types.StringType())]))
        self.assertEqual(
            schema.schema,
            {
                "x": {
                    "a": "string",
                },
                "y": {
                    "b": "string",
                },
            },
        )

    def test_schema_add_table_with_and_without_mapping(self):
        schema = MappingSchema()
        schema.add_table("test")
        self.assertEqual(schema.column_names("test"), [])
        schema.add_table("test", {"x": "string"})
        self.assertEqual(schema.column_names("test"), ["x"])
        schema.add_table("test", {"x": "string", "y": "int"})
        self.assertEqual(schema.column_names("test"), ["x", "y"])
        schema.add_table("test")
        self.assertEqual(schema.column_names("test"), ["x", "y"])
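`MappingSchema` is additive: re-registering a table with a column mapping extends what `column_names` reports, and registering it again with no mapping leaves the existing columns untouched. A sketch mirroring the last test:

```python
from sqlglot.schema import MappingSchema

schema = MappingSchema()
schema.add_table("test")
print(schema.column_names("test"))  # []

schema.add_table("test", {"x": "string", "y": "int"})
print(schema.column_names("test"))  # ['x', 'y']

schema.add_table("test")  # no mapping: existing columns are preserved
print(schema.column_names("test"))  # ['x', 'y']
```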