
Adding upstream version 9.0.1.

Signed-off-by: Daniel Baumann <daniel@debian.org>
Daniel Baumann 2025-02-13 14:47:39 +01:00
parent 768d386bf5
commit fca0265317
Signed by: daniel
GPG key ID: FBB4F0E80A80222F
87 changed files with 7994 additions and 421 deletions


@@ -20,7 +20,7 @@ jobs:
- name: Install dependencies
run: |
python -m pip install --upgrade pip
python -m pip install -r requirements.txt
python -m pip install -r dev-requirements.txt
- name: Run checks (linter, code style, tests)
run: |
./run_checks.sh


@@ -1,6 +1,21 @@
Changelog
=========
v9.0.0
------
Changes:
- Breaking: Changed the AST hierarchy of exp.Table and exp.Alias. Previously, Tables were children of their aliases; to simplify the AST and fix some issues, Tables now have an alias property (see the sketch below).
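For illustration, here is a minimal sketch of reading an alias directly from a parsed `exp.Table` node (the query and names are placeholders):

```python
import sqlglot
from sqlglot import expressions as exp

table = sqlglot.parse_one("SELECT a FROM employees AS e").find(exp.Table)
print(table.alias)          # "e" -- the alias now lives on the Table node itself
print(table.alias_or_name)  # "e"
```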
v8.0.0
------
Changes:
- Breaking: New add_table method in the Schema ABC.
- New: SQLGlot now supports the [PySpark](https://github.com/tobymao/sqlglot/tree/main/sqlglot/dataframe) DataFrame API. This is still relatively experimental.
v7.1.0
------


@@ -317,6 +317,7 @@ Dialect["custom"]
## Run Tests and Lint
```
pip install -r requirements.txt
# set `SKIP_INTEGRATION=1` to skip integration tests
./run_checks.sh
```
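For example, to run the checks without the integration tests (assuming the variable is set as an environment variable):
```
SKIP_INTEGRATION=1 ./run_checks.sh
```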


@@ -2,5 +2,7 @@ autoflake
black
duckdb
isort
mypy
pandas
pyspark
python-dateutil


@@ -11,4 +11,5 @@ python -m autoflake -i -r ${RETURN_ERROR_CODE} \
sqlglot/ tests/
python -m isort --profile black sqlglot/ tests/
python -m black ${RETURN_ERROR_CODE} --line-length 120 sqlglot/ tests/
python -m mypy sqlglot tests
python -m unittest

setup.cfg Normal file

@@ -0,0 +1,15 @@
[mypy]
disallow_untyped_calls = False
no_implicit_optional = True
[mypy-sqlglot.*]
ignore_errors = True
[mypy-sqlglot.dataframe.*]
ignore_errors = False
[mypy-tests.*]
ignore_errors = True
[mypy-tests.dataframe.*]
ignore_errors = False


@@ -21,12 +21,15 @@ from sqlglot.expressions import table_ as table
from sqlglot.expressions import union
from sqlglot.generator import Generator
from sqlglot.parser import Parser
from sqlglot.schema import MappingSchema
from sqlglot.tokens import Tokenizer, TokenType
__version__ = "7.1.3"
__version__ = "9.0.1"
pretty = False
schema = MappingSchema()
def parse(sql, read=None, **opts):
"""


@@ -40,8 +40,8 @@ parser.add_argument(
"--error-level",
dest="error_level",
type=str,
default="RAISE",
help="IGNORE, WARN, RAISE (default)",
default="IMMEDIATE",
help="IGNORE, WARN, RAISE, IMMEDIATE (default)",
)

sqlglot/dataframe/README.md Normal file

@@ -0,0 +1,224 @@
# PySpark DataFrame SQL Generator
This is a drop-in replacement for the PySpark DataFrame API that generates SQL instead of executing DataFrame operations directly. Combined with the transpiling support in SQLGlot, this allows you to write PySpark DataFrame code and execute it on other engines like [DuckDB](https://duckdb.org/), [Presto](https://prestodb.io/), [Spark](https://spark.apache.org/), [Snowflake](https://www.snowflake.com/en/), and [BigQuery](https://cloud.google.com/bigquery/).
Currently, many of the common operations are covered, and more functionality will be added over time. Please [open an issue](https://github.com/tobymao/sqlglot/issues) or [PR](https://github.com/tobymao/sqlglot/pulls) with your feedback or contribution to help influence what is prioritized next and to make sure your use case is properly supported.
# How to use
## Instructions
* [Install SQLGlot](https://github.com/tobymao/sqlglot/blob/main/README.md#install); that is all that is required to just generate SQL. [The examples](#examples) show generating SQL and then executing it on a specific engine, which requires that engine's client library.
* Find/replace all `from pyspark.sql` with `from sqlglot.dataframe`.
* Prior to any `spark.read.table` or `spark.table`, run `sqlglot.schema.add_table('<table_name>', <column_structure>)`.
  * The column structure can be defined in the following ways (see the sketch after this list):
    * A dictionary where the keys are column names and the values are strings of the Spark SQL type name.
      * Ex: `{'cola': 'string', 'colb': 'int'}`
    * A PySpark DataFrame `StructType`, similar to when using `createDataFrame`.
      * Ex: `StructType([StructField('cola', StringType()), StructField('colb', IntegerType())])`
    * A string of names and types, similar to what is supported in `createDataFrame`.
      * Ex: `cola: STRING, colb: INT`
    * [Not Recommended] A list of column names without types.
      * Ex: `['cola', 'colb']`
      * The lack of types may limit functionality in future releases.
  * See [Registering Custom Schema](#registering-custom-schema-class) for information on how to skip this step if the information is stored externally.
* Add `.sql(pretty=True)` to your final DataFrame command to return a list of SQL statements to run that command.
  * In most cases a single SQL statement is returned. Currently, the only exception is when caching DataFrames, which isn't supported in other dialects.
  * Spark is the default output dialect. See [dialects](https://github.com/tobymao/sqlglot/tree/main/sqlglot/dialects) for a full list of dialects.
    * Ex: `.sql(pretty=True, dialect='bigquery')`
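As a reference, here is a minimal sketch of the column-structure formats listed above (the table and column names are placeholders, and any one of these calls is sufficient on its own):
```python
import sqlglot
from sqlglot.dataframe.sql import types

# 1. Dictionary of column name -> Spark SQL type name
sqlglot.schema.add_table('employee', {'employee_id': 'INT', 'fname': 'STRING'})

# 2. PySpark-style StructType, as with createDataFrame
sqlglot.schema.add_table('employee', types.StructType([
    types.StructField('employee_id', types.IntegerType(), False),
    types.StructField('fname', types.StringType(), False),
]))

# 3. String of names and types
sqlglot.schema.add_table('employee', 'employee_id: INT, fname: STRING')

# 4. List of column names only (not recommended: missing types may limit functionality)
sqlglot.schema.add_table('employee', ['employee_id', 'fname'])
```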
## Examples
```python
import sqlglot
from sqlglot.dataframe.sql.session import SparkSession
from sqlglot.dataframe.sql import functions as F
sqlglot.schema.add_table('employee', {
'employee_id': 'INT',
'fname': 'STRING',
'lname': 'STRING',
'age': 'INT',
}) # Register the table structure prior to reading from the table
spark = SparkSession()
df = (
spark
.table('employee')
.groupBy(F.col("age"))
.agg(F.countDistinct(F.col("employee_id")).alias("num_employees"))
)
print(df.sql(pretty=True)) # Spark will be the dialect used by default
```
Output:
```sparksql
SELECT
`employee`.`age` AS `age`,
COUNT(DISTINCT `employee`.`employee_id`) AS `num_employees`
FROM `employee` AS `employee`
GROUP BY
`employee`.`age`
```
## Registering Custom Schema Class
The `sqlglot.schema.add_table` step can be skipped if you have the column structure stored externally, such as in a file or an external metadata table. This can be done by writing a class that implements the `sqlglot.schema.Schema` abstract class and then assigning an instance of that class to `sqlglot.schema`.
```python
import sqlglot
from sqlglot.dataframe.sql.session import SparkSession
from sqlglot.dataframe.sql import functions as F
from sqlglot.schema import Schema
class ExternalSchema(Schema):
...
sqlglot.schema = ExternalSchema()
spark = SparkSession()
df = (
spark
.table('employee')
.groupBy(F.col("age"))
.agg(F.countDistinct(F.col("employee_id")).alias("num_employees"))
)
print(df.sql(pretty=True))
```
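As one possible approach, here is a minimal sketch in which the external source is simply a JSON file of `{"table": {"column": "type"}}` mappings loaded into a `MappingSchema` subclass (the `FileBackedSchema` class and `schema.json` file are hypothetical, not part of SQLGlot):
```python
import json

import sqlglot
from sqlglot.schema import MappingSchema


class FileBackedSchema(MappingSchema):
    """Loads table -> column mappings from a JSON file."""

    def __init__(self, path):
        super().__init__()
        with open(path) as f:
            # Register every table found in the file, e.g. {"employee": {"employee_id": "INT"}}
            for table_name, columns in json.load(f).items():
                self.add_table(table_name, columns)


sqlglot.schema = FileBackedSchema("schema.json")
```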
## Example Implementations
### BigQuery
```python
from google.cloud import bigquery
from sqlglot.dataframe.sql.session import SparkSession
from sqlglot.dataframe.sql import types
from sqlglot.dataframe.sql import functions as F
client = bigquery.Client()
data = [
(1, "Jack", "Shephard", 34),
(2, "John", "Locke", 48),
(3, "Kate", "Austen", 34),
(4, "Claire", "Littleton", 22),
(5, "Hugo", "Reyes", 26),
]
schema = types.StructType([
types.StructField('employee_id', types.IntegerType(), False),
types.StructField('fname', types.StringType(), False),
types.StructField('lname', types.StringType(), False),
types.StructField('age', types.IntegerType(), False),
])
sql_statements = (
SparkSession()
.createDataFrame(data, schema)
.groupBy(F.col("age"))
.agg(F.countDistinct(F.col("employee_id")).alias("num_employees"))
.sql(dialect="bigquery")
)
result = None
for sql in sql_statements:
result = client.query(sql)
assert result is not None
for row in result:
print(f"Age: {row['age']}, Num Employees: {row['num_employees']}")
```
### Snowflake
```python
import os
import snowflake.connector
from sqlglot.dataframe.sql.session import SparkSession
from sqlglot.dataframe.sql import types
from sqlglot.dataframe.sql import functions as F
ctx = snowflake.connector.connect(
user=os.environ["SNOWFLAKE_USER"],
password=os.environ["SNOWFLAKE_PASS"],
account=os.environ["SNOWFLAKE_ACCOUNT"]
)
cs = ctx.cursor()
data = [
(1, "Jack", "Shephard", 34),
(2, "John", "Locke", 48),
(3, "Kate", "Austen", 34),
(4, "Claire", "Littleton", 22),
(5, "Hugo", "Reyes", 26),
]
schema = types.StructType([
types.StructField('employee_id', types.IntegerType(), False),
types.StructField('fname', types.StringType(), False),
types.StructField('lname', types.StringType(), False),
types.StructField('age', types.IntegerType(), False),
])
sql_statements = (
SparkSession()
.createDataFrame(data, schema)
.groupBy(F.col("age"))
.agg(F.countDistinct(F.col("lname")).alias("num_employees"))
.sql(dialect="snowflake")
)
try:
for sql in sql_statements:
cs.execute(sql)
results = cs.fetchall()
for row in results:
print(f"Age: {row[0]}, Num Employees: {row[1]}")
finally:
cs.close()
ctx.close()
```
### Spark
```python
from pyspark.sql.session import SparkSession as PySparkSession
from sqlglot.dataframe.sql.session import SparkSession
from sqlglot.dataframe.sql import types
from sqlglot.dataframe.sql import functions as F
data = [
(1, "Jack", "Shephard", 34),
(2, "John", "Locke", 48),
(3, "Kate", "Austen", 34),
(4, "Claire", "Littleton", 22),
(5, "Hugo", "Reyes", 26),
]
schema = types.StructType([
types.StructField('employee_id', types.IntegerType(), False),
types.StructField('fname', types.StringType(), False),
types.StructField('lname', types.StringType(), False),
types.StructField('age', types.IntegerType(), False),
])
sql_statements = (
SparkSession()
.createDataFrame(data, schema)
.groupBy(F.col("age"))
.agg(F.countDistinct(F.col("employee_id")).alias("num_employees"))
.sql(dialect="spark")
)
pyspark = PySparkSession.builder.master("local[*]").getOrCreate()
df = None
for sql in sql_statements:
df = pyspark.sql(sql)
assert df is not None
df.show()
```
# Unsupportable Operations
Any operation that cannot be represented in SQL cannot be supported by this tool. An example of this would be RDD operations. Since the DataFrame API is mostly modeled around SQL concepts, though, most operations can be supported.


@@ -0,0 +1,18 @@
from sqlglot.dataframe.sql.column import Column
from sqlglot.dataframe.sql.dataframe import DataFrame, DataFrameNaFunctions
from sqlglot.dataframe.sql.group import GroupedData
from sqlglot.dataframe.sql.readwriter import DataFrameReader, DataFrameWriter
from sqlglot.dataframe.sql.session import SparkSession
from sqlglot.dataframe.sql.window import Window, WindowSpec
__all__ = [
"SparkSession",
"DataFrame",
"GroupedData",
"Column",
"DataFrameNaFunctions",
"Window",
"WindowSpec",
"DataFrameReader",
"DataFrameWriter",
]


@@ -0,0 +1,20 @@
from __future__ import annotations
import datetime
import typing as t
from sqlglot import expressions as exp
if t.TYPE_CHECKING:
from sqlglot.dataframe.sql.column import Column
from sqlglot.dataframe.sql.types import StructType
ColumnLiterals = t.TypeVar(
"ColumnLiterals", bound=t.Union[str, float, int, bool, t.List, t.Tuple, datetime.date, datetime.datetime]
)
ColumnOrName = t.TypeVar("ColumnOrName", bound=t.Union[Column, str])
ColumnOrLiteral = t.TypeVar(
"ColumnOrLiteral", bound=t.Union[Column, str, float, int, bool, t.List, t.Tuple, datetime.date, datetime.datetime]
)
SchemaInput = t.TypeVar("SchemaInput", bound=t.Union[str, t.List[str], StructType, t.Dict[str, str]])
OutputExpressionContainer = t.TypeVar("OutputExpressionContainer", bound=t.Union[exp.Select, exp.Create, exp.Insert])


@@ -0,0 +1,295 @@
from __future__ import annotations
import typing as t
import sqlglot
from sqlglot import expressions as exp
from sqlglot.dataframe.sql.types import DataType
from sqlglot.helper import flatten
if t.TYPE_CHECKING:
from sqlglot.dataframe.sql._typing import ColumnOrLiteral
from sqlglot.dataframe.sql.window import WindowSpec
class Column:
def __init__(self, expression: t.Optional[t.Union[ColumnOrLiteral, exp.Expression]]):
if isinstance(expression, Column):
expression = expression.expression # type: ignore
elif expression is None or not isinstance(expression, (str, exp.Expression)):
expression = self._lit(expression).expression # type: ignore
self.expression: exp.Expression = sqlglot.maybe_parse(expression, dialect="spark")
def __repr__(self):
return repr(self.expression)
def __hash__(self):
return hash(self.expression)
def __eq__(self, other: ColumnOrLiteral) -> Column: # type: ignore
return self.binary_op(exp.EQ, other)
def __ne__(self, other: ColumnOrLiteral) -> Column: # type: ignore
return self.binary_op(exp.NEQ, other)
def __gt__(self, other: ColumnOrLiteral) -> Column:
return self.binary_op(exp.GT, other)
def __ge__(self, other: ColumnOrLiteral) -> Column:
return self.binary_op(exp.GTE, other)
def __lt__(self, other: ColumnOrLiteral) -> Column:
return self.binary_op(exp.LT, other)
def __le__(self, other: ColumnOrLiteral) -> Column:
return self.binary_op(exp.LTE, other)
def __and__(self, other: ColumnOrLiteral) -> Column:
return self.binary_op(exp.And, other)
def __or__(self, other: ColumnOrLiteral) -> Column:
return self.binary_op(exp.Or, other)
def __mod__(self, other: ColumnOrLiteral) -> Column:
return self.binary_op(exp.Mod, other)
def __add__(self, other: ColumnOrLiteral) -> Column:
return self.binary_op(exp.Add, other)
def __sub__(self, other: ColumnOrLiteral) -> Column:
return self.binary_op(exp.Sub, other)
def __mul__(self, other: ColumnOrLiteral) -> Column:
return self.binary_op(exp.Mul, other)
def __truediv__(self, other: ColumnOrLiteral) -> Column:
return self.binary_op(exp.Div, other)
def __div__(self, other: ColumnOrLiteral) -> Column:
return self.binary_op(exp.Div, other)
def __neg__(self) -> Column:
return self.unary_op(exp.Neg)
def __radd__(self, other: ColumnOrLiteral) -> Column:
return self.inverse_binary_op(exp.Add, other)
def __rsub__(self, other: ColumnOrLiteral) -> Column:
return self.inverse_binary_op(exp.Sub, other)
def __rmul__(self, other: ColumnOrLiteral) -> Column:
return self.inverse_binary_op(exp.Mul, other)
def __rdiv__(self, other: ColumnOrLiteral) -> Column:
return self.inverse_binary_op(exp.Div, other)
def __rtruediv__(self, other: ColumnOrLiteral) -> Column:
return self.inverse_binary_op(exp.Div, other)
def __rmod__(self, other: ColumnOrLiteral) -> Column:
return self.inverse_binary_op(exp.Mod, other)
def __pow__(self, power: ColumnOrLiteral, modulo=None):
return Column(exp.Pow(this=self.expression, power=Column(power).expression))
def __rpow__(self, power: ColumnOrLiteral):
return Column(exp.Pow(this=Column(power).expression, power=self.expression))
def __invert__(self):
return self.unary_op(exp.Not)
def __rand__(self, other: ColumnOrLiteral) -> Column:
return self.inverse_binary_op(exp.And, other)
def __ror__(self, other: ColumnOrLiteral) -> Column:
return self.inverse_binary_op(exp.Or, other)
@classmethod
def ensure_col(cls, value: t.Optional[t.Union[ColumnOrLiteral, exp.Expression]]):
return cls(value)
@classmethod
def ensure_cols(cls, args: t.List[t.Union[ColumnOrLiteral, exp.Expression]]) -> t.List[Column]:
return [cls.ensure_col(x) if not isinstance(x, Column) else x for x in args]
@classmethod
def _lit(cls, value: ColumnOrLiteral) -> Column:
if isinstance(value, dict):
columns = [cls._lit(v).alias(k).expression for k, v in value.items()]
return cls(exp.Struct(expressions=columns))
return cls(exp.convert(value))
@classmethod
def invoke_anonymous_function(
cls, column: t.Optional[ColumnOrLiteral], func_name: str, *args: t.Optional[ColumnOrLiteral]
) -> Column:
columns = [] if column is None else [cls.ensure_col(column)]
column_args = [cls.ensure_col(arg) for arg in args]
expressions = [x.expression for x in columns + column_args]
new_expression = exp.Anonymous(this=func_name.upper(), expressions=expressions)
return Column(new_expression)
@classmethod
def invoke_expression_over_column(
cls, column: t.Optional[ColumnOrLiteral], callable_expression: t.Callable, **kwargs
) -> Column:
ensured_column = None if column is None else cls.ensure_col(column)
new_expression = (
callable_expression(**kwargs)
if ensured_column is None
else callable_expression(this=ensured_column.column_expression, **kwargs)
)
return Column(new_expression)
def binary_op(self, klass: t.Callable, other: ColumnOrLiteral, **kwargs) -> Column:
return Column(klass(this=self.column_expression, expression=Column(other).column_expression, **kwargs))
def inverse_binary_op(self, klass: t.Callable, other: ColumnOrLiteral, **kwargs) -> Column:
return Column(klass(this=Column(other).column_expression, expression=self.column_expression, **kwargs))
def unary_op(self, klass: t.Callable, **kwargs) -> Column:
return Column(klass(this=self.column_expression, **kwargs))
@property
def is_alias(self):
return isinstance(self.expression, exp.Alias)
@property
def is_column(self):
return isinstance(self.expression, exp.Column)
@property
def column_expression(self) -> exp.Column:
return self.expression.unalias()
@property
def alias_or_name(self) -> str:
return self.expression.alias_or_name
@classmethod
def ensure_literal(cls, value) -> Column:
from sqlglot.dataframe.sql.functions import lit
if isinstance(value, cls):
value = value.expression
if not isinstance(value, exp.Literal):
return lit(value)
return Column(value)
def copy(self) -> Column:
return Column(self.expression.copy())
def set_table_name(self, table_name: str, copy=False) -> Column:
expression = self.expression.copy() if copy else self.expression
expression.set("table", exp.to_identifier(table_name))
return Column(expression)
def sql(self, **kwargs) -> Column:
return self.expression.sql(**{"dialect": "spark", **kwargs})
def alias(self, name: str) -> Column:
new_expression = exp.alias_(self.column_expression, name)
return Column(new_expression)
def asc(self) -> Column:
new_expression = exp.Ordered(this=self.column_expression, desc=False, nulls_first=True)
return Column(new_expression)
def desc(self) -> Column:
new_expression = exp.Ordered(this=self.column_expression, desc=True, nulls_first=False)
return Column(new_expression)
asc_nulls_first = asc
def asc_nulls_last(self) -> Column:
new_expression = exp.Ordered(this=self.column_expression, desc=False, nulls_first=False)
return Column(new_expression)
def desc_nulls_first(self) -> Column:
new_expression = exp.Ordered(this=self.column_expression, desc=True, nulls_first=True)
return Column(new_expression)
desc_nulls_last = desc
def when(self, condition: Column, value: t.Any) -> Column:
from sqlglot.dataframe.sql.functions import when
column_with_if = when(condition, value)
if not isinstance(self.expression, exp.Case):
return column_with_if
new_column = self.copy()
new_column.expression.args["ifs"].extend(column_with_if.expression.args["ifs"])
return new_column
def otherwise(self, value: t.Any) -> Column:
from sqlglot.dataframe.sql.functions import lit
true_value = value if isinstance(value, Column) else lit(value)
new_column = self.copy()
new_column.expression.set("default", true_value.column_expression)
return new_column
def isNull(self) -> Column:
new_expression = exp.Is(this=self.column_expression, expression=exp.Null())
return Column(new_expression)
def isNotNull(self) -> Column:
new_expression = exp.Not(this=exp.Is(this=self.column_expression, expression=exp.Null()))
return Column(new_expression)
def cast(self, dataType: t.Union[str, DataType]):
"""
Functionality difference: PySpark's cast accepts an instance of the DataType class;
sqlglot doesn't currently replicate this class, so it only accepts a string.
"""
if isinstance(dataType, DataType):
dataType = dataType.simpleString()
new_expression = exp.Cast(this=self.column_expression, to=dataType)
return Column(new_expression)
def startswith(self, value: t.Union[str, Column]) -> Column:
value = self._lit(value) if not isinstance(value, Column) else value
return self.invoke_anonymous_function(self, "STARTSWITH", value)
def endswith(self, value: t.Union[str, Column]) -> Column:
value = self._lit(value) if not isinstance(value, Column) else value
return self.invoke_anonymous_function(self, "ENDSWITH", value)
def rlike(self, regexp: str) -> Column:
return self.invoke_expression_over_column(
column=self, callable_expression=exp.RegexpLike, expression=self._lit(regexp).expression
)
def like(self, other: str):
return self.invoke_expression_over_column(self, exp.Like, expression=self._lit(other).expression)
def ilike(self, other: str):
return self.invoke_expression_over_column(self, exp.ILike, expression=self._lit(other).expression)
def substr(self, startPos: t.Union[int, Column], length: t.Union[int, Column]) -> Column:
startPos = self._lit(startPos) if not isinstance(startPos, Column) else startPos
length = self._lit(length) if not isinstance(length, Column) else length
return Column.invoke_expression_over_column(
self, exp.Substring, start=startPos.expression, length=length.expression
)
def isin(self, *cols: t.Union[ColumnOrLiteral, t.Iterable[ColumnOrLiteral]]):
columns = flatten(cols) if isinstance(cols[0], (list, set, tuple)) else cols # type: ignore
expressions = [self._lit(x).expression for x in columns]
return Column.invoke_expression_over_column(self, exp.In, expressions=expressions) # type: ignore
def between(
self,
lowerBound: t.Union[ColumnOrLiteral],
upperBound: t.Union[ColumnOrLiteral],
) -> Column:
lower_bound_exp = self._lit(lowerBound) if not isinstance(lowerBound, Column) else lowerBound
upper_bound_exp = self._lit(upperBound) if not isinstance(upperBound, Column) else upperBound
return Column(
exp.Between(this=self.column_expression, low=lower_bound_exp.expression, high=upper_bound_exp.expression)
)
def over(self, window: WindowSpec) -> Column:
window_expression = window.expression.copy()
window_expression.set("this", self.column_expression)
return Column(window_expression)


@@ -0,0 +1,730 @@
from __future__ import annotations
import functools
import typing as t
import zlib
from copy import copy
import sqlglot
from sqlglot import expressions as exp
from sqlglot.dataframe.sql import functions as F
from sqlglot.dataframe.sql.column import Column
from sqlglot.dataframe.sql.group import GroupedData
from sqlglot.dataframe.sql.normalize import normalize
from sqlglot.dataframe.sql.operations import Operation, operation
from sqlglot.dataframe.sql.readwriter import DataFrameWriter
from sqlglot.dataframe.sql.transforms import replace_id_value
from sqlglot.dataframe.sql.util import get_tables_from_expression_with_join
from sqlglot.dataframe.sql.window import Window
from sqlglot.helper import ensure_list, object_to_dict
from sqlglot.optimizer import optimize as optimize_func
from sqlglot.optimizer.qualify_columns import qualify_columns
if t.TYPE_CHECKING:
from sqlglot.dataframe.sql._typing import ColumnLiterals, OutputExpressionContainer
from sqlglot.dataframe.sql.session import SparkSession
JOIN_HINTS = {
"BROADCAST",
"BROADCASTJOIN",
"MAPJOIN",
"MERGE",
"SHUFFLEMERGE",
"MERGEJOIN",
"SHUFFLE_HASH",
"SHUFFLE_REPLICATE_NL",
}
class DataFrame:
def __init__(
self,
spark: SparkSession,
expression: exp.Select,
branch_id: t.Optional[str] = None,
sequence_id: t.Optional[str] = None,
last_op: Operation = Operation.INIT,
pending_hints: t.Optional[t.List[exp.Expression]] = None,
output_expression_container: t.Optional[OutputExpressionContainer] = None,
**kwargs,
):
self.spark = spark
self.expression = expression
self.branch_id = branch_id or self.spark._random_branch_id
self.sequence_id = sequence_id or self.spark._random_sequence_id
self.last_op = last_op
self.pending_hints = pending_hints or []
self.output_expression_container = output_expression_container or exp.Select()
def __getattr__(self, column_name: str) -> Column:
return self[column_name]
def __getitem__(self, column_name: str) -> Column:
column_name = f"{self.branch_id}.{column_name}"
return Column(column_name)
def __copy__(self):
return self.copy()
@property
def sparkSession(self):
return self.spark
@property
def write(self):
return DataFrameWriter(self)
@property
def latest_cte_name(self) -> str:
if not self.expression.ctes:
from_exp = self.expression.args["from"]
if from_exp.alias_or_name:
return from_exp.alias_or_name
table_alias = from_exp.find(exp.TableAlias)
if not table_alias:
raise RuntimeError(f"Could not find an alias name for this expression: {self.expression}")
return table_alias.alias_or_name
return self.expression.ctes[-1].alias
@property
def pending_join_hints(self):
return [hint for hint in self.pending_hints if isinstance(hint, exp.JoinHint)]
@property
def pending_partition_hints(self):
return [hint for hint in self.pending_hints if isinstance(hint, exp.Anonymous)]
@property
def columns(self) -> t.List[str]:
return self.expression.named_selects
@property
def na(self) -> DataFrameNaFunctions:
return DataFrameNaFunctions(self)
def _replace_cte_names_with_hashes(self, expression: exp.Select):
expression = expression.copy()
ctes = expression.ctes
replacement_mapping = {}
for cte in ctes:
old_name_id = cte.args["alias"].this
new_hashed_id = exp.to_identifier(
self._create_hash_from_expression(cte.this), quoted=old_name_id.args["quoted"]
)
replacement_mapping[old_name_id] = new_hashed_id
cte.set("alias", exp.TableAlias(this=new_hashed_id))
expression = expression.transform(replace_id_value, replacement_mapping)
return expression
def _create_cte_from_expression(
self,
expression: exp.Expression,
branch_id: t.Optional[str] = None,
sequence_id: t.Optional[str] = None,
**kwargs,
) -> t.Tuple[exp.CTE, str]:
name = self.spark._random_name
expression_to_cte = expression.copy()
expression_to_cte.set("with", None)
cte = exp.Select().with_(name, as_=expression_to_cte, **kwargs).ctes[0]
cte.set("branch_id", branch_id or self.branch_id)
cte.set("sequence_id", sequence_id or self.sequence_id)
return cte, name
def _ensure_list_of_columns(
self, cols: t.Union[str, t.Iterable[str], Column, t.Iterable[Column]]
) -> t.List[Column]:
columns = ensure_list(cols)
columns = Column.ensure_cols(columns)
return columns
def _ensure_and_normalize_cols(self, cols):
cols = self._ensure_list_of_columns(cols)
normalize(self.spark, self.expression, cols)
return cols
def _ensure_and_normalize_col(self, col):
col = Column.ensure_col(col)
normalize(self.spark, self.expression, col)
return col
def _convert_leaf_to_cte(self, sequence_id: t.Optional[str] = None) -> DataFrame:
df = self._resolve_pending_hints()
sequence_id = sequence_id or df.sequence_id
expression = df.expression.copy()
cte_expression, cte_name = df._create_cte_from_expression(expression=expression, sequence_id=sequence_id)
new_expression = df._add_ctes_to_expression(exp.Select(), expression.ctes + [cte_expression])
sel_columns = df._get_outer_select_columns(cte_expression)
new_expression = new_expression.from_(cte_name).select(*[x.alias_or_name for x in sel_columns])
return df.copy(expression=new_expression, sequence_id=sequence_id)
def _resolve_pending_hints(self) -> DataFrame:
df = self.copy()
if not self.pending_hints:
return df
expression = df.expression
hint_expression = expression.args.get("hint") or exp.Hint(expressions=[])
for hint in df.pending_partition_hints:
hint_expression.args.get("expressions").append(hint)
df.pending_hints.remove(hint)
join_aliases = {join_table.alias_or_name for join_table in get_tables_from_expression_with_join(expression)}
if join_aliases:
for hint in df.pending_join_hints:
for sequence_id_expression in hint.expressions:
sequence_id_or_name = sequence_id_expression.alias_or_name
sequence_ids_to_match = [sequence_id_or_name]
if sequence_id_or_name in df.spark.name_to_sequence_id_mapping:
sequence_ids_to_match = df.spark.name_to_sequence_id_mapping[sequence_id_or_name]
matching_ctes = [
cte for cte in reversed(expression.ctes) if cte.args["sequence_id"] in sequence_ids_to_match
]
for matching_cte in matching_ctes:
if matching_cte.alias_or_name in join_aliases:
sequence_id_expression.set("this", matching_cte.args["alias"].this)
df.pending_hints.remove(hint)
break
hint_expression.args.get("expressions").append(hint)
if hint_expression.expressions:
expression.set("hint", hint_expression)
return df
def _hint(self, hint_name: str, args: t.List[Column]) -> DataFrame:
hint_name = hint_name.upper()
hint_expression = (
exp.JoinHint(this=hint_name, expressions=[exp.to_table(parameter.alias_or_name) for parameter in args])
if hint_name in JOIN_HINTS
else exp.Anonymous(this=hint_name, expressions=[parameter.expression for parameter in args])
)
new_df = self.copy()
new_df.pending_hints.append(hint_expression)
return new_df
def _set_operation(self, klass: t.Callable, other: DataFrame, distinct: bool):
other_df = other._convert_leaf_to_cte()
base_expression = self.expression.copy()
base_expression = self._add_ctes_to_expression(base_expression, other_df.expression.ctes)
all_ctes = base_expression.ctes
other_df.expression.set("with", None)
base_expression.set("with", None)
operation = klass(this=base_expression, distinct=distinct, expression=other_df.expression)
operation.set("with", exp.With(expressions=all_ctes))
return self.copy(expression=operation)._convert_leaf_to_cte()
def _cache(self, storage_level: str):
df = self._convert_leaf_to_cte()
df.expression.ctes[-1].set("cache_storage_level", storage_level)
return df
@classmethod
def _add_ctes_to_expression(cls, expression: exp.Select, ctes: t.List[exp.CTE]) -> exp.Select:
expression = expression.copy()
with_expression = expression.args.get("with")
if with_expression:
existing_ctes = with_expression.expressions
existsing_cte_names = {x.alias_or_name for x in existing_ctes}
for cte in ctes:
if cte.alias_or_name not in existsing_cte_names:
existing_ctes.append(cte)
else:
existing_ctes = ctes
expression.set("with", exp.With(expressions=existing_ctes))
return expression
@classmethod
def _get_outer_select_columns(cls, item: t.Union[exp.Expression, DataFrame]) -> t.List[Column]:
expression = item.expression if isinstance(item, DataFrame) else item
return [Column(x) for x in expression.find(exp.Select).expressions]
@classmethod
def _create_hash_from_expression(cls, expression: exp.Select):
value = expression.sql(dialect="spark").encode("utf-8")
return f"t{zlib.crc32(value)}"[:6]
def _get_select_expressions(
self,
) -> t.List[t.Tuple[t.Union[t.Type[exp.Cache], OutputExpressionContainer], exp.Select]]:
select_expressions: t.List[t.Tuple[t.Union[t.Type[exp.Cache], OutputExpressionContainer], exp.Select]] = []
main_select_ctes: t.List[exp.CTE] = []
for cte in self.expression.ctes:
cache_storage_level = cte.args.get("cache_storage_level")
if cache_storage_level:
select_expression = cte.this.copy()
select_expression.set("with", exp.With(expressions=copy(main_select_ctes)))
select_expression.set("cte_alias_name", cte.alias_or_name)
select_expression.set("cache_storage_level", cache_storage_level)
select_expressions.append((exp.Cache, select_expression))
else:
main_select_ctes.append(cte)
main_select = self.expression.copy()
if main_select_ctes:
main_select.set("with", exp.With(expressions=main_select_ctes))
expression_select_pair = (type(self.output_expression_container), main_select)
select_expressions.append(expression_select_pair) # type: ignore
return select_expressions
def sql(self, dialect="spark", optimize=True, **kwargs) -> t.List[str]:
df = self._resolve_pending_hints()
select_expressions = df._get_select_expressions()
output_expressions: t.List[t.Union[exp.Select, exp.Cache, exp.Drop]] = []
replacement_mapping: t.Dict[exp.Identifier, exp.Identifier] = {}
for expression_type, select_expression in select_expressions:
select_expression = select_expression.transform(replace_id_value, replacement_mapping)
if optimize:
select_expression = optimize_func(select_expression)
select_expression = df._replace_cte_names_with_hashes(select_expression)
expression: t.Union[exp.Select, exp.Cache, exp.Drop]
if expression_type == exp.Cache:
cache_table_name = df._create_hash_from_expression(select_expression)
cache_table = exp.to_table(cache_table_name)
original_alias_name = select_expression.args["cte_alias_name"]
replacement_mapping[exp.to_identifier(original_alias_name)] = exp.to_identifier(cache_table_name)
sqlglot.schema.add_table(cache_table_name, select_expression.named_selects)
cache_storage_level = select_expression.args["cache_storage_level"]
options = [
exp.Literal.string("storageLevel"),
exp.Literal.string(cache_storage_level),
]
expression = exp.Cache(this=cache_table, expression=select_expression, lazy=True, options=options)
# We will drop the "view" if it exists before running the cache table
output_expressions.append(exp.Drop(this=cache_table, exists=True, kind="VIEW"))
elif expression_type == exp.Create:
expression = df.output_expression_container.copy()
expression.set("expression", select_expression)
elif expression_type == exp.Insert:
expression = df.output_expression_container.copy()
select_without_ctes = select_expression.copy()
select_without_ctes.set("with", None)
expression.set("expression", select_without_ctes)
if select_expression.ctes:
expression.set("with", exp.With(expressions=select_expression.ctes))
elif expression_type == exp.Select:
expression = select_expression
else:
raise ValueError(f"Invalid expression type: {expression_type}")
output_expressions.append(expression)
return [expression.sql(**{"dialect": dialect, **kwargs}) for expression in output_expressions]
def copy(self, **kwargs) -> DataFrame:
return DataFrame(**object_to_dict(self, **kwargs))
@operation(Operation.SELECT)
def select(self, *cols, **kwargs) -> DataFrame:
cols = self._ensure_and_normalize_cols(cols)
kwargs["append"] = kwargs.get("append", False)
if self.expression.args.get("joins"):
ambiguous_cols = [col for col in cols if not col.column_expression.table]
if ambiguous_cols:
join_table_identifiers = [x.this for x in get_tables_from_expression_with_join(self.expression)]
cte_names_in_join = [x.this for x in join_table_identifiers]
for ambiguous_col in ambiguous_cols:
ctes_with_column = [
cte
for cte in self.expression.ctes
if cte.alias_or_name in cte_names_in_join
and ambiguous_col.alias_or_name in cte.this.named_selects
]
# If the select column does not specify a table and there is a join
# then we assume they are referring to the left table
if len(ctes_with_column) > 1:
table_identifier = self.expression.args["from"].args["expressions"][0].this
else:
table_identifier = ctes_with_column[0].args["alias"].this
ambiguous_col.expression.set("table", table_identifier)
expression = self.expression.select(*[x.expression for x in cols], **kwargs)
qualify_columns(expression, sqlglot.schema)
return self.copy(expression=expression, **kwargs)
@operation(Operation.NO_OP)
def alias(self, name: str, **kwargs) -> DataFrame:
new_sequence_id = self.spark._random_sequence_id
df = self.copy()
for join_hint in df.pending_join_hints:
for expression in join_hint.expressions:
if expression.alias_or_name == self.sequence_id:
expression.set("this", Column.ensure_col(new_sequence_id).expression)
df.spark._add_alias_to_mapping(name, new_sequence_id)
return df._convert_leaf_to_cte(sequence_id=new_sequence_id)
@operation(Operation.WHERE)
def where(self, column: t.Union[Column, bool], **kwargs) -> DataFrame:
col = self._ensure_and_normalize_col(column)
return self.copy(expression=self.expression.where(col.expression))
filter = where
@operation(Operation.GROUP_BY)
def groupBy(self, *cols, **kwargs) -> GroupedData:
columns = self._ensure_and_normalize_cols(cols)
return GroupedData(self, columns, self.last_op)
@operation(Operation.SELECT)
def agg(self, *exprs, **kwargs) -> DataFrame:
cols = self._ensure_and_normalize_cols(exprs)
return self.groupBy().agg(*cols)
@operation(Operation.FROM)
def join(
self, other_df: DataFrame, on: t.Union[str, t.List[str], Column, t.List[Column]], how: str = "inner", **kwargs
) -> DataFrame:
other_df = other_df._convert_leaf_to_cte()
pre_join_self_latest_cte_name = self.latest_cte_name
columns = self._ensure_and_normalize_cols(on)
join_type = how.replace("_", " ")
if isinstance(columns[0].expression, exp.Column):
join_columns = [Column(x).set_table_name(pre_join_self_latest_cte_name) for x in columns]
join_clause = functools.reduce(
lambda x, y: x & y,
[
col.copy().set_table_name(pre_join_self_latest_cte_name)
== col.copy().set_table_name(other_df.latest_cte_name)
for col in columns
],
)
else:
if len(columns) > 1:
columns = [functools.reduce(lambda x, y: x & y, columns)]
join_clause = columns[0]
join_columns = [
Column(x).set_table_name(pre_join_self_latest_cte_name)
if i % 2 == 0
else Column(x).set_table_name(other_df.latest_cte_name)
for i, x in enumerate(join_clause.expression.find_all(exp.Column))
]
self_columns = [
column.set_table_name(pre_join_self_latest_cte_name, copy=True)
for column in self._get_outer_select_columns(self)
]
other_columns = [
column.set_table_name(other_df.latest_cte_name, copy=True)
for column in self._get_outer_select_columns(other_df)
]
column_value_mapping = {
column.alias_or_name if not isinstance(column.expression.this, exp.Star) else column.sql(): column
for column in other_columns + self_columns + join_columns
}
all_columns = [
column_value_mapping[name]
for name in {x.alias_or_name: None for x in join_columns + self_columns + other_columns}
]
new_df = self.copy(
expression=self.expression.join(other_df.latest_cte_name, on=join_clause.expression, join_type=join_type)
)
new_df.expression = new_df._add_ctes_to_expression(new_df.expression, other_df.expression.ctes)
new_df.pending_hints.extend(other_df.pending_hints)
new_df = new_df.select.__wrapped__(new_df, *all_columns)
return new_df
@operation(Operation.ORDER_BY)
def orderBy(
self, *cols: t.Union[str, Column], ascending: t.Optional[t.Union[t.Any, t.List[t.Any]]] = None
) -> DataFrame:
"""
This implementation lets any ordered columns take priority over whatever is provided in `ascending`. Spark
has irregular behavior and can result in runtime errors. Users shouldn't be mixing the two anyways so this
is unlikely to come up.
"""
columns = self._ensure_and_normalize_cols(cols)
pre_ordered_col_indexes = [
x
for x in [i if isinstance(col.expression, exp.Ordered) else None for i, col in enumerate(columns)]
if x is not None
]
if ascending is None:
ascending = [True] * len(columns)
elif not isinstance(ascending, list):
ascending = [ascending] * len(columns)
ascending = [bool(x) for i, x in enumerate(ascending)]
assert len(columns) == len(
ascending
), "The length of items in ascending must equal the number of columns provided"
col_and_ascending = list(zip(columns, ascending))
order_by_columns = [
exp.Ordered(this=col.expression, desc=not asc)
if i not in pre_ordered_col_indexes
else columns[i].column_expression
for i, (col, asc) in enumerate(col_and_ascending)
]
return self.copy(expression=self.expression.order_by(*order_by_columns))
sort = orderBy
@operation(Operation.FROM)
def union(self, other: DataFrame) -> DataFrame:
return self._set_operation(exp.Union, other, False)
unionAll = union
@operation(Operation.FROM)
def unionByName(self, other: DataFrame, allowMissingColumns: bool = False):
l_columns = self.columns
r_columns = other.columns
if not allowMissingColumns:
l_expressions = l_columns
r_expressions = l_columns
else:
l_expressions = []
r_expressions = []
r_columns_unused = copy(r_columns)
for l_column in l_columns:
l_expressions.append(l_column)
if l_column in r_columns:
r_expressions.append(l_column)
r_columns_unused.remove(l_column)
else:
r_expressions.append(exp.alias_(exp.Null(), l_column))
for r_column in r_columns_unused:
l_expressions.append(exp.alias_(exp.Null(), r_column))
r_expressions.append(r_column)
r_df = other.copy()._convert_leaf_to_cte().select(*self._ensure_list_of_columns(r_expressions))
l_df = self.copy()
if allowMissingColumns:
l_df = l_df._convert_leaf_to_cte().select(*self._ensure_list_of_columns(l_expressions))
return l_df._set_operation(exp.Union, r_df, False)
@operation(Operation.FROM)
def intersect(self, other: DataFrame) -> DataFrame:
return self._set_operation(exp.Intersect, other, True)
@operation(Operation.FROM)
def intersectAll(self, other: DataFrame) -> DataFrame:
return self._set_operation(exp.Intersect, other, False)
@operation(Operation.FROM)
def exceptAll(self, other: DataFrame) -> DataFrame:
return self._set_operation(exp.Except, other, False)
@operation(Operation.SELECT)
def distinct(self) -> DataFrame:
return self.copy(expression=self.expression.distinct())
@operation(Operation.SELECT)
def dropDuplicates(self, subset: t.Optional[t.List[str]] = None):
if not subset:
return self.distinct()
column_names = ensure_list(subset)
window = Window.partitionBy(*column_names).orderBy(*column_names)
return (
self.copy()
.withColumn("row_num", F.row_number().over(window))
.where(F.col("row_num") == F.lit(1))
.drop("row_num")
)
@operation(Operation.FROM)
def dropna(
self,
how: str = "any",
thresh: t.Optional[int] = None,
subset: t.Optional[t.Union[str, t.Tuple[str, ...], t.List[str]]] = None,
) -> DataFrame:
minimum_non_null = thresh or 0 # will be determined later if thresh is null
new_df = self.copy()
all_columns = self._get_outer_select_columns(new_df.expression)
if subset:
null_check_columns = self._ensure_and_normalize_cols(subset)
else:
null_check_columns = all_columns
if thresh is None:
minimum_num_nulls = 1 if how == "any" else len(null_check_columns)
else:
minimum_num_nulls = len(null_check_columns) - minimum_non_null + 1
if minimum_num_nulls > len(null_check_columns):
raise RuntimeError(
f"The minimum num nulls for dropna must be less than or equal to the number of columns. "
f"Minimum num nulls: {minimum_num_nulls}, Num Columns: {len(null_check_columns)}"
)
if_null_checks = [F.when(column.isNull(), F.lit(1)).otherwise(F.lit(0)) for column in null_check_columns]
nulls_added_together = functools.reduce(lambda x, y: x + y, if_null_checks)
num_nulls = nulls_added_together.alias("num_nulls")
new_df = new_df.select(num_nulls, append=True)
filtered_df = new_df.where(F.col("num_nulls") < F.lit(minimum_num_nulls))
final_df = filtered_df.select(*all_columns)
return final_df
@operation(Operation.FROM)
def fillna(
self,
value: t.Union[ColumnLiterals],
subset: t.Optional[t.Union[str, t.Tuple[str, ...], t.List[str]]] = None,
) -> DataFrame:
"""
Functionality difference: if you provide a value to replace a null and that type conflicts
with the type of the column, then PySpark will just ignore your replacement, while this
implementation will try to cast them to the same type in some cases, so the results won't always match.
It is best not to mix types, so make sure the replacement is the same type as the column.
Possibility for improvement: use the `typeof` function to get the type of the column and
check if it matches the type of the value provided; if not, then make it null.
"""
from sqlglot.dataframe.sql.functions import lit
values = None
columns = None
new_df = self.copy()
all_columns = self._get_outer_select_columns(new_df.expression)
all_column_mapping = {column.alias_or_name: column for column in all_columns}
if isinstance(value, dict):
values = value.values()
columns = self._ensure_and_normalize_cols(list(value))
if not columns:
columns = self._ensure_and_normalize_cols(subset) if subset else all_columns
if not values:
values = [value] * len(columns)
value_columns = [lit(value) for value in values]
null_replacement_mapping = {
column.alias_or_name: (F.when(column.isNull(), value).otherwise(column).alias(column.alias_or_name))
for column, value in zip(columns, value_columns)
}
null_replacement_mapping = {**all_column_mapping, **null_replacement_mapping}
null_replacement_columns = [null_replacement_mapping[column.alias_or_name] for column in all_columns]
new_df = new_df.select(*null_replacement_columns)
return new_df
@operation(Operation.FROM)
def replace(
self,
to_replace: t.Union[bool, int, float, str, t.List, t.Dict],
value: t.Optional[t.Union[bool, int, float, str, t.List]] = None,
subset: t.Optional[t.Union[str, t.List[str]]] = None,
) -> DataFrame:
from sqlglot.dataframe.sql.functions import lit
old_values = None
subset = ensure_list(subset)
new_df = self.copy()
all_columns = self._get_outer_select_columns(new_df.expression)
all_column_mapping = {column.alias_or_name: column for column in all_columns}
columns = self._ensure_and_normalize_cols(subset) if subset else all_columns
if isinstance(to_replace, dict):
old_values = list(to_replace)
new_values = list(to_replace.values())
elif not old_values and isinstance(to_replace, list):
assert isinstance(value, list), "value must be a list since the replacements are a list"
assert len(to_replace) == len(value), "the replacements and values must be the same length"
old_values = to_replace
new_values = value
else:
old_values = [to_replace] * len(columns)
new_values = [value] * len(columns)
old_values = [lit(value) for value in old_values]
new_values = [lit(value) for value in new_values]
replacement_mapping = {}
for column in columns:
expression = Column(None)
for i, (old_value, new_value) in enumerate(zip(old_values, new_values)):
if i == 0:
expression = F.when(column == old_value, new_value)
else:
expression = expression.when(column == old_value, new_value) # type: ignore
replacement_mapping[column.alias_or_name] = expression.otherwise(column).alias(
column.expression.alias_or_name
)
replacement_mapping = {**all_column_mapping, **replacement_mapping}
replacement_columns = [replacement_mapping[column.alias_or_name] for column in all_columns]
new_df = new_df.select(*replacement_columns)
return new_df
@operation(Operation.SELECT)
def withColumn(self, colName: str, col: Column) -> DataFrame:
col = self._ensure_and_normalize_col(col)
existing_col_names = self.expression.named_selects
existing_col_index = existing_col_names.index(colName) if colName in existing_col_names else None
if existing_col_index:
expression = self.expression.copy()
expression.expressions[existing_col_index] = col.expression
return self.copy(expression=expression)
return self.copy().select(col.alias(colName), append=True)
@operation(Operation.SELECT)
def withColumnRenamed(self, existing: str, new: str):
expression = self.expression.copy()
existing_columns = [expression for expression in expression.expressions if expression.alias_or_name == existing]
if not existing_columns:
raise ValueError("Tried to rename a column that doesn't exist")
for existing_column in existing_columns:
if isinstance(existing_column, exp.Column):
existing_column.replace(exp.alias_(existing_column.copy(), new))
else:
existing_column.set("alias", exp.to_identifier(new))
return self.copy(expression=expression)
@operation(Operation.SELECT)
def drop(self, *cols: t.Union[str, Column]) -> DataFrame:
all_columns = self._get_outer_select_columns(self.expression)
drop_cols = self._ensure_and_normalize_cols(cols)
new_columns = [
col
for col in all_columns
if col.alias_or_name not in [drop_column.alias_or_name for drop_column in drop_cols]
]
return self.copy().select(*new_columns, append=False)
@operation(Operation.LIMIT)
def limit(self, num: int) -> DataFrame:
return self.copy(expression=self.expression.limit(num))
@operation(Operation.NO_OP)
def hint(self, name: str, *parameters: t.Optional[t.Union[str, int]]) -> DataFrame:
parameter_list = ensure_list(parameters)
parameter_columns = (
self._ensure_list_of_columns(parameter_list) if parameters else Column.ensure_cols([self.sequence_id])
)
return self._hint(name, parameter_columns)
@operation(Operation.NO_OP)
def repartition(self, numPartitions: t.Union[int, str], *cols: t.Union[int, str]) -> DataFrame:
num_partitions = Column.ensure_cols(ensure_list(numPartitions))
columns = self._ensure_and_normalize_cols(cols)
args = num_partitions + columns
return self._hint("repartition", args)
@operation(Operation.NO_OP)
def coalesce(self, numPartitions: int) -> DataFrame:
num_partitions = Column.ensure_cols([numPartitions])
return self._hint("coalesce", num_partitions)
@operation(Operation.NO_OP)
def cache(self) -> DataFrame:
return self._cache(storage_level="MEMORY_AND_DISK")
@operation(Operation.NO_OP)
def persist(self, storageLevel: str = "MEMORY_AND_DISK_SER") -> DataFrame:
"""
Storage Level Options: https://spark.apache.org/docs/3.0.0-preview/sql-ref-syntax-aux-cache-cache-table.html
"""
return self._cache(storageLevel)
class DataFrameNaFunctions:
def __init__(self, df: DataFrame):
self.df = df
def drop(
self,
how: str = "any",
thresh: t.Optional[int] = None,
subset: t.Optional[t.Union[str, t.Tuple[str, ...], t.List[str]]] = None,
) -> DataFrame:
return self.df.dropna(how=how, thresh=thresh, subset=subset)
def fill(
self,
value: t.Union[int, bool, float, str, t.Dict[str, t.Any]],
subset: t.Optional[t.Union[str, t.Tuple[str, ...], t.List[str]]] = None,
) -> DataFrame:
return self.df.fillna(value=value, subset=subset)
def replace(
self,
to_replace: t.Union[bool, int, float, str, t.List, t.Dict],
value: t.Optional[t.Union[bool, int, float, str, t.List]] = None,
subset: t.Optional[t.Union[str, t.List[str]]] = None,
) -> DataFrame:
return self.df.replace(to_replace=to_replace, value=value, subset=subset)

File diff suppressed because it is too large


@@ -0,0 +1,57 @@
from __future__ import annotations
import typing as t
from sqlglot.dataframe.sql import functions as F
from sqlglot.dataframe.sql.column import Column
from sqlglot.dataframe.sql.operations import Operation, operation
if t.TYPE_CHECKING:
from sqlglot.dataframe.sql.dataframe import DataFrame
class GroupedData:
def __init__(self, df: DataFrame, group_by_cols: t.List[Column], last_op: Operation):
self._df = df.copy()
self.spark = df.spark
self.last_op = last_op
self.group_by_cols = group_by_cols
def _get_function_applied_columns(self, func_name: str, cols: t.Tuple[str, ...]) -> t.List[Column]:
func_name = func_name.lower()
return [getattr(F, func_name)(name).alias(f"{func_name}({name})") for name in cols]
@operation(Operation.SELECT)
def agg(self, *exprs: t.Union[Column, t.Dict[str, str]]) -> DataFrame:
columns = (
[Column(f"{agg_func}({column_name})") for column_name, agg_func in exprs[0].items()]
if isinstance(exprs[0], dict)
else exprs
)
cols = self._df._ensure_and_normalize_cols(columns)
expression = self._df.expression.group_by(*[x.expression for x in self.group_by_cols]).select(
*[x.expression for x in self.group_by_cols + cols], append=False
)
return self._df.copy(expression=expression)
def count(self) -> DataFrame:
return self.agg(F.count("*").alias("count"))
def mean(self, *cols: str) -> DataFrame:
return self.avg(*cols)
def avg(self, *cols: str) -> DataFrame:
return self.agg(*self._get_function_applied_columns("avg", cols))
def max(self, *cols: str) -> DataFrame:
return self.agg(*self._get_function_applied_columns("max", cols))
def min(self, *cols: str) -> DataFrame:
return self.agg(*self._get_function_applied_columns("min", cols))
def sum(self, *cols: str) -> DataFrame:
return self.agg(*self._get_function_applied_columns("sum", cols))
def pivot(self, *cols: str) -> DataFrame:
raise NotImplementedError("Sum distinct is not currently implemented")


@@ -0,0 +1,72 @@
from __future__ import annotations
import typing as t
from sqlglot import expressions as exp
from sqlglot.dataframe.sql.column import Column
from sqlglot.dataframe.sql.util import get_tables_from_expression_with_join
from sqlglot.helper import ensure_list
NORMALIZE_INPUT = t.TypeVar("NORMALIZE_INPUT", bound=t.Union[str, exp.Expression, Column])
if t.TYPE_CHECKING:
from sqlglot.dataframe.sql.session import SparkSession
def normalize(spark: SparkSession, expression_context: exp.Select, expr: t.List[NORMALIZE_INPUT]):
expr = ensure_list(expr)
expressions = _ensure_expressions(expr)
for expression in expressions:
identifiers = expression.find_all(exp.Identifier)
for identifier in identifiers:
replace_alias_name_with_cte_name(spark, expression_context, identifier)
replace_branch_and_sequence_ids_with_cte_name(spark, expression_context, identifier)
def replace_alias_name_with_cte_name(spark: SparkSession, expression_context: exp.Select, id: exp.Identifier):
if id.alias_or_name in spark.name_to_sequence_id_mapping:
for cte in reversed(expression_context.ctes):
if cte.args["sequence_id"] in spark.name_to_sequence_id_mapping[id.alias_or_name]:
_set_alias_name(id, cte.alias_or_name)
break
def replace_branch_and_sequence_ids_with_cte_name(
spark: SparkSession, expression_context: exp.Select, id: exp.Identifier
):
if id.alias_or_name in spark.known_ids:
# Check if we have a join and if both the tables in that join share a common branch id
# If so we need to have this reference the left table by default unless the id is a sequence
# id then it keeps that reference. This handles the weird edge case in spark that shouldn't
# be common in practice
if expression_context.args.get("joins") and id.alias_or_name in spark.known_branch_ids:
join_table_aliases = [x.alias_or_name for x in get_tables_from_expression_with_join(expression_context)]
ctes_in_join = [cte for cte in expression_context.ctes if cte.alias_or_name in join_table_aliases]
if ctes_in_join[0].args["branch_id"] == ctes_in_join[1].args["branch_id"]:
assert len(ctes_in_join) == 2
_set_alias_name(id, ctes_in_join[0].alias_or_name)
return
for cte in reversed(expression_context.ctes):
if id.alias_or_name in (cte.args["branch_id"], cte.args["sequence_id"]):
_set_alias_name(id, cte.alias_or_name)
return
def _set_alias_name(id: exp.Identifier, name: str):
id.set("this", name)
def _ensure_expressions(values: t.List[NORMALIZE_INPUT]) -> t.List[exp.Expression]:
values = ensure_list(values)
results = []
for value in values:
if isinstance(value, str):
results.append(Column.ensure_col(value).expression)
elif isinstance(value, Column):
results.append(value.expression)
elif isinstance(value, exp.Expression):
results.append(value)
else:
raise ValueError(f"Got an invalid type to normalize: {type(value)}")
return results


@@ -0,0 +1,53 @@
from __future__ import annotations
import functools
import typing as t
from enum import IntEnum
if t.TYPE_CHECKING:
from sqlglot.dataframe.sql.dataframe import DataFrame
from sqlglot.dataframe.sql.group import GroupedData
class Operation(IntEnum):
INIT = -1
NO_OP = 0
FROM = 1
WHERE = 2
GROUP_BY = 3
HAVING = 4
SELECT = 5
ORDER_BY = 6
LIMIT = 7
def operation(op: Operation):
"""
Decorator used around DataFrame methods to indicate what type of operation is being performed from the
ordered Operation enums. This is used to determine which operations should be performed on a CTE vs.
included with the previous operation.
Ex: After a user does a join we want to allow them to select which columns for the different
tables that they want to carry through to the following operation. If we put that join in
a CTE preemptively then the user would not have a chance to select which column they want
in cases where there is overlap in names.
"""
def decorator(func: t.Callable):
@functools.wraps(func)
def wrapper(self: DataFrame, *args, **kwargs):
if self.last_op == Operation.INIT:
self = self._convert_leaf_to_cte()
self.last_op = Operation.NO_OP
last_op = self.last_op
new_op = op if op != Operation.NO_OP else last_op
if new_op < last_op or (last_op == new_op and new_op == Operation.SELECT):
self = self._convert_leaf_to_cte()
df: t.Union[DataFrame, GroupedData] = func(self, *args, **kwargs)
df.last_op = new_op # type: ignore
return df
wrapper.__wrapped__ = func # type: ignore
return wrapper
return decorator


@@ -0,0 +1,79 @@
from __future__ import annotations
import typing as t
import sqlglot
from sqlglot import expressions as exp
from sqlglot.helper import object_to_dict
if t.TYPE_CHECKING:
from sqlglot.dataframe.sql.dataframe import DataFrame
from sqlglot.dataframe.sql.session import SparkSession
class DataFrameReader:
def __init__(self, spark: SparkSession):
self.spark = spark
def table(self, tableName: str) -> DataFrame:
from sqlglot.dataframe.sql.dataframe import DataFrame
sqlglot.schema.add_table(tableName)
return DataFrame(self.spark, exp.Select().from_(tableName).select(*sqlglot.schema.column_names(tableName)))
class DataFrameWriter:
def __init__(
self, df: DataFrame, spark: t.Optional[SparkSession] = None, mode: t.Optional[str] = None, by_name: bool = False
):
self._df = df
self._spark = spark or df.spark
self._mode = mode
self._by_name = by_name
def copy(self, **kwargs) -> DataFrameWriter:
return DataFrameWriter(
**{k[1:] if k.startswith("_") else k: v for k, v in object_to_dict(self, **kwargs).items()}
)
def sql(self, **kwargs) -> t.List[str]:
return self._df.sql(**kwargs)
def mode(self, saveMode: t.Optional[str]) -> DataFrameWriter:
return self.copy(_mode=saveMode)
@property
def byName(self):
return self.copy(by_name=True)
def insertInto(self, tableName: str, overwrite: t.Optional[bool] = None) -> DataFrameWriter:
output_expression_container = exp.Insert(
**{
"this": exp.to_table(tableName),
"overwrite": overwrite,
}
)
df = self._df.copy(output_expression_container=output_expression_container)
if self._by_name:
columns = sqlglot.schema.column_names(tableName, only_visible=True)
df = df._convert_leaf_to_cte().select(*columns)
return self.copy(_df=df)
def saveAsTable(self, name: str, format: t.Optional[str] = None, mode: t.Optional[str] = None):
if format is not None:
raise NotImplementedError("Providing Format in the save as table is not supported")
exists, replace, mode = None, None, mode or str(self._mode)
if mode == "append":
return self.insertInto(name)
if mode == "ignore":
exists = True
if mode == "overwrite":
replace = True
output_expression_container = exp.Create(
this=exp.to_table(name),
kind="TABLE",
exists=exists,
replace=replace,
)
return self.copy(_df=self._df.copy(output_expression_container=output_expression_container))
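A hedged usage sketch for the reader: DataFrameReader.table resolves columns through the global sqlglot.schema, so the table must be registered first. The table and column names below are made up, and the exact generated SQL may differ.
import sqlglot
from sqlglot.dataframe.sql.session import SparkSession

# The reader selects the columns registered in the global schema, so register them first.
sqlglot.schema.add_table("employees", {"id": "int", "name": "string"})

spark = SparkSession()
df = spark.read.table("employees")
print(df.sql())  # expected: a list containing a single SELECT over the registered columns of employees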

View file

@ -0,0 +1,148 @@
from __future__ import annotations
import typing as t
import uuid
from collections import defaultdict
import sqlglot
from sqlglot import expressions as exp
from sqlglot.dataframe.sql import functions as F
from sqlglot.dataframe.sql.dataframe import DataFrame
from sqlglot.dataframe.sql.readwriter import DataFrameReader
from sqlglot.dataframe.sql.types import StructType
from sqlglot.dataframe.sql.util import get_column_mapping_from_schema_input
if t.TYPE_CHECKING:
from sqlglot.dataframe.sql._typing import ColumnLiterals, SchemaInput
class SparkSession:
known_ids: t.ClassVar[t.Set[str]] = set()
known_branch_ids: t.ClassVar[t.Set[str]] = set()
known_sequence_ids: t.ClassVar[t.Set[str]] = set()
name_to_sequence_id_mapping: t.ClassVar[t.Dict[str, t.List[str]]] = defaultdict(list)
def __init__(self):
self.incrementing_id = 1
def __getattr__(self, name: str) -> SparkSession:
return self
def __call__(self, *args, **kwargs) -> SparkSession:
return self
@property
def read(self) -> DataFrameReader:
return DataFrameReader(self)
def table(self, tableName: str) -> DataFrame:
return self.read.table(tableName)
def createDataFrame(
self,
data: t.Sequence[t.Union[t.Dict[str, ColumnLiterals], t.List[ColumnLiterals], t.Tuple]],
schema: t.Optional[SchemaInput] = None,
samplingRatio: t.Optional[float] = None,
verifySchema: bool = False,
) -> DataFrame:
from sqlglot.dataframe.sql.dataframe import DataFrame
if samplingRatio is not None or verifySchema:
raise NotImplementedError("Sampling Ratio and Verify Schema are not supported")
if schema is not None and (
not isinstance(schema, (StructType, str, list))
or (isinstance(schema, list) and not isinstance(schema[0], str))
):
raise NotImplementedError("Only schema of either list or string of list supported")
if not data:
raise ValueError("Must provide data to create into a DataFrame")
column_mapping: t.Dict[str, t.Optional[str]]
if schema is not None:
column_mapping = get_column_mapping_from_schema_input(schema)
elif isinstance(data[0], dict):
column_mapping = {col_name.strip(): None for col_name in data[0]}
else:
column_mapping = {f"_{i}": None for i in range(1, len(data[0]) + 1)}
data_expressions = [
exp.Tuple(
expressions=list(map(lambda x: F.lit(x).expression, row if not isinstance(row, dict) else row.values()))
)
for row in data
]
sel_columns = [
F.col(name).cast(data_type).alias(name).expression if data_type is not None else F.col(name).expression
for name, data_type in column_mapping.items()
]
select_kwargs = {
"expressions": sel_columns,
"from": exp.From(
expressions=[
exp.Subquery(
this=exp.Values(expressions=data_expressions),
alias=exp.TableAlias(
this=exp.to_identifier(self._auto_incrementing_name),
columns=[exp.to_identifier(col_name) for col_name in column_mapping],
),
)
]
),
}
sel_expression = exp.Select(**select_kwargs)
return DataFrame(self, sel_expression)
def sql(self, sqlQuery: str) -> DataFrame:
expression = sqlglot.parse_one(sqlQuery, read="spark")
if isinstance(expression, exp.Select):
df = DataFrame(self, expression)
df = df._convert_leaf_to_cte()
elif isinstance(expression, (exp.Create, exp.Insert)):
select_expression = expression.expression.copy()
if isinstance(expression, exp.Insert):
select_expression.set("with", expression.args.get("with"))
expression.set("with", None)
del expression.args["expression"]
df = DataFrame(self, select_expression, output_expression_container=expression)
df = df._convert_leaf_to_cte()
else:
raise ValueError("Unknown expression type provided in the SQL. Please create an issue with the SQL.")
return df
@property
def _auto_incrementing_name(self) -> str:
name = f"a{self.incrementing_id}"
self.incrementing_id += 1
return name
@property
def _random_name(self) -> str:
return f"a{str(uuid.uuid4())[:8]}"
@property
def _random_branch_id(self) -> str:
id = self._random_id
self.known_branch_ids.add(id)
return id
@property
def _random_sequence_id(self):
id = self._random_id
self.known_sequence_ids.add(id)
return id
@property
def _random_id(self) -> str:
id = f"a{str(uuid.uuid4())[:8]}"
self.known_ids.add(id)
return id
@property
def _join_hint_names(self) -> t.Set[str]:
return {"BROADCAST", "MERGE", "SHUFFLE_HASH", "SHUFFLE_REPLICATE_NL"}
def _add_alias_to_mapping(self, name: str, sequence_id: str):
self.name_to_sequence_id_mapping[name].append(sequence_id)
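A hedged sketch of the VALUES-based construction implemented in createDataFrame above; column names and types are arbitrary.
from sqlglot.dataframe.sql.session import SparkSession

spark = SparkSession()
df = spark.createDataFrame([(1, "a"), (2, "b")], schema="cola: int, colb: string")
# Each row becomes a tuple in a VALUES clause aliased with the schema's column names,
# and every column is cast to its declared type in the outer SELECT.
for statement in df.sql():  # sql() is expected to return a list of generated statements
    print(statement)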

View file

@ -0,0 +1,9 @@
import typing as t
from sqlglot import expressions as exp
def replace_id_value(node, replacement_mapping: t.Dict[exp.Identifier, exp.Identifier]):
if isinstance(node, exp.Identifier) and node in replacement_mapping:
node = node.replace(replacement_mapping[node].copy())
return node

View file

@ -0,0 +1,208 @@
import typing as t
class DataType:
def __repr__(self) -> str:
return self.__class__.__name__ + "()"
def __hash__(self) -> int:
return hash(str(self))
def __eq__(self, other: t.Any) -> bool:
return isinstance(other, self.__class__) and self.__dict__ == other.__dict__
def __ne__(self, other: t.Any) -> bool:
return not self.__eq__(other)
def __str__(self) -> str:
return self.typeName()
@classmethod
def typeName(cls) -> str:
return cls.__name__[:-4].lower()
def simpleString(self) -> str:
return str(self)
def jsonValue(self) -> t.Union[str, t.Dict[str, t.Any]]:
return str(self)
class DataTypeWithLength(DataType):
def __init__(self, length: int):
self.length = length
def __repr__(self) -> str:
return f"{self.__class__.__name__}({self.length})"
def __str__(self) -> str:
return f"{self.typeName()}({self.length})"
class StringType(DataType):
pass
class CharType(DataTypeWithLength):
pass
class VarcharType(DataTypeWithLength):
pass
class BinaryType(DataType):
pass
class BooleanType(DataType):
pass
class DateType(DataType):
pass
class TimestampType(DataType):
pass
class TimestampNTZType(DataType):
@classmethod
def typeName(cls) -> str:
return "timestamp_ntz"
class DecimalType(DataType):
def __init__(self, precision: int = 10, scale: int = 0):
self.precision = precision
self.scale = scale
def simpleString(self) -> str:
return f"decimal({self.precision}, {self.scale})"
def jsonValue(self) -> str:
return f"decimal({self.precision}, {self.scale})"
def __repr__(self) -> str:
return f"DecimalType({self.precision}, {self.scale})"
class DoubleType(DataType):
pass
class FloatType(DataType):
pass
class ByteType(DataType):
def __str__(self) -> str:
return "tinyint"
class IntegerType(DataType):
def __str__(self) -> str:
return "int"
class LongType(DataType):
def __str__(self) -> str:
return "bigint"
class ShortType(DataType):
def __str__(self) -> str:
return "smallint"
class ArrayType(DataType):
def __init__(self, elementType: DataType, containsNull: bool = True):
self.elementType = elementType
self.containsNull = containsNull
def __repr__(self) -> str:
return f"ArrayType({self.elementType, str(self.containsNull)}"
def simpleString(self) -> str:
return f"array<{self.elementType.simpleString()}>"
def jsonValue(self) -> t.Dict[str, t.Any]:
return {
"type": self.typeName(),
"elementType": self.elementType.jsonValue(),
"containsNull": self.containsNull,
}
class MapType(DataType):
def __init__(self, keyType: DataType, valueType: DataType, valueContainsNull: bool = True):
self.keyType = keyType
self.valueType = valueType
self.valueContainsNull = valueContainsNull
def __repr__(self) -> str:
return f"MapType({self.keyType}, {self.valueType}, {str(self.valueContainsNull)})"
def simpleString(self) -> str:
return f"map<{self.keyType.simpleString()}, {self.valueType.simpleString()}>"
def jsonValue(self) -> t.Dict[str, t.Any]:
return {
"type": self.typeName(),
"keyType": self.keyType.jsonValue(),
"valueType": self.valueType.jsonValue(),
"valueContainsNull": self.valueContainsNull,
}
class StructField(DataType):
def __init__(
self, name: str, dataType: DataType, nullable: bool = True, metadata: t.Optional[t.Dict[str, t.Any]] = None
):
self.name = name
self.dataType = dataType
self.nullable = nullable
self.metadata = metadata or {}
def __repr__(self) -> str:
return f"StructField('{self.name}', {self.dataType}, {str(self.nullable)})"
def simpleString(self) -> str:
return f"{self.name}:{self.dataType.simpleString()}"
def jsonValue(self) -> t.Dict[str, t.Any]:
return {
"name": self.name,
"type": self.dataType.jsonValue(),
"nullable": self.nullable,
"metadata": self.metadata,
}
class StructType(DataType):
def __init__(self, fields: t.Optional[t.List[StructField]] = None):
if not fields:
self.fields = []
self.names = []
else:
self.fields = fields
self.names = [f.name for f in fields]
def __iter__(self) -> t.Iterator[StructField]:
return iter(self.fields)
def __len__(self) -> int:
return len(self.fields)
def __repr__(self) -> str:
return f"StructType({', '.join(str(field) for field in self)})"
def simpleString(self) -> str:
return f"struct<{', '.join(x.simpleString() for x in self)}>"
def jsonValue(self) -> t.Dict[str, t.Any]:
return {"type": self.typeName(), "fields": [x.jsonValue() for x in self]}
def fieldNames(self) -> t.List[str]:
return list(self.names)
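A small sketch of how these types compose; the field names are arbitrary.
from sqlglot.dataframe.sql import types

schema = types.StructType(
    [
        types.StructField("id", types.IntegerType()),
        types.StructField("name", types.StringType()),
    ]
)
print(schema.simpleString())  # struct<id:int, name:string>
print(schema.fieldNames())    # ['id', 'name']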

View file

@ -0,0 +1,32 @@
from __future__ import annotations
import typing as t
from sqlglot import expressions as exp
from sqlglot.dataframe.sql import types
if t.TYPE_CHECKING:
from sqlglot.dataframe.sql._typing import SchemaInput
def get_column_mapping_from_schema_input(schema: SchemaInput) -> t.Dict[str, t.Optional[str]]:
if isinstance(schema, dict):
return schema
elif isinstance(schema, str):
col_name_type_strs = [x.strip() for x in schema.split(",")]
return {
name_type_str.split(":")[0].strip(): name_type_str.split(":")[1].strip()
for name_type_str in col_name_type_strs
}
elif isinstance(schema, types.StructType):
return {struct_field.name: struct_field.dataType.simpleString() for struct_field in schema}
return {x.strip(): None for x in schema} # type: ignore
def get_tables_from_expression_with_join(expression: exp.Select) -> t.List[exp.Table]:
if not expression.args.get("joins"):
return []
left_table = expression.args["from"].args["expressions"][0]
other_tables = [join.this for join in expression.args["joins"]]
return [left_table] + other_tables
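The helper accepts any of the supported schema input forms; a quick sketch with arbitrary names and types.
from sqlglot.dataframe.sql.util import get_column_mapping_from_schema_input
from sqlglot.dataframe.sql import types

print(get_column_mapping_from_schema_input("cola: int, colb: string"))
# {'cola': 'int', 'colb': 'string'}
print(get_column_mapping_from_schema_input(["cola", "colb"]))
# {'cola': None, 'colb': None}
print(get_column_mapping_from_schema_input(types.StructType([types.StructField("cola", types.IntegerType())])))
# {'cola': 'int'}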

View file

@ -0,0 +1,117 @@
from __future__ import annotations
import sys
import typing as t
from sqlglot import expressions as exp
from sqlglot.dataframe.sql import functions as F
from sqlglot.helper import flatten
if t.TYPE_CHECKING:
from sqlglot.dataframe.sql._typing import ColumnOrName
class Window:
_JAVA_MIN_LONG = -(1 << 63) # -9223372036854775808
_JAVA_MAX_LONG = (1 << 63) - 1 # 9223372036854775807
_PRECEDING_THRESHOLD = max(-sys.maxsize, _JAVA_MIN_LONG)
_FOLLOWING_THRESHOLD = min(sys.maxsize, _JAVA_MAX_LONG)
unboundedPreceding: int = _JAVA_MIN_LONG
unboundedFollowing: int = _JAVA_MAX_LONG
currentRow: int = 0
@classmethod
def partitionBy(cls, *cols: t.Union[ColumnOrName, t.List[ColumnOrName]]) -> WindowSpec:
return WindowSpec().partitionBy(*cols)
@classmethod
def orderBy(cls, *cols: t.Union[ColumnOrName, t.List[ColumnOrName]]) -> WindowSpec:
return WindowSpec().orderBy(*cols)
@classmethod
def rowsBetween(cls, start: int, end: int) -> WindowSpec:
return WindowSpec().rowsBetween(start, end)
@classmethod
def rangeBetween(cls, start: int, end: int) -> WindowSpec:
return WindowSpec().rangeBetween(start, end)
class WindowSpec:
def __init__(self, expression: exp.Expression = exp.Window()):
self.expression = expression
def copy(self):
return WindowSpec(self.expression.copy())
def sql(self, **kwargs) -> str:
return self.expression.sql(dialect="spark", **kwargs)
def partitionBy(self, *cols: t.Union[ColumnOrName, t.List[ColumnOrName]]) -> WindowSpec:
from sqlglot.dataframe.sql.column import Column
cols = flatten(cols) if isinstance(cols[0], (list, set, tuple)) else cols # type: ignore
expressions = [Column.ensure_col(x).expression for x in cols]
window_spec = self.copy()
partition_by_expressions = window_spec.expression.args.get("partition_by", [])
partition_by_expressions.extend(expressions)
window_spec.expression.set("partition_by", partition_by_expressions)
return window_spec
def orderBy(self, *cols: t.Union[ColumnOrName, t.List[ColumnOrName]]) -> WindowSpec:
from sqlglot.dataframe.sql.column import Column
cols = flatten(cols) if isinstance(cols[0], (list, set, tuple)) else cols # type: ignore
expressions = [Column.ensure_col(x).expression for x in cols]
window_spec = self.copy()
if window_spec.expression.args.get("order") is None:
window_spec.expression.set("order", exp.Order(expressions=[]))
order_by = window_spec.expression.args["order"].expressions
order_by.extend(expressions)
window_spec.expression.args["order"].set("expressions", order_by)
return window_spec
def _calc_start_end(self, start: int, end: int) -> t.Dict[str, t.Optional[t.Union[str, exp.Expression]]]:
kwargs: t.Dict[str, t.Optional[t.Union[str, exp.Expression]]] = {"start_side": None, "end_side": None}
if start == Window.currentRow:
kwargs["start"] = "CURRENT ROW"
else:
kwargs = {
**kwargs,
**{
"start_side": "PRECEDING",
"start": "UNBOUNDED" if start <= Window.unboundedPreceding else F.lit(start).expression,
},
}
if end == Window.currentRow:
kwargs["end"] = "CURRENT ROW"
else:
kwargs = {
**kwargs,
**{
"end_side": "FOLLOWING",
"end": "UNBOUNDED" if end >= Window.unboundedFollowing else F.lit(end).expression,
},
}
return kwargs
def rowsBetween(self, start: int, end: int) -> WindowSpec:
window_spec = self.copy()
spec = self._calc_start_end(start, end)
spec["kind"] = "ROWS"
window_spec.expression.set(
"spec", exp.WindowSpec(**{**window_spec.expression.args.get("spec", exp.WindowSpec()).args, **spec})
)
return window_spec
def rangeBetween(self, start: int, end: int) -> WindowSpec:
window_spec = self.copy()
spec = self._calc_start_end(start, end)
spec["kind"] = "RANGE"
window_spec.expression.set(
"spec", exp.WindowSpec(**{**window_spec.expression.args.get("spec", exp.WindowSpec()).args, **spec})
)
return window_spec
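A hedged sketch of building a window spec with these classes. The column names are made up, the module path is assumed, and the rendered SQL shown is indicative only.
from sqlglot.dataframe.sql.window import Window  # assumed module path

spec = (
    Window.partitionBy("dept")
    .orderBy("salary")
    .rowsBetween(Window.unboundedPreceding, Window.currentRow)
)
print(spec.sql())
# expected to render roughly:
# OVER (PARTITION BY dept ORDER BY salary ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW)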

View file

@ -78,6 +78,16 @@ def _create_sql(self, expression):
class BigQuery(Dialect):
unnest_column_only = True
time_mapping = {
"%M": "%-M",
"%d": "%-d",
"%m": "%-m",
"%y": "%-y",
"%H": "%-H",
"%I": "%-I",
"%S": "%-S",
"%j": "%-j",
}
class Tokenizer(Tokenizer):
QUOTES = [
@ -113,6 +123,7 @@ class BigQuery(Dialect):
"DATETIME_SUB": _date_add(exp.DatetimeSub),
"TIME_SUB": _date_add(exp.TimeSub),
"TIMESTAMP_SUB": _date_add(exp.TimestampSub),
"PARSE_TIMESTAMP": lambda args: exp.StrToTime(this=list_get(args, 1), format=list_get(args, 0)),
}
NO_PAREN_FUNCTIONS = {
@ -137,6 +148,7 @@ class BigQuery(Dialect):
exp.DatetimeSub: _date_add_sql("DATETIME", "SUB"),
exp.DateDiff: lambda self, e: f"DATE_DIFF({self.sql(e, 'this')}, {self.sql(e, 'expression')}, {self.sql(e.args.get('unit', 'DAY'))})",
exp.ILike: no_ilike_sql,
exp.StrToTime: lambda self, e: f"PARSE_TIMESTAMP({self.format_time(e)}, {self.sql(e, 'this')})",
exp.TimeAdd: _date_add_sql("TIME", "ADD"),
exp.TimeSub: _date_add_sql("TIME", "SUB"),
exp.TimestampAdd: _date_add_sql("TIMESTAMP", "ADD"),
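With the new mapping, PARSE_TIMESTAMP should round-trip through exp.StrToTime; a hedged check, with arbitrary identifiers and an indicative expected output.
import sqlglot
from sqlglot import exp

node = sqlglot.parse_one("SELECT PARSE_TIMESTAMP('%d/%m/%Y', raw_ts) FROM t", read="bigquery")
print(node.find(exp.StrToTime) is not None)  # True: the call is parsed into a StrToTime node
print(node.sql("bigquery"))
# expected to round-trip: SELECT PARSE_TIMESTAMP('%d/%m/%Y', raw_ts) FROM t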

View file

@ -2,7 +2,7 @@ from enum import Enum
from sqlglot import exp
from sqlglot.generator import Generator
from sqlglot.helper import list_get
from sqlglot.helper import flatten, list_get
from sqlglot.parser import Parser
from sqlglot.time import format_time
from sqlglot.tokens import Tokenizer
@ -67,6 +67,11 @@ class _Dialect(type):
klass.generator_class.TRANSFORMS[
exp.HexString
] = lambda self, e: f"{hs_start}{int(self.sql(e, 'this')):X}{hs_end}"
if klass.tokenizer_class._BYTE_STRINGS and exp.ByteString not in klass.generator_class.TRANSFORMS:
be_start, be_end = list(klass.tokenizer_class._BYTE_STRINGS.items())[0]
klass.generator_class.TRANSFORMS[
exp.ByteString
] = lambda self, e: f"{be_start}{self.sql(e, 'this')}{be_end}"
return klass
@ -176,11 +181,7 @@ class Dialect(metaclass=_Dialect):
def rename_func(name):
def _rename(self, expression):
args = (
expression.expressions
if isinstance(expression, exp.Func) and expression.is_var_len_args
else expression.args.values()
)
args = flatten(expression.args.values())
return f"{name}({self.format_args(*args)})"
return _rename

View file

@ -121,6 +121,9 @@ class Hive(Dialect):
"ss": "%S",
"s": "%-S",
"S": "%f",
"a": "%p",
"DD": "%j",
"D": "%-j",
}
date_format = "'yyyy-MM-dd'"
@ -200,6 +203,7 @@ class Hive(Dialect):
exp.AnonymousProperty: _property_sql,
exp.ApproxDistinct: approx_count_distinct_sql,
exp.ArrayAgg: rename_func("COLLECT_LIST"),
exp.ArrayConcat: rename_func("CONCAT"),
exp.ArraySize: rename_func("SIZE"),
exp.ArraySort: _array_sort,
exp.With: no_recursive_cte_sql,

View file

@ -97,6 +97,8 @@ class MySQL(Dialect):
"%s": "%S",
"%S": "%S",
"%u": "%W",
"%k": "%-H",
"%l": "%-I",
}
class Tokenizer(Tokenizer):
@ -145,6 +147,9 @@ class MySQL(Dialect):
"_TIS620": TokenType.INTRODUCER,
"_UCS2": TokenType.INTRODUCER,
"_UJIS": TokenType.INTRODUCER,
# https://dev.mysql.com/doc/refman/8.0/en/string-literals.html
"N": TokenType.INTRODUCER,
"n": TokenType.INTRODUCER,
"_UTF8": TokenType.INTRODUCER,
"_UTF16": TokenType.INTRODUCER,
"_UTF16LE": TokenType.INTRODUCER,

View file

@ -80,17 +80,12 @@ class Oracle(Dialect):
sep="",
)
def alias_sql(self, expression):
if isinstance(expression.this, exp.Table):
to_sql = self.sql(expression, "alias")
# oracle does not allow "AS" between table and alias
to_sql = f" {to_sql}" if to_sql else ""
return f"{self.sql(expression, 'this')}{to_sql}"
return super().alias_sql(expression)
def offset_sql(self, expression):
return f"{super().offset_sql(expression)} ROWS"
def table_sql(self, expression):
return super().table_sql(expression, sep=" ")
class Tokenizer(Tokenizer):
KEYWORDS = {
**Tokenizer.KEYWORDS,

View file

@ -163,6 +163,7 @@ class Postgres(Dialect):
class Tokenizer(Tokenizer):
BIT_STRINGS = [("b'", "'"), ("B'", "'")]
HEX_STRINGS = [("x'", "'"), ("X'", "'")]
BYTE_STRINGS = [("e'", "'"), ("E'", "'")]
KEYWORDS = {
**Tokenizer.KEYWORDS,
"ALWAYS": TokenType.ALWAYS,
@ -176,6 +177,11 @@ class Postgres(Dialect):
"SMALLSERIAL": TokenType.SMALLSERIAL,
"UUID": TokenType.UUID,
}
QUOTES = ["'", "$$"]
SINGLE_TOKENS = {
**Tokenizer.SINGLE_TOKENS,
"$": TokenType.PARAMETER,
}
class Parser(Parser):
STRICT_CAST = False

View file

@ -172,6 +172,7 @@ class Presto(Dialect):
**transforms.UNALIAS_GROUP,
exp.ApproxDistinct: _approx_distinct_sql,
exp.Array: lambda self, e: f"ARRAY[{self.expressions(e, flat=True)}]",
exp.ArrayConcat: rename_func("CONCAT"),
exp.ArrayContains: rename_func("CONTAINS"),
exp.ArraySize: rename_func("CARDINALITY"),
exp.BitwiseAnd: lambda self, e: f"BITWISE_AND({self.sql(e, 'this')}, {self.sql(e, 'expression')})",

View file

@ -69,6 +69,35 @@ def _unix_to_time(self, expression):
raise ValueError("Improper scale for timestamp")
# https://docs.snowflake.com/en/sql-reference/functions/date_part.html
# https://docs.snowflake.com/en/sql-reference/functions-date-time.html#label-supported-date-time-parts
def _parse_date_part(self):
this = self._parse_var() or self._parse_type()
self._match(TokenType.COMMA)
expression = self._parse_bitwise()
name = this.name.upper()
if name.startswith("EPOCH"):
if name.startswith("EPOCH_MILLISECOND"):
scale = 10**3
elif name.startswith("EPOCH_MICROSECOND"):
scale = 10**6
elif name.startswith("EPOCH_NANOSECOND"):
scale = 10**9
else:
scale = None
ts = self.expression(exp.Cast, this=expression, to=exp.DataType.build("TIMESTAMP"))
to_unix = self.expression(exp.TimeToUnix, this=ts)
if scale:
to_unix = exp.Mul(this=to_unix, expression=exp.Literal.number(scale))
return to_unix
return self.expression(exp.Extract, this=this, expression=expression)
class Snowflake(Dialect):
null_ordering = "nulls_are_large"
time_format = "'yyyy-mm-dd hh24:mi:ss'"
@ -115,7 +144,7 @@ class Snowflake(Dialect):
FUNCTION_PARSERS = {
**Parser.FUNCTION_PARSERS,
"DATE_PART": lambda self: self._parse_extract(),
"DATE_PART": _parse_date_part,
}
FUNC_TOKENS = {
@ -161,9 +190,11 @@ class Snowflake(Dialect):
class Generator(Generator):
TRANSFORMS = {
**Generator.TRANSFORMS,
exp.ArrayConcat: rename_func("ARRAY_CAT"),
exp.If: rename_func("IFF"),
exp.StrToTime: lambda self, e: f"TO_TIMESTAMP({self.sql(e, 'this')}, {self.format_time(e)})",
exp.UnixToTime: _unix_to_time,
exp.TimeToUnix: lambda self, e: f"EXTRACT(epoch_second FROM {self.sql(e, 'this')})",
exp.Array: inline_array_sql,
exp.StrPosition: rename_func("POSITION"),
exp.Parameter: lambda self, e: f"${self.sql(e, 'this')}",
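A hedged example of the new DATE_PART handling for epoch parts; the identifiers are arbitrary and the outputs shown are indicative only.
import sqlglot

sql = "SELECT DATE_PART(epoch_second, created_at) FROM events"
print(sqlglot.transpile(sql, read="snowflake", write="snowflake")[0])
# expected to come back as something like:
# SELECT EXTRACT(epoch_second FROM CAST(created_at AS TIMESTAMP)) FROM events
print(sqlglot.transpile(sql, read="snowflake", write="presto")[0])
# expected to use Presto's TO_UNIXTIME over the cast instead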

View file

@ -1,9 +1,5 @@
from sqlglot import exp
from sqlglot.dialects.dialect import (
create_with_partitions_sql,
no_ilike_sql,
rename_func,
)
from sqlglot.dialects.dialect import create_with_partitions_sql, rename_func
from sqlglot.dialects.hive import Hive
from sqlglot.helper import list_get
from sqlglot.parser import Parser
@ -98,13 +94,14 @@ class Spark(Hive):
}
TRANSFORMS = {
**{k: v for k, v in Hive.Generator.TRANSFORMS.items() if k not in {exp.ArraySort}},
**{k: v for k, v in Hive.Generator.TRANSFORMS.items() if k not in {exp.ArraySort, exp.ILike}},
exp.ApproxDistinct: rename_func("APPROX_COUNT_DISTINCT"),
exp.FileFormatProperty: lambda self, e: f"USING {e.text('value').upper()}",
exp.ArraySum: lambda self, e: f"AGGREGATE({self.sql(e, 'this')}, 0, (acc, x) -> acc + x, acc -> acc)",
exp.BitwiseLeftShift: rename_func("SHIFTLEFT"),
exp.BitwiseRightShift: rename_func("SHIFTRIGHT"),
exp.DateTrunc: rename_func("TRUNC"),
exp.Hint: lambda self, e: f" /*+ {self.expressions(e).strip()} */",
exp.ILike: no_ilike_sql,
exp.StrToDate: _str_to_date,
exp.StrToTime: lambda self, e: f"TO_TIMESTAMP({self.sql(e, 'this')}, {self.format_time(e)})",
exp.UnixToTime: _unix_to_time,
@ -112,6 +109,8 @@ class Spark(Hive):
exp.Map: _map_sql,
exp.Reduce: rename_func("AGGREGATE"),
exp.StructKwarg: lambda self, e: f"{self.sql(e, 'this')}: {self.sql(e, 'expression')}",
exp.TimestampTrunc: lambda self, e: f"DATE_TRUNC({self.sql(e, 'unit')}, {self.sql(e, 'this')})",
exp.VariancePop: rename_func("VAR_POP"),
}
WRAP_DERIVED_VALUES = False

View file

@ -32,6 +32,11 @@ class TSQL(Dialect):
}
class Parser(Parser):
FUNCTIONS = {
**Parser.FUNCTIONS,
"CHARINDEX": exp.StrPosition.from_arg_list,
}
def _parse_convert(self):
to = self._parse_types()
self._match(TokenType.COMMA)

View file

@ -19,6 +19,7 @@ ENV = {
"datetime": datetime,
"locals": locals,
"re": re,
"bool": bool,
"float": float,
"int": int,
"str": str,

View file

@ -80,9 +80,10 @@ class PythonExecutor:
source = step.source
if isinstance(source, exp.Expression):
source = source.this.name or source.alias
source = source.name or source.alias
else:
source = step.name
condition = self.generate(step.condition)
projections = self.generate_tuple(step.projections)
@ -121,7 +122,7 @@ class PythonExecutor:
source = step.source
alias = source.alias
with csv_reader(source.this) as reader:
with csv_reader(source) as reader:
columns = next(reader)
table = Table(columns)
context = self.context({alias: table})
@ -308,7 +309,7 @@ def _interval_py(self, expression):
def _like_py(self, expression):
this = self.sql(expression, "this")
expression = self.sql(expression, "expression")
return f"""re.match({expression}.replace("_", ".").replace("%", ".*"), {this})"""
return f"""bool(re.match({expression}.replace("_", ".").replace("%", ".*"), {this}))"""
def _ordered_py(self, expression):
@ -330,6 +331,7 @@ class Python(Dialect):
exp.Cast: _cast_py,
exp.Column: _column_py,
exp.EQ: lambda self, e: self.binary(e, "=="),
exp.In: lambda self, e: f"{self.sql(e, 'this')} in {self.expressions(e)}",
exp.Interval: _interval_py,
exp.Is: lambda self, e: self.binary(e, "is"),
exp.Like: _like_py,

View file

@ -11,6 +11,7 @@ from sqlglot.helper import (
camel_to_snake_case,
ensure_list,
list_get,
split_num_words,
subclasses,
)
@ -108,6 +109,8 @@ class Expression(metaclass=_Expression):
@property
def alias_or_name(self):
if isinstance(self, Null):
return "NULL"
return self.alias or self.name
def __deepcopy__(self, memo):
@ -659,6 +662,10 @@ class HexString(Condition):
pass
class ByteString(Condition):
pass
class Column(Condition):
arg_types = {"this": True, "table": False}
@ -725,7 +732,7 @@ class Constraint(Expression):
class Delete(Expression):
arg_types = {"with": False, "this": True, "where": False}
arg_types = {"with": False, "this": True, "using": False, "where": False}
class Drop(Expression):
@ -1192,6 +1199,7 @@ QUERY_MODIFIERS = {
class Table(Expression):
arg_types = {
"this": True,
"alias": False,
"db": False,
"catalog": False,
"laterals": False,
@ -1323,6 +1331,7 @@ class Select(Subqueryable):
*expressions (str or Expression): the SQL code strings to parse.
If a `Group` instance is passed, this is used as-is.
If another `Expression` instance is passed, it will be wrapped in a `Group`.
If nothing is passed in then a group by is not applied to the expression
append (bool): if `True`, add to any existing expressions.
Otherwise, this flattens all the `Group` expression into a single expression.
dialect (str): the dialect used to parse the input expression.
@ -1332,6 +1341,8 @@ class Select(Subqueryable):
Returns:
Select: the modified expression.
"""
if not expressions:
return self if not copy else self.copy()
return _apply_child_list_builder(
*expressions,
instance=self,
@ -2239,6 +2250,11 @@ class ArrayAny(Func):
arg_types = {"this": True, "expression": True}
class ArrayConcat(Func):
arg_types = {"this": True, "expressions": False}
is_var_len_args = True
class ArrayContains(Func):
arg_types = {"this": True, "expression": True}
@ -2570,7 +2586,7 @@ class SortArray(Func):
class Split(Func):
arg_types = {"this": True, "expression": True}
arg_types = {"this": True, "expression": True, "limit": False}
# Start may be omitted in the case of postgres
@ -3209,29 +3225,49 @@ def to_identifier(alias, quoted=None):
return identifier
def to_table(sql_path, **kwargs):
def to_table(sql_path: str, **kwargs) -> Table:
"""
Create a table expression from a `[catalog].[schema].[table]` sql path. Catalog and schema are optional.
Example:
>>> to_table('catalog.db.table_name').sql()
'catalog.db.table_name'
If a table is passed in then that table is returned.
Args:
sql_path(str): `[catalog].[schema].[table]` string
sql_path(str|Table): `[catalog].[schema].[table]` string
Returns:
Table: A table expression
"""
table_parts = sql_path.split(".")
catalog, db, table_name = [
to_identifier(x) if x is not None else x for x in [None] * (3 - len(table_parts)) + table_parts
]
if sql_path is None or isinstance(sql_path, Table):
return sql_path
if not isinstance(sql_path, str):
raise ValueError(f"Invalid type provided for a table: {type(sql_path)}")
catalog, db, table_name = [to_identifier(x) for x in split_num_words(sql_path, ".", 3)]
return Table(this=table_name, db=db, catalog=catalog, **kwargs)
def to_column(sql_path: str, **kwargs) -> Column:
"""
Create a column from a `[table].[column]` sql path. Table is optional.
If a column is passed in then that column is returned.
Args:
sql_path: `[table].[column]` string
Returns:
Column: A column expression
"""
if sql_path is None or isinstance(sql_path, Column):
return sql_path
if not isinstance(sql_path, str):
raise ValueError(f"Invalid type provided for column: {type(sql_path)}")
table_name, column_name = [to_identifier(x) for x in split_num_words(sql_path, ".", 2)]
return Column(this=column_name, table=table_name, **kwargs)
def alias_(expression, alias, table=False, dialect=None, quoted=None, **opts):
"""
Create an Alias expression.
Expample:
Example:
>>> alias_('foo', 'bar').sql()
'foo AS bar'
@ -3249,7 +3285,16 @@ def alias_(expression, alias, table=False, dialect=None, quoted=None, **opts):
"""
exp = maybe_parse(expression, dialect=dialect, **opts)
alias = to_identifier(alias, quoted=quoted)
alias = TableAlias(this=alias) if table else alias
if table:
expression.set("alias", TableAlias(this=alias))
return expression
# We don't set the "alias" arg for Window expressions, because that would add an IDENTIFIER node in
# the AST, representing a "named_window" [1] construct (eg. bigquery). What we want is an ALIAS node
# for the complete Window expression.
#
# [1]: https://cloud.google.com/bigquery/docs/reference/standard-sql/window-function-calls
if "alias" in exp.arg_types and not isinstance(exp, Window):
exp = exp.copy()
@ -3295,7 +3340,7 @@ def column(col, table=None, quoted=None):
)
def table_(table, db=None, catalog=None, quoted=None):
def table_(table, db=None, catalog=None, quoted=None, alias=None):
"""Build a Table.
Args:
@ -3310,6 +3355,7 @@ def table_(table, db=None, catalog=None, quoted=None):
this=to_identifier(table, quoted=quoted),
db=to_identifier(db, quoted=quoted),
catalog=to_identifier(catalog, quoted=quoted),
alias=TableAlias(this=to_identifier(alias)) if alias else None,
)
@ -3453,7 +3499,7 @@ def replace_tables(expression, mapping):
Examples:
>>> from sqlglot import exp, parse_one
>>> replace_tables(parse_one("select * from a.b"), {"a.b": "c"}).sql()
'SELECT * FROM "c"'
'SELECT * FROM c'
Returns:
The mapped expression
@ -3463,7 +3509,10 @@ def replace_tables(expression, mapping):
if isinstance(node, Table):
new_name = mapping.get(table_name(node))
if new_name:
return table_(*reversed(new_name.split(".")), quoted=True)
return to_table(
new_name,
**{k: v for k, v in node.args.items() if k not in ("this", "db", "catalog")},
)
return node
return expression.transform(_replace_tables)
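The new helpers can be sketched as follows; the paths are arbitrary.
from sqlglot import exp

table = exp.to_table("catalog.db.table_name")
print(table.sql())   # catalog.db.table_name
print(table.name)    # table_name

col = exp.to_column("table_name.col_a")
print(col.sql())     # table_name.col_a
print(col.name)      # col_a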

View file

@ -47,6 +47,8 @@ class Generator:
The default is on the smaller end because the length only represents a segment and not the true
line length.
Default: 80
annotations: Whether or not to show annotations in the SQL.
Default: True
"""
TRANSFORMS = {
@ -116,6 +118,7 @@ class Generator:
"_escaped_quote_end",
"_leading_comma",
"_max_text_width",
"_annotations",
)
def __init__(
@ -141,6 +144,7 @@ class Generator:
max_unsupported=3,
leading_comma=False,
max_text_width=80,
annotations=True,
):
import sqlglot
@ -169,6 +173,7 @@ class Generator:
self._escaped_quote_end = self.escape + self.quote_end
self._leading_comma = leading_comma
self._max_text_width = max_text_width
self._annotations = annotations
def generate(self, expression):
"""
@ -275,7 +280,9 @@ class Generator:
raise ValueError(f"Unsupported expression type {expression.__class__.__name__}")
def annotation_sql(self, expression):
return f"{self.sql(expression, 'expression')} # {expression.name.strip()}"
if self._annotations:
return f"{self.sql(expression, 'expression')} # {expression.name}"
return self.sql(expression, "expression")
def uncache_sql(self, expression):
table = self.sql(expression, "this")
@ -423,8 +430,11 @@ class Generator:
def delete_sql(self, expression):
this = self.sql(expression, "this")
using_sql = (
f" USING {self.expressions(expression, 'using', sep=', USING ')}" if expression.args.get("using") else ""
)
where_sql = self.sql(expression, "where")
sql = f"DELETE FROM {this}{where_sql}"
sql = f"DELETE FROM {this}{using_sql}{where_sql}"
return self.prepend_ctes(expression, sql)
def drop_sql(self, expression):
@ -571,7 +581,7 @@ class Generator:
null = f" NULL DEFINED AS {null}" if null else ""
return f"ROW FORMAT DELIMITED{fields}{escaped}{items}{keys}{lines}{null}"
def table_sql(self, expression):
def table_sql(self, expression, sep=" AS "):
table = ".".join(
part
for part in [
@ -582,13 +592,20 @@ class Generator:
if part
)
alias = self.sql(expression, "alias")
alias = f"{sep}{alias}" if alias else ""
laterals = self.expressions(expression, key="laterals", sep="")
joins = self.expressions(expression, key="joins", sep="")
pivots = self.expressions(expression, key="pivots", sep="")
return f"{table}{laterals}{joins}{pivots}"
if alias and pivots:
pivots = f"{pivots}{alias}"
alias = ""
return f"{table}{alias}{laterals}{joins}{pivots}"
def tablesample_sql(self, expression):
if self.alias_post_tablesample and isinstance(expression.this, exp.Alias):
if self.alias_post_tablesample and expression.this.alias:
this = self.sql(expression.this, "this")
alias = f" AS {self.sql(expression.this, 'alias')}"
else:
@ -1188,7 +1205,7 @@ class Generator:
if isinstance(arg_value, list):
for value in arg_value:
args.append(value)
elif arg_value:
else:
args.append(arg_value)
return f"{self.normalize_func(expression.sql_name())}({self.format_args(*args)})"

View file

@ -2,7 +2,9 @@ import inspect
import logging
import re
import sys
import typing as t
from contextlib import contextmanager
from copy import copy
from enum import Enum
CAMEL_CASE_PATTERN = re.compile("(?<!^)(?=[A-Z])")
@ -162,3 +164,54 @@ def find_new_name(taken, base):
i += 1
new = f"{base}_{i}"
return new
def object_to_dict(obj, **kwargs):
return {**{k: copy(v) for k, v in vars(obj).copy().items()}, **kwargs}
def split_num_words(value: str, sep: str, min_num_words: int, fill_from_start: bool = True) -> t.List[t.Optional[str]]:
"""
Split a value and return at least `min_num_words` words, padding with None for words that don't exist.
Args:
value: The value to be split
sep: The separator to split on
min_num_words: The minimum number of words in the result
fill_from_start: Indicates whether None values should be inserted at the start or end of the list
Examples:
>>> split_num_words("db.table", ".", 3)
[None, 'db', 'table']
>>> split_num_words("db.table", ".", 3, fill_from_start=False)
['db', 'table', None]
>>> split_num_words("db.table", ".", 1)
['db', 'table']
"""
words = value.split(sep)
if fill_from_start:
return [None] * (min_num_words - len(words)) + words
return words + [None] * (min_num_words - len(words))
def flatten(values: t.Iterable[t.Union[t.Iterable[t.Any], t.Any]]) -> t.Generator[t.Any, None, None]:
"""
Flattens a list that can contain both iterables and non-iterable elements
Examples:
>>> list(flatten([[1, 2], 3]))
[1, 2, 3]
>>> list(flatten([1, 2, 3]))
[1, 2, 3]
Args:
values: The value to be flattened
Returns:
Yields non-iterable elements (str and bytes are not treated as iterables)
"""
for value in values:
if hasattr(value, "__iter__") and not isinstance(value, (str, bytes)):
yield from flatten(value)
else:
yield value

View file

@ -1,2 +1 @@
from sqlglot.optimizer.optimizer import RULES, optimize
from sqlglot.optimizer.schema import Schema

View file

@ -1,7 +1,7 @@
from sqlglot import exp
from sqlglot.helper import ensure_list, subclasses
from sqlglot.optimizer.schema import ensure_schema
from sqlglot.optimizer.scope import Scope, traverse_scope
from sqlglot.schema import ensure_schema
def annotate_types(expression, schema=None, annotators=None, coerces_to=None):

View file

@ -86,7 +86,7 @@ def _eliminate(scope, existing_ctes, taken):
if scope.is_union:
return _eliminate_union(scope, existing_ctes, taken)
if scope.is_derived_table and not isinstance(scope.expression, (exp.Unnest, exp.Lateral)):
if scope.is_derived_table and not isinstance(scope.expression, exp.UDTF):
return _eliminate_derived_table(scope, existing_ctes, taken)

View file

@ -12,18 +12,16 @@ def isolate_table_selects(expression):
if not isinstance(source, exp.Table):
continue
if not isinstance(source.parent, exp.Alias):
if not source.alias:
raise OptimizeError("Tables require an alias. Run qualify_tables optimization.")
parent = source.parent
parent.replace(
source.replace(
exp.select("*")
.from_(
alias(source, source.name or parent.alias, table=True),
alias(source.copy(), source.name or source.alias, table=True),
copy=False,
)
.subquery(parent.alias, copy=False)
.subquery(source.alias, copy=False)
)
return expression

View file

@ -70,15 +70,10 @@ def merge_ctes(expression, leave_tables_isolated=False):
inner_select = inner_scope.expression.unnest()
from_or_join = table.find_ancestor(exp.From, exp.Join)
if _mergeable(outer_scope, inner_select, leave_tables_isolated, from_or_join):
node_to_replace = table
if isinstance(node_to_replace.parent, exp.Alias):
node_to_replace = node_to_replace.parent
alias = node_to_replace.alias
else:
alias = table.name
alias = table.alias_or_name
_rename_inner_sources(outer_scope, inner_scope, alias)
_merge_from(outer_scope, inner_scope, node_to_replace, alias)
_merge_from(outer_scope, inner_scope, table, alias)
_merge_expressions(outer_scope, inner_scope, alias)
_merge_joins(outer_scope, inner_scope, from_or_join)
_merge_where(outer_scope, inner_scope, from_or_join)
@ -179,8 +174,8 @@ def _rename_inner_sources(outer_scope, inner_scope, alias):
if isinstance(source, exp.Subquery):
source.set("alias", exp.TableAlias(this=new_alias))
elif isinstance(source, exp.Table) and isinstance(source.parent, exp.Alias):
source.parent.set("alias", new_alias)
elif isinstance(source, exp.Table) and source.alias:
source.set("alias", new_alias)
elif isinstance(source, exp.Table):
source.replace(exp.alias_(source.copy(), new_alias))
@ -206,8 +201,7 @@ def _merge_from(outer_scope, inner_scope, node_to_replace, alias):
tables = join_hint.find_all(exp.Table)
for table in tables:
if table.alias_or_name == node_to_replace.alias_or_name:
new_table = new_subquery.this if isinstance(new_subquery, exp.Alias) else new_subquery
table.set("this", exp.to_identifier(new_table.alias_or_name))
table.set("this", exp.to_identifier(new_subquery.alias_or_name))
outer_scope.remove_source(alias)
outer_scope.add_source(new_subquery.alias_or_name, inner_scope.sources[new_subquery.alias_or_name])

View file

@ -1,3 +1,4 @@
import sqlglot
from sqlglot.optimizer.eliminate_ctes import eliminate_ctes
from sqlglot.optimizer.eliminate_joins import eliminate_joins
from sqlglot.optimizer.eliminate_subqueries import eliminate_subqueries
@ -43,6 +44,7 @@ def optimize(expression, schema=None, db=None, catalog=None, rules=RULES, **kwar
1. {table: {col: type}}
2. {db: {table: {col: type}}}
3. {catalog: {db: {table: {col: type}}}}
If no schema is provided then the default schema defined at `sqlgot.schema` will be used
db (str): specify the default database, as might be set by a `USE DATABASE db` statement
catalog (str): specify the default catalog, as might be set by a `USE CATALOG c` statement
rules (list): sequence of optimizer rules to use
@ -50,13 +52,12 @@ def optimize(expression, schema=None, db=None, catalog=None, rules=RULES, **kwar
Returns:
sqlglot.Expression: optimized expression
"""
possible_kwargs = {"db": db, "catalog": catalog, "schema": schema, **kwargs}
possible_kwargs = {"db": db, "catalog": catalog, "schema": schema or sqlglot.schema, **kwargs}
expression = expression.copy()
for rule in rules:
# Find any additional rule parameters, beyond `expression`
rule_params = rule.__code__.co_varnames
rule_kwargs = {param: possible_kwargs[param] for param in rule_params if param in possible_kwargs}
expression = rule(expression, **rule_kwargs)
return expression
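With this change, optimize() can fall back to the module-level sqlglot.schema when no schema is passed; a hedged sketch, with a made-up table and columns and only an indicative description of the output.
import sqlglot
from sqlglot.optimizer import optimize

sqlglot.schema.add_table("x", {"a": "int", "b": "int"})
print(optimize(sqlglot.parse_one("SELECT a FROM x")).sql())
# expected: the column and table come back fully qualified/aliased using the globally registered schema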

View file

@ -6,6 +6,9 @@ from sqlglot.optimizer.scope import Scope, traverse_scope
# Sentinel value that means an outer query selecting ALL columns
SELECT_ALL = object()
# SELECTION TO USE IF SELECTION LIST IS EMPTY
DEFAULT_SELECTION = alias("1", "_")
def pushdown_projections(expression):
"""
@ -25,7 +28,8 @@ def pushdown_projections(expression):
"""
# Map of Scope to all columns being selected by outer queries.
referenced_columns = defaultdict(set)
left_union = None
right_union = None
# We build the scope tree (which is traversed in DFS postorder), then iterate
# over the result in reverse order. This should ensure that the set of selected
# columns for a particular scope are completely build by the time we get to it.
@ -37,12 +41,16 @@ def pushdown_projections(expression):
parent_selections = {SELECT_ALL}
if isinstance(scope.expression, exp.Union):
left, right = scope.union_scopes
referenced_columns[left] = parent_selections
referenced_columns[right] = parent_selections
left_union, right_union = scope.union_scopes
referenced_columns[left_union] = parent_selections
referenced_columns[right_union] = parent_selections
if isinstance(scope.expression, exp.Select):
_remove_unused_selections(scope, parent_selections)
if isinstance(scope.expression, exp.Select) and scope != right_union:
removed_indexes = _remove_unused_selections(scope, parent_selections)
# The left side of the union provides the output column names, so if we remove columns from the left
# we must also remove the columns at the same positions from the right
if scope is left_union:
_remove_indexed_selections(right_union, removed_indexes)
# Group columns by source name
selects = defaultdict(set)
@ -61,6 +69,7 @@ def pushdown_projections(expression):
def _remove_unused_selections(scope, parent_selections):
removed_indexes = []
order = scope.expression.args.get("order")
if order:
@ -70,16 +79,26 @@ def _remove_unused_selections(scope, parent_selections):
order_refs = set()
new_selections = []
for selection in scope.selects:
for i, selection in enumerate(scope.selects):
if (
SELECT_ALL in parent_selections
or selection.alias_or_name in parent_selections
or selection.alias_or_name in order_refs
):
new_selections.append(selection)
else:
removed_indexes.append(i)
# If there are no remaining selections, just select a single constant
if not new_selections:
new_selections.append(alias("1", "_"))
new_selections.append(DEFAULT_SELECTION)
scope.expression.set("expressions", new_selections)
return removed_indexes
def _remove_indexed_selections(scope, indexes_to_remove):
new_selections = [selection for i, selection in enumerate(scope.selects) if i not in indexes_to_remove]
if not new_selections:
new_selections.append(DEFAULT_SELECTION)
scope.expression.set("expressions", new_selections)
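A hedged illustration of the union handling; the query is arbitrary and this rule normally runs as part of the full optimizer pipeline (module path assumed to follow the other optimizer rules).
from sqlglot import parse_one
from sqlglot.optimizer.pushdown_projections import pushdown_projections

sql = "SELECT t.a FROM (SELECT a, b FROM x UNION ALL SELECT a, b FROM y) AS t"
print(pushdown_projections(parse_one(sql)).sql())
# expected: column b is dropped from the left side of the union and, by position, from the right side too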

View file

@ -2,8 +2,8 @@ import itertools
from sqlglot import alias, exp
from sqlglot.errors import OptimizeError
from sqlglot.optimizer.schema import ensure_schema
from sqlglot.optimizer.scope import traverse_scope
from sqlglot.optimizer.scope import Scope, traverse_scope
from sqlglot.schema import ensure_schema
def qualify_columns(expression, schema):
@ -48,7 +48,7 @@ def _pop_table_column_aliases(derived_tables):
(e.g. SELECT ... FROM (SELECT ...) AS foo(col1, col2)
"""
for derived_table in derived_tables:
if isinstance(derived_table, exp.UDTF):
if isinstance(derived_table.unnest(), exp.UDTF):
continue
table_alias = derived_table.args.get("alias")
if table_alias:
@ -211,6 +211,22 @@ def _qualify_columns(scope, resolver):
if column_table:
column.set("table", exp.to_identifier(column_table))
# Determine whether each reference in the order by clause is to a column or an alias.
for ordered in scope.find_all(exp.Ordered):
for column in ordered.find_all(exp.Column):
column_table = column.table
column_name = column.name
if column_table or column.parent is ordered or column_name not in resolver.all_columns:
continue
column_table = resolver.get_table(column_name)
if column_table is None:
raise OptimizeError(f"Ambiguous column: {column_name}")
column.set("table", exp.to_identifier(column_table))
def _expand_stars(scope, resolver):
"""Expand stars to lists of column selections"""
@ -346,6 +362,11 @@ class _Resolver:
except Exception as e:
raise OptimizeError(str(e)) from e
if isinstance(source, Scope) and isinstance(source.expression, exp.Values):
values_alias = source.expression.parent
if hasattr(values_alias, "alias_column_names"):
return values_alias.alias_column_names
# Otherwise, if referencing another scope, return that scope's named selects
return source.expression.named_selects

View file

@ -40,7 +40,7 @@ def qualify_tables(expression, db=None, catalog=None):
if not source.args.get("catalog"):
source.set("catalog", exp.to_identifier(catalog))
if not isinstance(source.parent, exp.Alias):
if not source.alias:
source.replace(
alias(
source.copy(),

View file

@ -1,180 +0,0 @@
import abc
from sqlglot import exp
from sqlglot.errors import OptimizeError
from sqlglot.helper import csv_reader
class Schema(abc.ABC):
"""Abstract base class for database schemas"""
@abc.abstractmethod
def column_names(self, table, only_visible=False):
"""
Get the column names for a table.
Args:
table (sqlglot.expressions.Table): Table expression instance
only_visible (bool): Whether to include invisible columns
Returns:
list[str]: list of column names
"""
@abc.abstractmethod
def get_column_type(self, table, column):
"""
Get the exp.DataType type of a column in the schema.
Args:
table (sqlglot.expressions.Table): The source table.
column (sqlglot.expressions.Column): The target column.
Returns:
sqlglot.expressions.DataType.Type: The resulting column type.
"""
class MappingSchema(Schema):
"""
Schema based on a nested mapping.
Args:
schema (dict): Mapping in one of the following forms:
1. {table: {col: type}}
2. {db: {table: {col: type}}}
3. {catalog: {db: {table: {col: type}}}}
visible (dict): Optional mapping of which columns in the schema are visible. If not provided, all columns
are assumed to be visible. The nesting should mirror that of the schema:
1. {table: set(*cols)}}
2. {db: {table: set(*cols)}}}
3. {catalog: {db: {table: set(*cols)}}}}
dialect (str): The dialect to be used for custom type mappings.
"""
def __init__(self, schema, visible=None, dialect=None):
self.schema = schema
self.visible = visible
self.dialect = dialect
self._type_mapping_cache = {}
depth = _dict_depth(schema)
if not depth: # {}
self.supported_table_args = []
elif depth == 2: # {table: {col: type}}
self.supported_table_args = ("this",)
elif depth == 3: # {db: {table: {col: type}}}
self.supported_table_args = ("db", "this")
elif depth == 4: # {catalog: {db: {table: {col: type}}}}
self.supported_table_args = ("catalog", "db", "this")
else:
raise OptimizeError(f"Invalid schema shape. Depth: {depth}")
self.forbidden_args = {"catalog", "db", "this"} - set(self.supported_table_args)
def column_names(self, table, only_visible=False):
if not isinstance(table.this, exp.Identifier):
return fs_get(table)
args = tuple(table.text(p) for p in self.supported_table_args)
for forbidden in self.forbidden_args:
if table.text(forbidden):
raise ValueError(f"Schema doesn't support {forbidden}. Received: {table.sql()}")
columns = list(_nested_get(self.schema, *zip(self.supported_table_args, args)))
if not only_visible or not self.visible:
return columns
visible = _nested_get(self.visible, *zip(self.supported_table_args, args))
return [col for col in columns if col in visible]
def get_column_type(self, table, column):
try:
schema_type = self.schema.get(table.name, {}).get(column.name).upper()
return self._convert_type(schema_type)
except:
raise OptimizeError(f"Failed to get type for column {column.sql()}")
def _convert_type(self, schema_type):
"""
Convert a type represented as a string to the corresponding exp.DataType.Type object.
Args:
schema_type (str): The type we want to convert.
Returns:
sqlglot.expressions.DataType.Type: The resulting expression type.
"""
if schema_type not in self._type_mapping_cache:
try:
self._type_mapping_cache[schema_type] = exp.maybe_parse(
schema_type, into=exp.DataType, dialect=self.dialect
).this
except AttributeError:
raise OptimizeError(f"Failed to convert type {schema_type}")
return self._type_mapping_cache[schema_type]
def ensure_schema(schema):
if isinstance(schema, Schema):
return schema
return MappingSchema(schema)
def fs_get(table):
name = table.this.name
if name.upper() == "READ_CSV":
with csv_reader(table) as reader:
return next(reader)
raise ValueError(f"Cannot read schema for {table}")
def _nested_get(d, *path):
"""
Get a value for a nested dictionary.
Args:
d (dict): dictionary
*path (tuple[str, str]): tuples of (name, key)
`key` is the key in the dictionary to get.
`name` is a string to use in the error if `key` isn't found.
"""
for name, key in path:
d = d.get(key)
if d is None:
name = "table" if name == "this" else name
raise ValueError(f"Unknown {name}")
return d
def _dict_depth(d):
"""
Get the nesting depth of a dictionary.
For example:
>>> _dict_depth(None)
0
>>> _dict_depth({})
1
>>> _dict_depth({"a": "b"})
1
>>> _dict_depth({"a": {}})
2
>>> _dict_depth({"a": {"b": {}}})
3
Args:
d (dict): dictionary
Returns:
int: depth
"""
try:
return 1 + _dict_depth(next(iter(d.values())))
except AttributeError:
# d doesn't have attribute "values"
return 0
except StopIteration:
# d.values() returns an empty sequence
return 1

View file

@ -257,12 +257,7 @@ class Scope:
referenced_names = []
for table in self.tables:
referenced_names.append(
(
table.parent.alias if isinstance(table.parent, exp.Alias) else table.name,
table,
)
)
referenced_names.append((table.alias_or_name, table))
for derived_table in self.derived_tables:
referenced_names.append((derived_table.alias, derived_table.unnest()))
@ -538,8 +533,8 @@ def _add_table_sources(scope):
for table in scope.tables:
table_name = table.name
if isinstance(table.parent, exp.Alias):
source_name = table.parent.alias
if table.alias:
source_name = table.alias
else:
source_name = table_name

View file

@ -329,6 +329,7 @@ class Parser:
exp.DataType: lambda self: self._parse_types(),
exp.From: lambda self: self._parse_from(),
exp.Group: lambda self: self._parse_group(),
exp.Identifier: lambda self: self._parse_id_var(),
exp.Lateral: lambda self: self._parse_lateral(),
exp.Join: lambda self: self._parse_join(),
exp.Order: lambda self: self._parse_order(),
@ -371,11 +372,8 @@ class Parser:
TokenType.PARAMETER: lambda self, _: exp.Parameter(this=self._parse_var() or self._parse_primary()),
TokenType.BIT_STRING: lambda _, token: exp.BitString(this=token.text),
TokenType.HEX_STRING: lambda _, token: exp.HexString(this=token.text),
TokenType.INTRODUCER: lambda self, token: self.expression(
exp.Introducer,
this=token.text,
expression=self._parse_var_or_string(),
),
TokenType.BYTE_STRING: lambda _, token: exp.ByteString(this=token.text),
TokenType.INTRODUCER: lambda self, token: self._parse_introducer(token),
}
RANGE_PARSERS = {
@ -500,7 +498,7 @@ class Parser:
max_errors=3,
null_ordering=None,
):
self.error_level = error_level or ErrorLevel.RAISE
self.error_level = error_level or ErrorLevel.IMMEDIATE
self.error_message_context = error_message_context
self.index_offset = index_offset
self.unnest_column_only = unnest_column_only
@ -928,6 +926,7 @@ class Parser:
return self.expression(
exp.Delete,
this=self._parse_table(schema=True),
using=self._parse_csv(lambda: self._match(TokenType.USING) and self._parse_table(schema=True)),
where=self._parse_where(),
)
@ -1148,7 +1147,7 @@ class Parser:
def _parse_annotation(self, expression):
if self._match(TokenType.ANNOTATION):
return self.expression(exp.Annotation, this=self._prev.text, expression=expression)
return self.expression(exp.Annotation, this=self._prev.text.strip(), expression=expression)
return expression
@ -1277,7 +1276,7 @@ class Parser:
alias = self._parse_table_alias()
if alias:
this = self.expression(exp.Alias, this=this, alias=alias)
this.set("alias", alias)
if not self.alias_post_tablesample:
table_sample = self._parse_table_sample()
@ -1876,6 +1875,17 @@ class Parser:
self._match_r_paren()
return self.expression(exp.UserDefinedFunction, this=this, expressions=expressions)
def _parse_introducer(self, token):
literal = self._parse_primary()
if literal:
return self.expression(
exp.Introducer,
this=token.text,
expression=literal,
)
return self.expression(exp.Identifier, this=token.text)
def _parse_udf_kwarg(self):
this = self._parse_id_var()
kind = self._parse_types()
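Together with the delete_sql change further up, the new `using` argument lets DELETE ... USING round-trip; a hedged sketch with arbitrary table names.
import sqlglot

sql = "DELETE FROM users USING sessions WHERE users.id = sessions.user_id"
print(sqlglot.transpile(sql)[0])
# expected to round-trip with the USING clause preserved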

View file

@ -199,13 +199,14 @@ class Step:
class Scan(Step):
@classmethod
def from_expression(cls, expression, ctes=None):
table = expression.this
table = expression
alias_ = expression.alias
if not alias_:
raise UnsupportedError("Tables/Subqueries must be aliased. Run it through the optimizer")
if isinstance(expression, exp.Subquery):
table = expression.this
step = Step.from_expression(table, ctes)
step.name = alias_
return step

297
sqlglot/schema.py Normal file
View file

@ -0,0 +1,297 @@
import abc
from sqlglot import expressions as exp
from sqlglot.errors import OptimizeError
from sqlglot.helper import csv_reader
class Schema(abc.ABC):
"""Abstract base class for database schemas"""
@abc.abstractmethod
def add_table(self, table, column_mapping=None):
"""
Register or update a table. Some implementing classes may require column information to also be provided.
Args:
table (sqlglot.expressions.Table|str): Table expression instance or string representing the table
column_mapping (dict|str|sqlglot.dataframe.sql.types.StructType|list): A column mapping that describes the structure of the table
"""
@abc.abstractmethod
def column_names(self, table, only_visible=False):
"""
Get the column names for a table.
Args:
table (sqlglot.expressions.Table): Table expression instance
only_visible (bool): Whether to include invisible columns
Returns:
list[str]: list of column names
"""
@abc.abstractmethod
def get_column_type(self, table, column):
"""
Get the exp.DataType type of a column in the schema.
Args:
table (sqlglot.expressions.Table): The source table.
column (sqlglot.expressions.Column): The target column.
Returns:
sqlglot.expressions.DataType.Type: The resulting column type.
"""
class MappingSchema(Schema):
"""
Schema based on a nested mapping.
Args:
schema (dict): Mapping in one of the following forms:
1. {table: {col: type}}
2. {db: {table: {col: type}}}
3. {catalog: {db: {table: {col: type}}}}
4. None - Tables will be added later
visible (dict): Optional mapping of which columns in the schema are visible. If not provided, all columns
are assumed to be visible. The nesting should mirror that of the schema:
1. {table: set(*cols)}}
2. {db: {table: set(*cols)}}}
3. {catalog: {db: {table: set(*cols)}}}}
dialect (str): The dialect to be used for custom type mappings.
"""
def __init__(self, schema=None, visible=None, dialect=None):
self.schema = schema or {}
self.visible = visible
self.dialect = dialect
self._type_mapping_cache = {}
self.supported_table_args = []
self.forbidden_table_args = set()
if self.schema:
self._initialize_supported_args()
@classmethod
def from_mapping_schema(cls, mapping_schema):
return MappingSchema(
schema=mapping_schema.schema, visible=mapping_schema.visible, dialect=mapping_schema.dialect
)
def copy(self, **kwargs):
return MappingSchema(**{"schema": self.schema.copy(), **kwargs})
def add_table(self, table, column_mapping=None):
"""
Register or update a table. Updates are only performed if a new column mapping is provided.
Args:
table (sqlglot.expressions.Table|str): Table expression instance or string representing the table
column_mapping (dict|str|sqlglot.dataframe.sql.types.StructType|list): A column mapping that describes the structure of the table
"""
table = exp.to_table(table)
self._validate_table(table)
column_mapping = ensure_column_mapping(column_mapping)
table_args = [table.text(p) for p in self.supported_table_args or self._get_table_args_from_table(table)]
existing_column_mapping = _nested_get(
self.schema, *zip(self.supported_table_args, table_args), raise_on_missing=False
)
if existing_column_mapping and not column_mapping:
return
_nested_set(
self.schema,
[table.text(p) for p in self.supported_table_args or self._get_table_args_from_table(table)],
column_mapping,
)
self._initialize_supported_args()
def _get_table_args_from_table(self, table):
if table.args.get("catalog") is not None:
return "catalog", "db", "this"
if table.args.get("db") is not None:
return "db", "this"
return ("this",)
def _validate_table(self, table):
if not self.supported_table_args and isinstance(table, exp.Table):
return
for forbidden in self.forbidden_table_args:
if table.text(forbidden):
raise ValueError(f"Schema doesn't support {forbidden}. Received: {table.sql()}")
for expected in self.supported_table_args:
if not table.text(expected):
raise ValueError(f"Table is expected to have {expected}. Received: {table.sql()} ")
def column_names(self, table, only_visible=False):
table = exp.to_table(table)
if not isinstance(table.this, exp.Identifier):
return fs_get(table)
args = tuple(table.text(p) for p in self.supported_table_args)
for forbidden in self.forbidden_table_args:
if table.text(forbidden):
raise ValueError(f"Schema doesn't support {forbidden}. Received: {table.sql()}")
columns = list(_nested_get(self.schema, *zip(self.supported_table_args, args)))
if not only_visible or not self.visible:
return columns
visible = _nested_get(self.visible, *zip(self.supported_table_args, args))
return [col for col in columns if col in visible]
def get_column_type(self, table, column):
try:
schema_type = self.schema.get(table.name, {}).get(column.name).upper()
return self._convert_type(schema_type)
except:
raise OptimizeError(f"Failed to get type for column {column.sql()}")
def _convert_type(self, schema_type):
"""
Convert a type represented as a string to the corresponding exp.DataType.Type object.
Args:
schema_type (str): The type we want to convert.
Returns:
sqlglot.expressions.DataType.Type: The resulting expression type.
"""
if schema_type not in self._type_mapping_cache:
try:
self._type_mapping_cache[schema_type] = exp.maybe_parse(
schema_type, into=exp.DataType, dialect=self.dialect
).this
except AttributeError:
raise OptimizeError(f"Failed to convert type {schema_type}")
return self._type_mapping_cache[schema_type]
def _initialize_supported_args(self):
if not self.supported_table_args:
depth = _dict_depth(self.schema)
all_args = ["this", "db", "catalog"]
if not depth or depth == 1: # {}
self.supported_table_args = []
elif 2 <= depth <= 4:
self.supported_table_args = tuple(reversed(all_args[: depth - 1]))
else:
raise OptimizeError(f"Invalid schema shape. Depth: {depth}")
self.forbidden_table_args = {"catalog", "db", "this"} - set(self.supported_table_args)
def ensure_schema(schema):
if isinstance(schema, Schema):
return schema
return MappingSchema(schema)
def ensure_column_mapping(mapping):
if isinstance(mapping, dict):
return mapping
elif isinstance(mapping, str):
col_name_type_strs = [x.strip() for x in mapping.split(",")]
return {
name_type_str.split(":")[0].strip(): name_type_str.split(":")[1].strip()
for name_type_str in col_name_type_strs
}
# Check if mapping looks like a DataFrame StructType
elif hasattr(mapping, "simpleString"):
return {struct_field.name: struct_field.dataType.simpleString() for struct_field in mapping}
elif isinstance(mapping, list):
return {x.strip(): None for x in mapping}
elif mapping is None:
return {}
raise ValueError(f"Invalid mapping provided: {type(mapping)}")
def fs_get(table):
name = table.this.name
if name.upper() == "READ_CSV":
with csv_reader(table) as reader:
return next(reader)
raise ValueError(f"Cannot read schema for {table}")
def _nested_get(d, *path, raise_on_missing=True):
"""
Get a value from a nested dictionary.
Args:
d (dict): dictionary
*path (tuple[str, str]): tuples of (name, key)
`key` is the key in the dictionary to get.
`name` is a string to use in the error if `key` isn't found.
Returns:
The value or None if it doesn't exist
"""
for name, key in path:
d = d.get(key)
if d is None:
if raise_on_missing:
name = "table" if name == "this" else name
raise ValueError(f"Unknown {name}")
return None
return d
def _nested_set(d, keys, value):
"""
In-place set a value for a nested dictionary
Ex:
>>> _nested_set({}, ["top_key", "second_key"], "value")
{'top_key': {'second_key': 'value'}}
>>> _nested_set({"top_key": {"third_key": "third_value"}}, ["top_key", "second_key"], "value")
{'top_key': {'third_key': 'third_value', 'second_key': 'value'}}
Args:
d (dict): dictionary
keys (Iterable[str]): ordered iterable of keys that makeup path to value
value (Any): The value to set in the dictionary for the given key path
"""
if not keys:
return
if len(keys) == 1:
d[keys[0]] = value
return
subd = d
for key in keys[:-1]:
if key not in subd:
subd = subd.setdefault(key, {})
else:
subd = subd[key]
subd[keys[-1]] = value
return d
def _dict_depth(d):
"""
Get the nesting depth of a dictionary.
For example:
>>> _dict_depth(None)
0
>>> _dict_depth({})
1
>>> _dict_depth({"a": "b"})
1
>>> _dict_depth({"a": {}})
2
>>> _dict_depth({"a": {"b": {}}})
3
Args:
d (dict): dictionary
Returns:
int: depth
"""
try:
return 1 + _dict_depth(next(iter(d.values())))
except AttributeError:
# d doesn't have attribute "values"
return 0
except StopIteration:
# d.values() returns an empty sequence
return 1

View file

@ -56,6 +56,7 @@ class TokenType(AutoName):
VAR = auto()
BIT_STRING = auto()
HEX_STRING = auto()
BYTE_STRING = auto()
# types
BOOLEAN = auto()
@ -320,6 +321,7 @@ class _Tokenizer(type):
klass._QUOTES = cls._delimeter_list_to_dict(klass.QUOTES)
klass._BIT_STRINGS = cls._delimeter_list_to_dict(klass.BIT_STRINGS)
klass._HEX_STRINGS = cls._delimeter_list_to_dict(klass.HEX_STRINGS)
klass._BYTE_STRINGS = cls._delimeter_list_to_dict(klass.BYTE_STRINGS)
klass._IDENTIFIERS = cls._delimeter_list_to_dict(klass.IDENTIFIERS)
klass._COMMENTS = dict(
(comment, None) if isinstance(comment, str) else (comment[0], comment[1]) for comment in klass.COMMENTS
@ -333,6 +335,7 @@ class _Tokenizer(type):
**{quote: TokenType.QUOTE for quote in klass._QUOTES},
**{bit_string: TokenType.BIT_STRING for bit_string in klass._BIT_STRINGS},
**{hex_string: TokenType.HEX_STRING for hex_string in klass._HEX_STRINGS},
**{byte_string: TokenType.BYTE_STRING for byte_string in klass._BYTE_STRINGS},
}.items()
if " " in key or any(single in key for single in klass.SINGLE_TOKENS)
)
@ -385,6 +388,8 @@ class Tokenizer(metaclass=_Tokenizer):
HEX_STRINGS = []
BYTE_STRINGS = []
IDENTIFIERS = ['"']
ESCAPE = "'"
@ -799,7 +804,7 @@ class Tokenizer(metaclass=_Tokenizer):
if self._scan_string(word):
return
if self._scan_numeric_string(word):
if self._scan_formatted_string(word):
return
if self._scan_comment(word):
return
@ -906,7 +911,8 @@ class Tokenizer(metaclass=_Tokenizer):
self._add(TokenType.STRING, text)
return True
def _scan_numeric_string(self, string_start):
# X'1234', b'0110', E'\\\\\' etc.
def _scan_formatted_string(self, string_start):
if string_start in self._HEX_STRINGS:
delimiters = self._HEX_STRINGS
token_type = TokenType.HEX_STRING
@ -915,6 +921,10 @@ class Tokenizer(metaclass=_Tokenizer):
delimiters = self._BIT_STRINGS
token_type = TokenType.BIT_STRING
base = 2
elif string_start in self._BYTE_STRINGS:
delimiters = self._BYTE_STRINGS
token_type = TokenType.BYTE_STRING
base = None
else:
return False
@ -922,10 +932,14 @@ class Tokenizer(metaclass=_Tokenizer):
string_end = delimiters.get(string_start)
text = self._extract_string(string_end)
try:
self._add(token_type, f"{int(text, base)}")
except ValueError:
raise RuntimeError(f"Numeric string contains invalid characters from {self._line}:{self._start}")
if base is None:
self._add(token_type, text)
else:
try:
self._add(token_type, f"{int(text, base)}")
except:
raise RuntimeError(f"Numeric string contains invalid characters from {self._line}:{self._start}")
return True
def _scan_identifier(self, identifier_end):

View file

View file

View file

@ -0,0 +1,149 @@
import typing as t
import unittest
import warnings
import sqlglot
from tests.helpers import SKIP_INTEGRATION
if t.TYPE_CHECKING:
from pyspark.sql import DataFrame as SparkDataFrame
@unittest.skipIf(SKIP_INTEGRATION, "Skipping Integration Tests since `SKIP_INTEGRATION` is set")
class DataFrameValidator(unittest.TestCase):
spark = None
sqlglot = None
df_employee = None
df_store = None
df_district = None
spark_employee_schema = None
sqlglot_employee_schema = None
spark_store_schema = None
sqlglot_store_schema = None
spark_district_schema = None
sqlglot_district_schema = None
@classmethod
def setUpClass(cls):
from pyspark import SparkConf
from pyspark.sql import SparkSession, types
from sqlglot.dataframe.sql import types as sqlglotSparkTypes
from sqlglot.dataframe.sql.session import SparkSession as SqlglotSparkSession
# This is for test `test_branching_root_dataframes`
config = SparkConf().setAll([("spark.sql.analyzer.failAmbiguousSelfJoin", "false")])
cls.spark = SparkSession.builder.master("local[*]").appName("Unit-tests").config(conf=config).getOrCreate()
cls.spark.sparkContext.setLogLevel("ERROR")
cls.sqlglot = SqlglotSparkSession()
cls.spark_employee_schema = types.StructType(
[
types.StructField("employee_id", types.IntegerType(), False),
types.StructField("fname", types.StringType(), False),
types.StructField("lname", types.StringType(), False),
types.StructField("age", types.IntegerType(), False),
types.StructField("store_id", types.IntegerType(), False),
]
)
cls.sqlglot_employee_schema = sqlglotSparkTypes.StructType(
[
sqlglotSparkTypes.StructField("employee_id", sqlglotSparkTypes.IntegerType(), False),
sqlglotSparkTypes.StructField("fname", sqlglotSparkTypes.StringType(), False),
sqlglotSparkTypes.StructField("lname", sqlglotSparkTypes.StringType(), False),
sqlglotSparkTypes.StructField("age", sqlglotSparkTypes.IntegerType(), False),
sqlglotSparkTypes.StructField("store_id", sqlglotSparkTypes.IntegerType(), False),
]
)
employee_data = [
(1, "Jack", "Shephard", 37, 1),
(2, "John", "Locke", 65, 1),
(3, "Kate", "Austen", 37, 2),
(4, "Claire", "Littleton", 27, 2),
(5, "Hugo", "Reyes", 29, 100),
]
cls.df_employee = cls.spark.createDataFrame(data=employee_data, schema=cls.spark_employee_schema)
cls.dfs_employee = cls.sqlglot.createDataFrame(data=employee_data, schema=cls.sqlglot_employee_schema)
cls.df_employee.createOrReplaceTempView("employee")
cls.spark_store_schema = types.StructType(
[
types.StructField("store_id", types.IntegerType(), False),
types.StructField("store_name", types.StringType(), False),
types.StructField("district_id", types.IntegerType(), False),
types.StructField("num_sales", types.IntegerType(), False),
]
)
cls.sqlglot_store_schema = sqlglotSparkTypes.StructType(
[
sqlglotSparkTypes.StructField("store_id", sqlglotSparkTypes.IntegerType(), False),
sqlglotSparkTypes.StructField("store_name", sqlglotSparkTypes.StringType(), False),
sqlglotSparkTypes.StructField("district_id", sqlglotSparkTypes.IntegerType(), False),
sqlglotSparkTypes.StructField("num_sales", sqlglotSparkTypes.IntegerType(), False),
]
)
store_data = [
(1, "Hydra", 1, 37),
(2, "Arrow", 2, 2000),
]
cls.df_store = cls.spark.createDataFrame(data=store_data, schema=cls.spark_store_schema)
cls.dfs_store = cls.sqlglot.createDataFrame(data=store_data, schema=cls.sqlglot_store_schema)
cls.df_store.createOrReplaceTempView("store")
cls.spark_district_schema = types.StructType(
[
types.StructField("district_id", types.IntegerType(), False),
types.StructField("district_name", types.StringType(), False),
types.StructField("manager_name", types.StringType(), False),
]
)
cls.sqlglot_district_schema = sqlglotSparkTypes.StructType(
[
sqlglotSparkTypes.StructField("district_id", sqlglotSparkTypes.IntegerType(), False),
sqlglotSparkTypes.StructField("district_name", sqlglotSparkTypes.StringType(), False),
sqlglotSparkTypes.StructField("manager_name", sqlglotSparkTypes.StringType(), False),
]
)
district_data = [
(1, "Temple", "Dogen"),
(2, "Lighthouse", "Jacob"),
]
cls.df_district = cls.spark.createDataFrame(data=district_data, schema=cls.spark_district_schema)
cls.dfs_district = cls.sqlglot.createDataFrame(data=district_data, schema=cls.sqlglot_district_schema)
cls.df_district.createOrReplaceTempView("district")
sqlglot.schema.add_table("employee", cls.sqlglot_employee_schema)
sqlglot.schema.add_table("store", cls.sqlglot_store_schema)
sqlglot.schema.add_table("district", cls.sqlglot_district_schema)
def setUp(self) -> None:
warnings.filterwarnings("ignore", category=ResourceWarning)
self.df_spark_store = self.df_store.alias("df_store") # type: ignore
self.df_spark_employee = self.df_employee.alias("df_employee") # type: ignore
self.df_spark_district = self.df_district.alias("df_district") # type: ignore
self.df_sqlglot_store = self.dfs_store.alias("store") # type: ignore
self.df_sqlglot_employee = self.dfs_employee.alias("employee") # type: ignore
self.df_sqlglot_district = self.dfs_district.alias("district") # type: ignore
def compare_spark_with_sqlglot(
self, df_spark, df_sqlglot, no_empty=True, skip_schema_compare=False
) -> t.Tuple["SparkDataFrame", "SparkDataFrame"]:
def compare_schemas(schema_1, schema_2):
for schema in [schema_1, schema_2]:
for struct_field in schema.fields:
struct_field.metadata = {}
self.assertEqual(schema_1, schema_2)
for statement in df_sqlglot.sql():
actual_df_sqlglot = self.spark.sql(statement) # type: ignore
df_sqlglot_results = actual_df_sqlglot.collect()
df_spark_results = df_spark.collect()
if not skip_schema_compare:
compare_schemas(df_spark.schema, actual_df_sqlglot.schema)
self.assertEqual(df_spark_results, df_sqlglot_results)
if no_empty:
self.assertNotEqual(len(df_spark_results), 0)
self.assertNotEqual(len(df_sqlglot_results), 0)
return df_spark, actual_df_sqlglot
@classmethod
def get_explain_plan(cls, df: "SparkDataFrame", mode: str = "extended") -> str:
return df._sc._jvm.PythonSQLUtils.explainString(df._jdf.queryExecution(), mode) # type: ignore

File diff suppressed because it is too large Load diff

View file

@ -0,0 +1,71 @@
from pyspark.sql import functions as F
from sqlglot.dataframe.sql import functions as SF
from tests.dataframe.integration.dataframe_validator import DataFrameValidator
class TestDataframeFunc(DataFrameValidator):
def test_group_by(self):
df_employee = self.df_spark_employee.groupBy(self.df_spark_employee.age).agg(
F.min(self.df_spark_employee.employee_id)
)
dfs_employee = self.df_sqlglot_employee.groupBy(self.df_sqlglot_employee.age).agg(
SF.min(self.df_sqlglot_employee.employee_id)
)
self.compare_spark_with_sqlglot(df_employee, dfs_employee, skip_schema_compare=True)
def test_group_by_where_non_aggregate(self):
df_employee = (
self.df_spark_employee.groupBy(self.df_spark_employee.age)
.agg(F.min(self.df_spark_employee.employee_id).alias("min_employee_id"))
.where(F.col("age") > F.lit(50))
)
dfs_employee = (
self.df_sqlglot_employee.groupBy(self.df_sqlglot_employee.age)
.agg(SF.min(self.df_sqlglot_employee.employee_id).alias("min_employee_id"))
.where(SF.col("age") > SF.lit(50))
)
self.compare_spark_with_sqlglot(df_employee, dfs_employee)
def test_group_by_where_aggregate_like_having(self):
df_employee = (
self.df_spark_employee.groupBy(self.df_spark_employee.age)
.agg(F.min(self.df_spark_employee.employee_id).alias("min_employee_id"))
.where(F.col("min_employee_id") > F.lit(1))
)
dfs_employee = (
self.df_sqlglot_employee.groupBy(self.df_sqlglot_employee.age)
.agg(SF.min(self.df_sqlglot_employee.employee_id).alias("min_employee_id"))
.where(SF.col("min_employee_id") > SF.lit(1))
)
self.compare_spark_with_sqlglot(df_employee, dfs_employee)
def test_count(self):
df = self.df_spark_employee.groupBy(self.df_spark_employee.age).count()
dfs = self.df_sqlglot_employee.groupBy(self.df_sqlglot_employee.age).count()
self.compare_spark_with_sqlglot(df, dfs)
def test_mean(self):
df = self.df_spark_employee.groupBy().mean("age", "store_id")
dfs = self.df_sqlglot_employee.groupBy().mean("age", "store_id")
self.compare_spark_with_sqlglot(df, dfs)
def test_avg(self):
df = self.df_spark_employee.groupBy("age").avg("store_id")
dfs = self.df_sqlglot_employee.groupBy("age").avg("store_id")
self.compare_spark_with_sqlglot(df, dfs)
def test_max(self):
df = self.df_spark_employee.groupBy("age").max("store_id")
dfs = self.df_sqlglot_employee.groupBy("age").max("store_id")
self.compare_spark_with_sqlglot(df, dfs)
def test_min(self):
df = self.df_spark_employee.groupBy("age").min("store_id")
dfs = self.df_sqlglot_employee.groupBy("age").min("store_id")
self.compare_spark_with_sqlglot(df, dfs)
def test_sum(self):
df = self.df_spark_employee.groupBy("age").sum("store_id")
dfs = self.df_sqlglot_employee.groupBy("age").sum("store_id")
self.compare_spark_with_sqlglot(df, dfs)

View file

@ -0,0 +1,28 @@
from pyspark.sql import functions as F
from sqlglot.dataframe.sql import functions as SF
from tests.dataframe.integration.dataframe_validator import DataFrameValidator
class TestSessionFunc(DataFrameValidator):
def test_sql_simple_select(self):
query = "SELECT fname, lname FROM employee"
df = self.spark.sql(query)
dfs = self.sqlglot.sql(query)
self.compare_spark_with_sqlglot(df, dfs)
def test_sql_with_join(self):
query = """
SELECT
e.employee_id
, s.store_id
FROM
employee e
INNER JOIN
store s
ON
e.store_id = s.store_id
"""
df = self.spark.sql(query).groupBy(F.col("store_id")).agg(F.countDistinct(F.col("employee_id")))
dfs = self.sqlglot.sql(query).groupBy(SF.col("store_id")).agg(SF.countDistinct(SF.col("employee_id")))
self.compare_spark_with_sqlglot(df, dfs, skip_schema_compare=True)

View file

View file

@ -0,0 +1,35 @@
import typing as t
import unittest
from sqlglot.dataframe.sql import types
from sqlglot.dataframe.sql.dataframe import DataFrame
from sqlglot.dataframe.sql.session import SparkSession
class DataFrameSQLValidator(unittest.TestCase):
def setUp(self) -> None:
self.spark = SparkSession()
self.employee_schema = types.StructType(
[
types.StructField("employee_id", types.IntegerType(), False),
types.StructField("fname", types.StringType(), False),
types.StructField("lname", types.StringType(), False),
types.StructField("age", types.IntegerType(), False),
types.StructField("store_id", types.IntegerType(), False),
]
)
employee_data = [
(1, "Jack", "Shephard", 37, 1),
(2, "John", "Locke", 65, 1),
(3, "Kate", "Austen", 37, 2),
(4, "Claire", "Littleton", 27, 2),
(5, "Hugo", "Reyes", 29, 100),
]
self.df_employee = self.spark.createDataFrame(data=employee_data, schema=self.employee_schema)
def compare_sql(self, df: DataFrame, expected_statements: t.Union[str, t.List[str]], pretty=False):
actual_sqls = df.sql(pretty=pretty)
expected_statements = [expected_statements] if isinstance(expected_statements, str) else expected_statements
self.assertEqual(len(expected_statements), len(actual_sqls))
for expected, actual in zip(expected_statements, actual_sqls):
self.assertEqual(expected, actual)

View file

@ -0,0 +1,167 @@
import datetime
import unittest
from sqlglot.dataframe.sql import functions as F
from sqlglot.dataframe.sql.window import Window
class TestDataframeColumn(unittest.TestCase):
def test_eq(self):
self.assertEqual("cola = 1", (F.col("cola") == 1).sql())
def test_neq(self):
self.assertEqual("cola <> 1", (F.col("cola") != 1).sql())
def test_gt(self):
self.assertEqual("cola > 1", (F.col("cola") > 1).sql())
def test_lt(self):
self.assertEqual("cola < 1", (F.col("cola") < 1).sql())
def test_le(self):
self.assertEqual("cola <= 1", (F.col("cola") <= 1).sql())
def test_ge(self):
self.assertEqual("cola >= 1", (F.col("cola") >= 1).sql())
def test_and(self):
self.assertEqual(
"cola = colb AND colc = cold", ((F.col("cola") == F.col("colb")) & (F.col("colc") == F.col("cold"))).sql()
)
def test_or(self):
self.assertEqual(
"cola = colb OR colc = cold", ((F.col("cola") == F.col("colb")) | (F.col("colc") == F.col("cold"))).sql()
)
def test_mod(self):
self.assertEqual("cola % 2", (F.col("cola") % 2).sql())
def test_add(self):
self.assertEqual("cola + 1", (F.col("cola") + 1).sql())
def test_sub(self):
self.assertEqual("cola - 1", (F.col("cola") - 1).sql())
def test_mul(self):
self.assertEqual("cola * 2", (F.col("cola") * 2).sql())
def test_div(self):
self.assertEqual("cola / 2", (F.col("cola") / 2).sql())
def test_radd(self):
self.assertEqual("1 + cola", (1 + F.col("cola")).sql())
def test_rsub(self):
self.assertEqual("1 - cola", (1 - F.col("cola")).sql())
def test_rmul(self):
self.assertEqual("1 * cola", (1 * F.col("cola")).sql())
def test_rdiv(self):
self.assertEqual("1 / cola", (1 / F.col("cola")).sql())
def test_pow(self):
self.assertEqual("POWER(cola, 2)", (F.col("cola") ** 2).sql())
def test_rpow(self):
self.assertEqual("POWER(2, cola)", (2 ** F.col("cola")).sql())
def test_invert(self):
self.assertEqual("NOT cola", (~F.col("cola")).sql())
def test_startswith(self):
self.assertEqual("STARTSWITH(cola, 'test')", F.col("cola").startswith("test").sql())
def test_endswith(self):
self.assertEqual("ENDSWITH(cola, 'test')", F.col("cola").endswith("test").sql())
def test_rlike(self):
self.assertEqual("cola RLIKE 'foo'", F.col("cola").rlike("foo").sql())
def test_like(self):
self.assertEqual("cola LIKE 'foo%'", F.col("cola").like("foo%").sql())
def test_ilike(self):
self.assertEqual("cola ILIKE 'foo%'", F.col("cola").ilike("foo%").sql())
def test_substring(self):
self.assertEqual("SUBSTRING(cola, 2, 3)", F.col("cola").substr(2, 3).sql())
def test_isin(self):
self.assertEqual("cola IN (1, 2, 3)", F.col("cola").isin([1, 2, 3]).sql())
self.assertEqual("cola IN (1, 2, 3)", F.col("cola").isin(1, 2, 3).sql())
def test_asc(self):
self.assertEqual("cola", F.col("cola").asc().sql())
def test_desc(self):
self.assertEqual("cola DESC", F.col("cola").desc().sql())
def test_asc_nulls_first(self):
self.assertEqual("cola", F.col("cola").asc_nulls_first().sql())
def test_asc_nulls_last(self):
self.assertEqual("cola NULLS LAST", F.col("cola").asc_nulls_last().sql())
def test_desc_nulls_first(self):
self.assertEqual("cola DESC NULLS FIRST", F.col("cola").desc_nulls_first().sql())
def test_desc_nulls_last(self):
self.assertEqual("cola DESC", F.col("cola").desc_nulls_last().sql())
def test_when_otherwise(self):
self.assertEqual("CASE WHEN cola = 1 THEN 2 END", F.when(F.col("cola") == 1, 2).sql())
self.assertEqual("CASE WHEN cola = 1 THEN 2 END", F.col("cola").when(F.col("cola") == 1, 2).sql())
self.assertEqual(
"CASE WHEN cola = 1 THEN 2 WHEN colb = 2 THEN 3 END",
(F.when(F.col("cola") == 1, 2).when(F.col("colb") == 2, 3)).sql(),
)
self.assertEqual(
"CASE WHEN cola = 1 THEN 2 WHEN colb = 2 THEN 3 END",
F.col("cola").when(F.col("cola") == 1, 2).when(F.col("colb") == 2, 3).sql(),
)
self.assertEqual(
"CASE WHEN cola = 1 THEN 2 WHEN colb = 2 THEN 3 ELSE 4 END",
F.when(F.col("cola") == 1, 2).when(F.col("colb") == 2, 3).otherwise(4).sql(),
)
def test_is_null(self):
self.assertEqual("cola IS NULL", F.col("cola").isNull().sql())
def test_is_not_null(self):
self.assertEqual("NOT cola IS NULL", F.col("cola").isNotNull().sql())
def test_cast(self):
self.assertEqual("CAST(cola AS INT)", F.col("cola").cast("INT").sql())
def test_alias(self):
self.assertEqual("cola AS new_name", F.col("cola").alias("new_name").sql())
def test_between(self):
self.assertEqual("cola BETWEEN 1 AND 3", F.col("cola").between(1, 3).sql())
self.assertEqual("cola BETWEEN 10.1 AND 12.1", F.col("cola").between(10.1, 12.1).sql())
self.assertEqual(
"cola BETWEEN TO_DATE('2022-01-01') AND TO_DATE('2022-03-01')",
F.col("cola").between(datetime.date(2022, 1, 1), datetime.date(2022, 3, 1)).sql(),
)
self.assertEqual(
"cola BETWEEN CAST('2022-01-01 01:01:01' AS TIMESTAMP) " "AND CAST('2022-03-01 01:01:01' AS TIMESTAMP)",
F.col("cola").between(datetime.datetime(2022, 1, 1, 1, 1, 1), datetime.datetime(2022, 3, 1, 1, 1, 1)).sql(),
)
def test_over(self):
over_rows = F.sum("cola").over(
Window.partitionBy("colb").orderBy("colc").rowsBetween(1, Window.unboundedFollowing)
)
self.assertEqual(
"SUM(cola) OVER (PARTITION BY colb ORDER BY colc ROWS BETWEEN 1 PRECEDING AND UNBOUNDED FOLLOWING)",
over_rows.sql(),
)
over_range = F.sum("cola").over(
Window.partitionBy("colb").orderBy("colc").rangeBetween(1, Window.unboundedFollowing)
)
self.assertEqual(
"SUM(cola) OVER (PARTITION BY colb ORDER BY colc RANGE BETWEEN 1 PRECEDING AND UNBOUNDED FOLLOWING)",
over_range.sql(),
)

View file

@ -0,0 +1,39 @@
from sqlglot import expressions as exp
from sqlglot.dataframe.sql.dataframe import DataFrame
from tests.dataframe.unit.dataframe_sql_validator import DataFrameSQLValidator
class TestDataframe(DataFrameSQLValidator):
def test_hash_select_expression(self):
expression = exp.select("cola").from_("table")
self.assertEqual("t17051", DataFrame._create_hash_from_expression(expression))
def test_columns(self):
self.assertEqual(["employee_id", "fname", "lname", "age", "store_id"], self.df_employee.columns)
def test_cache(self):
df = self.df_employee.select("fname").cache()
expected_statements = [
"DROP VIEW IF EXISTS t11623",
"CACHE LAZY TABLE t11623 OPTIONS('storageLevel' = 'MEMORY_AND_DISK') AS SELECT CAST(`a1`.`fname` AS string) AS `fname` FROM (VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100)) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)",
"SELECT `t11623`.`fname` AS `fname` FROM `t11623` AS `t11623`",
]
self.compare_sql(df, expected_statements)
def test_persist_default(self):
df = self.df_employee.select("fname").persist()
expected_statements = [
"DROP VIEW IF EXISTS t11623",
"CACHE LAZY TABLE t11623 OPTIONS('storageLevel' = 'MEMORY_AND_DISK_SER') AS SELECT CAST(`a1`.`fname` AS string) AS `fname` FROM (VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100)) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)",
"SELECT `t11623`.`fname` AS `fname` FROM `t11623` AS `t11623`",
]
self.compare_sql(df, expected_statements)
def test_persist_storagelevel(self):
df = self.df_employee.select("fname").persist("DISK_ONLY_2")
expected_statements = [
"DROP VIEW IF EXISTS t11623",
"CACHE LAZY TABLE t11623 OPTIONS('storageLevel' = 'DISK_ONLY_2') AS SELECT CAST(`a1`.`fname` AS string) AS `fname` FROM (VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100)) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)",
"SELECT `t11623`.`fname` AS `fname` FROM `t11623` AS `t11623`",
]
self.compare_sql(df, expected_statements)

View file

@ -0,0 +1,86 @@
from unittest import mock
import sqlglot
from sqlglot.schema import MappingSchema
from tests.dataframe.unit.dataframe_sql_validator import DataFrameSQLValidator
class TestDataFrameWriter(DataFrameSQLValidator):
def test_insertInto_full_path(self):
df = self.df_employee.write.insertInto("catalog.db.table_name")
expected = "INSERT INTO catalog.db.table_name SELECT CAST(`a1`.`employee_id` AS int) AS `employee_id`, CAST(`a1`.`fname` AS string) AS `fname`, CAST(`a1`.`lname` AS string) AS `lname`, CAST(`a1`.`age` AS int) AS `age`, CAST(`a1`.`store_id` AS int) AS `store_id` FROM (VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100)) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)"
self.compare_sql(df, expected)
def test_insertInto_db_table(self):
df = self.df_employee.write.insertInto("db.table_name")
expected = "INSERT INTO db.table_name SELECT CAST(`a1`.`employee_id` AS int) AS `employee_id`, CAST(`a1`.`fname` AS string) AS `fname`, CAST(`a1`.`lname` AS string) AS `lname`, CAST(`a1`.`age` AS int) AS `age`, CAST(`a1`.`store_id` AS int) AS `store_id` FROM (VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100)) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)"
self.compare_sql(df, expected)
def test_insertInto_table(self):
df = self.df_employee.write.insertInto("table_name")
expected = "INSERT INTO table_name SELECT CAST(`a1`.`employee_id` AS int) AS `employee_id`, CAST(`a1`.`fname` AS string) AS `fname`, CAST(`a1`.`lname` AS string) AS `lname`, CAST(`a1`.`age` AS int) AS `age`, CAST(`a1`.`store_id` AS int) AS `store_id` FROM (VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100)) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)"
self.compare_sql(df, expected)
def test_insertInto_overwrite(self):
df = self.df_employee.write.insertInto("table_name", overwrite=True)
expected = "INSERT OVERWRITE TABLE table_name SELECT CAST(`a1`.`employee_id` AS int) AS `employee_id`, CAST(`a1`.`fname` AS string) AS `fname`, CAST(`a1`.`lname` AS string) AS `lname`, CAST(`a1`.`age` AS int) AS `age`, CAST(`a1`.`store_id` AS int) AS `store_id` FROM (VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100)) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)"
self.compare_sql(df, expected)
@mock.patch("sqlglot.schema", MappingSchema())
def test_insertInto_byName(self):
sqlglot.schema.add_table("table_name", {"employee_id": "INT"})
df = self.df_employee.write.byName.insertInto("table_name")
expected = "INSERT INTO table_name SELECT CAST(`a1`.`employee_id` AS int) AS `employee_id` FROM (VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100)) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)"
self.compare_sql(df, expected)
def test_insertInto_cache(self):
df = self.df_employee.cache().write.insertInto("table_name")
expected_statements = [
"DROP VIEW IF EXISTS t35612",
"CACHE LAZY TABLE t35612 OPTIONS('storageLevel' = 'MEMORY_AND_DISK') AS SELECT CAST(`a1`.`employee_id` AS int) AS `employee_id`, CAST(`a1`.`fname` AS string) AS `fname`, CAST(`a1`.`lname` AS string) AS `lname`, CAST(`a1`.`age` AS int) AS `age`, CAST(`a1`.`store_id` AS int) AS `store_id` FROM (VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100)) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)",
"INSERT INTO table_name SELECT `t35612`.`employee_id` AS `employee_id`, `t35612`.`fname` AS `fname`, `t35612`.`lname` AS `lname`, `t35612`.`age` AS `age`, `t35612`.`store_id` AS `store_id` FROM `t35612` AS `t35612`",
]
self.compare_sql(df, expected_statements)
def test_saveAsTable_format(self):
with self.assertRaises(NotImplementedError):
self.df_employee.write.saveAsTable("table_name", format="parquet").sql(pretty=False)[0]
def test_saveAsTable_append(self):
df = self.df_employee.write.saveAsTable("table_name", mode="append")
expected = "INSERT INTO table_name SELECT CAST(`a1`.`employee_id` AS int) AS `employee_id`, CAST(`a1`.`fname` AS string) AS `fname`, CAST(`a1`.`lname` AS string) AS `lname`, CAST(`a1`.`age` AS int) AS `age`, CAST(`a1`.`store_id` AS int) AS `store_id` FROM (VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100)) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)"
self.compare_sql(df, expected)
def test_saveAsTable_overwrite(self):
df = self.df_employee.write.saveAsTable("table_name", mode="overwrite")
expected = "CREATE OR REPLACE TABLE table_name AS SELECT CAST(`a1`.`employee_id` AS int) AS `employee_id`, CAST(`a1`.`fname` AS string) AS `fname`, CAST(`a1`.`lname` AS string) AS `lname`, CAST(`a1`.`age` AS int) AS `age`, CAST(`a1`.`store_id` AS int) AS `store_id` FROM (VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100)) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)"
self.compare_sql(df, expected)
def test_saveAsTable_error(self):
df = self.df_employee.write.saveAsTable("table_name", mode="error")
expected = "CREATE TABLE table_name AS SELECT CAST(`a1`.`employee_id` AS int) AS `employee_id`, CAST(`a1`.`fname` AS string) AS `fname`, CAST(`a1`.`lname` AS string) AS `lname`, CAST(`a1`.`age` AS int) AS `age`, CAST(`a1`.`store_id` AS int) AS `store_id` FROM (VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100)) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)"
self.compare_sql(df, expected)
def test_saveAsTable_ignore(self):
df = self.df_employee.write.saveAsTable("table_name", mode="ignore")
expected = "CREATE TABLE IF NOT EXISTS table_name AS SELECT CAST(`a1`.`employee_id` AS int) AS `employee_id`, CAST(`a1`.`fname` AS string) AS `fname`, CAST(`a1`.`lname` AS string) AS `lname`, CAST(`a1`.`age` AS int) AS `age`, CAST(`a1`.`store_id` AS int) AS `store_id` FROM (VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100)) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)"
self.compare_sql(df, expected)
def test_mode_standalone(self):
df = self.df_employee.write.mode("ignore").saveAsTable("table_name")
expected = "CREATE TABLE IF NOT EXISTS table_name AS SELECT CAST(`a1`.`employee_id` AS int) AS `employee_id`, CAST(`a1`.`fname` AS string) AS `fname`, CAST(`a1`.`lname` AS string) AS `lname`, CAST(`a1`.`age` AS int) AS `age`, CAST(`a1`.`store_id` AS int) AS `store_id` FROM (VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100)) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)"
self.compare_sql(df, expected)
def test_mode_override(self):
df = self.df_employee.write.mode("ignore").saveAsTable("table_name", mode="overwrite")
expected = "CREATE OR REPLACE TABLE table_name AS SELECT CAST(`a1`.`employee_id` AS int) AS `employee_id`, CAST(`a1`.`fname` AS string) AS `fname`, CAST(`a1`.`lname` AS string) AS `lname`, CAST(`a1`.`age` AS int) AS `age`, CAST(`a1`.`store_id` AS int) AS `store_id` FROM (VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100)) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)"
self.compare_sql(df, expected)
def test_saveAsTable_cache(self):
df = self.df_employee.cache().write.saveAsTable("table_name")
expected_statements = [
"DROP VIEW IF EXISTS t35612",
"CACHE LAZY TABLE t35612 OPTIONS('storageLevel' = 'MEMORY_AND_DISK') AS SELECT CAST(`a1`.`employee_id` AS int) AS `employee_id`, CAST(`a1`.`fname` AS string) AS `fname`, CAST(`a1`.`lname` AS string) AS `lname`, CAST(`a1`.`age` AS int) AS `age`, CAST(`a1`.`store_id` AS int) AS `store_id` FROM (VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100)) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)",
"CREATE TABLE table_name AS SELECT `t35612`.`employee_id` AS `employee_id`, `t35612`.`fname` AS `fname`, `t35612`.`lname` AS `lname`, `t35612`.`age` AS `age`, `t35612`.`store_id` AS `store_id` FROM `t35612` AS `t35612`",
]
self.compare_sql(df, expected_statements)

File diff suppressed because it is too large Load diff

View file

@ -0,0 +1,114 @@
from unittest import mock
import sqlglot
from sqlglot.dataframe.sql import functions as F
from sqlglot.dataframe.sql import types
from sqlglot.dataframe.sql.session import SparkSession
from sqlglot.schema import MappingSchema
from tests.dataframe.unit.dataframe_sql_validator import DataFrameSQLValidator
class TestDataframeSession(DataFrameSQLValidator):
def test_cdf_one_row(self):
df = self.spark.createDataFrame([[1, 2]], ["cola", "colb"])
expected = "SELECT `a2`.`cola` AS `cola`, `a2`.`colb` AS `colb` FROM (VALUES (1, 2)) AS `a2`(`cola`, `colb`)"
self.compare_sql(df, expected)
def test_cdf_multiple_rows(self):
df = self.spark.createDataFrame([[1, 2], [3, 4], [None, 6]], ["cola", "colb"])
expected = "SELECT `a2`.`cola` AS `cola`, `a2`.`colb` AS `colb` FROM (VALUES (1, 2), (3, 4), (NULL, 6)) AS `a2`(`cola`, `colb`)"
self.compare_sql(df, expected)
def test_cdf_no_schema(self):
df = self.spark.createDataFrame([[1, 2], [3, 4], [None, 6]])
expected = (
"SELECT `a2`.`_1` AS `_1`, `a2`.`_2` AS `_2` FROM (VALUES (1, 2), (3, 4), (NULL, 6)) AS `a2`(`_1`, `_2`)"
)
self.compare_sql(df, expected)
def test_cdf_row_mixed_primitives(self):
df = self.spark.createDataFrame([[1, 10.1, "test", False, None]])
expected = "SELECT `a2`.`_1` AS `_1`, `a2`.`_2` AS `_2`, `a2`.`_3` AS `_3`, `a2`.`_4` AS `_4`, `a2`.`_5` AS `_5` FROM (VALUES (1, 10.1, 'test', FALSE, NULL)) AS `a2`(`_1`, `_2`, `_3`, `_4`, `_5`)"
self.compare_sql(df, expected)
def test_cdf_dict_rows(self):
df = self.spark.createDataFrame([{"cola": 1, "colb": "test"}, {"cola": 2, "colb": "test2"}])
expected = "SELECT `a2`.`cola` AS `cola`, `a2`.`colb` AS `colb` FROM (VALUES (1, 'test'), (2, 'test2')) AS `a2`(`cola`, `colb`)"
self.compare_sql(df, expected)
def test_cdf_str_schema(self):
df = self.spark.createDataFrame([[1, "test"]], "cola: INT, colb: STRING")
expected = "SELECT CAST(`a2`.`cola` AS INT) AS `cola`, CAST(`a2`.`colb` AS STRING) AS `colb` FROM (VALUES (1, 'test')) AS `a2`(`cola`, `colb`)"
self.compare_sql(df, expected)
def test_typed_schema_basic(self):
schema = types.StructType(
[
types.StructField("cola", types.IntegerType()),
types.StructField("colb", types.StringType()),
]
)
df = self.spark.createDataFrame([[1, "test"]], schema)
expected = "SELECT CAST(`a2`.`cola` AS int) AS `cola`, CAST(`a2`.`colb` AS string) AS `colb` FROM (VALUES (1, 'test')) AS `a2`(`cola`, `colb`)"
self.compare_sql(df, expected)
def test_typed_schema_nested(self):
schema = types.StructType(
[
types.StructField(
"cola",
types.StructType(
[
types.StructField("sub_cola", types.IntegerType()),
types.StructField("sub_colb", types.StringType()),
]
),
)
]
)
df = self.spark.createDataFrame([[{"sub_cola": 1, "sub_colb": "test"}]], schema)
expected = "SELECT CAST(`a2`.`cola` AS struct<sub_cola:int, sub_colb:string>) AS `cola` FROM (VALUES (STRUCT(1 AS `sub_cola`, 'test' AS `sub_colb`))) AS `a2`(`cola`)"
self.compare_sql(df, expected)
@mock.patch("sqlglot.schema", MappingSchema())
def test_sql_select_only(self):
# TODO: Do exact matches once CTE names are deterministic
query = "SELECT cola, colb FROM table"
sqlglot.schema.add_table("table", {"cola": "string", "colb": "string"})
df = self.spark.sql(query)
self.assertIn(
"SELECT `table`.`cola` AS `cola`, `table`.`colb` AS `colb` FROM `table` AS `table`", df.sql(pretty=False)
)
@mock.patch("sqlglot.schema", MappingSchema())
def test_sql_with_aggs(self):
# TODO: Do exact matches once CTE names are deterministic
query = "SELECT cola, colb FROM table"
sqlglot.schema.add_table("table", {"cola": "string", "colb": "string"})
df = self.spark.sql(query).groupBy(F.col("cola")).agg(F.sum("colb"))
result = df.sql(pretty=False, optimize=False)[0]
self.assertIn("SELECT cola, colb FROM table", result)
self.assertIn("SUM(colb)", result)
self.assertIn("GROUP BY cola", result)
@mock.patch("sqlglot.schema", MappingSchema())
def test_sql_create(self):
query = "CREATE TABLE new_table AS WITH t1 AS (SELECT cola, colb FROM table) SELECT cola, colb, FROM t1"
sqlglot.schema.add_table("table", {"cola": "string", "colb": "string"})
df = self.spark.sql(query)
expected = "CREATE TABLE new_table AS SELECT `table`.`cola` AS `cola`, `table`.`colb` AS `colb` FROM `table` AS `table`"
self.compare_sql(df, expected)
@mock.patch("sqlglot.schema", MappingSchema())
def test_sql_insert(self):
query = "WITH t1 AS (SELECT cola, colb FROM table) INSERT INTO new_table SELECT cola, colb FROM t1"
sqlglot.schema.add_table("table", {"cola": "string", "colb": "string"})
df = self.spark.sql(query)
expected = (
"INSERT INTO new_table SELECT `table`.`cola` AS `cola`, `table`.`colb` AS `colb` FROM `table` AS `table`"
)
self.compare_sql(df, expected)
def test_session_create_builder_patterns(self):
spark = SparkSession()
self.assertEqual(spark.builder.appName("abc").getOrCreate(), spark)

View file

@ -0,0 +1,70 @@
import unittest
from sqlglot.dataframe.sql import types
class TestDataframeTypes(unittest.TestCase):
def test_string(self):
self.assertEqual("string", types.StringType().simpleString())
def test_char(self):
self.assertEqual("char(100)", types.CharType(100).simpleString())
def test_varchar(self):
self.assertEqual("varchar(65)", types.VarcharType(65).simpleString())
def test_binary(self):
self.assertEqual("binary", types.BinaryType().simpleString())
def test_boolean(self):
self.assertEqual("boolean", types.BooleanType().simpleString())
def test_date(self):
self.assertEqual("date", types.DateType().simpleString())
def test_timestamp(self):
self.assertEqual("timestamp", types.TimestampType().simpleString())
def test_timestamp_ntz(self):
self.assertEqual("timestamp_ntz", types.TimestampNTZType().simpleString())
def test_decimal(self):
self.assertEqual("decimal(10, 3)", types.DecimalType(10, 3).simpleString())
def test_double(self):
self.assertEqual("double", types.DoubleType().simpleString())
def test_float(self):
self.assertEqual("float", types.FloatType().simpleString())
def test_byte(self):
self.assertEqual("tinyint", types.ByteType().simpleString())
def test_integer(self):
self.assertEqual("int", types.IntegerType().simpleString())
def test_long(self):
self.assertEqual("bigint", types.LongType().simpleString())
def test_short(self):
self.assertEqual("smallint", types.ShortType().simpleString())
def test_array(self):
self.assertEqual("array<int>", types.ArrayType(types.IntegerType()).simpleString())
def test_map(self):
self.assertEqual("map<int, string>", types.MapType(types.IntegerType(), types.StringType()).simpleString())
def test_struct_field(self):
self.assertEqual("cola:int", types.StructField("cola", types.IntegerType()).simpleString())
def test_struct_type(self):
self.assertEqual(
"struct<cola:int, colb:string>",
types.StructType(
[
types.StructField("cola", types.IntegerType()),
types.StructField("colb", types.StringType()),
]
).simpleString(),
)

View file

@ -0,0 +1,60 @@
import unittest
from sqlglot.dataframe.sql import functions as F
from sqlglot.dataframe.sql.window import Window, WindowSpec
class TestDataframeWindow(unittest.TestCase):
def test_window_spec_partition_by(self):
partition_by = WindowSpec().partitionBy(F.col("cola"), F.col("colb"))
self.assertEqual("OVER (PARTITION BY cola, colb)", partition_by.sql())
def test_window_spec_order_by(self):
order_by = WindowSpec().orderBy("cola", "colb")
self.assertEqual("OVER (ORDER BY cola, colb)", order_by.sql())
def test_window_spec_rows_between(self):
rows_between = WindowSpec().rowsBetween(3, 5)
self.assertEqual("OVER ( ROWS BETWEEN 3 PRECEDING AND 5 FOLLOWING)", rows_between.sql())
def test_window_spec_range_between(self):
range_between = WindowSpec().rangeBetween(3, 5)
self.assertEqual("OVER ( RANGE BETWEEN 3 PRECEDING AND 5 FOLLOWING)", range_between.sql())
def test_window_partition_by(self):
partition_by = Window.partitionBy(F.col("cola"), F.col("colb"))
self.assertEqual("OVER (PARTITION BY cola, colb)", partition_by.sql())
def test_window_order_by(self):
order_by = Window.orderBy("cola", "colb")
self.assertEqual("OVER (ORDER BY cola, colb)", order_by.sql())
def test_window_rows_between(self):
rows_between = Window.rowsBetween(3, 5)
self.assertEqual("OVER ( ROWS BETWEEN 3 PRECEDING AND 5 FOLLOWING)", rows_between.sql())
def test_window_range_between(self):
range_between = Window.rangeBetween(3, 5)
self.assertEqual("OVER ( RANGE BETWEEN 3 PRECEDING AND 5 FOLLOWING)", range_between.sql())
def test_window_rows_unbounded(self):
rows_between_unbounded_start = Window.rowsBetween(Window.unboundedPreceding, 2)
self.assertEqual("OVER ( ROWS BETWEEN UNBOUNDED PRECEDING AND 2 FOLLOWING)", rows_between_unbounded_start.sql())
rows_between_unbounded_end = Window.rowsBetween(1, Window.unboundedFollowing)
self.assertEqual("OVER ( ROWS BETWEEN 1 PRECEDING AND UNBOUNDED FOLLOWING)", rows_between_unbounded_end.sql())
rows_between_unbounded_both = Window.rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing)
self.assertEqual(
"OVER ( ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING)", rows_between_unbounded_both.sql()
)
def test_window_range_unbounded(self):
range_between_unbounded_start = Window.rangeBetween(Window.unboundedPreceding, 2)
self.assertEqual(
"OVER ( RANGE BETWEEN UNBOUNDED PRECEDING AND 2 FOLLOWING)", range_between_unbounded_start.sql()
)
range_between_unbounded_end = Window.rangeBetween(1, Window.unboundedFollowing)
self.assertEqual("OVER ( RANGE BETWEEN 1 PRECEDING AND UNBOUNDED FOLLOWING)", range_between_unbounded_end.sql())
range_between_unbounded_both = Window.rangeBetween(Window.unboundedPreceding, Window.unboundedFollowing)
self.assertEqual(
"OVER ( RANGE BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING)", range_between_unbounded_both.sql()
)

View file

@ -694,29 +694,6 @@ class TestDialect(Validator):
},
)
# https://dev.mysql.com/doc/refman/8.0/en/join.html
# https://www.postgresql.org/docs/current/queries-table-expressions.html
def test_joined_tables(self):
self.validate_identity("SELECT * FROM (tbl1 LEFT JOIN tbl2 ON 1 = 1)")
self.validate_identity("SELECT * FROM (tbl1 JOIN tbl2 JOIN tbl3)")
self.validate_identity("SELECT * FROM (tbl1 JOIN (tbl2 JOIN tbl3) ON bla = foo)")
self.validate_identity("SELECT * FROM (tbl1 JOIN LATERAL (SELECT * FROM bla) AS tbl)")
self.validate_all(
"SELECT * FROM (tbl1 LEFT JOIN tbl2 ON 1 = 1)",
write={
"postgres": "SELECT * FROM (tbl1 LEFT JOIN tbl2 ON 1 = 1)",
"mysql": "SELECT * FROM (tbl1 LEFT JOIN tbl2 ON 1 = 1)",
},
)
self.validate_all(
"SELECT * FROM (tbl1 JOIN LATERAL (SELECT * FROM bla) AS tbl)",
write={
"postgres": "SELECT * FROM (tbl1 JOIN LATERAL (SELECT * FROM bla) AS tbl)",
"mysql": "SELECT * FROM (tbl1 JOIN LATERAL (SELECT * FROM bla) AS tbl)",
},
)
def test_lateral_subquery(self):
self.validate_identity(
"SELECT art FROM tbl1 INNER JOIN LATERAL (SELECT art FROM tbl2) AS tbl2 ON tbl1.art = tbl2.art"
@ -856,7 +833,7 @@ class TestDialect(Validator):
"postgres": "x ILIKE '%y'",
"presto": "LOWER(x) LIKE '%y'",
"snowflake": "x ILIKE '%y'",
"spark": "LOWER(x) LIKE '%y'",
"spark": "x ILIKE '%y'",
"sqlite": "LOWER(x) LIKE '%y'",
"starrocks": "LOWER(x) LIKE '%y'",
"trino": "LOWER(x) LIKE '%y'",

View file

@ -48,7 +48,7 @@ class TestDuckDB(Validator):
self.validate_all(
"STRPTIME(x, '%y-%-m')",
write={
"bigquery": "STR_TO_TIME(x, '%y-%-m')",
"bigquery": "PARSE_TIMESTAMP('%y-%m', x)",
"duckdb": "STRPTIME(x, '%y-%-m')",
"presto": "DATE_PARSE(x, '%y-%c')",
"hive": "CAST(FROM_UNIXTIME(UNIX_TIMESTAMP(x, 'yy-M')) AS TIMESTAMP)",
@ -63,6 +63,16 @@ class TestDuckDB(Validator):
"hive": "CAST(x AS TIMESTAMP)",
},
)
self.validate_all(
"STRPTIME(x, '%-m/%-d/%y %-I:%M %p')",
write={
"bigquery": "PARSE_TIMESTAMP('%m/%d/%y %I:%M %p', x)",
"duckdb": "STRPTIME(x, '%-m/%-d/%y %-I:%M %p')",
"presto": "DATE_PARSE(x, '%c/%e/%y %l:%i %p')",
"hive": "CAST(FROM_UNIXTIME(UNIX_TIMESTAMP(x, 'M/d/yy h:mm a')) AS TIMESTAMP)",
"spark": "TO_TIMESTAMP(x, 'M/d/yy h:mm a')",
},
)
def test_duckdb(self):
self.validate_all(
@ -268,6 +278,17 @@ class TestDuckDB(Validator):
"spark": "MONTH('2021-03-01')",
},
)
self.validate_all(
"ARRAY_CONCAT(LIST_VALUE(1, 2), LIST_VALUE(3, 4))",
write={
"duckdb": "ARRAY_CONCAT(LIST_VALUE(1, 2), LIST_VALUE(3, 4))",
"presto": "CONCAT(ARRAY[1, 2], ARRAY[3, 4])",
"hive": "CONCAT(ARRAY(1, 2), ARRAY(3, 4))",
"spark": "CONCAT(ARRAY(1, 2), ARRAY(3, 4))",
"snowflake": "ARRAY_CAT([1, 2], [3, 4])",
"bigquery": "ARRAY_CONCAT([1, 2], [3, 4])",
},
)
with self.assertRaises(UnsupportedError):
transpile(

View file

@ -31,6 +31,24 @@ class TestMySQL(Validator):
"mysql": "_utf8mb4 'hola'",
},
)
self.validate_all(
"N 'some text'",
read={
"mysql": "N'some text'",
},
write={
"mysql": "N 'some text'",
},
)
self.validate_all(
"_latin1 x'4D7953514C'",
read={
"mysql": "_latin1 X'4D7953514C'",
},
write={
"mysql": "_latin1 x'4D7953514C'",
},
)
def test_hexadecimal_literal(self):
self.validate_all(

View file

@ -69,6 +69,8 @@ class TestPostgres(Validator):
self.validate_identity("SELECT TRIM(LEADING 'bla' FROM ' XXX ' COLLATE utf8_bin)")
self.validate_identity("SELECT TO_TIMESTAMP(1284352323.5), TO_TIMESTAMP('05 Dec 2000', 'DD Mon YYYY')")
self.validate_identity("COMMENT ON TABLE mytable IS 'this'")
self.validate_identity("SELECT e'\\xDEADBEEF'")
self.validate_identity("SELECT CAST(e'\\176' AS BYTEA)")
self.validate_all(
"CREATE TABLE x (a UUID, b BYTEA)",
@ -204,3 +206,11 @@ class TestPostgres(Validator):
"""'{"a":[1,2,3],"b":[4,5,6]}'::json#>>'{a,2}'""",
write={"postgres": """CAST('{"a":[1,2,3],"b":[4,5,6]}' AS JSON)#>>'{a,2}'"""},
)
self.validate_all(
"SELECT $$a$$",
write={"postgres": "SELECT 'a'"},
)
self.validate_all(
"SELECT $$Dianne's horse$$",
write={"postgres": "SELECT 'Dianne''s horse'"},
)

View file

@ -321,7 +321,7 @@ class TestPresto(Validator):
"duckdb": "SELECT APPROX_COUNT_DISTINCT(a) FROM foo",
"presto": "SELECT APPROX_DISTINCT(a, 0.1) FROM foo",
"hive": "SELECT APPROX_COUNT_DISTINCT(a) FROM foo",
"spark": "SELECT APPROX_COUNT_DISTINCT(a) FROM foo",
"spark": "SELECT APPROX_COUNT_DISTINCT(a, 0.1) FROM foo",
},
)
self.validate_all(
@ -329,7 +329,7 @@ class TestPresto(Validator):
write={
"presto": "SELECT APPROX_DISTINCT(a, 0.1) FROM foo",
"hive": UnsupportedError,
"spark": UnsupportedError,
"spark": "SELECT APPROX_COUNT_DISTINCT(a, 0.1) FROM foo",
},
)
self.validate_all(

View file

@ -65,7 +65,7 @@ class TestSnowflake(Validator):
self.validate_all(
"SELECT TO_TIMESTAMP('2013-04-05 01:02:03')",
write={
"bigquery": "SELECT STR_TO_TIME('2013-04-05 01:02:03', '%Y-%m-%d %H:%M:%S')",
"bigquery": "SELECT PARSE_TIMESTAMP('%Y-%m-%d %H:%M:%S', '2013-04-05 01:02:03')",
"snowflake": "SELECT TO_TIMESTAMP('2013-04-05 01:02:03', 'yyyy-mm-dd hh24:mi:ss')",
"spark": "SELECT TO_TIMESTAMP('2013-04-05 01:02:03', 'yyyy-MM-dd HH:mm:ss')",
},
@ -73,16 +73,17 @@ class TestSnowflake(Validator):
self.validate_all(
"SELECT TO_TIMESTAMP('04/05/2013 01:02:03', 'mm/dd/yyyy hh24:mi:ss')",
read={
"bigquery": "SELECT STR_TO_TIME('04/05/2013 01:02:03', '%m/%d/%Y %H:%M:%S')",
"bigquery": "SELECT PARSE_TIMESTAMP('%m/%d/%Y %H:%M:%S', '04/05/2013 01:02:03')",
"duckdb": "SELECT STRPTIME('04/05/2013 01:02:03', '%m/%d/%Y %H:%M:%S')",
"snowflake": "SELECT TO_TIMESTAMP('04/05/2013 01:02:03', 'mm/dd/yyyy hh24:mi:ss')",
},
write={
"bigquery": "SELECT STR_TO_TIME('04/05/2013 01:02:03', '%m/%d/%Y %H:%M:%S')",
"bigquery": "SELECT PARSE_TIMESTAMP('%m/%d/%Y %H:%M:%S', '04/05/2013 01:02:03')",
"snowflake": "SELECT TO_TIMESTAMP('04/05/2013 01:02:03', 'mm/dd/yyyy hh24:mi:ss')",
"spark": "SELECT TO_TIMESTAMP('04/05/2013 01:02:03', 'MM/dd/yyyy HH:mm:ss')",
},
)
self.validate_all(
"SELECT IFF(TRUE, 'true', 'false')",
write={
@ -240,11 +241,25 @@ class TestSnowflake(Validator):
},
)
self.validate_all(
"SELECT DATE_PART(month FROM a::DATETIME)",
"SELECT DATE_PART(month, a::DATETIME)",
write={
"snowflake": "SELECT EXTRACT(month FROM CAST(a AS DATETIME))",
},
)
self.validate_all(
"SELECT DATE_PART(epoch_second, foo) as ddate from table_name",
write={
"snowflake": "SELECT EXTRACT(epoch_second FROM CAST(foo AS TIMESTAMPNTZ)) AS ddate FROM table_name",
"presto": "SELECT TO_UNIXTIME(CAST(foo AS TIMESTAMP)) AS ddate FROM table_name",
},
)
self.validate_all(
"SELECT DATE_PART(epoch_milliseconds, foo) as ddate from table_name",
write={
"snowflake": "SELECT EXTRACT(epoch_second FROM CAST(foo AS TIMESTAMPNTZ)) * 1000 AS ddate FROM table_name",
"presto": "SELECT TO_UNIXTIME(CAST(foo AS TIMESTAMP)) * 1000 AS ddate FROM table_name",
},
)
def test_semi_structured_types(self):
self.validate_identity("SELECT CAST(a AS VARIANT)")

View file

@ -45,3 +45,29 @@ class TestTSQL(Validator):
"tsql": "CAST(x AS DATETIME2)",
},
)
def test_charindex(self):
self.validate_all(
"CHARINDEX(x, y, 9)",
write={
"spark": "LOCATE(x, y, 9)",
},
)
self.validate_all(
"CHARINDEX(x, y)",
write={
"spark": "LOCATE(x, y)",
},
)
self.validate_all(
"CHARINDEX('sub', 'testsubstring', 3)",
write={
"spark": "LOCATE('sub', 'testsubstring', 3)",
},
)
self.validate_all(
"CHARINDEX('sub', 'testsubstring')",
write={
"spark": "LOCATE('sub', 'testsubstring')",
},
)

View file

@ -513,6 +513,8 @@ ALTER TYPE electronic_mail RENAME TO email
ANALYZE a.y
DELETE FROM x WHERE y > 1
DELETE FROM y
DELETE FROM event USING sales WHERE event.eventid = sales.eventid
DELETE FROM event USING sales, USING bla WHERE event.eventid = sales.eventid
DROP TABLE a
DROP TABLE a.b
DROP TABLE IF EXISTS a
@ -563,3 +565,8 @@ WITH a AS ((SELECT 1 AS b) UNION ALL (SELECT 1 AS b)) SELECT * FROM a
SELECT (WITH x AS (SELECT 1 AS y) SELECT * FROM x) AS z
SELECT ((SELECT 1) + 1)
SELECT * FROM project.dataset.INFORMATION_SCHEMA.TABLES
SELECT * FROM (table1 AS t1 LEFT JOIN table2 AS t2 ON 1 = 1)
SELECT * FROM (tbl1 LEFT JOIN tbl2 ON 1 = 1)
SELECT * FROM (tbl1 JOIN tbl2 JOIN tbl3)
SELECT * FROM (tbl1 JOIN (tbl2 JOIN tbl3) ON bla = foo)
SELECT * FROM (tbl1 JOIN LATERAL (SELECT * FROM bla) AS tbl)

View file

@ -287,3 +287,27 @@ SELECT
FROM
t1;
SELECT x.a AS a, x.b AS b, ROW_NUMBER() OVER (PARTITION BY x.a ORDER BY x.a) AS row_num FROM x AS x;
# title: Values Test
# dialect: spark
WITH t1 AS (
SELECT
a1.cola
FROM
VALUES (1) AS a1(cola)
), t2 AS (
SELECT
a2.cola
FROM
VALUES (1) AS a2(cola)
)
SELECT /*+ BROADCAST(t2) */
t1.cola,
t2.cola,
FROM
t1
JOIN
t2
ON
t1.cola = t2.cola;
SELECT /*+ BROADCAST(a2) */ a1.cola AS cola, a2.cola AS cola FROM VALUES (1) AS a1(cola) JOIN VALUES (1) AS a2(cola) ON a1.cola = a2.cola;

View file

@ -33,3 +33,6 @@ SELECT x.a AS a FROM x AS x JOIN (SELECT * FROM y AS y WHERE y.a = 1) AS y ON y.
with t1 as (SELECT x.a, x.b, ROW_NUMBER() OVER (PARTITION BY x.a ORDER BY x.a) as row_num FROM x) SELECT t1.a, t1.b FROM t1 WHERE row_num = 1;
WITH t1 AS (SELECT x.a, x.b, ROW_NUMBER() OVER (PARTITION BY x.a ORDER BY x.a) AS row_num FROM x) SELECT t1.a, t1.b FROM t1 WHERE row_num = 1;
WITH m AS (SELECT a, b FROM (VALUES (1, 2)) AS a1(a, b)), n AS (SELECT a, b FROM m WHERE m.a = 1), o AS (SELECT a, b FROM m WHERE m.a = 2) SELECT n.a, n.b, n.a, o.b FROM n FULL OUTER JOIN o ON n.a = o.a;
WITH m AS (SELECT a, b FROM (VALUES (1, 2)) AS a1(a, b)), n AS (SELECT a, b FROM m WHERE m.a = 1), o AS (SELECT a, b FROM m WHERE m.a = 2) SELECT n.a, n.b, n.a, o.b FROM n FULL OUTER JOIN o ON n.a = o.a;

View file

@ -22,6 +22,9 @@ SELECT "_q_0".a AS a FROM (SELECT DISTINCT x.a AS a, x.b AS b FROM x AS x) AS "_
SELECT a FROM (SELECT a, b FROM x UNION ALL SELECT a, b FROM x);
SELECT "_q_0".a AS a FROM (SELECT x.a AS a FROM x AS x UNION ALL SELECT x.a AS a FROM x AS x) AS "_q_0";
WITH t1 AS (SELECT x.a AS a, x.b AS b FROM x UNION ALL SELECT z.b AS b, z.c AS c FROM z) SELECT a, b FROM t1;
WITH t1 AS (SELECT x.a AS a, x.b AS b FROM x AS x UNION ALL SELECT z.b AS b, z.c AS c FROM z AS z) SELECT t1.a AS a, t1.b AS b FROM t1;
SELECT a FROM (SELECT a, b FROM x UNION SELECT a, b FROM x);
SELECT "_q_0".a AS a FROM (SELECT x.a AS a, x.b AS b FROM x AS x UNION SELECT x.a AS a, x.b AS b FROM x AS x) AS "_q_0";

View file

@ -72,6 +72,9 @@ SELECT x.a AS a, x.b AS b FROM x AS x ORDER BY a;
SELECT a FROM x ORDER BY b;
SELECT x.a AS a FROM x AS x ORDER BY x.b;
SELECT SUM(a) AS a FROM x ORDER BY SUM(a);
SELECT SUM(x.a) AS a FROM x AS x ORDER BY SUM(x.a);
# dialect: bigquery
SELECT ROW_NUMBER() OVER (PARTITION BY a ORDER BY b) AS row_num FROM x QUALIFY row_num = 1;
SELECT ROW_NUMBER() OVER (PARTITION BY x.a ORDER BY x.b) AS row_num FROM x AS x QUALIFY row_num = 1;

View file

@ -53,6 +53,8 @@ def string_to_bool(string):
return string and string.lower() in ("true", "1")
SKIP_INTEGRATION = string_to_bool(os.environ.get("SKIP_INTEGRATION", "0").lower())
TPCH_SCHEMA = {
"lineitem": {
"l_orderkey": "uint64",

View file

@ -7,11 +7,17 @@ from pandas.testing import assert_frame_equal
from sqlglot import exp, parse_one
from sqlglot.executor import execute
from sqlglot.executor.python import Python
from tests.helpers import FIXTURES_DIR, TPCH_SCHEMA, load_sql_fixture_pairs
from tests.helpers import (
FIXTURES_DIR,
SKIP_INTEGRATION,
TPCH_SCHEMA,
load_sql_fixture_pairs,
)
DIR = FIXTURES_DIR + "/optimizer/tpc-h/"
@unittest.skipIf(SKIP_INTEGRATION, "Skipping Integration Tests since `SKIP_INTEGRATION` is set")
class TestExecutor(unittest.TestCase):
@classmethod
def setUpClass(cls):

View file

@ -123,13 +123,16 @@ class TestExpressions(unittest.TestCase):
self.assertEqual(exp.table_name(parse_one("a.b.c", into=exp.Table)), "a.b.c")
self.assertEqual(exp.table_name("a.b.c"), "a.b.c")
def test_table(self):
self.assertEqual(exp.table_("a", alias="b"), parse_one("select * from a b").find(exp.Table))
def test_replace_tables(self):
self.assertEqual(
exp.replace_tables(
parse_one("select * from a join b join c.a join d.a join e.a"),
parse_one("select * from a AS a join b join c.a join d.a join e.a"),
{"a": "a1", "b": "b.a", "c.a": "c.a2", "d.a": "d2"},
).sql(),
'SELECT * FROM "a1" JOIN "b"."a" JOIN "c"."a2" JOIN "d2" JOIN e.a',
"SELECT * FROM a1 AS a JOIN b.a JOIN c.a2 JOIN d2 JOIN e.a",
)
def test_named_selects(self):
@ -495,11 +498,15 @@ class TestExpressions(unittest.TestCase):
self.assertEqual(exp.convert(value).sql(), expected)
def test_annotation_alias(self):
expression = parse_one("SELECT a, b AS B, c #comment, d AS D #another_comment FROM foo")
sql = "SELECT a, b AS B, c # comment, d AS D # another_comment FROM foo"
expression = parse_one(sql)
self.assertEqual(
[e.alias_or_name for e in expression.expressions],
["a", "B", "c", "D"],
)
self.assertEqual(expression.sql(), sql)
self.assertEqual(expression.expressions[2].name, "comment")
self.assertEqual(expression.sql(annotations=False), "SELECT a, b AS B, c, d AS D")
def test_to_table(self):
table_only = exp.to_table("table_name")
@ -514,6 +521,18 @@ class TestExpressions(unittest.TestCase):
self.assertEqual(catalog_db_and_table.name, "table_name")
self.assertEqual(catalog_db_and_table.args.get("db"), exp.to_identifier("db"))
self.assertEqual(catalog_db_and_table.args.get("catalog"), exp.to_identifier("catalog"))
with self.assertRaises(ValueError):
exp.to_table(1)
def test_to_column(self):
column_only = exp.to_column("column_name")
self.assertEqual(column_only.name, "column_name")
self.assertIsNone(column_only.args.get("table"))
table_and_column = exp.to_column("table_name.column_name")
self.assertEqual(table_and_column.name, "column_name")
self.assertEqual(table_and_column.args.get("table"), exp.to_identifier("table_name"))
with self.assertRaises(ValueError):
exp.to_column(1)
def test_union(self):
expression = parse_one("SELECT cola, colb UNION SELECT colx, coly")

View file

@ -5,11 +5,11 @@ import duckdb
from pandas.testing import assert_frame_equal
import sqlglot
from sqlglot import exp, optimizer, parse_one, table
from sqlglot import exp, optimizer, parse_one
from sqlglot.errors import OptimizeError
from sqlglot.optimizer.annotate_types import annotate_types
from sqlglot.optimizer.schema import MappingSchema, ensure_schema
from sqlglot.optimizer.scope import build_scope, traverse_scope, walk_in_scope
from sqlglot.schema import MappingSchema
from tests.helpers import (
TPCH_SCHEMA,
load_sql_fixture_pairs,
@ -29,19 +29,19 @@ class TestOptimizer(unittest.TestCase):
CREATE TABLE x (a INT, b INT);
CREATE TABLE y (b INT, c INT);
CREATE TABLE z (b INT, c INT);
INSERT INTO x VALUES (1, 1);
INSERT INTO x VALUES (2, 2);
INSERT INTO x VALUES (2, 2);
INSERT INTO x VALUES (3, 3);
INSERT INTO x VALUES (null, null);
INSERT INTO y VALUES (2, 2);
INSERT INTO y VALUES (2, 2);
INSERT INTO y VALUES (3, 3);
INSERT INTO y VALUES (4, 4);
INSERT INTO y VALUES (null, null);
INSERT INTO y VALUES (3, 3);
INSERT INTO y VALUES (3, 3);
INSERT INTO y VALUES (4, 4);
@ -80,8 +80,8 @@ class TestOptimizer(unittest.TestCase):
with self.subTest(title):
self.assertEqual(
optimized.sql(pretty=pretty, dialect=dialect),
expected,
optimized.sql(pretty=pretty, dialect=dialect),
)
should_execute = meta.get("execute")
@ -223,85 +223,6 @@ class TestOptimizer(unittest.TestCase):
def test_tpch(self):
self.check_file("tpc-h/tpc-h", optimizer.optimize, schema=TPCH_SCHEMA, pretty=True)
def test_schema(self):
schema = ensure_schema(
{
"x": {
"a": "uint64",
}
}
)
self.assertEqual(
schema.column_names(
table(
"x",
)
),
["a"],
)
with self.assertRaises(ValueError):
schema.column_names(table("x", db="db", catalog="c"))
with self.assertRaises(ValueError):
schema.column_names(table("x", db="db"))
with self.assertRaises(ValueError):
schema.column_names(table("x2"))
schema = ensure_schema(
{
"db": {
"x": {
"a": "uint64",
}
}
}
)
self.assertEqual(schema.column_names(table("x", db="db")), ["a"])
with self.assertRaises(ValueError):
schema.column_names(table("x", db="db", catalog="c"))
with self.assertRaises(ValueError):
schema.column_names(table("x"))
with self.assertRaises(ValueError):
schema.column_names(table("x", db="db2"))
with self.assertRaises(ValueError):
schema.column_names(table("x2", db="db"))
schema = ensure_schema(
{
"c": {
"db": {
"x": {
"a": "uint64",
}
}
}
}
)
self.assertEqual(schema.column_names(table("x", db="db", catalog="c")), ["a"])
with self.assertRaises(ValueError):
schema.column_names(table("x", db="db"))
with self.assertRaises(ValueError):
schema.column_names(table("x"))
with self.assertRaises(ValueError):
schema.column_names(table("x", db="db", catalog="c2"))
with self.assertRaises(ValueError):
schema.column_names(table("x", db="db2"))
with self.assertRaises(ValueError):
schema.column_names(table("x2", db="db"))
schema = ensure_schema(
MappingSchema(
{
"x": {
"a": "uint64",
}
}
)
)
self.assertEqual(schema.column_names(table("x")), ["a"])
with self.assertRaises(OptimizeError):
ensure_schema({})
def test_file_schema(self):
expression = parse_one(
"""
@ -327,6 +248,8 @@ FROM READ_CSV('tests/fixtures/optimizer/tpc-h/nation.csv.gz', 'delimiter', '|')
SELECT x.b FROM x
), r AS (
SELECT y.b FROM y
), z as (
SELECT cola, colb FROM (VALUES(1, 'test')) AS tab(cola, colb)
)
SELECT
r.b,
@ -340,19 +263,23 @@ FROM READ_CSV('tests/fixtures/optimizer/tpc-h/nation.csv.gz', 'delimiter', '|')
"""
expression = parse_one(sql)
for scopes in traverse_scope(expression), list(build_scope(expression).traverse()):
self.assertEqual(len(scopes), 5)
self.assertEqual(len(scopes), 7)
self.assertEqual(scopes[0].expression.sql(), "SELECT x.b FROM x")
self.assertEqual(scopes[1].expression.sql(), "SELECT y.b FROM y")
self.assertEqual(scopes[2].expression.sql(), "SELECT y.c AS b FROM y")
self.assertEqual(scopes[3].expression.sql(), "SELECT MAX(x.a) FROM x WHERE x.b = s.b")
self.assertEqual(scopes[4].expression.sql(), parse_one(sql).sql())
self.assertEqual(scopes[2].expression.sql(), "(VALUES (1, 'test')) AS tab(cola, colb)")
self.assertEqual(
scopes[3].expression.sql(), "SELECT cola, colb FROM (VALUES (1, 'test')) AS tab(cola, colb)"
)
self.assertEqual(scopes[4].expression.sql(), "SELECT y.c AS b FROM y")
self.assertEqual(scopes[5].expression.sql(), "SELECT MAX(x.a) FROM x WHERE x.b = s.b")
self.assertEqual(scopes[6].expression.sql(), parse_one(sql).sql())
self.assertEqual(set(scopes[4].sources), {"q", "r", "s"})
self.assertEqual(len(scopes[4].columns), 6)
self.assertEqual(set(c.table for c in scopes[4].columns), {"r", "s"})
self.assertEqual(scopes[4].source_columns("q"), [])
self.assertEqual(len(scopes[4].source_columns("r")), 2)
self.assertEqual(set(c.table for c in scopes[4].source_columns("r")), {"r"})
self.assertEqual(set(scopes[6].sources), {"q", "z", "r", "s"})
self.assertEqual(len(scopes[6].columns), 6)
self.assertEqual(set(c.table for c in scopes[6].columns), {"r", "s"})
self.assertEqual(scopes[6].source_columns("q"), [])
self.assertEqual(len(scopes[6].source_columns("r")), 2)
self.assertEqual(set(c.table for c in scopes[6].source_columns("r")), {"r"})
self.assertEqual({c.sql() for c in scopes[-1].find_all(exp.Column)}, {"r.b", "s.b"})
self.assertEqual(scopes[-1].find(exp.Column).sql(), "r.b")
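
The scope assertions above come from `sqlglot.optimizer.scope`, which splits a statement into one scope per CTE, derived table, and subquery; the new `z` CTE and its inner `VALUES` derived table are why the expected count grows from 5 to 7. A minimal sketch of walking scopes on a simpler, illustrative query:

```
from sqlglot import exp, parse_one
from sqlglot.optimizer.scope import build_scope, traverse_scope

sql = "WITH q AS (SELECT x.b FROM x) SELECT q.b FROM q"
expression = parse_one(sql)

# traverse_scope yields inner scopes first and the root scope last
scopes = traverse_scope(expression)
assert scopes[0].expression.sql() == "SELECT x.b FROM x"
assert scopes[-1].expression.sql() == parse_one(sql).sql()

# build_scope returns the root scope, whose traversal matches traverse_scope
root = build_scope(expression)
assert [s.expression.sql() for s in root.traverse()] == [s.expression.sql() for s in scopes]

# each scope knows its sources and the columns referenced inside it
assert set(scopes[-1].sources) == {"q"}
assert {c.sql() for c in scopes[-1].find_all(exp.Column)} == {"q.b"}
```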

View file

@ -81,7 +81,7 @@ class TestParser(unittest.TestCase):
self.assertIsInstance(ignore.expression(exp.Hint, y=""), exp.Hint)
self.assertIsInstance(ignore.expression(exp.Hint), exp.Hint)
default = Parser()
default = Parser(error_level=ErrorLevel.RAISE)
self.assertIsInstance(default.expression(exp.Hint, expressions=[""]), exp.Hint)
default.expression(exp.Hint, y="")
default.expression(exp.Hint)
@ -139,12 +139,12 @@ class TestParser(unittest.TestCase):
)
assert expression.expressions[0].name == "annotation1"
assert expression.expressions[1].name == "annotation2:testing "
assert expression.expressions[1].name == "annotation2:testing"
assert expression.expressions[2].name == "test#annotation"
assert expression.expressions[3].name == "annotation3"
assert expression.expressions[4].name == "annotation4"
assert expression.expressions[5].name == ""
assert expression.expressions[6].name == " space"
assert expression.expressions[6].name == "space"
def test_pretty_config_override(self):
self.assertEqual(parse_one("SELECT col FROM x").sql(), "SELECT col FROM x")

290
tests/test_schema.py Normal file
View file

@ -0,0 +1,290 @@
import unittest
from sqlglot import table
from sqlglot.dataframe.sql import types as df_types
from sqlglot.schema import MappingSchema, ensure_schema
class TestSchema(unittest.TestCase):
def test_schema(self):
schema = ensure_schema(
{
"x": {
"a": "uint64",
}
}
)
self.assertEqual(
schema.column_names(
table(
"x",
)
),
["a"],
)
with self.assertRaises(ValueError):
schema.column_names(table("x", db="db", catalog="c"))
with self.assertRaises(ValueError):
schema.column_names(table("x", db="db"))
with self.assertRaises(ValueError):
schema.column_names(table("x2"))
with self.assertRaises(ValueError):
schema.add_table(table("y", db="db"), {"b": "string"})
with self.assertRaises(ValueError):
schema.add_table(table("y", db="db", catalog="c"), {"b": "string"})
schema.add_table(table("y"), {"b": "string"})
schema_with_y = {
"x": {
"a": "uint64",
},
"y": {
"b": "string",
},
}
self.assertEqual(schema.schema, schema_with_y)
new_schema = schema.copy()
new_schema.add_table(table("z"), {"c": "string"})
self.assertEqual(schema.schema, schema_with_y)
self.assertEqual(
new_schema.schema,
{
"x": {
"a": "uint64",
},
"y": {
"b": "string",
},
"z": {
"c": "string",
},
},
)
schema.add_table(table("m"), {"d": "string"})
schema.add_table(table("n"), {"e": "string"})
schema_with_m_n = {
"x": {
"a": "uint64",
},
"y": {
"b": "string",
},
"m": {
"d": "string",
},
"n": {
"e": "string",
},
}
self.assertEqual(schema.schema, schema_with_m_n)
new_schema = schema.copy()
new_schema.add_table(table("o"), {"f": "string"})
new_schema.add_table(table("p"), {"g": "string"})
self.assertEqual(schema.schema, schema_with_m_n)
self.assertEqual(
new_schema.schema,
{
"x": {
"a": "uint64",
},
"y": {
"b": "string",
},
"m": {
"d": "string",
},
"n": {
"e": "string",
},
"o": {
"f": "string",
},
"p": {
"g": "string",
},
},
)
schema = ensure_schema(
{
"db": {
"x": {
"a": "uint64",
}
}
}
)
self.assertEqual(schema.column_names(table("x", db="db")), ["a"])
with self.assertRaises(ValueError):
schema.column_names(table("x", db="db", catalog="c"))
with self.assertRaises(ValueError):
schema.column_names(table("x"))
with self.assertRaises(ValueError):
schema.column_names(table("x", db="db2"))
with self.assertRaises(ValueError):
schema.column_names(table("x2", db="db"))
with self.assertRaises(ValueError):
schema.add_table(table("y"), {"b": "string"})
with self.assertRaises(ValueError):
schema.add_table(table("y", db="db", catalog="c"), {"b": "string"})
schema.add_table(table("y", db="db"), {"b": "string"})
self.assertEqual(
schema.schema,
{
"db": {
"x": {
"a": "uint64",
},
"y": {
"b": "string",
},
}
},
)
schema = ensure_schema(
{
"c": {
"db": {
"x": {
"a": "uint64",
}
}
}
}
)
self.assertEqual(schema.column_names(table("x", db="db", catalog="c")), ["a"])
with self.assertRaises(ValueError):
schema.column_names(table("x", db="db"))
with self.assertRaises(ValueError):
schema.column_names(table("x"))
with self.assertRaises(ValueError):
schema.column_names(table("x", db="db", catalog="c2"))
with self.assertRaises(ValueError):
schema.column_names(table("x", db="db2"))
with self.assertRaises(ValueError):
schema.column_names(table("x2", db="db"))
with self.assertRaises(ValueError):
schema.add_table(table("x"), {"b": "string"})
with self.assertRaises(ValueError):
schema.add_table(table("x", db="db"), {"b": "string"})
schema.add_table(table("y", db="db", catalog="c"), {"a": "string", "b": "int"})
self.assertEqual(
schema.schema,
{
"c": {
"db": {
"x": {
"a": "uint64",
},
"y": {
"a": "string",
"b": "int",
},
}
}
},
)
schema.add_table(table("z", db="db2", catalog="c"), {"c": "string", "d": "int"})
self.assertEqual(
schema.schema,
{
"c": {
"db": {
"x": {
"a": "uint64",
},
"y": {
"a": "string",
"b": "int",
},
},
"db2": {
"z": {
"c": "string",
"d": "int",
}
},
}
},
)
schema.add_table(table("m", db="db2", catalog="c2"), {"e": "string", "f": "int"})
self.assertEqual(
schema.schema,
{
"c": {
"db": {
"x": {
"a": "uint64",
},
"y": {
"a": "string",
"b": "int",
},
},
"db2": {
"z": {
"c": "string",
"d": "int",
}
},
},
"c2": {
"db2": {
"m": {
"e": "string",
"f": "int",
}
}
},
},
)
schema = ensure_schema(
{
"x": {
"a": "uint64",
}
}
)
self.assertEqual(schema.column_names(table("x")), ["a"])
schema = MappingSchema()
schema.add_table(table("x"), {"a": "string"})
self.assertEqual(
schema.schema,
{
"x": {
"a": "string",
}
},
)
schema.add_table(table("y"), df_types.StructType([df_types.StructField("b", df_types.StringType())]))
self.assertEqual(
schema.schema,
{
"x": {
"a": "string",
},
"y": {
"b": "string",
},
},
)
def test_schema_add_table_with_and_without_mapping(self):
schema = MappingSchema()
schema.add_table("test")
self.assertEqual(schema.column_names("test"), [])
schema.add_table("test", {"x": "string"})
self.assertEqual(schema.column_names("test"), ["x"])
schema.add_table("test", {"x": "string", "y": "int"})
self.assertEqual(schema.column_names("test"), ["x", "y"])
schema.add_table("test")
self.assertEqual(schema.column_names("test"), ["x", "y"])
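
Taken together, the new `tests/test_schema.py` pins down the `MappingSchema` behaviour introduced in this release: `add_table` accepts either a table expression or a plain name, with or without a column mapping (including dataframe-style `StructType` mappings), `column_names` only resolves tables at the same catalog/db nesting depth the schema was built with, and `copy` yields an independent schema. A condensed sketch following those assertions:

```
from sqlglot import table
from sqlglot.dataframe.sql import types as df_types
from sqlglot.schema import MappingSchema, ensure_schema

# single-level schemas resolve bare table names
schema = ensure_schema({"x": {"a": "uint64"}})
assert schema.column_names(table("x")) == ["a"]

# tables can be added with a plain dict column mapping ...
schema.add_table(table("y"), {"b": "string"})
assert schema.column_names(table("y")) == ["b"]

# ... or with a dataframe-style StructType
schema.add_table(table("v"), df_types.StructType([df_types.StructField("f", df_types.StringType())]))
assert schema.column_names(table("v")) == ["f"]

# copies are independent of the schema they were copied from
new_schema = schema.copy()
new_schema.add_table(table("w"), {"e": "string"})
assert "w" not in schema.schema
assert new_schema.column_names(table("w")) == ["e"]

# add_table also accepts bare names, with or without a mapping
empty = MappingSchema()
empty.add_table("test")
assert empty.column_names("test") == []
empty.add_table("test", {"x": "string", "y": "int"})
assert empty.column_names("test") == ["x", "y"]
```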