
Merging upstream version 10.0.1.

Signed-off-by: Daniel Baumann <daniel@debian.org>
Daniel Baumann 2025-02-13 14:53:05 +01:00
parent 528822bfd4
commit b7d21c45b7
Signed by: daniel
GPG key ID: FBB4F0E80A80222F
98 changed files with 4080 additions and 1666 deletions


@@ -1,6 +1,29 @@
 Changelog
 =========
+v10.0.0
+------
+Changes:
+- Breaking: replaced SQLGlot annotations with comments. Now comments can be preserved after transpilation, and they can appear in other places besides SELECT's expressions.
+- Breaking: renamed list_get to seq_get.
+- Breaking: activated mypy type checking for SQLGlot.
+- New: Azure Databricks support.
+- New: placeholders can now be replaced in an expression.
+- New: null safe equal operator (<=>).
+- New: [SET statements](https://github.com/tobymao/sqlglot/pull/673) for MySQL.
+- New: [SHOW commands](https://dev.mysql.com/doc/refman/8.0/en/show.html) for MySQL.
+- New: [FORMAT function](https://www.w3schools.com/sql/func_sqlserver_format.asp) for TSQL.
+- New: CROSS APPLY / OUTER APPLY [support](https://github.com/tobymao/sqlglot/pull/641) for TSQL.
+- New: added formats for TSQL's [DATENAME/DATEPART functions](https://learn.microsoft.com/en-us/sql/t-sql/functions/datename-transact-sql?view=sql-server-ver16).
+- New: added styles for TSQL's [CONVERT function](https://learn.microsoft.com/en-us/sql/t-sql/functions/cast-and-convert-transact-sql?view=sql-server-ver16).
+- Improvement: [refactored the schema](https://github.com/tobymao/sqlglot/pull/668) to be more lenient; before it needed to do an exact match of db.table, now it finds table if there are no ambiguities.
+- Improvement: allow functions to [inherit](https://github.com/tobymao/sqlglot/pull/674) their arguments' types, so that annotating CASE, IF etc. is possible.
+- Improvement: allow [joining with same names](https://github.com/tobymao/sqlglot/pull/660) in the python executor.
+- Improvement: the "using" field can now be set for the [join expression builders](https://github.com/tobymao/sqlglot/pull/636).
+- Improvement: qualify_columns [now qualifies](https://github.com/tobymao/sqlglot/pull/635) only non-alias columns in the having clause.
 v9.0.0
 ------

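One of the v10 additions listed above is the null-safe equal operator (`<=>`). As a minimal sketch (plain Python, not sqlglot code), here is the difference in semantics between `<=>` and ordinary `=`, with `None` standing in for SQL `NULL`:

```python
def null_safe_eq(a, b):
    """Mimics MySQL's a <=> b: always True or False, never NULL."""
    if a is None and b is None:
        return True  # two NULLs compare equal
    if a is None or b is None:
        return False  # NULL vs non-NULL is simply unequal
    return a == b

def plain_eq(a, b):
    """Mimics SQL's a = b: yields NULL (None) if either side is NULL."""
    if a is None or b is None:
        return None
    return a == b
```

Dialects without a literal `<=>` typically express the same comparison with a construct like `IS NOT DISTINCT FROM`, which is the kind of rewrite a transpiler has to perform.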

@@ -14,7 +14,7 @@ Contributions are very welcome in SQLGlot; read the [contribution guide](https:/
 * [Install](#install)
 * [Documentation](#documentation)
-* [Run Tests & Lint](#run-tests-and-lint)
+* [Run Tests and Lint](#run-tests-and-lint)
 * [Examples](#examples)
 * [Formatting and Transpiling](#formatting-and-transpiling)
 * [Metadata](#metadata)
@@ -22,7 +22,6 @@ Contributions are very welcome in SQLGlot; read the [contribution guide](https:/
 * [Unsupported Errors](#unsupported-errors)
 * [Build and Modify SQL](#build-and-modify-sql)
 * [SQL Optimizer](#sql-optimizer)
-* [SQL Annotations](#sql-annotations)
 * [AST Introspection](#ast-introspection)
 * [AST Diff](#ast-diff)
 * [Custom Dialects](#custom-dialects)
@@ -51,7 +50,7 @@ pip3 install -r dev-requirements.txt
 ## Documentation

-SQLGlot's uses [pdocs](https://pdoc.dev/) to serve its API documentation:
+SQLGlot uses [pdocs](https://pdoc.dev/) to serve its API documentation:
 ```
 pdoc sqlglot --docformat google
@@ -121,6 +120,39 @@ LEFT JOIN `baz`
 ON `f`.`a` = `baz`.`a`
 ```

+Comments are also preserved on a best-effort basis when transpiling SQL code:
+```python
+sql = """
+/* multi
+   line
+   comment
+*/
+SELECT
+  tbl.cola /* comment 1 */ + tbl.colb /* comment 2 */,
+  CAST(x AS INT), # comment 3
+  y -- comment 4
+FROM
+  bar /* comment 5 */,
+  tbl # comment 6
+"""
+print(sqlglot.transpile(sql, read='mysql', pretty=True)[0])
+```
+```sql
+/* multi
+   line
+   comment
+*/
+SELECT
+  tbl.cola /* comment 1 */ + tbl.colb /* comment 2 */,
+  CAST(x AS INT), -- comment 3
+  y -- comment 4
+FROM bar /* comment 5 */, tbl /* comment 6 */
+```

 ### Metadata

 You can explore SQL with expression helpers to do things like find columns and tables:
@@ -249,17 +281,6 @@ WHERE
   "x"."Z" = CAST('2021-02-01' AS DATE)
 ```

-### SQL Annotations
-
-SQLGlot supports annotations in the sql expression. This is an experimental feature that is not part of any of the SQL standards but it can be useful when needing to annotate what a selected field is supposed to be. Below is an example:
-
-```sql
-SELECT
-  user # primary_key,
-  country
-FROM users
-```

 ### AST Introspection

 You can see the AST version of the sql by calling `repr`:


@@ -1,15 +1,8 @@
 #!/bin/bash -e

 [[ -z "${GITHUB_ACTIONS}" ]] && RETURN_ERROR_CODE='' || RETURN_ERROR_CODE='--check'
-python -m autoflake -i -r ${RETURN_ERROR_CODE} \
-  --expand-star-imports \
-  --remove-all-unused-imports \
-  --ignore-init-module-imports \
-  --remove-duplicate-keys \
-  --remove-unused-variables \
-  sqlglot/ tests/
-python -m isort --profile black sqlglot/ tests/
-python -m black ${RETURN_ERROR_CODE} --line-length 120 sqlglot/ tests/
-python -m mypy sqlglot tests
+TARGETS="sqlglot/ tests/"
+python -m mypy $TARGETS
+python -m autoflake -i -r ${RETURN_ERROR_CODE} $TARGETS
+python -m isort $TARGETS
+python -m black --line-length 100 ${RETURN_ERROR_CODE} $TARGETS
 python -m unittest

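The first line of the check script above drives its dual local/CI behavior. A small sketch of that idiom: when `GITHUB_ACTIONS` is unset (a local run), the formatters rewrite files in place; when it is set (CI), they receive `--check` and fail on violations instead of rewriting:

```shell
#!/bin/bash
# Sketch of the RETURN_ERROR_CODE idiom from run_checks.sh.
flag_for_env() {
  [[ -z "${GITHUB_ACTIONS}" ]] && RETURN_ERROR_CODE='' || RETURN_ERROR_CODE='--check'
  echo "${RETURN_ERROR_CODE}"
}

unset GITHUB_ACTIONS
local_flag="$(flag_for_env)"   # empty: tools fix files locally

export GITHUB_ACTIONS=true
ci_flag="$(flag_for_env)"      # --check: CI only reports

echo "local='${local_flag}' ci='${ci_flag}'"
```

The same variable is then spliced into the `autoflake` and `black` invocations, so one script serves both environments.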

@@ -3,7 +3,7 @@ disallow_untyped_calls = False
 no_implicit_optional = True

 [mypy-sqlglot.*]
-ignore_errors = True
+ignore_errors = False

 [mypy-sqlglot.dataframe.*]
 ignore_errors = False
@@ -13,3 +13,16 @@ ignore_errors = True

 [mypy-tests.dataframe.*]
 ignore_errors = False
+
+[autoflake]
+in-place = True
+expand-star-imports = True
+remove-all-unused-imports = True
+ignore-init-module-imports = True
+remove-duplicate-keys = True
+remove-unused-variables = True
+quiet = True
+
+[isort]
+profile=black
+known_first_party=sqlglot


@@ -21,6 +21,7 @@ setup(
     author_email="toby.mao@gmail.com",
     license="MIT",
     packages=find_packages(include=["sqlglot", "sqlglot.*"]),
+    package_data={"sqlglot": ["py.typed"]},
     classifiers=[
         "Development Status :: 5 - Production/Stable",
         "Intended Audience :: Developers",


@@ -1,5 +1,9 @@
 """## Python SQL parser, transpiler and optimizer."""

+from __future__ import annotations
+
+import typing as t
+
 from sqlglot import expressions as exp
 from sqlglot.dialects import Dialect, Dialects
 from sqlglot.diff import diff
@@ -20,51 +24,54 @@ from sqlglot.expressions import (
     subquery,
 )
 from sqlglot.expressions import table_ as table
-from sqlglot.expressions import union
+from sqlglot.expressions import to_column, to_table, union
 from sqlglot.generator import Generator
 from sqlglot.parser import Parser
 from sqlglot.schema import MappingSchema
 from sqlglot.tokens import Tokenizer, TokenType

-__version__ = "9.0.6"
+__version__ = "10.0.1"

 pretty = False

 schema = MappingSchema()

-def parse(sql, read=None, **opts):
+def parse(
+    sql: str, read: t.Optional[str | Dialect] = None, **opts
+) -> t.List[t.Optional[Expression]]:
     """
-    Parses the given SQL string into a collection of syntax trees, one per
-    parsed SQL statement.
+    Parses the given SQL string into a collection of syntax trees, one per parsed SQL statement.

     Args:
-        sql (str): the SQL code string to parse.
-        read (str): the SQL dialect to apply during parsing
-            (eg. "spark", "hive", "presto", "mysql").
+        sql: the SQL code string to parse.
+        read: the SQL dialect to apply during parsing (eg. "spark", "hive", "presto", "mysql").
         **opts: other options.

     Returns:
-        typing.List[Expression]: the list of parsed syntax trees.
+        The resulting syntax tree collection.
     """
     dialect = Dialect.get_or_raise(read)()
     return dialect.parse(sql, **opts)

-def parse_one(sql, read=None, into=None, **opts):
+def parse_one(
+    sql: str,
+    read: t.Optional[str | Dialect] = None,
+    into: t.Optional[Expression | str] = None,
+    **opts,
+) -> t.Optional[Expression]:
     """
-    Parses the given SQL string and returns a syntax tree for the first
-    parsed SQL statement.
+    Parses the given SQL string and returns a syntax tree for the first parsed SQL statement.

     Args:
-        sql (str): the SQL code string to parse.
-        read (str): the SQL dialect to apply during parsing
-            (eg. "spark", "hive", "presto", "mysql").
-        into (Expression): the SQLGlot Expression to parse into
+        sql: the SQL code string to parse.
+        read: the SQL dialect to apply during parsing (eg. "spark", "hive", "presto", "mysql").
+        into: the SQLGlot Expression to parse into.
         **opts: other options.

     Returns:
-        Expression: the syntax tree for the first parsed statement.
+        The syntax tree for the first parsed statement.
     """
     dialect = Dialect.get_or_raise(read)()
@@ -77,25 +84,29 @@ def parse_one(sql, read=None, into=None, **opts):
     return result[0] if result else None

-def transpile(sql, read=None, write=None, identity=True, error_level=None, **opts):
+def transpile(
+    sql: str,
+    read: t.Optional[str | Dialect] = None,
+    write: t.Optional[str | Dialect] = None,
+    identity: bool = True,
+    error_level: t.Optional[ErrorLevel] = None,
+    **opts,
+) -> t.List[str]:
     """
-    Parses the given SQL string using the source dialect and returns a list of SQL strings
-    transformed to conform to the target dialect. Each string in the returned list represents
-    a single transformed SQL statement.
+    Parses the given SQL string in accordance with the source dialect and returns a list of SQL strings transformed
+    to conform to the target dialect. Each string in the returned list represents a single transformed SQL statement.

     Args:
-        sql (str): the SQL code string to transpile.
-        read (str): the source dialect used to parse the input string
-            (eg. "spark", "hive", "presto", "mysql").
-        write (str): the target dialect into which the input should be transformed
-            (eg. "spark", "hive", "presto", "mysql").
-        identity (bool): if set to True and if the target dialect is not specified
-            the source dialect will be used as both: the source and the target dialect.
-        error_level (ErrorLevel): the desired error level of the parser.
+        sql: the SQL code string to transpile.
+        read: the source dialect used to parse the input string (eg. "spark", "hive", "presto", "mysql").
+        write: the target dialect into which the input should be transformed (eg. "spark", "hive", "presto", "mysql").
+        identity: if set to `True` and if the target dialect is not specified the source dialect will be used as both:
+            the source and the target dialect.
+        error_level: the desired error level of the parser.
         **opts: other options.

     Returns:
-        typing.List[str]: the list of transpiled SQL statements / expressions.
+        The list of transpiled SQL statements.
     """
     write = write or read if identity else write
     return [

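The body of `transpile` above resolves the target dialect with `write = write or read if identity else write`. A small sketch of how that one-liner behaves (assuming plain string dialect names): Python's conditional expression binds looser than `or`, so the line reads as `(write or read) if identity else write`:

```python
def resolve_write(read, write, identity=True):
    # Same expression as in transpile(): with identity=True and no explicit
    # target, the source dialect is reused as the target; with identity=False,
    # write is left untouched.
    return write or read if identity else write

print(resolve_write("mysql", None))                   # source dialect reused
print(resolve_write("mysql", "presto"))               # explicit target wins
print(resolve_write("mysql", None, identity=False))   # no fallback
```

This is why `transpile(sql, read="mysql")` alone round-trips MySQL back to MySQL.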

@@ -49,7 +49,10 @@ args = parser.parse_args()
 error_level = sqlglot.ErrorLevel[args.error_level.upper()]

 if args.parse:
-    sqls = [repr(expression) for expression in sqlglot.parse(args.sql, read=args.read, error_level=error_level)]
+    sqls = [
+        repr(expression)
+        for expression in sqlglot.parse(args.sql, read=args.read, error_level=error_level)
+    ]
 else:
     sqls = sqlglot.transpile(
         args.sql,


@@ -10,11 +10,17 @@ if t.TYPE_CHECKING:
     from sqlglot.dataframe.sql.types import StructType

 ColumnLiterals = t.TypeVar(
-    "ColumnLiterals", bound=t.Union[str, float, int, bool, t.List, t.Tuple, datetime.date, datetime.datetime]
+    "ColumnLiterals",
+    bound=t.Union[str, float, int, bool, t.List, t.Tuple, datetime.date, datetime.datetime],
 )
 ColumnOrName = t.TypeVar("ColumnOrName", bound=t.Union[Column, str])
 ColumnOrLiteral = t.TypeVar(
-    "ColumnOrLiteral", bound=t.Union[Column, str, float, int, bool, t.List, t.Tuple, datetime.date, datetime.datetime]
+    "ColumnOrLiteral",
+    bound=t.Union[Column, str, float, int, bool, t.List, t.Tuple, datetime.date, datetime.datetime],
 )
-SchemaInput = t.TypeVar("SchemaInput", bound=t.Union[str, t.List[str], StructType, t.Dict[str, str]])
-OutputExpressionContainer = t.TypeVar("OutputExpressionContainer", bound=t.Union[exp.Select, exp.Create, exp.Insert])
+SchemaInput = t.TypeVar(
+    "SchemaInput", bound=t.Union[str, t.List[str], StructType, t.Dict[str, str]]
+)
+OutputExpressionContainer = t.TypeVar(
+    "OutputExpressionContainer", bound=t.Union[exp.Select, exp.Create, exp.Insert]
+)


@@ -18,7 +18,11 @@ class Column:
             expression = expression.expression  # type: ignore
         elif expression is None or not isinstance(expression, (str, exp.Expression)):
             expression = self._lit(expression).expression  # type: ignore
-        self.expression: exp.Expression = sqlglot.maybe_parse(expression, dialect="spark")
+
+        expression = sqlglot.maybe_parse(expression, dialect="spark")
+        if expression is None:
+            raise ValueError(f"Could not parse {expression}")
+        self.expression: exp.Expression = expression

     def __repr__(self):
         return repr(self.expression)
@@ -135,21 +139,29 @@ class Column:
     ) -> Column:
         ensured_column = None if column is None else cls.ensure_col(column)
         ensure_expression_values = {
-            k: [Column.ensure_col(x).expression for x in v] if is_iterable(v) else Column.ensure_col(v).expression
+            k: [Column.ensure_col(x).expression for x in v]
+            if is_iterable(v)
+            else Column.ensure_col(v).expression
             for k, v in kwargs.items()
         }
         new_expression = (
             callable_expression(**ensure_expression_values)
             if ensured_column is None
-            else callable_expression(this=ensured_column.column_expression, **ensure_expression_values)
+            else callable_expression(
+                this=ensured_column.column_expression, **ensure_expression_values
+            )
         )
         return Column(new_expression)

     def binary_op(self, klass: t.Callable, other: ColumnOrLiteral, **kwargs) -> Column:
-        return Column(klass(this=self.column_expression, expression=Column(other).column_expression, **kwargs))
+        return Column(
+            klass(this=self.column_expression, expression=Column(other).column_expression, **kwargs)
+        )

     def inverse_binary_op(self, klass: t.Callable, other: ColumnOrLiteral, **kwargs) -> Column:
-        return Column(klass(this=Column(other).column_expression, expression=self.column_expression, **kwargs))
+        return Column(
+            klass(this=Column(other).column_expression, expression=self.column_expression, **kwargs)
+        )

     def unary_op(self, klass: t.Callable, **kwargs) -> Column:
         return Column(klass(this=self.column_expression, **kwargs))
@@ -188,7 +200,7 @@ class Column:
         expression.set("table", exp.to_identifier(table_name))
         return Column(expression)

-    def sql(self, **kwargs) -> Column:
+    def sql(self, **kwargs) -> str:
         return self.expression.sql(**{"dialect": "spark", **kwargs})

     def alias(self, name: str) -> Column:
@@ -265,10 +277,14 @@ class Column:
         )

     def like(self, other: str):
-        return self.invoke_expression_over_column(self, exp.Like, expression=self._lit(other).expression)
+        return self.invoke_expression_over_column(
+            self, exp.Like, expression=self._lit(other).expression
+        )

     def ilike(self, other: str):
-        return self.invoke_expression_over_column(self, exp.ILike, expression=self._lit(other).expression)
+        return self.invoke_expression_over_column(
+            self, exp.ILike, expression=self._lit(other).expression
+        )

     def substr(self, startPos: t.Union[int, Column], length: t.Union[int, Column]) -> Column:
         startPos = self._lit(startPos) if not isinstance(startPos, Column) else startPos
@@ -287,10 +303,18 @@ class Column:
         lowerBound: t.Union[ColumnOrLiteral],
         upperBound: t.Union[ColumnOrLiteral],
     ) -> Column:
-        lower_bound_exp = self._lit(lowerBound) if not isinstance(lowerBound, Column) else lowerBound
-        upper_bound_exp = self._lit(upperBound) if not isinstance(upperBound, Column) else upperBound
+        lower_bound_exp = (
+            self._lit(lowerBound) if not isinstance(lowerBound, Column) else lowerBound
+        )
+        upper_bound_exp = (
+            self._lit(upperBound) if not isinstance(upperBound, Column) else upperBound
+        )
         return Column(
-            exp.Between(this=self.column_expression, low=lower_bound_exp.expression, high=upper_bound_exp.expression)
+            exp.Between(
+                this=self.column_expression,
+                low=lower_bound_exp.expression,
+                high=upper_bound_exp.expression,
+            )
         )

     def over(self, window: WindowSpec) -> Column:


@@ -21,7 +21,12 @@ from sqlglot.optimizer import optimize as optimize_func
 from sqlglot.optimizer.qualify_columns import qualify_columns

 if t.TYPE_CHECKING:
-    from sqlglot.dataframe.sql._typing import ColumnLiterals, OutputExpressionContainer
+    from sqlglot.dataframe.sql._typing import (
+        ColumnLiterals,
+        ColumnOrLiteral,
+        ColumnOrName,
+        OutputExpressionContainer,
+    )
     from sqlglot.dataframe.sql.session import SparkSession
@@ -83,7 +88,9 @@ class DataFrame:
             return from_exp.alias_or_name
         table_alias = from_exp.find(exp.TableAlias)
         if not table_alias:
-            raise RuntimeError(f"Could not find an alias name for this expression: {self.expression}")
+            raise RuntimeError(
+                f"Could not find an alias name for this expression: {self.expression}"
+            )
         return table_alias.alias_or_name
         return self.expression.ctes[-1].alias
@@ -132,12 +139,16 @@ class DataFrame:
         cte.set("sequence_id", sequence_id or self.sequence_id)
         return cte, name

-    def _ensure_list_of_columns(
-        self, cols: t.Union[str, t.Iterable[str], Column, t.Iterable[Column]]
-    ) -> t.List[Column]:
-        columns = ensure_list(cols)
-        columns = Column.ensure_cols(columns)
-        return columns
+    @t.overload
+    def _ensure_list_of_columns(self, cols: t.Collection[ColumnOrLiteral]) -> t.List[Column]:
+        ...
+
+    @t.overload
+    def _ensure_list_of_columns(self, cols: ColumnOrLiteral) -> t.List[Column]:
+        ...
+
+    def _ensure_list_of_columns(self, cols):
+        return Column.ensure_cols(ensure_list(cols))

     def _ensure_and_normalize_cols(self, cols):
         cols = self._ensure_list_of_columns(cols)
@@ -153,10 +164,16 @@ class DataFrame:
         df = self._resolve_pending_hints()
         sequence_id = sequence_id or df.sequence_id
         expression = df.expression.copy()
-        cte_expression, cte_name = df._create_cte_from_expression(expression=expression, sequence_id=sequence_id)
-        new_expression = df._add_ctes_to_expression(exp.Select(), expression.ctes + [cte_expression])
+        cte_expression, cte_name = df._create_cte_from_expression(
+            expression=expression, sequence_id=sequence_id
+        )
+        new_expression = df._add_ctes_to_expression(
+            exp.Select(), expression.ctes + [cte_expression]
+        )
         sel_columns = df._get_outer_select_columns(cte_expression)
-        new_expression = new_expression.from_(cte_name).select(*[x.alias_or_name for x in sel_columns])
+        new_expression = new_expression.from_(cte_name).select(
+            *[x.alias_or_name for x in sel_columns]
+        )
         return df.copy(expression=new_expression, sequence_id=sequence_id)

     def _resolve_pending_hints(self) -> DataFrame:
def _resolve_pending_hints(self) -> DataFrame: def _resolve_pending_hints(self) -> DataFrame:
@@ -169,16 +186,23 @@ class DataFrame:
             hint_expression.args.get("expressions").append(hint)
             df.pending_hints.remove(hint)

-        join_aliases = {join_table.alias_or_name for join_table in get_tables_from_expression_with_join(expression)}
+        join_aliases = {
+            join_table.alias_or_name
+            for join_table in get_tables_from_expression_with_join(expression)
+        }
         if join_aliases:
             for hint in df.pending_join_hints:
                 for sequence_id_expression in hint.expressions:
                     sequence_id_or_name = sequence_id_expression.alias_or_name
                     sequence_ids_to_match = [sequence_id_or_name]
                     if sequence_id_or_name in df.spark.name_to_sequence_id_mapping:
-                        sequence_ids_to_match = df.spark.name_to_sequence_id_mapping[sequence_id_or_name]
+                        sequence_ids_to_match = df.spark.name_to_sequence_id_mapping[
+                            sequence_id_or_name
+                        ]
                     matching_ctes = [
-                        cte for cte in reversed(expression.ctes) if cte.args["sequence_id"] in sequence_ids_to_match
+                        cte
+                        for cte in reversed(expression.ctes)
+                        if cte.args["sequence_id"] in sequence_ids_to_match
                     ]
                     for matching_cte in matching_ctes:
                         if matching_cte.alias_or_name in join_aliases:
@@ -193,9 +217,14 @@ class DataFrame:
     def _hint(self, hint_name: str, args: t.List[Column]) -> DataFrame:
         hint_name = hint_name.upper()
         hint_expression = (
-            exp.JoinHint(this=hint_name, expressions=[exp.to_table(parameter.alias_or_name) for parameter in args])
+            exp.JoinHint(
+                this=hint_name,
+                expressions=[exp.to_table(parameter.alias_or_name) for parameter in args],
+            )
             if hint_name in JOIN_HINTS
-            else exp.Anonymous(this=hint_name, expressions=[parameter.expression for parameter in args])
+            else exp.Anonymous(
+                this=hint_name, expressions=[parameter.expression for parameter in args]
+            )
         )
         new_df = self.copy()
         new_df.pending_hints.append(hint_expression)
@@ -245,7 +274,9 @@ class DataFrame:
     def _get_select_expressions(
         self,
     ) -> t.List[t.Tuple[t.Union[t.Type[exp.Cache], OutputExpressionContainer], exp.Select]]:
-        select_expressions: t.List[t.Tuple[t.Union[t.Type[exp.Cache], OutputExpressionContainer], exp.Select]] = []
+        select_expressions: t.List[
+            t.Tuple[t.Union[t.Type[exp.Cache], OutputExpressionContainer], exp.Select]
+        ] = []
         main_select_ctes: t.List[exp.CTE] = []
         for cte in self.expression.ctes:
             cache_storage_level = cte.args.get("cache_storage_level")
@@ -279,14 +310,19 @@ class DataFrame:
                 cache_table_name = df._create_hash_from_expression(select_expression)
                 cache_table = exp.to_table(cache_table_name)
                 original_alias_name = select_expression.args["cte_alias_name"]
-                replacement_mapping[exp.to_identifier(original_alias_name)] = exp.to_identifier(cache_table_name)
+
+                replacement_mapping[exp.to_identifier(original_alias_name)] = exp.to_identifier(  # type: ignore
+                    cache_table_name
+                )
                 sqlglot.schema.add_table(cache_table_name, select_expression.named_selects)
                 cache_storage_level = select_expression.args["cache_storage_level"]
                 options = [
                     exp.Literal.string("storageLevel"),
                     exp.Literal.string(cache_storage_level),
                 ]
-                expression = exp.Cache(this=cache_table, expression=select_expression, lazy=True, options=options)
+                expression = exp.Cache(
+                    this=cache_table, expression=select_expression, lazy=True, options=options
+                )

                 # We will drop the "view" if it exists before running the cache table
                 output_expressions.append(exp.Drop(this=cache_table, exists=True, kind="VIEW"))
             elif expression_type == exp.Create:
@@ -305,7 +341,9 @@ class DataFrame:
                 raise ValueError(f"Invalid expression type: {expression_type}")
             output_expressions.append(expression)

-        return [expression.sql(**{"dialect": dialect, **kwargs}) for expression in output_expressions]
+        return [
+            expression.sql(**{"dialect": dialect, **kwargs}) for expression in output_expressions
+        ]

     def copy(self, **kwargs) -> DataFrame:
         return DataFrame(**object_to_dict(self, **kwargs))
@@ -317,7 +355,9 @@ class DataFrame:
         if self.expression.args.get("joins"):
             ambiguous_cols = [col for col in cols if not col.column_expression.table]
             if ambiguous_cols:
-                join_table_identifiers = [x.this for x in get_tables_from_expression_with_join(self.expression)]
+                join_table_identifiers = [
+                    x.this for x in get_tables_from_expression_with_join(self.expression)
+                ]
                 cte_names_in_join = [x.this for x in join_table_identifiers]
                 for ambiguous_col in ambiguous_cols:
                     ctes_with_column = [
@@ -367,14 +407,20 @@ class DataFrame:
     @operation(Operation.FROM)
     def join(
-        self, other_df: DataFrame, on: t.Union[str, t.List[str], Column, t.List[Column]], how: str = "inner", **kwargs
+        self,
+        other_df: DataFrame,
+        on: t.Union[str, t.List[str], Column, t.List[Column]],
+        how: str = "inner",
+        **kwargs,
     ) -> DataFrame:
         other_df = other_df._convert_leaf_to_cte()
         pre_join_self_latest_cte_name = self.latest_cte_name
         columns = self._ensure_and_normalize_cols(on)
         join_type = how.replace("_", " ")
         if isinstance(columns[0].expression, exp.Column):
-            join_columns = [Column(x).set_table_name(pre_join_self_latest_cte_name) for x in columns]
+            join_columns = [
+                Column(x).set_table_name(pre_join_self_latest_cte_name) for x in columns
+            ]
             join_clause = functools.reduce(
                 lambda x, y: x & y,
                 [
@@ -402,7 +448,9 @@ class DataFrame:
                 for column in self._get_outer_select_columns(other_df)
             ]
             column_value_mapping = {
-                column.alias_or_name if not isinstance(column.expression.this, exp.Star) else column.sql(): column
+                column.alias_or_name
+                if not isinstance(column.expression.this, exp.Star)
+                else column.sql(): column
                 for column in other_columns + self_columns + join_columns
             }
             all_columns = [
@@ -410,16 +458,22 @@ class DataFrame:
             for name in {x.alias_or_name: None for x in join_columns + self_columns + other_columns}
         ]
         new_df = self.copy(
-            expression=self.expression.join(other_df.latest_cte_name, on=join_clause.expression, join_type=join_type)
+            expression=self.expression.join(
+                other_df.latest_cte_name, on=join_clause.expression, join_type=join_type
+            )
+        )
+        new_df.expression = new_df._add_ctes_to_expression(
+            new_df.expression, other_df.expression.ctes
         )
-        new_df.expression = new_df._add_ctes_to_expression(new_df.expression, other_df.expression.ctes)
         new_df.pending_hints.extend(other_df.pending_hints)
         new_df = new_df.select.__wrapped__(new_df, *all_columns)
         return new_df

     @operation(Operation.ORDER_BY)
     def orderBy(
-        self, *cols: t.Union[str, Column], ascending: t.Optional[t.Union[t.Any, t.List[t.Any]]] = None
+        self,
+        *cols: t.Union[str, Column],
+        ascending: t.Optional[t.Union[t.Any, t.List[t.Any]]] = None,
     ) -> DataFrame:
         """
         This implementation lets any ordered columns take priority over whatever is provided in `ascending`. Spark
@@ -429,7 +483,10 @@ class DataFrame:
         columns = self._ensure_and_normalize_cols(cols)
         pre_ordered_col_indexes = [
             x
-            for x in [i if isinstance(col.expression, exp.Ordered) else None for i, col in enumerate(columns)]
+            for x in [
+                i if isinstance(col.expression, exp.Ordered) else None
+                for i, col in enumerate(columns)
+            ]
             if x is not None
         ]
         if ascending is None:
@@ -478,7 +535,9 @@ class DataFrame:
             for r_column in r_columns_unused:
                 l_expressions.append(exp.alias_(exp.Null(), r_column))
                 r_expressions.append(r_column)
-        r_df = other.copy()._convert_leaf_to_cte().select(*self._ensure_list_of_columns(r_expressions))
+        r_df = (
+            other.copy()._convert_leaf_to_cte().select(*self._ensure_list_of_columns(r_expressions))
+        )
         l_df = self.copy()
         if allowMissingColumns:
             l_df = l_df._convert_leaf_to_cte().select(*self._ensure_list_of_columns(l_expressions))
@@ -536,7 +595,9 @@ class DataFrame:
                 f"The minimum num nulls for dropna must be less than or equal to the number of columns. "
                 f"Minimum num nulls: {minimum_num_nulls}, Num Columns: {len(null_check_columns)}"
             )
-        if_null_checks = [F.when(column.isNull(), F.lit(1)).otherwise(F.lit(0)) for column in null_check_columns]
+        if_null_checks = [
+            F.when(column.isNull(), F.lit(1)).otherwise(F.lit(0)) for column in null_check_columns
+        ]
         nulls_added_together = functools.reduce(lambda x, y: x + y, if_null_checks)
         num_nulls = nulls_added_together.alias("num_nulls")
         new_df = new_df.select(num_nulls, append=True)
@@ -576,11 +637,15 @@ class DataFrame:
         value_columns = [lit(value) for value in values]
         null_replacement_mapping = {
-            column.alias_or_name: (F.when(column.isNull(), value).otherwise(column).alias(column.alias_or_name))
+            column.alias_or_name: (
+                F.when(column.isNull(), value).otherwise(column).alias(column.alias_or_name)
+            )
             for column, value in zip(columns, value_columns)
         }
         null_replacement_mapping = {**all_column_mapping, **null_replacement_mapping}
-        null_replacement_columns = [null_replacement_mapping[column.alias_or_name] for column in all_columns]
+        null_replacement_columns = [
+            null_replacement_mapping[column.alias_or_name] for column in all_columns
+        ]
         new_df = new_df.select(*null_replacement_columns)
         return new_df
@@ -589,12 +654,11 @@ class DataFrame:
         self,
         to_replace: t.Union[bool, int, float, str, t.List, t.Dict],
         value: t.Optional[t.Union[bool, int, float, str, t.List]] = None,
-        subset: t.Optional[t.Union[str, t.List[str]]] = None,
+        subset: t.Optional[t.Collection[ColumnOrName] | ColumnOrName] = None,
     ) -> DataFrame:
         from sqlglot.dataframe.sql.functions import lit

         old_values = None
-        subset = ensure_list(subset)
         new_df = self.copy()
         all_columns = self._get_outer_select_columns(new_df.expression)
         all_column_mapping = {column.alias_or_name: column for column in all_columns}
@@ -605,7 +669,9 @@ class DataFrame:
             new_values = list(to_replace.values())
         elif not old_values and isinstance(to_replace, list):
             assert isinstance(value, list), "value must be a list since the replacements are a list"
-            assert len(to_replace) == len(value), "the replacements and values must be the same length"
+            assert len(to_replace) == len(
+                value
+            ), "the replacements and values must be the same length"
             old_values = to_replace
             new_values = value
         else:
@@ -635,7 +701,9 @@ class DataFrame:
     def withColumn(self, colName: str, col: Column) -> DataFrame:
         col = self._ensure_and_normalize_col(col)
         existing_col_names = self.expression.named_selects
-        existing_col_index = existing_col_names.index(colName) if colName in existing_col_names else None
+        existing_col_index = (
+            existing_col_names.index(colName) if colName in existing_col_names else None
+        )
         if existing_col_index:
             expression = self.expression.copy()
             expression.expressions[existing_col_index] = col.expression
@@ -645,7 +713,11 @@ class DataFrame:
     @operation(Operation.SELECT)
     def withColumnRenamed(self, existing: str, new: str):
         expression = self.expression.copy()
-        existing_columns = [expression for expression in expression.expressions if expression.alias_or_name == existing]
+        existing_columns = [
+            expression
+            for expression in expression.expressions
+            if expression.alias_or_name == existing
+        ]
         if not existing_columns:
             raise ValueError("Tried to rename a column that doesn't exist")
         for existing_column in existing_columns:
@@ -674,15 +746,19 @@ class DataFrame:
     def hint(self, name: str, *parameters: t.Optional[t.Union[str, int]]) -> DataFrame:
         parameter_list = ensure_list(parameters)
         parameter_columns = (
-            self._ensure_list_of_columns(parameter_list) if parameters else Column.ensure_cols([self.sequence_id])
+            self._ensure_list_of_columns(parameter_list)
+            if parameters
+            else Column.ensure_cols([self.sequence_id])
         )
         return self._hint(name, parameter_columns)

     @operation(Operation.NO_OP)
-    def repartition(self, numPartitions: t.Union[int, str], *cols: t.Union[int, str]) -> DataFrame:
-        num_partitions = Column.ensure_cols(ensure_list(numPartitions))
+    def repartition(
+        self, numPartitions: t.Union[int, ColumnOrName], *cols: ColumnOrName
+    ) -> DataFrame:
+        num_partition_cols = self._ensure_list_of_columns(numPartitions)
         columns = self._ensure_and_normalize_cols(cols)
-        args = num_partitions + columns
+        args = num_partition_cols + columns
         return self._hint("repartition", args)

     @operation(Operation.NO_OP)
View file
@@ -45,7 +45,11 @@ def countDistinct(col: ColumnOrName, *cols: ColumnOrName) -> Column:

 def when(condition: Column, value: t.Any) -> Column:
     true_value = value if isinstance(value, Column) else lit(value)
-    return Column(glotexp.Case(ifs=[glotexp.If(this=condition.column_expression, true=true_value.column_expression)]))
+    return Column(
+        glotexp.Case(
+            ifs=[glotexp.If(this=condition.column_expression, true=true_value.column_expression)]
+        )
+    )


 def asc(col: ColumnOrName) -> Column:
@@ -407,7 +411,9 @@ def percentile_approx(
         return Column.invoke_expression_over_column(
             col, glotexp.ApproxQuantile, quantile=lit(percentage), accuracy=accuracy
         )
-    return Column.invoke_expression_over_column(col, glotexp.ApproxQuantile, quantile=lit(percentage))
+    return Column.invoke_expression_over_column(
+        col, glotexp.ApproxQuantile, quantile=lit(percentage)
+    )


 def rand(seed: t.Optional[ColumnOrLiteral] = None) -> Column:
@@ -471,7 +477,9 @@ def factorial(col: ColumnOrName) -> Column:
     return Column.invoke_anonymous_function(col, "FACTORIAL")


-def lag(col: ColumnOrName, offset: t.Optional[int] = 1, default: t.Optional[ColumnOrLiteral] = None) -> Column:
+def lag(
+    col: ColumnOrName, offset: t.Optional[int] = 1, default: t.Optional[ColumnOrLiteral] = None
+) -> Column:
     if default is not None:
         return Column.invoke_anonymous_function(col, "LAG", offset, default)
     if offset != 1:
@@ -479,7 +487,9 @@ def lag(col: ColumnOrName, offset: t.Optional[int] = 1, default: t.Optional[Colu
     return Column.invoke_anonymous_function(col, "LAG")


-def lead(col: ColumnOrName, offset: t.Optional[int] = 1, default: t.Optional[t.Any] = None) -> Column:
+def lead(
+    col: ColumnOrName, offset: t.Optional[int] = 1, default: t.Optional[t.Any] = None
+) -> Column:
     if default is not None:
         return Column.invoke_anonymous_function(col, "LEAD", offset, default)
     if offset != 1:
@@ -487,7 +497,9 @@ def lead(col: ColumnOrName, offset: t.Optional[int] = 1, default: t.Optional[t.A
     return Column.invoke_anonymous_function(col, "LEAD")


-def nth_value(col: ColumnOrName, offset: t.Optional[int] = 1, ignoreNulls: t.Optional[bool] = None) -> Column:
+def nth_value(
+    col: ColumnOrName, offset: t.Optional[int] = 1, ignoreNulls: t.Optional[bool] = None
+) -> Column:
     if ignoreNulls is not None:
         raise NotImplementedError("There is currently not support for `ignoreNulls` parameter")
     if offset != 1:
@@ -571,7 +583,9 @@ def add_months(start: ColumnOrName, months: t.Union[ColumnOrName, int]) -> Colum
     return Column.invoke_anonymous_function(start, "ADD_MONTHS", months)


-def months_between(date1: ColumnOrName, date2: ColumnOrName, roundOff: t.Optional[bool] = None) -> Column:
+def months_between(
+    date1: ColumnOrName, date2: ColumnOrName, roundOff: t.Optional[bool] = None
+) -> Column:
     if roundOff is None:
         return Column.invoke_anonymous_function(date1, "MONTHS_BETWEEN", date2)
     return Column.invoke_anonymous_function(date1, "MONTHS_BETWEEN", date2, roundOff)
@@ -611,9 +625,13 @@ def from_unixtime(col: ColumnOrName, format: t.Optional[str] = None) -> Column:
     return Column.invoke_expression_over_column(col, glotexp.UnixToStr)


-def unix_timestamp(timestamp: t.Optional[ColumnOrName] = None, format: t.Optional[str] = None) -> Column:
+def unix_timestamp(
+    timestamp: t.Optional[ColumnOrName] = None, format: t.Optional[str] = None
+) -> Column:
     if format is not None:
-        return Column.invoke_expression_over_column(timestamp, glotexp.StrToUnix, format=lit(format))
+        return Column.invoke_expression_over_column(
+            timestamp, glotexp.StrToUnix, format=lit(format)
+        )
     return Column.invoke_expression_over_column(timestamp, glotexp.StrToUnix)
@@ -642,7 +660,9 @@ def window(
             timeColumn, "WINDOW", lit(windowDuration), lit(slideDuration), lit(startTime)
         )
     if slideDuration is not None:
-        return Column.invoke_anonymous_function(timeColumn, "WINDOW", lit(windowDuration), lit(slideDuration))
+        return Column.invoke_anonymous_function(
+            timeColumn, "WINDOW", lit(windowDuration), lit(slideDuration)
+        )
     if startTime is not None:
         return Column.invoke_anonymous_function(
             timeColumn, "WINDOW", lit(windowDuration), lit(windowDuration), lit(startTime)
@@ -731,7 +751,9 @@ def trim(col: ColumnOrName) -> Column:


 def concat_ws(sep: str, *cols: ColumnOrName) -> Column:
-    return Column.invoke_expression_over_column(None, glotexp.ConcatWs, expressions=[lit(sep)] + list(cols))
+    return Column.invoke_expression_over_column(
+        None, glotexp.ConcatWs, expressions=[lit(sep)] + list(cols)
+    )


 def decode(col: ColumnOrName, charset: str) -> Column:
@@ -768,7 +790,9 @@ def overlay(


 def sentences(
-    string: ColumnOrName, language: t.Optional[ColumnOrName] = None, country: t.Optional[ColumnOrName] = None
+    string: ColumnOrName,
+    language: t.Optional[ColumnOrName] = None,
+    country: t.Optional[ColumnOrName] = None,
 ) -> Column:
     if language is not None and country is not None:
         return Column.invoke_anonymous_function(string, "SENTENCES", language, country)
@@ -794,7 +818,9 @@ def levenshtein(left: ColumnOrName, right: ColumnOrName) -> Column:
 def locate(substr: str, str: ColumnOrName, pos: t.Optional[int] = None) -> Column:
     substr_col = lit(substr)
     if pos is not None:
-        return Column.invoke_expression_over_column(str, glotexp.StrPosition, substr=substr_col, position=pos)
+        return Column.invoke_expression_over_column(
+            str, glotexp.StrPosition, substr=substr_col, position=pos
+        )
     return Column.invoke_expression_over_column(str, glotexp.StrPosition, substr=substr_col)
@@ -872,7 +898,10 @@ def array(*cols: t.Union[ColumnOrName, t.Iterable[ColumnOrName]]) -> Column:
 def create_map(*cols: t.Union[ColumnOrName, t.Iterable[ColumnOrName]]) -> Column:
     cols = list(_flatten(cols)) if not isinstance(cols[0], (str, Column)) else cols  # type: ignore
     return Column.invoke_expression_over_column(
-        None, glotexp.VarMap, keys=array(*cols[::2]).expression, values=array(*cols[1::2]).expression
+        None,
+        glotexp.VarMap,
+        keys=array(*cols[::2]).expression,
+        values=array(*cols[1::2]).expression,
     )
@@ -882,29 +911,39 @@ def map_from_arrays(col1: ColumnOrName, col2: ColumnOrName) -> Column:

 def array_contains(col: ColumnOrName, value: ColumnOrLiteral) -> Column:
     value_col = value if isinstance(value, Column) else lit(value)
-    return Column.invoke_expression_over_column(col, glotexp.ArrayContains, expression=value_col.expression)
+    return Column.invoke_expression_over_column(
+        col, glotexp.ArrayContains, expression=value_col.expression
+    )


 def arrays_overlap(col1: ColumnOrName, col2: ColumnOrName) -> Column:
     return Column.invoke_anonymous_function(col1, "ARRAYS_OVERLAP", Column.ensure_col(col2))


-def slice(x: ColumnOrName, start: t.Union[ColumnOrName, int], length: t.Union[ColumnOrName, int]) -> Column:
+def slice(
+    x: ColumnOrName, start: t.Union[ColumnOrName, int], length: t.Union[ColumnOrName, int]
+) -> Column:
     start_col = start if isinstance(start, Column) else lit(start)
     length_col = length if isinstance(length, Column) else lit(length)
     return Column.invoke_anonymous_function(x, "SLICE", start_col, length_col)


-def array_join(col: ColumnOrName, delimiter: str, null_replacement: t.Optional[str] = None) -> Column:
+def array_join(
+    col: ColumnOrName, delimiter: str, null_replacement: t.Optional[str] = None
+) -> Column:
     if null_replacement is not None:
-        return Column.invoke_anonymous_function(col, "ARRAY_JOIN", lit(delimiter), lit(null_replacement))
+        return Column.invoke_anonymous_function(
+            col, "ARRAY_JOIN", lit(delimiter), lit(null_replacement)
+        )
     return Column.invoke_anonymous_function(col, "ARRAY_JOIN", lit(delimiter))


 def concat(*cols: ColumnOrName) -> Column:
     if len(cols) == 1:
         return Column.invoke_anonymous_function(cols[0], "CONCAT")
-    return Column.invoke_anonymous_function(cols[0], "CONCAT", *[Column.ensure_col(x).expression for x in cols[1:]])
+    return Column.invoke_anonymous_function(
+        cols[0], "CONCAT", *[Column.ensure_col(x).expression for x in cols[1:]]
+    )


 def array_position(col: ColumnOrName, value: ColumnOrLiteral) -> Column:
@@ -1076,7 +1115,9 @@ def map_concat(*cols: t.Union[ColumnOrName, t.Iterable[ColumnOrName]]) -> Column
     return Column.invoke_anonymous_function(columns[0], "MAP_CONCAT", *columns[1:])


-def sequence(start: ColumnOrName, stop: ColumnOrName, step: t.Optional[ColumnOrName] = None) -> Column:
+def sequence(
+    start: ColumnOrName, stop: ColumnOrName, step: t.Optional[ColumnOrName] = None
+) -> Column:
     if step is not None:
         return Column.invoke_anonymous_function(start, "SEQUENCE", stop, step)
     return Column.invoke_anonymous_function(start, "SEQUENCE", stop)
@@ -1103,12 +1144,15 @@ def aggregate(
     merge_exp = _get_lambda_from_func(merge)
     if finish is not None:
         finish_exp = _get_lambda_from_func(finish)
-        return Column.invoke_anonymous_function(col, "AGGREGATE", initialValue, Column(merge_exp), Column(finish_exp))
+        return Column.invoke_anonymous_function(
+            col, "AGGREGATE", initialValue, Column(merge_exp), Column(finish_exp)
+        )
     return Column.invoke_anonymous_function(col, "AGGREGATE", initialValue, Column(merge_exp))


 def transform(
-    col: ColumnOrName, f: t.Union[t.Callable[[Column], Column], t.Callable[[Column, Column], Column]]
+    col: ColumnOrName,
+    f: t.Union[t.Callable[[Column], Column], t.Callable[[Column, Column], Column]],
 ) -> Column:
     f_expression = _get_lambda_from_func(f)
     return Column.invoke_anonymous_function(col, "TRANSFORM", Column(f_expression))
@@ -1124,12 +1168,17 @@ def forall(col: ColumnOrName, f: t.Callable[[Column], Column]) -> Column:
     return Column.invoke_anonymous_function(col, "FORALL", Column(f_expression))


-def filter(col: ColumnOrName, f: t.Union[t.Callable[[Column], Column], t.Callable[[Column, Column], Column]]) -> Column:
+def filter(
+    col: ColumnOrName,
+    f: t.Union[t.Callable[[Column], Column], t.Callable[[Column, Column], Column]],
+) -> Column:
     f_expression = _get_lambda_from_func(f)
     return Column.invoke_expression_over_column(col, glotexp.ArrayFilter, expression=f_expression)


-def zip_with(left: ColumnOrName, right: ColumnOrName, f: t.Callable[[Column, Column], Column]) -> Column:
+def zip_with(
+    left: ColumnOrName, right: ColumnOrName, f: t.Callable[[Column, Column], Column]
+) -> Column:
     f_expression = _get_lambda_from_func(f)
     return Column.invoke_anonymous_function(left, "ZIP_WITH", right, Column(f_expression))
@@ -1163,7 +1212,10 @@ def _lambda_quoted(value: str) -> t.Optional[bool]:

 def _get_lambda_from_func(lambda_expression: t.Callable):
-    variables = [glotexp.to_identifier(x, quoted=_lambda_quoted(x)) for x in lambda_expression.__code__.co_varnames]
+    variables = [
+        glotexp.to_identifier(x, quoted=_lambda_quoted(x))
+        for x in lambda_expression.__code__.co_varnames
+    ]
     return glotexp.Lambda(
         this=lambda_expression(*[Column(x) for x in variables]).expression,
         expressions=variables,
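The `_get_lambda_from_func` helper above reads a Python lambda's parameter names off its code object in order to build the SQL lambda's argument list. A minimal, self-contained illustration of that introspection (note that `co_varnames` also lists local variables, so a stricter version slices by `co_argcount`, as sketched here):

```python
# Introspect a lambda's parameter names, as the diff's helper does.
f = lambda x, y: x + y

# co_varnames lists parameters first, then locals; co_argcount bounds the parameters.
params = f.__code__.co_varnames[: f.__code__.co_argcount]
print(params)  # ('x', 'y')
```

The original helper can use `co_varnames` directly because lambdas passed to these higher-order SQL functions rarely define locals.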
View file
@@ -17,7 +17,9 @@ class GroupedData:
         self.last_op = last_op
         self.group_by_cols = group_by_cols

-    def _get_function_applied_columns(self, func_name: str, cols: t.Tuple[str, ...]) -> t.List[Column]:
+    def _get_function_applied_columns(
+        self, func_name: str, cols: t.Tuple[str, ...]
+    ) -> t.List[Column]:
         func_name = func_name.lower()
         return [getattr(F, func_name)(name).alias(f"{func_name}({name})") for name in cols]
@@ -30,9 +32,9 @@ class GroupedData:
         )
         cols = self._df._ensure_and_normalize_cols(columns)
-        expression = self._df.expression.group_by(*[x.expression for x in self.group_by_cols]).select(
-            *[x.expression for x in self.group_by_cols + cols], append=False
-        )
+        expression = self._df.expression.group_by(
+            *[x.expression for x in self.group_by_cols]
+        ).select(*[x.expression for x in self.group_by_cols + cols], append=False)
         return self._df.copy(expression=expression)

     def count(self) -> DataFrame:
View file
@@ -23,7 +23,9 @@ def normalize(spark: SparkSession, expression_context: exp.Select, expr: t.List[
         replace_branch_and_sequence_ids_with_cte_name(spark, expression_context, identifier)


-def replace_alias_name_with_cte_name(spark: SparkSession, expression_context: exp.Select, id: exp.Identifier):
+def replace_alias_name_with_cte_name(
+    spark: SparkSession, expression_context: exp.Select, id: exp.Identifier
+):
     if id.alias_or_name in spark.name_to_sequence_id_mapping:
         for cte in reversed(expression_context.ctes):
             if cte.args["sequence_id"] in spark.name_to_sequence_id_mapping[id.alias_or_name]:
@@ -40,8 +42,12 @@ def replace_branch_and_sequence_ids_with_cte_name(
     # id then it keeps that reference. This handles the weird edge case in spark that shouldn't
     # be common in practice
     if expression_context.args.get("joins") and id.alias_or_name in spark.known_branch_ids:
-        join_table_aliases = [x.alias_or_name for x in get_tables_from_expression_with_join(expression_context)]
-        ctes_in_join = [cte for cte in expression_context.ctes if cte.alias_or_name in join_table_aliases]
+        join_table_aliases = [
+            x.alias_or_name for x in get_tables_from_expression_with_join(expression_context)
+        ]
+        ctes_in_join = [
+            cte for cte in expression_context.ctes if cte.alias_or_name in join_table_aliases
+        ]
         if ctes_in_join[0].args["branch_id"] == ctes_in_join[1].args["branch_id"]:
             assert len(ctes_in_join) == 2
             _set_alias_name(id, ctes_in_join[0].alias_or_name)
@@ -58,7 +64,6 @@ def _set_alias_name(id: exp.Identifier, name: str):

 def _ensure_expressions(values: t.List[NORMALIZE_INPUT]) -> t.List[exp.Expression]:
-    values = ensure_list(values)
     results = []
     for value in values:
         if isinstance(value, str):
View file
@@ -19,12 +19,19 @@ class DataFrameReader:
         from sqlglot.dataframe.sql.dataframe import DataFrame

         sqlglot.schema.add_table(tableName)
-        return DataFrame(self.spark, exp.Select().from_(tableName).select(*sqlglot.schema.column_names(tableName)))
+        return DataFrame(
+            self.spark,
+            exp.Select().from_(tableName).select(*sqlglot.schema.column_names(tableName)),
+        )


 class DataFrameWriter:
     def __init__(
-        self, df: DataFrame, spark: t.Optional[SparkSession] = None, mode: t.Optional[str] = None, by_name: bool = False
+        self,
+        df: DataFrame,
+        spark: t.Optional[SparkSession] = None,
+        mode: t.Optional[str] = None,
+        by_name: bool = False,
     ):
         self._df = df
         self._spark = spark or df.spark
@@ -33,7 +40,10 @@ class DataFrameWriter:
     def copy(self, **kwargs) -> DataFrameWriter:
         return DataFrameWriter(
-            **{k[1:] if k.startswith("_") else k: v for k, v in object_to_dict(self, **kwargs).items()}
+            **{
+                k[1:] if k.startswith("_") else k: v
+                for k, v in object_to_dict(self, **kwargs).items()
+            }
         )

     def sql(self, **kwargs) -> t.List[str]:
View file
@@ -67,13 +67,20 @@ class SparkSession:
         data_expressions = [
             exp.Tuple(
-                expressions=list(map(lambda x: F.lit(x).expression, row if not isinstance(row, dict) else row.values()))
+                expressions=list(
+                    map(
+                        lambda x: F.lit(x).expression,
+                        row if not isinstance(row, dict) else row.values(),
+                    )
+                )
             )
             for row in data
         ]

         sel_columns = [
-            F.col(name).cast(data_type).alias(name).expression if data_type is not None else F.col(name).expression
+            F.col(name).cast(data_type).alias(name).expression
+            if data_type is not None
+            else F.col(name).expression
             for name, data_type in column_mapping.items()
         ]
@@ -106,10 +113,12 @@ class SparkSession:
             select_expression.set("with", expression.args.get("with"))
             expression.set("with", None)
             del expression.args["expression"]
-            df = DataFrame(self, select_expression, output_expression_container=expression)
+            df = DataFrame(self, select_expression, output_expression_container=expression)  # type: ignore
             df = df._convert_leaf_to_cte()
         else:
-            raise ValueError("Unknown expression type provided in the SQL. Please create an issue with the SQL.")
+            raise ValueError(
+                "Unknown expression type provided in the SQL. Please create an issue with the SQL."
+            )
         return df

     @property
View file
@@ -158,7 +158,11 @@ class MapType(DataType):
 class StructField(DataType):
     def __init__(
-        self, name: str, dataType: DataType, nullable: bool = True, metadata: t.Optional[t.Dict[str, t.Any]] = None
+        self,
+        name: str,
+        dataType: DataType,
+        nullable: bool = True,
+        metadata: t.Optional[t.Dict[str, t.Any]] = None,
     ):
         self.name = name
         self.dataType = dataType
View file
@@ -74,8 +74,13 @@ class WindowSpec:
         window_spec.expression.args["order"].set("expressions", order_by)
         return window_spec

-    def _calc_start_end(self, start: int, end: int) -> t.Dict[str, t.Optional[t.Union[str, exp.Expression]]]:
-        kwargs: t.Dict[str, t.Optional[t.Union[str, exp.Expression]]] = {"start_side": None, "end_side": None}
+    def _calc_start_end(
+        self, start: int, end: int
+    ) -> t.Dict[str, t.Optional[t.Union[str, exp.Expression]]]:
+        kwargs: t.Dict[str, t.Optional[t.Union[str, exp.Expression]]] = {
+            "start_side": None,
+            "end_side": None,
+        }
         if start == Window.currentRow:
             kwargs["start"] = "CURRENT ROW"
         else:
@@ -83,7 +88,9 @@ class WindowSpec:
                 **kwargs,
                 **{
                     "start_side": "PRECEDING",
-                    "start": "UNBOUNDED" if start <= Window.unboundedPreceding else F.lit(start).expression,
+                    "start": "UNBOUNDED"
+                    if start <= Window.unboundedPreceding
+                    else F.lit(start).expression,
                 },
             }
         if end == Window.currentRow:
@@ -93,7 +100,9 @@ class WindowSpec:
                 **kwargs,
                 **{
                     "end_side": "FOLLOWING",
-                    "end": "UNBOUNDED" if end >= Window.unboundedFollowing else F.lit(end).expression,
+                    "end": "UNBOUNDED"
+                    if end >= Window.unboundedFollowing
+                    else F.lit(end).expression,
                 },
            }
         return kwargs
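The frame-bound logic being rewrapped in the two hunks above maps Spark's integer sentinels for unbounded frame edges onto SQL keywords. A standalone sketch of that mapping, using assumed sentinel values mirroring `pyspark.sql.Window` (the real class pulls them from the JVM):

```python
# Assumed sentinels standing in for pyspark.sql.Window's class attributes.
UNBOUNDED_PRECEDING = -(1 << 63)
UNBOUNDED_FOLLOWING = (1 << 63) - 1
CURRENT_ROW = 0

def frame_bound(value: int) -> str:
    # Translate a Spark frame-bound int into its SQL window-frame keyword.
    if value == CURRENT_ROW:
        return "CURRENT ROW"
    if value <= UNBOUNDED_PRECEDING:
        return "UNBOUNDED PRECEDING"
    if value >= UNBOUNDED_FOLLOWING:
        return "UNBOUNDED FOLLOWING"
    side = "PRECEDING" if value < 0 else "FOLLOWING"
    return f"{abs(value)} {side}"

print(frame_bound(-3))  # 3 PRECEDING
```

The `<=`/`>=` comparisons (rather than `==`) match the diff's behavior: any value at or beyond the sentinel is treated as unbounded.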
@@ -103,7 +112,10 @@ class WindowSpec:
         spec = self._calc_start_end(start, end)
         spec["kind"] = "ROWS"
         window_spec.expression.set(
-            "spec", exp.WindowSpec(**{**window_spec.expression.args.get("spec", exp.WindowSpec()).args, **spec})
+            "spec",
+            exp.WindowSpec(
+                **{**window_spec.expression.args.get("spec", exp.WindowSpec()).args, **spec}
+            ),
         )
         return window_spec
@@ -112,6 +124,9 @@ class WindowSpec:
         spec = self._calc_start_end(start, end)
         spec["kind"] = "RANGE"
         window_spec.expression.set(
-            "spec", exp.WindowSpec(**{**window_spec.expression.args.get("spec", exp.WindowSpec()).args, **spec})
+            "spec",
+            exp.WindowSpec(
+                **{**window_spec.expression.args.get("spec", exp.WindowSpec()).args, **spec}
+            ),
         )
         return window_spec
View file
@ -1,21 +1,21 @@
from sqlglot import exp from __future__ import annotations
from sqlglot import exp, generator, parser, tokens
from sqlglot.dialects.dialect import ( from sqlglot.dialects.dialect import (
Dialect, Dialect,
inline_array_sql, inline_array_sql,
no_ilike_sql, no_ilike_sql,
rename_func, rename_func,
) )
from sqlglot.generator import Generator from sqlglot.helper import seq_get
from sqlglot.helper import list_get from sqlglot.tokens import TokenType
from sqlglot.parser import Parser
from sqlglot.tokens import Tokenizer, TokenType
def _date_add(expression_class): def _date_add(expression_class):
def func(args): def func(args):
interval = list_get(args, 1) interval = seq_get(args, 1)
return expression_class( return expression_class(
this=list_get(args, 0), this=seq_get(args, 0),
expression=interval.this, expression=interval.this,
unit=interval.args.get("unit"), unit=interval.args.get("unit"),
) )
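The `list_get` → `seq_get` change in this hunk is the mechanical rename noted in the release's changelog; the helper's observable behavior — bounds-safe indexing that yields `None` instead of raising — can be sketched as follows (this is an illustrative stand-in, not sqlglot's exact source):

```python
import typing as t

T = t.TypeVar("T")

def seq_get(seq: t.Sequence[T], index: int) -> t.Optional[T]:
    # Bounds-safe indexing: parsers probe optional arguments by position,
    # so a missing argument should read as None, not raise IndexError.
    try:
        return seq[index]
    except IndexError:
        return None

print(seq_get([1, 2, 3], 1))    # 2
print(seq_get([1, 2, 3], 10))   # None
```

This is why `_date_add` above can call `seq_get(args, 1)` without first checking how many arguments the parser collected.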
@@ -23,6 +23,13 @@ def _date_add(expression_class):
     return func


+def _date_trunc(args):
+    unit = seq_get(args, 1)
+    if isinstance(unit, exp.Column):
+        unit = exp.Var(this=unit.name)
+    return exp.DateTrunc(this=seq_get(args, 0), expression=unit)
+
+
 def _date_add_sql(data_type, kind):
     def func(self, expression):
         this = self.sql(expression, "this")
@@ -40,7 +47,8 @@ def _derived_table_values_to_unnest(self, expression):
structs = [] structs = []
for row in rows: for row in rows:
aliases = [ aliases = [
exp.alias_(value, column_name) for value, column_name in zip(row, expression.args["alias"].args["columns"]) exp.alias_(value, column_name)
for value, column_name in zip(row, expression.args["alias"].args["columns"])
] ]
structs.append(exp.Struct(expressions=aliases)) structs.append(exp.Struct(expressions=aliases))
unnest_exp = exp.Unnest(expressions=[exp.Array(expressions=structs)]) unnest_exp = exp.Unnest(expressions=[exp.Array(expressions=structs)])
@@ -89,18 +97,19 @@ class BigQuery(Dialect):
"%j": "%-j", "%j": "%-j",
} }
class Tokenizer(Tokenizer): class Tokenizer(tokens.Tokenizer):
QUOTES = [ QUOTES = [
(prefix + quote, quote) if prefix else quote (prefix + quote, quote) if prefix else quote
for quote in ["'", '"', '"""', "'''"] for quote in ["'", '"', '"""', "'''"]
for prefix in ["", "r", "R"] for prefix in ["", "r", "R"]
] ]
COMMENTS = ["--", "#", ("/*", "*/")]
IDENTIFIERS = ["`"] IDENTIFIERS = ["`"]
ESCAPE = "\\" ESCAPES = ["\\"]
HEX_STRINGS = [("0x", ""), ("0X", "")] HEX_STRINGS = [("0x", ""), ("0X", "")]
KEYWORDS = { KEYWORDS = {
**Tokenizer.KEYWORDS, **tokens.Tokenizer.KEYWORDS,
"CURRENT_DATETIME": TokenType.CURRENT_DATETIME, "CURRENT_DATETIME": TokenType.CURRENT_DATETIME,
"CURRENT_TIME": TokenType.CURRENT_TIME, "CURRENT_TIME": TokenType.CURRENT_TIME,
"GEOGRAPHY": TokenType.GEOGRAPHY, "GEOGRAPHY": TokenType.GEOGRAPHY,
@@ -111,35 +120,40 @@ class BigQuery(Dialect):
"WINDOW": TokenType.WINDOW, "WINDOW": TokenType.WINDOW,
"NOT DETERMINISTIC": TokenType.VOLATILE, "NOT DETERMINISTIC": TokenType.VOLATILE,
} }
KEYWORDS.pop("DIV")
class Parser(Parser): class Parser(parser.Parser):
FUNCTIONS = { FUNCTIONS = {
**Parser.FUNCTIONS, **parser.Parser.FUNCTIONS,
"DATE_TRUNC": _date_trunc,
"DATE_ADD": _date_add(exp.DateAdd), "DATE_ADD": _date_add(exp.DateAdd),
"DATETIME_ADD": _date_add(exp.DatetimeAdd), "DATETIME_ADD": _date_add(exp.DatetimeAdd),
"DIV": lambda args: exp.IntDiv(this=seq_get(args, 0), expression=seq_get(args, 1)),
"TIME_ADD": _date_add(exp.TimeAdd), "TIME_ADD": _date_add(exp.TimeAdd),
"TIMESTAMP_ADD": _date_add(exp.TimestampAdd), "TIMESTAMP_ADD": _date_add(exp.TimestampAdd),
"DATE_SUB": _date_add(exp.DateSub), "DATE_SUB": _date_add(exp.DateSub),
"DATETIME_SUB": _date_add(exp.DatetimeSub), "DATETIME_SUB": _date_add(exp.DatetimeSub),
"TIME_SUB": _date_add(exp.TimeSub), "TIME_SUB": _date_add(exp.TimeSub),
"TIMESTAMP_SUB": _date_add(exp.TimestampSub), "TIMESTAMP_SUB": _date_add(exp.TimestampSub),
"PARSE_TIMESTAMP": lambda args: exp.StrToTime(this=list_get(args, 1), format=list_get(args, 0)), "PARSE_TIMESTAMP": lambda args: exp.StrToTime(
this=seq_get(args, 1), format=seq_get(args, 0)
),
} }
NO_PAREN_FUNCTIONS = { NO_PAREN_FUNCTIONS = {
**Parser.NO_PAREN_FUNCTIONS, **parser.Parser.NO_PAREN_FUNCTIONS,
TokenType.CURRENT_DATETIME: exp.CurrentDatetime, TokenType.CURRENT_DATETIME: exp.CurrentDatetime,
TokenType.CURRENT_TIME: exp.CurrentTime, TokenType.CURRENT_TIME: exp.CurrentTime,
} }
NESTED_TYPE_TOKENS = { NESTED_TYPE_TOKENS = {
*Parser.NESTED_TYPE_TOKENS, *parser.Parser.NESTED_TYPE_TOKENS,
TokenType.TABLE, TokenType.TABLE,
} }
class Generator(Generator): class Generator(generator.Generator):
TRANSFORMS = { TRANSFORMS = {
**Generator.TRANSFORMS, **generator.Generator.TRANSFORMS,
exp.Array: inline_array_sql, exp.Array: inline_array_sql,
exp.ArraySize: rename_func("ARRAY_LENGTH"), exp.ArraySize: rename_func("ARRAY_LENGTH"),
exp.DateAdd: _date_add_sql("DATE", "ADD"), exp.DateAdd: _date_add_sql("DATE", "ADD"),
@@ -148,6 +162,7 @@ class BigQuery(Dialect):
exp.DatetimeSub: _date_add_sql("DATETIME", "SUB"), exp.DatetimeSub: _date_add_sql("DATETIME", "SUB"),
exp.DateDiff: lambda self, e: f"DATE_DIFF({self.sql(e, 'this')}, {self.sql(e, 'expression')}, {self.sql(e.args.get('unit', 'DAY'))})", exp.DateDiff: lambda self, e: f"DATE_DIFF({self.sql(e, 'this')}, {self.sql(e, 'expression')}, {self.sql(e.args.get('unit', 'DAY'))})",
exp.ILike: no_ilike_sql, exp.ILike: no_ilike_sql,
exp.IntDiv: rename_func("DIV"),
exp.StrToTime: lambda self, e: f"PARSE_TIMESTAMP({self.format_time(e)}, {self.sql(e, 'this')})", exp.StrToTime: lambda self, e: f"PARSE_TIMESTAMP({self.format_time(e)}, {self.sql(e, 'this')})",
exp.TimeAdd: _date_add_sql("TIME", "ADD"), exp.TimeAdd: _date_add_sql("TIME", "ADD"),
exp.TimeSub: _date_add_sql("TIME", "SUB"), exp.TimeSub: _date_add_sql("TIME", "SUB"),
@@ -157,11 +172,13 @@ class BigQuery(Dialect):
exp.Values: _derived_table_values_to_unnest, exp.Values: _derived_table_values_to_unnest,
exp.ReturnsProperty: _returnsproperty_sql, exp.ReturnsProperty: _returnsproperty_sql,
exp.Create: _create_sql, exp.Create: _create_sql,
exp.VolatilityProperty: lambda self, e: f"DETERMINISTIC" if e.name == "IMMUTABLE" else "NOT DETERMINISTIC", exp.VolatilityProperty: lambda self, e: f"DETERMINISTIC"
if e.name == "IMMUTABLE"
else "NOT DETERMINISTIC",
} }
TYPE_MAPPING = { TYPE_MAPPING = {
**Generator.TYPE_MAPPING, **generator.Generator.TYPE_MAPPING,
exp.DataType.Type.TINYINT: "INT64", exp.DataType.Type.TINYINT: "INT64",
exp.DataType.Type.SMALLINT: "INT64", exp.DataType.Type.SMALLINT: "INT64",
exp.DataType.Type.INT: "INT64", exp.DataType.Type.INT: "INT64",

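The new `_date_trunc` above special-cases the unit argument: in `DATE_TRUNC(ts, DAY)` the bare `DAY` parses as an ordinary identifier, so it arrives as a column node and gets re-tagged as a variable before the `DateTrunc` expression is built. A self-contained sketch with assumed stand-in node classes (not sqlglot's real `exp` types):

```python
from dataclasses import dataclass

@dataclass
class Column:  # stand-in for exp.Column
    name: str

@dataclass
class Var:  # stand-in for exp.Var
    name: str

@dataclass
class DateTrunc:  # stand-in for exp.DateTrunc
    this: object
    unit: object

def date_trunc(args):
    # seq_get-style access: a missing unit becomes None instead of IndexError
    unit = args[1] if len(args) > 1 else None
    if isinstance(unit, Column):
        # DAY was parsed as a column reference; re-tag it as a plain variable
        unit = Var(unit.name)
    return DateTrunc(this=args[0], unit=unit)
```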

@@ -1,8 +1,9 @@
from sqlglot import exp from __future__ import annotations
from sqlglot import exp, generator, parser, tokens
from sqlglot.dialects.dialect import Dialect, inline_array_sql, var_map_sql from sqlglot.dialects.dialect import Dialect, inline_array_sql, var_map_sql
from sqlglot.generator import Generator from sqlglot.parser import parse_var_map
from sqlglot.parser import Parser, parse_var_map from sqlglot.tokens import TokenType
from sqlglot.tokens import Tokenizer, TokenType
def _lower_func(sql): def _lower_func(sql):
@@ -14,11 +15,12 @@ class ClickHouse(Dialect):
normalize_functions = None normalize_functions = None
null_ordering = "nulls_are_last" null_ordering = "nulls_are_last"
class Tokenizer(Tokenizer): class Tokenizer(tokens.Tokenizer):
COMMENTS = ["--", "#", "#!", ("/*", "*/")]
IDENTIFIERS = ['"', "`"] IDENTIFIERS = ['"', "`"]
KEYWORDS = { KEYWORDS = {
**Tokenizer.KEYWORDS, **tokens.Tokenizer.KEYWORDS,
"FINAL": TokenType.FINAL, "FINAL": TokenType.FINAL,
"DATETIME64": TokenType.DATETIME, "DATETIME64": TokenType.DATETIME,
"INT8": TokenType.TINYINT, "INT8": TokenType.TINYINT,
@@ -30,9 +32,9 @@ class ClickHouse(Dialect):
"TUPLE": TokenType.STRUCT, "TUPLE": TokenType.STRUCT,
} }
class Parser(Parser): class Parser(parser.Parser):
FUNCTIONS = { FUNCTIONS = {
**Parser.FUNCTIONS, **parser.Parser.FUNCTIONS,
"MAP": parse_var_map, "MAP": parse_var_map,
} }
@@ -44,11 +46,11 @@ class ClickHouse(Dialect):
return this return this
class Generator(Generator): class Generator(generator.Generator):
STRUCT_DELIMITER = ("(", ")") STRUCT_DELIMITER = ("(", ")")
TYPE_MAPPING = { TYPE_MAPPING = {
**Generator.TYPE_MAPPING, **generator.Generator.TYPE_MAPPING,
exp.DataType.Type.NULLABLE: "Nullable", exp.DataType.Type.NULLABLE: "Nullable",
exp.DataType.Type.DATETIME: "DateTime64", exp.DataType.Type.DATETIME: "DateTime64",
exp.DataType.Type.MAP: "Map", exp.DataType.Type.MAP: "Map",
@@ -63,7 +65,7 @@ class ClickHouse(Dialect):
} }
TRANSFORMS = { TRANSFORMS = {
**Generator.TRANSFORMS, **generator.Generator.TRANSFORMS,
exp.Array: inline_array_sql, exp.Array: inline_array_sql,
exp.StrPosition: lambda self, e: f"position({self.format_args(e.this, e.args.get('substr'), e.args.get('position'))})", exp.StrPosition: lambda self, e: f"position({self.format_args(e.this, e.args.get('substr'), e.args.get('position'))})",
exp.Final: lambda self, e: f"{self.sql(e, 'this')} FINAL", exp.Final: lambda self, e: f"{self.sql(e, 'this')} FINAL",

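The ClickHouse tokenizer above adds `#!` alongside `--`, `#`, and `/* */` as comment delimiters; since `#!` shares a prefix with `#`, longer markers must be tried first or `#! note` would be consumed as a `#` comment starting at the wrong offset. A naive, stdlib-only sketch of that matching order (it deliberately ignores string literals, which a real tokenizer must handle):

```python
def strip_comments(sql, markers=("--", "#!", "#", ("/*", "*/"))):
    """Remove comments given line markers (str) and block markers (tuple).

    Markers are tried in order, so longer prefixes like "#!" must come
    before "#". Naive: does not skip markers inside string literals.
    """
    out, i = [], 0
    while i < len(sql):
        for m in markers:
            if isinstance(m, tuple):
                start, end = m
                if sql.startswith(start, i):
                    j = sql.find(end, i + len(start))
                    i = len(sql) if j == -1 else j + len(end)
                    break
            elif sql.startswith(m, i):
                j = sql.find("\n", i)  # line comment runs to end of line
                i = len(sql) if j == -1 else j
                break
        else:
            out.append(sql[i])
            i += 1
    return "".join(out)
```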

@@ -1,3 +1,5 @@
from __future__ import annotations
from sqlglot import exp from sqlglot import exp
from sqlglot.dialects.dialect import parse_date_delta from sqlglot.dialects.dialect import parse_date_delta
from sqlglot.dialects.spark import Spark from sqlglot.dialects.spark import Spark
@@ -15,7 +17,7 @@ class Databricks(Spark):
class Generator(Spark.Generator): class Generator(Spark.Generator):
TRANSFORMS = { TRANSFORMS = {
**Spark.Generator.TRANSFORMS, **Spark.Generator.TRANSFORMS, # type: ignore
exp.DateAdd: generate_date_delta_with_unit_sql, exp.DateAdd: generate_date_delta_with_unit_sql,
exp.DateDiff: generate_date_delta_with_unit_sql, exp.DateDiff: generate_date_delta_with_unit_sql,
} }


@@ -1,8 +1,11 @@
from __future__ import annotations
import typing as t
from enum import Enum from enum import Enum
from sqlglot import exp from sqlglot import exp
from sqlglot.generator import Generator from sqlglot.generator import Generator
from sqlglot.helper import flatten, list_get from sqlglot.helper import flatten, seq_get
from sqlglot.parser import Parser from sqlglot.parser import Parser
from sqlglot.time import format_time from sqlglot.time import format_time
from sqlglot.tokens import Tokenizer from sqlglot.tokens import Tokenizer
@@ -32,7 +35,7 @@ class Dialects(str, Enum):
class _Dialect(type): class _Dialect(type):
classes = {} classes: t.Dict[str, Dialect] = {}
@classmethod @classmethod
def __getitem__(cls, key): def __getitem__(cls, key):
@@ -56,19 +59,30 @@ class _Dialect(type):
klass.generator_class = getattr(klass, "Generator", Generator) klass.generator_class = getattr(klass, "Generator", Generator)
klass.quote_start, klass.quote_end = list(klass.tokenizer_class._QUOTES.items())[0] klass.quote_start, klass.quote_end = list(klass.tokenizer_class._QUOTES.items())[0]
klass.identifier_start, klass.identifier_end = list(klass.tokenizer_class._IDENTIFIERS.items())[0] klass.identifier_start, klass.identifier_end = list(
klass.tokenizer_class._IDENTIFIERS.items()
)[0]
if klass.tokenizer_class._BIT_STRINGS and exp.BitString not in klass.generator_class.TRANSFORMS: if (
klass.tokenizer_class._BIT_STRINGS
and exp.BitString not in klass.generator_class.TRANSFORMS
):
bs_start, bs_end = list(klass.tokenizer_class._BIT_STRINGS.items())[0] bs_start, bs_end = list(klass.tokenizer_class._BIT_STRINGS.items())[0]
klass.generator_class.TRANSFORMS[ klass.generator_class.TRANSFORMS[
exp.BitString exp.BitString
] = lambda self, e: f"{bs_start}{int(self.sql(e, 'this')):b}{bs_end}" ] = lambda self, e: f"{bs_start}{int(self.sql(e, 'this')):b}{bs_end}"
if klass.tokenizer_class._HEX_STRINGS and exp.HexString not in klass.generator_class.TRANSFORMS: if (
klass.tokenizer_class._HEX_STRINGS
and exp.HexString not in klass.generator_class.TRANSFORMS
):
hs_start, hs_end = list(klass.tokenizer_class._HEX_STRINGS.items())[0] hs_start, hs_end = list(klass.tokenizer_class._HEX_STRINGS.items())[0]
klass.generator_class.TRANSFORMS[ klass.generator_class.TRANSFORMS[
exp.HexString exp.HexString
] = lambda self, e: f"{hs_start}{int(self.sql(e, 'this')):X}{hs_end}" ] = lambda self, e: f"{hs_start}{int(self.sql(e, 'this')):X}{hs_end}"
if klass.tokenizer_class._BYTE_STRINGS and exp.ByteString not in klass.generator_class.TRANSFORMS: if (
klass.tokenizer_class._BYTE_STRINGS
and exp.ByteString not in klass.generator_class.TRANSFORMS
):
be_start, be_end = list(klass.tokenizer_class._BYTE_STRINGS.items())[0] be_start, be_end = list(klass.tokenizer_class._BYTE_STRINGS.items())[0]
klass.generator_class.TRANSFORMS[ klass.generator_class.TRANSFORMS[
exp.ByteString exp.ByteString
@@ -81,13 +95,13 @@ class Dialect(metaclass=_Dialect):
index_offset = 0 index_offset = 0
unnest_column_only = False unnest_column_only = False
alias_post_tablesample = False alias_post_tablesample = False
normalize_functions = "upper" normalize_functions: t.Optional[str] = "upper"
null_ordering = "nulls_are_small" null_ordering = "nulls_are_small"
date_format = "'%Y-%m-%d'" date_format = "'%Y-%m-%d'"
dateint_format = "'%Y%m%d'" dateint_format = "'%Y%m%d'"
time_format = "'%Y-%m-%d %H:%M:%S'" time_format = "'%Y-%m-%d %H:%M:%S'"
time_mapping = {} time_mapping: t.Dict[str, str] = {}
# autofilled # autofilled
quote_start = None quote_start = None
@@ -167,7 +181,7 @@ class Dialect(metaclass=_Dialect):
"quote_end": self.quote_end, "quote_end": self.quote_end,
"identifier_start": self.identifier_start, "identifier_start": self.identifier_start,
"identifier_end": self.identifier_end, "identifier_end": self.identifier_end,
"escape": self.tokenizer_class.ESCAPE, "escape": self.tokenizer_class.ESCAPES[0],
"index_offset": self.index_offset, "index_offset": self.index_offset,
"time_mapping": self.inverse_time_mapping, "time_mapping": self.inverse_time_mapping,
"time_trie": self.inverse_time_trie, "time_trie": self.inverse_time_trie,
@@ -195,7 +209,9 @@ def approx_count_distinct_sql(self, expression):
def if_sql(self, expression): def if_sql(self, expression):
expressions = self.format_args(expression.this, expression.args.get("true"), expression.args.get("false")) expressions = self.format_args(
expression.this, expression.args.get("true"), expression.args.get("false")
)
return f"IF({expressions})" return f"IF({expressions})"
@@ -298,9 +314,9 @@ def format_time_lambda(exp_class, dialect, default=None):
def _format_time(args): def _format_time(args):
return exp_class( return exp_class(
this=list_get(args, 0), this=seq_get(args, 0),
format=Dialect[dialect].format_time( format=Dialect[dialect].format_time(
list_get(args, 1) or (Dialect[dialect].time_format if default is True else default) seq_get(args, 1) or (Dialect[dialect].time_format if default is True else default)
), ),
) )
@@ -328,7 +344,9 @@ def create_with_partitions_sql(self, expression):
"expressions", "expressions",
[e for e in schema.expressions if e not in partitions], [e for e in schema.expressions if e not in partitions],
) )
prop.replace(exp.PartitionedByProperty(this=prop.this, value=exp.Schema(expressions=partitions))) prop.replace(
exp.PartitionedByProperty(this=prop.this, value=exp.Schema(expressions=partitions))
)
expression.set("this", schema) expression.set("this", schema)
return self.create_sql(expression) return self.create_sql(expression)
@@ -337,9 +355,9 @@ def create_with_partitions_sql(self, expression):
def parse_date_delta(exp_class, unit_mapping=None): def parse_date_delta(exp_class, unit_mapping=None):
def inner_func(args): def inner_func(args):
unit_based = len(args) == 3 unit_based = len(args) == 3
this = list_get(args, 2) if unit_based else list_get(args, 0) this = seq_get(args, 2) if unit_based else seq_get(args, 0)
expression = list_get(args, 1) if unit_based else list_get(args, 1) expression = seq_get(args, 1) if unit_based else seq_get(args, 1)
unit = list_get(args, 0) if unit_based else exp.Literal.string("DAY") unit = seq_get(args, 0) if unit_based else exp.Literal.string("DAY")
unit = unit_mapping.get(unit.name.lower(), unit) if unit_mapping else unit unit = unit_mapping.get(unit.name.lower(), unit) if unit_mapping else unit
return exp_class(this=this, expression=expression, unit=unit) return exp_class(this=this, expression=expression, unit=unit)

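The `list_get` → `seq_get` rename running through these hunks is mechanical; judging by its call sites above (`seq_get(args, 1) or default`), the helper returns the element at an index, or `None` when the sequence is too short. A sketch of that behavior (my reading of the call sites, not a verbatim copy of `sqlglot.helper`):

```python
import typing as t

T = t.TypeVar("T")

def seq_get(seq: t.Sequence[T], index: int) -> t.Optional[T]:
    """Return seq[index], or None if the index is out of range."""
    try:
        return seq[index]
    except IndexError:
        return None
```

The `or default` call sites work because a missing optional argument collapses to `None` instead of raising, which keeps every function parser free of explicit length checks.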

@@ -1,4 +1,6 @@
from sqlglot import exp from __future__ import annotations
from sqlglot import exp, generator, parser, tokens
from sqlglot.dialects.dialect import ( from sqlglot.dialects.dialect import (
Dialect, Dialect,
approx_count_distinct_sql, approx_count_distinct_sql,
@@ -12,10 +14,8 @@ from sqlglot.dialects.dialect import (
rename_func, rename_func,
str_position_sql, str_position_sql,
) )
from sqlglot.generator import Generator from sqlglot.helper import seq_get
from sqlglot.helper import list_get from sqlglot.tokens import TokenType
from sqlglot.parser import Parser
from sqlglot.tokens import Tokenizer, TokenType
def _unix_to_time(self, expression): def _unix_to_time(self, expression):
@@ -61,11 +61,14 @@ def _sort_array_sql(self, expression):
def _sort_array_reverse(args): def _sort_array_reverse(args):
return exp.SortArray(this=list_get(args, 0), asc=exp.FALSE) return exp.SortArray(this=seq_get(args, 0), asc=exp.FALSE)
def _struct_pack_sql(self, expression): def _struct_pack_sql(self, expression):
args = [self.binary(e, ":=") if isinstance(e, exp.EQ) else self.sql(e) for e in expression.expressions] args = [
self.binary(e, ":=") if isinstance(e, exp.EQ) else self.sql(e)
for e in expression.expressions
]
return f"STRUCT_PACK({', '.join(args)})" return f"STRUCT_PACK({', '.join(args)})"
@@ -76,15 +79,15 @@ def _datatype_sql(self, expression):
class DuckDB(Dialect): class DuckDB(Dialect):
class Tokenizer(Tokenizer): class Tokenizer(tokens.Tokenizer):
KEYWORDS = { KEYWORDS = {
**Tokenizer.KEYWORDS, **tokens.Tokenizer.KEYWORDS,
":=": TokenType.EQ, ":=": TokenType.EQ,
} }
class Parser(Parser): class Parser(parser.Parser):
FUNCTIONS = { FUNCTIONS = {
**Parser.FUNCTIONS, **parser.Parser.FUNCTIONS,
"APPROX_COUNT_DISTINCT": exp.ApproxDistinct.from_arg_list, "APPROX_COUNT_DISTINCT": exp.ApproxDistinct.from_arg_list,
"ARRAY_LENGTH": exp.ArraySize.from_arg_list, "ARRAY_LENGTH": exp.ArraySize.from_arg_list,
"ARRAY_SORT": exp.SortArray.from_arg_list, "ARRAY_SORT": exp.SortArray.from_arg_list,
@@ -92,7 +95,7 @@ class DuckDB(Dialect):
"EPOCH": exp.TimeToUnix.from_arg_list, "EPOCH": exp.TimeToUnix.from_arg_list,
"EPOCH_MS": lambda args: exp.UnixToTime( "EPOCH_MS": lambda args: exp.UnixToTime(
this=exp.Div( this=exp.Div(
this=list_get(args, 0), this=seq_get(args, 0),
expression=exp.Literal.number(1000), expression=exp.Literal.number(1000),
) )
), ),
@@ -112,11 +115,11 @@ class DuckDB(Dialect):
"UNNEST": exp.Explode.from_arg_list, "UNNEST": exp.Explode.from_arg_list,
} }
class Generator(Generator): class Generator(generator.Generator):
STRUCT_DELIMITER = ("(", ")") STRUCT_DELIMITER = ("(", ")")
TRANSFORMS = { TRANSFORMS = {
**Generator.TRANSFORMS, **generator.Generator.TRANSFORMS,
exp.ApproxDistinct: approx_count_distinct_sql, exp.ApproxDistinct: approx_count_distinct_sql,
exp.Array: rename_func("LIST_VALUE"), exp.Array: rename_func("LIST_VALUE"),
exp.ArraySize: rename_func("ARRAY_LENGTH"), exp.ArraySize: rename_func("ARRAY_LENGTH"),
@@ -160,7 +163,7 @@ class DuckDB(Dialect):
} }
TYPE_MAPPING = { TYPE_MAPPING = {
**Generator.TYPE_MAPPING, **generator.Generator.TYPE_MAPPING,
exp.DataType.Type.VARCHAR: "TEXT", exp.DataType.Type.VARCHAR: "TEXT",
exp.DataType.Type.NVARCHAR: "TEXT", exp.DataType.Type.NVARCHAR: "TEXT",
} }

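DuckDB's `EPOCH_MS` above is modeled as `UnixToTime(ms / 1000)`: milliseconds are scaled down to the Unix seconds the generic expression expects. The same conversion in plain Python (illustrative, not part of sqlglot):

```python
from datetime import datetime, timezone

def epoch_ms(ms: int) -> datetime:
    # EPOCH_MS(ms) ~ UnixToTime(ms / 1000): scale milliseconds to seconds
    return datetime.fromtimestamp(ms / 1000, tz=timezone.utc)
```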

@@ -1,4 +1,6 @@
from sqlglot import exp, transforms from __future__ import annotations
from sqlglot import exp, generator, parser, tokens, transforms
from sqlglot.dialects.dialect import ( from sqlglot.dialects.dialect import (
Dialect, Dialect,
approx_count_distinct_sql, approx_count_distinct_sql,
@@ -13,10 +15,8 @@ from sqlglot.dialects.dialect import (
struct_extract_sql, struct_extract_sql,
var_map_sql, var_map_sql,
) )
from sqlglot.generator import Generator from sqlglot.helper import seq_get
from sqlglot.helper import list_get from sqlglot.parser import parse_var_map
from sqlglot.parser import Parser, parse_var_map
from sqlglot.tokens import Tokenizer
# (FuncType, Multiplier) # (FuncType, Multiplier)
DATE_DELTA_INTERVAL = { DATE_DELTA_INTERVAL = {
@@ -34,7 +34,9 @@ def _add_date_sql(self, expression):
unit = expression.text("unit").upper() unit = expression.text("unit").upper()
func, multiplier = DATE_DELTA_INTERVAL.get(unit, ("DATE_ADD", 1)) func, multiplier = DATE_DELTA_INTERVAL.get(unit, ("DATE_ADD", 1))
modified_increment = ( modified_increment = (
int(expression.text("expression")) * multiplier if expression.expression.is_number else expression.expression int(expression.text("expression")) * multiplier
if expression.expression.is_number
else expression.expression
) )
modified_increment = exp.Literal.number(modified_increment) modified_increment = exp.Literal.number(modified_increment)
return f"{func}({self.format_args(expression.this, modified_increment.this)})" return f"{func}({self.format_args(expression.this, modified_increment.this)})"
@@ -165,10 +167,10 @@ class Hive(Dialect):
dateint_format = "'yyyyMMdd'" dateint_format = "'yyyyMMdd'"
time_format = "'yyyy-MM-dd HH:mm:ss'" time_format = "'yyyy-MM-dd HH:mm:ss'"
class Tokenizer(Tokenizer): class Tokenizer(tokens.Tokenizer):
QUOTES = ["'", '"'] QUOTES = ["'", '"']
IDENTIFIERS = ["`"] IDENTIFIERS = ["`"]
ESCAPE = "\\" ESCAPES = ["\\"]
ENCODE = "utf-8" ENCODE = "utf-8"
NUMERIC_LITERALS = { NUMERIC_LITERALS = {
@@ -180,40 +182,44 @@ class Hive(Dialect):
"BD": "DECIMAL", "BD": "DECIMAL",
} }
class Parser(Parser): class Parser(parser.Parser):
STRICT_CAST = False STRICT_CAST = False
FUNCTIONS = { FUNCTIONS = {
**Parser.FUNCTIONS, **parser.Parser.FUNCTIONS,
"APPROX_COUNT_DISTINCT": exp.ApproxDistinct.from_arg_list, "APPROX_COUNT_DISTINCT": exp.ApproxDistinct.from_arg_list,
"COLLECT_LIST": exp.ArrayAgg.from_arg_list, "COLLECT_LIST": exp.ArrayAgg.from_arg_list,
"DATE_ADD": lambda args: exp.TsOrDsAdd( "DATE_ADD": lambda args: exp.TsOrDsAdd(
this=list_get(args, 0), this=seq_get(args, 0),
expression=list_get(args, 1), expression=seq_get(args, 1),
unit=exp.Literal.string("DAY"), unit=exp.Literal.string("DAY"),
), ),
"DATEDIFF": lambda args: exp.DateDiff( "DATEDIFF": lambda args: exp.DateDiff(
this=exp.TsOrDsToDate(this=list_get(args, 0)), this=exp.TsOrDsToDate(this=seq_get(args, 0)),
expression=exp.TsOrDsToDate(this=list_get(args, 1)), expression=exp.TsOrDsToDate(this=seq_get(args, 1)),
), ),
"DATE_SUB": lambda args: exp.TsOrDsAdd( "DATE_SUB": lambda args: exp.TsOrDsAdd(
this=list_get(args, 0), this=seq_get(args, 0),
expression=exp.Mul( expression=exp.Mul(
this=list_get(args, 1), this=seq_get(args, 1),
expression=exp.Literal.number(-1), expression=exp.Literal.number(-1),
), ),
unit=exp.Literal.string("DAY"), unit=exp.Literal.string("DAY"),
), ),
"DATE_FORMAT": format_time_lambda(exp.TimeToStr, "hive"), "DATE_FORMAT": format_time_lambda(exp.TimeToStr, "hive"),
"DAY": lambda args: exp.Day(this=exp.TsOrDsToDate(this=list_get(args, 0))), "DAY": lambda args: exp.Day(this=exp.TsOrDsToDate(this=seq_get(args, 0))),
"FROM_UNIXTIME": format_time_lambda(exp.UnixToStr, "hive", True), "FROM_UNIXTIME": format_time_lambda(exp.UnixToStr, "hive", True),
"GET_JSON_OBJECT": exp.JSONExtractScalar.from_arg_list, "GET_JSON_OBJECT": exp.JSONExtractScalar.from_arg_list,
"LOCATE": lambda args: exp.StrPosition( "LOCATE": lambda args: exp.StrPosition(
this=list_get(args, 1), this=seq_get(args, 1),
substr=list_get(args, 0), substr=seq_get(args, 0),
position=list_get(args, 2), position=seq_get(args, 2),
),
"LOG": (
lambda args: exp.Log.from_arg_list(args)
if len(args) > 1
else exp.Ln.from_arg_list(args)
), ),
"LOG": (lambda args: exp.Log.from_arg_list(args) if len(args) > 1 else exp.Ln.from_arg_list(args)),
"MAP": parse_var_map, "MAP": parse_var_map,
"MONTH": lambda args: exp.Month(this=exp.TsOrDsToDate.from_arg_list(args)), "MONTH": lambda args: exp.Month(this=exp.TsOrDsToDate.from_arg_list(args)),
"PERCENTILE": exp.Quantile.from_arg_list, "PERCENTILE": exp.Quantile.from_arg_list,
@@ -226,15 +232,16 @@ class Hive(Dialect):
"YEAR": lambda args: exp.Year(this=exp.TsOrDsToDate.from_arg_list(args)), "YEAR": lambda args: exp.Year(this=exp.TsOrDsToDate.from_arg_list(args)),
} }
class Generator(Generator): class Generator(generator.Generator):
TYPE_MAPPING = { TYPE_MAPPING = {
**Generator.TYPE_MAPPING, **generator.Generator.TYPE_MAPPING,
exp.DataType.Type.TEXT: "STRING", exp.DataType.Type.TEXT: "STRING",
exp.DataType.Type.VARBINARY: "BINARY",
} }
TRANSFORMS = { TRANSFORMS = {
**Generator.TRANSFORMS, **generator.Generator.TRANSFORMS,
**transforms.UNALIAS_GROUP, **transforms.UNALIAS_GROUP, # type: ignore
exp.AnonymousProperty: _property_sql, exp.AnonymousProperty: _property_sql,
exp.ApproxDistinct: approx_count_distinct_sql, exp.ApproxDistinct: approx_count_distinct_sql,
exp.ArrayAgg: rename_func("COLLECT_LIST"), exp.ArrayAgg: rename_func("COLLECT_LIST"),

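Hive's `LOG` entry above dispatches on arity: one argument means natural log (`Ln`), two mean log of `x` in a given base (`Log`). The numeric equivalent of that dispatch (a sketch; sqlglot builds expression nodes here, not numbers):

```python
import math

def parse_log(args):
    # LOG(x) -> natural log; LOG(base, x) -> log of x in that base
    if len(args) > 1:
        base, x = args
        return math.log(x, base)
    return math.log(args[0])
```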

@@ -1,4 +1,8 @@
from sqlglot import exp from __future__ import annotations
import typing as t
from sqlglot import exp, generator, parser, tokens
from sqlglot.dialects.dialect import ( from sqlglot.dialects.dialect import (
Dialect, Dialect,
no_ilike_sql, no_ilike_sql,
@@ -6,42 +10,47 @@ from sqlglot.dialects.dialect import (
no_tablesample_sql, no_tablesample_sql,
no_trycast_sql, no_trycast_sql,
) )
from sqlglot.generator import Generator from sqlglot.helper import seq_get
from sqlglot.helper import list_get from sqlglot.tokens import TokenType
from sqlglot.parser import Parser
from sqlglot.tokens import Tokenizer, TokenType
def _show_parser(*args, **kwargs):
def _parse(self):
return self._parse_show_mysql(*args, **kwargs)
return _parse
def _date_trunc_sql(self, expression): def _date_trunc_sql(self, expression):
unit = expression.text("unit").lower() unit = expression.name.lower()
this = self.sql(expression.this) expr = self.sql(expression.expression)
if unit == "day": if unit == "day":
return f"DATE({this})" return f"DATE({expr})"
if unit == "week": if unit == "week":
concat = f"CONCAT(YEAR({this}), ' ', WEEK({this}, 1), ' 1')" concat = f"CONCAT(YEAR({expr}), ' ', WEEK({expr}, 1), ' 1')"
date_format = "%Y %u %w" date_format = "%Y %u %w"
elif unit == "month": elif unit == "month":
concat = f"CONCAT(YEAR({this}), ' ', MONTH({this}), ' 1')" concat = f"CONCAT(YEAR({expr}), ' ', MONTH({expr}), ' 1')"
date_format = "%Y %c %e" date_format = "%Y %c %e"
elif unit == "quarter": elif unit == "quarter":
concat = f"CONCAT(YEAR({this}), ' ', QUARTER({this}) * 3 - 2, ' 1')" concat = f"CONCAT(YEAR({expr}), ' ', QUARTER({expr}) * 3 - 2, ' 1')"
date_format = "%Y %c %e" date_format = "%Y %c %e"
elif unit == "year": elif unit == "year":
concat = f"CONCAT(YEAR({this}), ' 1 1')" concat = f"CONCAT(YEAR({expr}), ' 1 1')"
date_format = "%Y %c %e" date_format = "%Y %c %e"
else: else:
        self.unsupported(f"Unexpected interval unit: {unit}")         self.unsupported(f"Unexpected interval unit: {unit}")
return f"DATE({this})" return f"DATE({expr})"
return f"STR_TO_DATE({concat}, '{date_format}')" return f"STR_TO_DATE({concat}, '{date_format}')"
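`_date_trunc_sql` above emulates `DATE_TRUNC` for MySQL by rebuilding a truncated date string with `CONCAT`/`YEAR`/`WEEK` etc. and parsing it back with `STR_TO_DATE`. A runnable sketch over pre-rendered SQL fragments (plain strings stand in for the generator's `self.sql(...)` calls, and the unsupported-unit warning is reduced to the same `DATE(...)` fallback):

```python
def date_trunc_sql(unit: str, expr: str) -> str:
    """Render MySQL SQL truncating `expr` to `unit`, per the mapping above."""
    unit = unit.lower()
    if unit == "day":
        return f"DATE({expr})"
    if unit == "week":
        concat = f"CONCAT(YEAR({expr}), ' ', WEEK({expr}, 1), ' 1')"
        date_format = "%Y %u %w"
    elif unit == "month":
        concat = f"CONCAT(YEAR({expr}), ' ', MONTH({expr}), ' 1')"
        date_format = "%Y %c %e"
    elif unit == "quarter":
        concat = f"CONCAT(YEAR({expr}), ' ', QUARTER({expr}) * 3 - 2, ' 1')"
        date_format = "%Y %c %e"
    elif unit == "year":
        concat = f"CONCAT(YEAR({expr}), ' 1 1')"
        date_format = "%Y %c %e"
    else:
        # unsupported unit: fall back to plain DATE(), as the dialect does
        return f"DATE({expr})"
    return f"STR_TO_DATE({concat}, '{date_format}')"
```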
def _str_to_date(args): def _str_to_date(args):
date_format = MySQL.format_time(list_get(args, 1)) date_format = MySQL.format_time(seq_get(args, 1))
return exp.StrToDate(this=list_get(args, 0), format=date_format) return exp.StrToDate(this=seq_get(args, 0), format=date_format)
def _str_to_date_sql(self, expression): def _str_to_date_sql(self, expression):
@@ -66,9 +75,9 @@ def _trim_sql(self, expression):
def _date_add(expression_class): def _date_add(expression_class):
def func(args): def func(args):
interval = list_get(args, 1) interval = seq_get(args, 1)
return expression_class( return expression_class(
this=list_get(args, 0), this=seq_get(args, 0),
expression=interval.this, expression=interval.this,
unit=exp.Literal.string(interval.text("unit").lower()), unit=exp.Literal.string(interval.text("unit").lower()),
) )
@@ -101,15 +110,16 @@ class MySQL(Dialect):
"%l": "%-I", "%l": "%-I",
} }
class Tokenizer(Tokenizer): class Tokenizer(tokens.Tokenizer):
QUOTES = ["'", '"'] QUOTES = ["'", '"']
COMMENTS = ["--", "#", ("/*", "*/")] COMMENTS = ["--", "#", ("/*", "*/")]
IDENTIFIERS = ["`"] IDENTIFIERS = ["`"]
ESCAPES = ["'", "\\"]
BIT_STRINGS = [("b'", "'"), ("B'", "'"), ("0b", "")] BIT_STRINGS = [("b'", "'"), ("B'", "'"), ("0b", "")]
HEX_STRINGS = [("x'", "'"), ("X'", "'"), ("0x", "")] HEX_STRINGS = [("x'", "'"), ("X'", "'"), ("0x", "")]
KEYWORDS = { KEYWORDS = {
**Tokenizer.KEYWORDS, **tokens.Tokenizer.KEYWORDS,
"SEPARATOR": TokenType.SEPARATOR, "SEPARATOR": TokenType.SEPARATOR,
"_ARMSCII8": TokenType.INTRODUCER, "_ARMSCII8": TokenType.INTRODUCER,
"_ASCII": TokenType.INTRODUCER, "_ASCII": TokenType.INTRODUCER,
@@ -156,20 +166,23 @@ class MySQL(Dialect):
"_UTF32": TokenType.INTRODUCER, "_UTF32": TokenType.INTRODUCER,
"_UTF8MB3": TokenType.INTRODUCER, "_UTF8MB3": TokenType.INTRODUCER,
"_UTF8MB4": TokenType.INTRODUCER, "_UTF8MB4": TokenType.INTRODUCER,
"@@": TokenType.SESSION_PARAMETER,
} }
class Parser(Parser): COMMANDS = tokens.Tokenizer.COMMANDS - {TokenType.SET, TokenType.SHOW}
class Parser(parser.Parser):
STRICT_CAST = False STRICT_CAST = False
FUNCTIONS = { FUNCTIONS = {
**Parser.FUNCTIONS, **parser.Parser.FUNCTIONS,
"DATE_ADD": _date_add(exp.DateAdd), "DATE_ADD": _date_add(exp.DateAdd),
"DATE_SUB": _date_add(exp.DateSub), "DATE_SUB": _date_add(exp.DateSub),
"STR_TO_DATE": _str_to_date, "STR_TO_DATE": _str_to_date,
} }
FUNCTION_PARSERS = { FUNCTION_PARSERS = {
**Parser.FUNCTION_PARSERS, **parser.Parser.FUNCTION_PARSERS,
"GROUP_CONCAT": lambda self: self.expression( "GROUP_CONCAT": lambda self: self.expression(
exp.GroupConcat, exp.GroupConcat,
this=self._parse_lambda(), this=self._parse_lambda(),
@@ -178,15 +191,212 @@ class MySQL(Dialect):
} }
PROPERTY_PARSERS = { PROPERTY_PARSERS = {
**Parser.PROPERTY_PARSERS, **parser.Parser.PROPERTY_PARSERS,
TokenType.ENGINE: lambda self: self._parse_property_assignment(exp.EngineProperty), TokenType.ENGINE: lambda self: self._parse_property_assignment(exp.EngineProperty),
} }
class Generator(Generator): STATEMENT_PARSERS = {
**parser.Parser.STATEMENT_PARSERS,
TokenType.SHOW: lambda self: self._parse_show(),
TokenType.SET: lambda self: self._parse_set(),
}
SHOW_PARSERS = {
"BINARY LOGS": _show_parser("BINARY LOGS"),
"MASTER LOGS": _show_parser("BINARY LOGS"),
"BINLOG EVENTS": _show_parser("BINLOG EVENTS"),
"CHARACTER SET": _show_parser("CHARACTER SET"),
"CHARSET": _show_parser("CHARACTER SET"),
"COLLATION": _show_parser("COLLATION"),
"FULL COLUMNS": _show_parser("COLUMNS", target="FROM", full=True),
"COLUMNS": _show_parser("COLUMNS", target="FROM"),
"CREATE DATABASE": _show_parser("CREATE DATABASE", target=True),
"CREATE EVENT": _show_parser("CREATE EVENT", target=True),
"CREATE FUNCTION": _show_parser("CREATE FUNCTION", target=True),
"CREATE PROCEDURE": _show_parser("CREATE PROCEDURE", target=True),
"CREATE TABLE": _show_parser("CREATE TABLE", target=True),
"CREATE TRIGGER": _show_parser("CREATE TRIGGER", target=True),
"CREATE VIEW": _show_parser("CREATE VIEW", target=True),
"DATABASES": _show_parser("DATABASES"),
"ENGINE": _show_parser("ENGINE", target=True),
"STORAGE ENGINES": _show_parser("ENGINES"),
"ENGINES": _show_parser("ENGINES"),
"ERRORS": _show_parser("ERRORS"),
"EVENTS": _show_parser("EVENTS"),
"FUNCTION CODE": _show_parser("FUNCTION CODE", target=True),
"FUNCTION STATUS": _show_parser("FUNCTION STATUS"),
"GRANTS": _show_parser("GRANTS", target="FOR"),
"INDEX": _show_parser("INDEX", target="FROM"),
"MASTER STATUS": _show_parser("MASTER STATUS"),
"OPEN TABLES": _show_parser("OPEN TABLES"),
"PLUGINS": _show_parser("PLUGINS"),
"PROCEDURE CODE": _show_parser("PROCEDURE CODE", target=True),
"PROCEDURE STATUS": _show_parser("PROCEDURE STATUS"),
"PRIVILEGES": _show_parser("PRIVILEGES"),
"FULL PROCESSLIST": _show_parser("PROCESSLIST", full=True),
"PROCESSLIST": _show_parser("PROCESSLIST"),
"PROFILE": _show_parser("PROFILE"),
"PROFILES": _show_parser("PROFILES"),
"RELAYLOG EVENTS": _show_parser("RELAYLOG EVENTS"),
"REPLICAS": _show_parser("REPLICAS"),
"SLAVE HOSTS": _show_parser("REPLICAS"),
"REPLICA STATUS": _show_parser("REPLICA STATUS"),
"SLAVE STATUS": _show_parser("REPLICA STATUS"),
"GLOBAL STATUS": _show_parser("STATUS", global_=True),
"SESSION STATUS": _show_parser("STATUS"),
"STATUS": _show_parser("STATUS"),
"TABLE STATUS": _show_parser("TABLE STATUS"),
"FULL TABLES": _show_parser("TABLES", full=True),
"TABLES": _show_parser("TABLES"),
"TRIGGERS": _show_parser("TRIGGERS"),
"GLOBAL VARIABLES": _show_parser("VARIABLES", global_=True),
"SESSION VARIABLES": _show_parser("VARIABLES"),
"VARIABLES": _show_parser("VARIABLES"),
"WARNINGS": _show_parser("WARNINGS"),
}
SET_PARSERS = {
"GLOBAL": lambda self: self._parse_set_item_assignment("GLOBAL"),
"PERSIST": lambda self: self._parse_set_item_assignment("PERSIST"),
"PERSIST_ONLY": lambda self: self._parse_set_item_assignment("PERSIST_ONLY"),
"SESSION": lambda self: self._parse_set_item_assignment("SESSION"),
"LOCAL": lambda self: self._parse_set_item_assignment("LOCAL"),
"CHARACTER SET": lambda self: self._parse_set_item_charset("CHARACTER SET"),
"CHARSET": lambda self: self._parse_set_item_charset("CHARACTER SET"),
"NAMES": lambda self: self._parse_set_item_names(),
}
PROFILE_TYPES = {
"ALL",
"BLOCK IO",
"CONTEXT SWITCHES",
"CPU",
"IPC",
"MEMORY",
"PAGE FAULTS",
"SOURCE",
"SWAPS",
}
def _parse_show_mysql(self, this, target=False, full=None, global_=None):
if target:
if isinstance(target, str):
self._match_text(target)
target_id = self._parse_id_var()
else:
target_id = None
log = self._parse_string() if self._match_text("IN") else None
if this in {"BINLOG EVENTS", "RELAYLOG EVENTS"}:
position = self._parse_number() if self._match_text("FROM") else None
db = None
else:
position = None
db = self._parse_id_var() if self._match_text("FROM") else None
channel = self._parse_id_var() if self._match_text("FOR", "CHANNEL") else None
like = self._parse_string() if self._match_text("LIKE") else None
where = self._parse_where()
if this == "PROFILE":
types = self._parse_csv(self._parse_show_profile_type)
query = self._parse_number() if self._match_text("FOR", "QUERY") else None
offset = self._parse_number() if self._match_text("OFFSET") else None
limit = self._parse_number() if self._match_text("LIMIT") else None
else:
types, query = None, None
offset, limit = self._parse_oldstyle_limit()
mutex = True if self._match_text("MUTEX") else None
mutex = False if self._match_text("STATUS") else mutex
return self.expression(
exp.Show,
this=this,
target=target_id,
full=full,
log=log,
position=position,
db=db,
channel=channel,
like=like,
where=where,
types=types,
query=query,
offset=offset,
limit=limit,
mutex=mutex,
**{"global": global_},
)
def _parse_show_profile_type(self):
for type_ in self.PROFILE_TYPES:
if self._match_text(*type_.split(" ")):
return exp.Var(this=type_)
return None
def _parse_oldstyle_limit(self):
limit = None
offset = None
if self._match_text("LIMIT"):
parts = self._parse_csv(self._parse_number)
if len(parts) == 1:
limit = parts[0]
elif len(parts) == 2:
limit = parts[1]
offset = parts[0]
return offset, limit
def _default_parse_set_item(self):
return self._parse_set_item_assignment(kind=None)
def _parse_set_item_assignment(self, kind):
left = self._parse_primary() or self._parse_id_var()
if not self._match(TokenType.EQ):
self.raise_error("Expected =")
right = self._parse_statement() or self._parse_id_var()
this = self.expression(
exp.EQ,
this=left,
expression=right,
)
return self.expression(
exp.SetItem,
this=this,
kind=kind,
)
def _parse_set_item_charset(self, kind):
this = self._parse_string() or self._parse_id_var()
return self.expression(
exp.SetItem,
this=this,
kind=kind,
)
def _parse_set_item_names(self):
charset = self._parse_string() or self._parse_id_var()
if self._match_text("COLLATE"):
collate = self._parse_string() or self._parse_id_var()
else:
collate = None
return self.expression(
exp.SetItem,
this=charset,
collate=collate,
kind="NAMES",
)
    class Generator(generator.Generator):
        NULL_ORDERING_SUPPORTED = False

        TRANSFORMS = {
-            **Generator.TRANSFORMS,
+            **generator.Generator.TRANSFORMS,
            exp.CurrentDate: no_paren_current_date_sql,
            exp.CurrentTimestamp: lambda *_: "CURRENT_TIMESTAMP",
            exp.ILike: no_ilike_sql,
@@ -199,6 +409,8 @@ class MySQL(Dialect):
            exp.StrToDate: _str_to_date_sql,
            exp.StrToTime: _str_to_date_sql,
            exp.Trim: _trim_sql,
+            exp.NullSafeEQ: lambda self, e: self.binary(e, "<=>"),
+            exp.NullSafeNEQ: lambda self, e: self.not_sql(self.binary(e, "<=>")),
        }

        ROOT_PROPERTIES = {
@@ -209,4 +421,69 @@ class MySQL(Dialect):
            exp.SchemaCommentProperty,
        }

-        WITH_PROPERTIES = {}
+        WITH_PROPERTIES: t.Set[t.Type[exp.Property]] = set()
        def show_sql(self, expression):
            this = f" {expression.name}"
            full = " FULL" if expression.args.get("full") else ""
            global_ = " GLOBAL" if expression.args.get("global") else ""

            target = self.sql(expression, "target")
            target = f" {target}" if target else ""
            if expression.name in {"COLUMNS", "INDEX"}:
                target = f" FROM{target}"
            elif expression.name == "GRANTS":
                target = f" FOR{target}"

            db = self._prefixed_sql("FROM", expression, "db")

            like = self._prefixed_sql("LIKE", expression, "like")
            where = self.sql(expression, "where")

            types = self.expressions(expression, key="types")
            types = f" {types}" if types else types
            query = self._prefixed_sql("FOR QUERY", expression, "query")

            if expression.name == "PROFILE":
                offset = self._prefixed_sql("OFFSET", expression, "offset")
                limit = self._prefixed_sql("LIMIT", expression, "limit")
            else:
                offset = ""
                limit = self._oldstyle_limit_sql(expression)

            log = self._prefixed_sql("IN", expression, "log")
            position = self._prefixed_sql("FROM", expression, "position")
            channel = self._prefixed_sql("FOR CHANNEL", expression, "channel")

            if expression.name == "ENGINE":
                mutex_or_status = " MUTEX" if expression.args.get("mutex") else " STATUS"
            else:
                mutex_or_status = ""

            return f"SHOW{full}{global_}{this}{target}{types}{db}{query}{log}{position}{channel}{mutex_or_status}{like}{where}{offset}{limit}"

        def _prefixed_sql(self, prefix, expression, arg):
            sql = self.sql(expression, arg)
            if not sql:
                return ""
            return f" {prefix} {sql}"

        def _oldstyle_limit_sql(self, expression):
            limit = self.sql(expression, "limit")
            offset = self.sql(expression, "offset")
            if limit:
                limit_offset = f"{offset}, {limit}" if offset else limit
                return f" LIMIT {limit_offset}"
            return ""

        def setitem_sql(self, expression):
            kind = self.sql(expression, "kind")
            kind = f"{kind} " if kind else ""
            this = self.sql(expression, "this")
            collate = self.sql(expression, "collate")
            collate = f" COLLATE {collate}" if collate else ""
            return f"{kind}{this}{collate}"

        def set_sql(self, expression):
            return f"SET {self.expressions(expression)}"

View file

@@ -1,8 +1,9 @@
-from sqlglot import exp, transforms
+from __future__ import annotations
+
+from sqlglot import exp, generator, tokens, transforms
from sqlglot.dialects.dialect import Dialect, no_ilike_sql
-from sqlglot.generator import Generator
from sqlglot.helper import csv
-from sqlglot.tokens import Tokenizer, TokenType
+from sqlglot.tokens import TokenType


def _limit_sql(self, expression):
@@ -36,9 +37,9 @@ class Oracle(Dialect):
        "YYYY": "%Y",  # 2015
    }

-    class Generator(Generator):
+    class Generator(generator.Generator):
        TYPE_MAPPING = {
-            **Generator.TYPE_MAPPING,
+            **generator.Generator.TYPE_MAPPING,
            exp.DataType.Type.TINYINT: "NUMBER",
            exp.DataType.Type.SMALLINT: "NUMBER",
            exp.DataType.Type.INT: "NUMBER",
@@ -49,11 +50,12 @@ class Oracle(Dialect):
            exp.DataType.Type.NVARCHAR: "NVARCHAR2",
            exp.DataType.Type.TEXT: "CLOB",
            exp.DataType.Type.BINARY: "BLOB",
+            exp.DataType.Type.VARBINARY: "BLOB",
        }

        TRANSFORMS = {
-            **Generator.TRANSFORMS,
+            **generator.Generator.TRANSFORMS,
-            **transforms.UNALIAS_GROUP,
+            **transforms.UNALIAS_GROUP,  # type: ignore
            exp.ILike: no_ilike_sql,
            exp.Limit: _limit_sql,
            exp.StrToTime: lambda self, e: f"TO_TIMESTAMP({self.sql(e, 'this')}, {self.format_time(e)})",
@@ -86,9 +88,9 @@ class Oracle(Dialect):
        def table_sql(self, expression):
            return super().table_sql(expression, sep=" ")

-    class Tokenizer(Tokenizer):
+    class Tokenizer(tokens.Tokenizer):
        KEYWORDS = {
-            **Tokenizer.KEYWORDS,
+            **tokens.Tokenizer.KEYWORDS,
            "TOP": TokenType.TOP,
            "VARCHAR2": TokenType.VARCHAR,
            "NVARCHAR2": TokenType.NVARCHAR,

View file

@@ -1,4 +1,6 @@
-from sqlglot import exp
+from __future__ import annotations
+
+from sqlglot import exp, generator, parser, tokens
from sqlglot.dialects.dialect import (
    Dialect,
    arrow_json_extract_scalar_sql,
@@ -9,9 +11,7 @@ from sqlglot.dialects.dialect import (
    no_trycast_sql,
    str_position_sql,
)
-from sqlglot.generator import Generator
-from sqlglot.parser import Parser
-from sqlglot.tokens import Tokenizer, TokenType
+from sqlglot.tokens import TokenType
from sqlglot.transforms import delegate, preprocess
@@ -160,12 +160,12 @@ class Postgres(Dialect):
        "YYYY": "%Y",  # 2015
    }

-    class Tokenizer(Tokenizer):
+    class Tokenizer(tokens.Tokenizer):
        BIT_STRINGS = [("b'", "'"), ("B'", "'")]
        HEX_STRINGS = [("x'", "'"), ("X'", "'")]
        BYTE_STRINGS = [("e'", "'"), ("E'", "'")]
        KEYWORDS = {
-            **Tokenizer.KEYWORDS,
+            **tokens.Tokenizer.KEYWORDS,
            "ALWAYS": TokenType.ALWAYS,
            "BY DEFAULT": TokenType.BY_DEFAULT,
            "COMMENT ON": TokenType.COMMENT_ON,
@@ -179,31 +179,32 @@ class Postgres(Dialect):
        }
        QUOTES = ["'", "$$"]
        SINGLE_TOKENS = {
-            **Tokenizer.SINGLE_TOKENS,
+            **tokens.Tokenizer.SINGLE_TOKENS,
            "$": TokenType.PARAMETER,
        }

-    class Parser(Parser):
+    class Parser(parser.Parser):
        STRICT_CAST = False
        FUNCTIONS = {
-            **Parser.FUNCTIONS,
+            **parser.Parser.FUNCTIONS,
            "TO_TIMESTAMP": _to_timestamp,
            "TO_CHAR": format_time_lambda(exp.TimeToStr, "postgres"),
        }

-    class Generator(Generator):
+    class Generator(generator.Generator):
        TYPE_MAPPING = {
-            **Generator.TYPE_MAPPING,
+            **generator.Generator.TYPE_MAPPING,
            exp.DataType.Type.TINYINT: "SMALLINT",
            exp.DataType.Type.FLOAT: "REAL",
            exp.DataType.Type.DOUBLE: "DOUBLE PRECISION",
            exp.DataType.Type.BINARY: "BYTEA",
+            exp.DataType.Type.VARBINARY: "BYTEA",
            exp.DataType.Type.DATETIME: "TIMESTAMP",
        }

        TRANSFORMS = {
-            **Generator.TRANSFORMS,
+            **generator.Generator.TRANSFORMS,
            exp.ColumnDef: preprocess(
                [
                    _auto_increment_to_serial,

View file

@@ -1,4 +1,6 @@
-from sqlglot import exp, transforms
+from __future__ import annotations
+
+from sqlglot import exp, generator, parser, tokens, transforms
from sqlglot.dialects.dialect import (
    Dialect,
    format_time_lambda,
@@ -10,10 +12,8 @@ from sqlglot.dialects.dialect import (
    struct_extract_sql,
)
from sqlglot.dialects.mysql import MySQL
-from sqlglot.generator import Generator
-from sqlglot.helper import list_get
-from sqlglot.parser import Parser
-from sqlglot.tokens import Tokenizer, TokenType
+from sqlglot.helper import seq_get
+from sqlglot.tokens import TokenType


def _approx_distinct_sql(self, expression):
@@ -110,30 +110,29 @@ class Presto(Dialect):
    index_offset = 1
    null_ordering = "nulls_are_last"
    time_format = "'%Y-%m-%d %H:%i:%S'"
-    time_mapping = MySQL.time_mapping
+    time_mapping = MySQL.time_mapping  # type: ignore

-    class Tokenizer(Tokenizer):
+    class Tokenizer(tokens.Tokenizer):
        KEYWORDS = {
-            **Tokenizer.KEYWORDS,
-            "VARBINARY": TokenType.BINARY,
+            **tokens.Tokenizer.KEYWORDS,
            "ROW": TokenType.STRUCT,
        }

-    class Parser(Parser):
+    class Parser(parser.Parser):
        FUNCTIONS = {
-            **Parser.FUNCTIONS,
+            **parser.Parser.FUNCTIONS,
            "APPROX_DISTINCT": exp.ApproxDistinct.from_arg_list,
            "CARDINALITY": exp.ArraySize.from_arg_list,
            "CONTAINS": exp.ArrayContains.from_arg_list,
            "DATE_ADD": lambda args: exp.DateAdd(
-                this=list_get(args, 2),
-                expression=list_get(args, 1),
-                unit=list_get(args, 0),
+                this=seq_get(args, 2),
+                expression=seq_get(args, 1),
+                unit=seq_get(args, 0),
            ),
            "DATE_DIFF": lambda args: exp.DateDiff(
-                this=list_get(args, 2),
-                expression=list_get(args, 1),
-                unit=list_get(args, 0),
+                this=seq_get(args, 2),
+                expression=seq_get(args, 1),
+                unit=seq_get(args, 0),
            ),
            "DATE_FORMAT": format_time_lambda(exp.TimeToStr, "presto"),
            "DATE_PARSE": format_time_lambda(exp.StrToTime, "presto"),
@@ -143,7 +142,7 @@ class Presto(Dialect):
            "APPROX_PERCENTILE": exp.ApproxQuantile.from_arg_list,
        }

-    class Generator(Generator):
+    class Generator(generator.Generator):
        STRUCT_DELIMITER = ("(", ")")
@@ -159,7 +158,7 @@ class Presto(Dialect):
        }

        TYPE_MAPPING = {
-            **Generator.TYPE_MAPPING,
+            **generator.Generator.TYPE_MAPPING,
            exp.DataType.Type.INT: "INTEGER",
            exp.DataType.Type.FLOAT: "REAL",
            exp.DataType.Type.BINARY: "VARBINARY",
@@ -169,8 +168,8 @@ class Presto(Dialect):
        }

        TRANSFORMS = {
-            **Generator.TRANSFORMS,
-            **transforms.UNALIAS_GROUP,
+            **generator.Generator.TRANSFORMS,
+            **transforms.UNALIAS_GROUP,  # type: ignore
            exp.ApproxDistinct: _approx_distinct_sql,
            exp.Array: lambda self, e: f"ARRAY[{self.expressions(e, flat=True)}]",
            exp.ArrayConcat: rename_func("CONCAT"),

View file

@@ -1,3 +1,5 @@
+from __future__ import annotations
+
from sqlglot import exp
from sqlglot.dialects.postgres import Postgres
from sqlglot.tokens import TokenType
@@ -6,29 +8,30 @@ from sqlglot.tokens import TokenType
class Redshift(Postgres):
    time_format = "'YYYY-MM-DD HH:MI:SS'"
    time_mapping = {
-        **Postgres.time_mapping,
+        **Postgres.time_mapping,  # type: ignore
        "MON": "%b",
        "HH": "%H",
    }

    class Tokenizer(Postgres.Tokenizer):
-        ESCAPE = "\\"
+        ESCAPES = ["\\"]
        KEYWORDS = {
-            **Postgres.Tokenizer.KEYWORDS,
+            **Postgres.Tokenizer.KEYWORDS,  # type: ignore
            "GEOMETRY": TokenType.GEOMETRY,
            "GEOGRAPHY": TokenType.GEOGRAPHY,
            "HLLSKETCH": TokenType.HLLSKETCH,
            "SUPER": TokenType.SUPER,
            "TIME": TokenType.TIMESTAMP,
            "TIMETZ": TokenType.TIMESTAMPTZ,
-            "VARBYTE": TokenType.BINARY,
+            "VARBYTE": TokenType.VARBINARY,
            "SIMILAR TO": TokenType.SIMILAR_TO,
        }

    class Generator(Postgres.Generator):
        TYPE_MAPPING = {
-            **Postgres.Generator.TYPE_MAPPING,
+            **Postgres.Generator.TYPE_MAPPING,  # type: ignore
            exp.DataType.Type.BINARY: "VARBYTE",
+            exp.DataType.Type.VARBINARY: "VARBYTE",
            exp.DataType.Type.INT: "INTEGER",
        }

View file

@@ -1,4 +1,6 @@
-from sqlglot import exp
+from __future__ import annotations
+
+from sqlglot import exp, generator, parser, tokens
from sqlglot.dialects.dialect import (
    Dialect,
    format_time_lambda,
@@ -6,10 +8,8 @@ from sqlglot.dialects.dialect import (
    rename_func,
)
from sqlglot.expressions import Literal
-from sqlglot.generator import Generator
-from sqlglot.helper import list_get
-from sqlglot.parser import Parser
-from sqlglot.tokens import Tokenizer, TokenType
+from sqlglot.helper import seq_get
+from sqlglot.tokens import TokenType


def _check_int(s):
@@ -28,7 +28,9 @@ def _snowflake_to_timestamp(args):
        # case: <numeric_expr> [ , <scale> ]
        if second_arg.name not in ["0", "3", "9"]:
-            raise ValueError(f"Scale for snowflake numeric timestamp is {second_arg}, but should be 0, 3, or 9")
+            raise ValueError(
+                f"Scale for snowflake numeric timestamp is {second_arg}, but should be 0, 3, or 9"
+            )

        if second_arg.name == "0":
            timescale = exp.UnixToTime.SECONDS
@@ -39,7 +41,7 @@ def _snowflake_to_timestamp(args):
        return exp.UnixToTime(this=first_arg, scale=timescale)

-    first_arg = list_get(args, 0)
+    first_arg = seq_get(args, 0)
    if not isinstance(first_arg, Literal):
        # case: <variant_expr>
        return format_time_lambda(exp.StrToTime, "snowflake", default=True)(args)
@@ -56,7 +58,7 @@ def _snowflake_to_timestamp(args):
    return exp.UnixToTime.from_arg_list(args)


-def _unix_to_time(self, expression):
+def _unix_to_time_sql(self, expression):
    scale = expression.args.get("scale")
    timestamp = self.sql(expression, "this")
    if scale in [None, exp.UnixToTime.SECONDS]:
@@ -132,9 +134,9 @@ class Snowflake(Dialect):
        "ff6": "%f",
    }

-    class Parser(Parser):
+    class Parser(parser.Parser):
        FUNCTIONS = {
-            **Parser.FUNCTIONS,
+            **parser.Parser.FUNCTIONS,
            "ARRAYAGG": exp.ArrayAgg.from_arg_list,
            "IFF": exp.If.from_arg_list,
            "TO_TIMESTAMP": _snowflake_to_timestamp,
@@ -143,18 +145,18 @@ class Snowflake(Dialect):
        }

        FUNCTION_PARSERS = {
-            **Parser.FUNCTION_PARSERS,
+            **parser.Parser.FUNCTION_PARSERS,
            "DATE_PART": _parse_date_part,
        }

        FUNC_TOKENS = {
-            *Parser.FUNC_TOKENS,
+            *parser.Parser.FUNC_TOKENS,
            TokenType.RLIKE,
            TokenType.TABLE,
        }

        COLUMN_OPERATORS = {
-            **Parser.COLUMN_OPERATORS,
+            **parser.Parser.COLUMN_OPERATORS,  # type: ignore
            TokenType.COLON: lambda self, this, path: self.expression(
                exp.Bracket,
                this=this,
@@ -163,21 +165,21 @@ class Snowflake(Dialect):
        }

        PROPERTY_PARSERS = {
-            **Parser.PROPERTY_PARSERS,
+            **parser.Parser.PROPERTY_PARSERS,
            TokenType.PARTITION_BY: lambda self: self._parse_partitioned_by(),
        }

-    class Tokenizer(Tokenizer):
+    class Tokenizer(tokens.Tokenizer):
        QUOTES = ["'", "$$"]
-        ESCAPE = "\\"
+        ESCAPES = ["\\"]

        SINGLE_TOKENS = {
-            **Tokenizer.SINGLE_TOKENS,
+            **tokens.Tokenizer.SINGLE_TOKENS,
            "$": TokenType.PARAMETER,
        }

        KEYWORDS = {
-            **Tokenizer.KEYWORDS,
+            **tokens.Tokenizer.KEYWORDS,
            "QUALIFY": TokenType.QUALIFY,
            "DOUBLE PRECISION": TokenType.DOUBLE,
            "TIMESTAMP_LTZ": TokenType.TIMESTAMPLTZ,
@@ -187,15 +189,15 @@ class Snowflake(Dialect):
            "SAMPLE": TokenType.TABLE_SAMPLE,
        }

-    class Generator(Generator):
+    class Generator(generator.Generator):
        CREATE_TRANSIENT = True

        TRANSFORMS = {
-            **Generator.TRANSFORMS,
+            **generator.Generator.TRANSFORMS,
            exp.ArrayConcat: rename_func("ARRAY_CAT"),
            exp.If: rename_func("IFF"),
            exp.StrToTime: lambda self, e: f"TO_TIMESTAMP({self.sql(e, 'this')}, {self.format_time(e)})",
-            exp.UnixToTime: _unix_to_time,
+            exp.UnixToTime: _unix_to_time_sql,
            exp.TimeToUnix: lambda self, e: f"EXTRACT(epoch_second FROM {self.sql(e, 'this')})",
            exp.Array: inline_array_sql,
            exp.StrPosition: rename_func("POSITION"),
@@ -204,7 +206,7 @@ class Snowflake(Dialect):
        }

        TYPE_MAPPING = {
-            **Generator.TYPE_MAPPING,
+            **generator.Generator.TYPE_MAPPING,
            exp.DataType.Type.TIMESTAMP: "TIMESTAMPNTZ",
        }

View file

@@ -1,8 +1,9 @@
-from sqlglot import exp
+from __future__ import annotations
+
+from sqlglot import exp, parser
from sqlglot.dialects.dialect import create_with_partitions_sql, rename_func
from sqlglot.dialects.hive import Hive
-from sqlglot.helper import list_get
-from sqlglot.parser import Parser
+from sqlglot.helper import seq_get


def _create_sql(self, e):
@@ -46,36 +47,36 @@ def _unix_to_time(self, expression):
class Spark(Hive):
    class Parser(Hive.Parser):
        FUNCTIONS = {
-            **Hive.Parser.FUNCTIONS,
+            **Hive.Parser.FUNCTIONS,  # type: ignore
            "MAP_FROM_ARRAYS": exp.Map.from_arg_list,
            "TO_UNIX_TIMESTAMP": exp.StrToUnix.from_arg_list,
            "LEFT": lambda args: exp.Substring(
-                this=list_get(args, 0),
+                this=seq_get(args, 0),
                start=exp.Literal.number(1),
-                length=list_get(args, 1),
+                length=seq_get(args, 1),
            ),
            "SHIFTLEFT": lambda args: exp.BitwiseLeftShift(
-                this=list_get(args, 0),
-                expression=list_get(args, 1),
+                this=seq_get(args, 0),
+                expression=seq_get(args, 1),
            ),
            "SHIFTRIGHT": lambda args: exp.BitwiseRightShift(
-                this=list_get(args, 0),
-                expression=list_get(args, 1),
+                this=seq_get(args, 0),
+                expression=seq_get(args, 1),
            ),
            "RIGHT": lambda args: exp.Substring(
-                this=list_get(args, 0),
+                this=seq_get(args, 0),
                start=exp.Sub(
-                    this=exp.Length(this=list_get(args, 0)),
-                    expression=exp.Add(this=list_get(args, 1), expression=exp.Literal.number(1)),
+                    this=exp.Length(this=seq_get(args, 0)),
+                    expression=exp.Add(this=seq_get(args, 1), expression=exp.Literal.number(1)),
                ),
-                length=list_get(args, 1),
+                length=seq_get(args, 1),
            ),
            "APPROX_PERCENTILE": exp.ApproxQuantile.from_arg_list,
            "IIF": exp.If.from_arg_list,
        }

        FUNCTION_PARSERS = {
-            **Parser.FUNCTION_PARSERS,
+            **parser.Parser.FUNCTION_PARSERS,
            "BROADCAST": lambda self: self._parse_join_hint("BROADCAST"),
            "BROADCASTJOIN": lambda self: self._parse_join_hint("BROADCASTJOIN"),
            "MAPJOIN": lambda self: self._parse_join_hint("MAPJOIN"),
@@ -88,14 +89,14 @@ class Spark(Hive):
    class Generator(Hive.Generator):
        TYPE_MAPPING = {
-            **Hive.Generator.TYPE_MAPPING,
+            **Hive.Generator.TYPE_MAPPING,  # type: ignore
            exp.DataType.Type.TINYINT: "BYTE",
            exp.DataType.Type.SMALLINT: "SHORT",
            exp.DataType.Type.BIGINT: "LONG",
        }

        TRANSFORMS = {
-            **{k: v for k, v in Hive.Generator.TRANSFORMS.items() if k not in {exp.ArraySort, exp.ILike}},
+            **Hive.Generator.TRANSFORMS,  # type: ignore
            exp.ApproxDistinct: rename_func("APPROX_COUNT_DISTINCT"),
            exp.FileFormatProperty: lambda self, e: f"USING {e.text('value').upper()}",
            exp.ArraySum: lambda self, e: f"AGGREGATE({self.sql(e, 'this')}, 0, (acc, x) -> acc + x, acc -> acc)",
@@ -114,6 +115,8 @@ class Spark(Hive):
            exp.VariancePop: rename_func("VAR_POP"),
            exp.DateFromParts: rename_func("MAKE_DATE"),
        }
+        TRANSFORMS.pop(exp.ArraySort)
+        TRANSFORMS.pop(exp.ILike)

        WRAP_DERIVED_VALUES = False

View file

@@ -1,4 +1,6 @@
-from sqlglot import exp
+from __future__ import annotations
+
+from sqlglot import exp, generator, parser, tokens
from sqlglot.dialects.dialect import (
    Dialect,
    arrow_json_extract_scalar_sql,
@@ -8,31 +10,28 @@ from sqlglot.dialects.dialect import (
    no_trycast_sql,
    rename_func,
)
-from sqlglot.generator import Generator
-from sqlglot.parser import Parser
-from sqlglot.tokens import Tokenizer, TokenType
+from sqlglot.tokens import TokenType


class SQLite(Dialect):
-    class Tokenizer(Tokenizer):
+    class Tokenizer(tokens.Tokenizer):
        IDENTIFIERS = ['"', ("[", "]"), "`"]
        HEX_STRINGS = [("x'", "'"), ("X'", "'"), ("0x", ""), ("0X", "")]
        KEYWORDS = {
-            **Tokenizer.KEYWORDS,
-            "VARBINARY": TokenType.BINARY,
+            **tokens.Tokenizer.KEYWORDS,
            "AUTOINCREMENT": TokenType.AUTO_INCREMENT,
        }

-    class Parser(Parser):
+    class Parser(parser.Parser):
        FUNCTIONS = {
-            **Parser.FUNCTIONS,
+            **parser.Parser.FUNCTIONS,
            "EDITDIST3": exp.Levenshtein.from_arg_list,
        }

-    class Generator(Generator):
+    class Generator(generator.Generator):
        TYPE_MAPPING = {
-            **Generator.TYPE_MAPPING,
+            **generator.Generator.TYPE_MAPPING,
            exp.DataType.Type.BOOLEAN: "INTEGER",
            exp.DataType.Type.TINYINT: "INTEGER",
            exp.DataType.Type.SMALLINT: "INTEGER",
@@ -46,6 +45,7 @@ class SQLite(Dialect):
            exp.DataType.Type.VARCHAR: "TEXT",
            exp.DataType.Type.NVARCHAR: "TEXT",
            exp.DataType.Type.BINARY: "BLOB",
+            exp.DataType.Type.VARBINARY: "BLOB",
        }

        TOKEN_MAPPING = {
@@ -53,7 +53,7 @@ class SQLite(Dialect):
        }

        TRANSFORMS = {
-            **Generator.TRANSFORMS,
+            **generator.Generator.TRANSFORMS,
            exp.ILike: no_ilike_sql,
            exp.JSONExtract: arrow_json_extract_sql,
            exp.JSONExtractScalar: arrow_json_extract_scalar_sql,

View file

@@ -1,10 +1,12 @@
+from __future__ import annotations
+
from sqlglot import exp
from sqlglot.dialects.dialect import arrow_json_extract_sql, rename_func
from sqlglot.dialects.mysql import MySQL


class StarRocks(MySQL):
-    class Generator(MySQL.Generator):
+    class Generator(MySQL.Generator):  # type: ignore
        TYPE_MAPPING = {
            **MySQL.Generator.TYPE_MAPPING,
            exp.DataType.Type.TEXT: "STRING",
@@ -13,7 +15,7 @@ class StarRocks(MySQL):
        }

        TRANSFORMS = {
-            **MySQL.Generator.TRANSFORMS,
+            **MySQL.Generator.TRANSFORMS,  # type: ignore
            exp.JSONExtractScalar: arrow_json_extract_sql,
            exp.JSONExtract: arrow_json_extract_sql,
            exp.DateDiff: rename_func("DATEDIFF"),
@@ -22,3 +24,4 @@ class StarRocks(MySQL):
            exp.UnixToStr: lambda self, e: f"FROM_UNIXTIME({self.sql(e, 'this')}, {self.format_time(e)})",
            exp.UnixToTime: rename_func("FROM_UNIXTIME"),
        }
+        TRANSFORMS.pop(exp.DateTrunc)

View file

@@ -1,7 +1,7 @@
-from sqlglot import exp
+from __future__ import annotations
+
+from sqlglot import exp, generator, parser
from sqlglot.dialects.dialect import Dialect
-from sqlglot.generator import Generator
-from sqlglot.parser import Parser


def _if_sql(self, expression):
@@ -20,17 +20,17 @@ def _count_sql(self, expression):
class Tableau(Dialect):
-    class Generator(Generator):
+    class Generator(generator.Generator):
        TRANSFORMS = {
-            **Generator.TRANSFORMS,
+            **generator.Generator.TRANSFORMS,  # type: ignore
            exp.If: _if_sql,
            exp.Coalesce: _coalesce_sql,
            exp.Count: _count_sql,
        }

-    class Parser(Parser):
+    class Parser(parser.Parser):
        FUNCTIONS = {
-            **Parser.FUNCTIONS,
+            **parser.Parser.FUNCTIONS,
            "IFNULL": exp.Coalesce.from_arg_list,
            "COUNTD": lambda args: exp.Count(this=exp.Distinct(expressions=args)),
        }

View file

@@ -1,3 +1,5 @@
+from __future__ import annotations
+
from sqlglot import exp
from sqlglot.dialects.presto import Presto
@@ -5,7 +7,7 @@ from sqlglot.dialects.presto import Presto
class Trino(Presto):
    class Generator(Presto.Generator):
        TRANSFORMS = {
-            **Presto.Generator.TRANSFORMS,
+            **Presto.Generator.TRANSFORMS,  # type: ignore
            exp.ArraySum: lambda self, e: f"REDUCE({self.sql(e, 'this')}, 0, (acc, x) -> acc + x, acc -> acc)",
        }

View file

@@ -1,15 +1,22 @@
+from __future__ import annotations
+
 import re

-from sqlglot import exp
+from sqlglot import exp, generator, parser, tokens
 from sqlglot.dialects.dialect import Dialect, parse_date_delta, rename_func
 from sqlglot.expressions import DataType
-from sqlglot.generator import Generator
-from sqlglot.helper import list_get
-from sqlglot.parser import Parser
+from sqlglot.helper import seq_get
 from sqlglot.time import format_time
-from sqlglot.tokens import Tokenizer, TokenType
+from sqlglot.tokens import TokenType

-FULL_FORMAT_TIME_MAPPING = {"weekday": "%A", "dw": "%A", "w": "%A", "month": "%B", "mm": "%B", "m": "%B"}
+FULL_FORMAT_TIME_MAPPING = {
+    "weekday": "%A",
+    "dw": "%A",
+    "w": "%A",
+    "month": "%B",
+    "mm": "%B",
+    "m": "%B",
+}
 DATE_DELTA_INTERVAL = {
     "year": "year",
     "yyyy": "year",
@@ -37,11 +44,13 @@ TRANSPILE_SAFE_NUMBER_FMT = {"N", "C"}
 def tsql_format_time_lambda(exp_class, full_format_mapping=None, default=None):
     def _format_time(args):
         return exp_class(
-            this=list_get(args, 1),
+            this=seq_get(args, 1),
             format=exp.Literal.string(
                 format_time(
-                    list_get(args, 0).name or (TSQL.time_format if default is True else default),
-                    {**TSQL.time_mapping, **FULL_FORMAT_TIME_MAPPING} if full_format_mapping else TSQL.time_mapping,
+                    seq_get(args, 0).name or (TSQL.time_format if default is True else default),
+                    {**TSQL.time_mapping, **FULL_FORMAT_TIME_MAPPING}
+                    if full_format_mapping
+                    else TSQL.time_mapping,
                 )
             ),
         )
@@ -50,12 +59,12 @@ def tsql_format_time_lambda(exp_class, full_format_mapping=None, default=None):
 def parse_format(args):
-    fmt = list_get(args, 1)
+    fmt = seq_get(args, 1)
     number_fmt = fmt.name in TRANSPILE_SAFE_NUMBER_FMT or not DATE_FMT_RE.search(fmt.this)
     if number_fmt:
-        return exp.NumberToStr(this=list_get(args, 0), format=fmt)
+        return exp.NumberToStr(this=seq_get(args, 0), format=fmt)
     return exp.TimeToStr(
-        this=list_get(args, 0),
+        this=seq_get(args, 0),
         format=exp.Literal.string(
             format_time(fmt.name, TSQL.format_time_mapping)
             if len(fmt.name) == 1
@@ -188,11 +197,11 @@ class TSQL(Dialect):
         "Y": "%a %Y",
     }

-    class Tokenizer(Tokenizer):
+    class Tokenizer(tokens.Tokenizer):
         IDENTIFIERS = ['"', ("[", "]")]

         KEYWORDS = {
-            **Tokenizer.KEYWORDS,
+            **tokens.Tokenizer.KEYWORDS,
             "BIT": TokenType.BOOLEAN,
             "REAL": TokenType.FLOAT,
             "NTEXT": TokenType.TEXT,
@@ -200,7 +209,6 @@ class TSQL(Dialect):
             "DATETIME2": TokenType.DATETIME,
             "DATETIMEOFFSET": TokenType.TIMESTAMPTZ,
             "TIME": TokenType.TIMESTAMP,
-            "VARBINARY": TokenType.BINARY,
             "IMAGE": TokenType.IMAGE,
             "MONEY": TokenType.MONEY,
             "SMALLMONEY": TokenType.SMALLMONEY,
@@ -213,9 +221,9 @@ class TSQL(Dialect):
             "TOP": TokenType.TOP,
         }

-    class Parser(Parser):
+    class Parser(parser.Parser):
         FUNCTIONS = {
-            **Parser.FUNCTIONS,
+            **parser.Parser.FUNCTIONS,
             "CHARINDEX": exp.StrPosition.from_arg_list,
             "ISNULL": exp.Coalesce.from_arg_list,
             "DATEADD": parse_date_delta(exp.DateAdd, unit_mapping=DATE_DELTA_INTERVAL),
@@ -243,14 +251,16 @@ class TSQL(Dialect):
             this = self._parse_column()

             # Retrieve length of datatype and override to default if not specified
-            if list_get(to.expressions, 0) is None and to.this in self.VAR_LENGTH_DATATYPES:
+            if seq_get(to.expressions, 0) is None and to.this in self.VAR_LENGTH_DATATYPES:
                 to = exp.DataType.build(to.this, expressions=[exp.Literal.number(30)], nested=False)

             # Check whether a conversion with format is applicable
             if self._match(TokenType.COMMA):
                 format_val = self._parse_number().name
                 if format_val not in TSQL.convert_format_mapping:
-                    raise ValueError(f"CONVERT function at T-SQL does not support format style {format_val}")
+                    raise ValueError(
+                        f"CONVERT function at T-SQL does not support format style {format_val}"
+                    )
                 format_norm = exp.Literal.string(TSQL.convert_format_mapping[format_val])

                 # Check whether the convert entails a string to date format
@@ -272,9 +282,9 @@ class TSQL(Dialect):
             # Entails a simple cast without any format requirement
             return self.expression(exp.Cast if strict else exp.TryCast, this=this, to=to)

-    class Generator(Generator):
+    class Generator(generator.Generator):
         TYPE_MAPPING = {
-            **Generator.TYPE_MAPPING,
+            **generator.Generator.TYPE_MAPPING,
             exp.DataType.Type.BOOLEAN: "BIT",
             exp.DataType.Type.INT: "INTEGER",
             exp.DataType.Type.DECIMAL: "NUMERIC",
@@ -283,7 +293,7 @@ class TSQL(Dialect):
         }

         TRANSFORMS = {
-            **Generator.TRANSFORMS,
+            **generator.Generator.TRANSFORMS,  # type: ignore
             exp.DateAdd: generate_date_delta_with_unit_sql,
             exp.DateDiff: generate_date_delta_with_unit_sql,
             exp.CurrentDate: rename_func("GETDATE"),
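
The `list_get` to `seq_get` rename throughout this file is mechanical; per the changelog it is just a renamed helper. The assumed contract is "return the element at an index, or `None` when the index is out of range", which is why parsers can use it for optional function arguments. A minimal sketch of that assumed behavior, with illustrative names:

```python
import typing as t

T = t.TypeVar("T")


def seq_get(seq: t.Sequence[T], index: int) -> t.Optional[T]:
    """Return seq[index], or None if the index is out of range."""
    try:
        return seq[index]
    except IndexError:
        return None


# Optional trailing arguments can then be probed without length checks,
# e.g. a date-format argument that may be absent:
args = ["2020-01-01"]
assert seq_get(args, 0) == "2020-01-01"
assert seq_get(args, 1) is None
```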


@@ -4,7 +4,7 @@ from heapq import heappop, heappush

 from sqlglot import Dialect
 from sqlglot import expressions as exp
-from sqlglot.helper import ensure_list
+from sqlglot.helper import ensure_collection


 @dataclass(frozen=True)
@@ -116,7 +116,9 @@ class ChangeDistiller:
             source_node = self._source_index[kept_source_node_id]
             target_node = self._target_index[kept_target_node_id]
             if not isinstance(source_node, LEAF_EXPRESSION_TYPES) or source_node == target_node:
-                edit_script.extend(self._generate_move_edits(source_node, target_node, matching_set))
+                edit_script.extend(
+                    self._generate_move_edits(source_node, target_node, matching_set)
+                )
                 edit_script.append(Keep(source_node, target_node))
             else:
                 edit_script.append(Update(source_node, target_node))
@@ -158,13 +160,16 @@ class ChangeDistiller:
                 max_leaves_num = max(len(source_leaf_ids), len(target_leaf_ids))
                 if max_leaves_num:
                     common_leaves_num = sum(
-                        1 if s in source_leaf_ids and t in target_leaf_ids else 0 for s, t in leaves_matching_set
+                        1 if s in source_leaf_ids and t in target_leaf_ids else 0
+                        for s, t in leaves_matching_set
                     )
                     leaf_similarity_score = common_leaves_num / max_leaves_num
                 else:
                     leaf_similarity_score = 0.0

-                adjusted_t = self.t if min(len(source_leaf_ids), len(target_leaf_ids)) > 4 else 0.4
+                adjusted_t = (
+                    self.t if min(len(source_leaf_ids), len(target_leaf_ids)) > 4 else 0.4
+                )

                 if leaf_similarity_score >= 0.8 or (
                     leaf_similarity_score >= adjusted_t
@@ -201,7 +206,10 @@ class ChangeDistiller:
         matching_set = set()
         while candidate_matchings:
             _, _, source_leaf, target_leaf = heappop(candidate_matchings)
-            if id(source_leaf) in self._unmatched_source_nodes and id(target_leaf) in self._unmatched_target_nodes:
+            if (
+                id(source_leaf) in self._unmatched_source_nodes
+                and id(target_leaf) in self._unmatched_target_nodes
+            ):
                 matching_set.add((id(source_leaf), id(target_leaf)))
                 self._unmatched_source_nodes.remove(id(source_leaf))
                 self._unmatched_target_nodes.remove(id(target_leaf))
@@ -241,8 +249,7 @@ def _get_leaves(expression):
     has_child_exprs = False

     for a in expression.args.values():
-        nodes = ensure_list(a)
-        for node in nodes:
+        for node in ensure_collection(a):
             if isinstance(node, exp.Expression):
                 has_child_exprs = True
                 yield from _get_leaves(node)
@@ -268,7 +275,7 @@ def _expression_only_args(expression):
     args = []
     if expression:
         for a in expression.args.values():
-            args.extend(ensure_list(a))
+            args.extend(ensure_collection(a))
     return [a for a in args if isinstance(a, exp.Expression)]
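
`ensure_collection` replaces `ensure_list` at every iteration site in this file, which is why `nodes = ensure_list(a)` collapses into a single `for` loop. The assumed contract is that it passes existing collections through unchanged and wraps anything else (including strings) in a tuple; this is a rough sketch of that assumption, not sqlglot's exact implementation:

```python
def ensure_collection(value):
    """Return value unchanged if it is a non-string collection, else wrap it."""
    if value is None:
        return ()
    if isinstance(value, (list, tuple, set)):
        return value
    # Scalars (and strings) become a one-element tuple so callers can iterate.
    return (value,)


assert ensure_collection([1, 2]) == [1, 2]
assert ensure_collection("x") == ("x",)
assert ensure_collection(None) == ()
```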


@@ -1,3 +1,6 @@
+from __future__ import annotations
+import typing as t
+
 from enum import auto

 from sqlglot.helper import AutoName
@@ -30,7 +33,11 @@ class OptimizeError(SqlglotError):
     pass


-def concat_errors(errors, maximum):
+class SchemaError(SqlglotError):
+    pass
+
+
+def concat_errors(errors: t.Sequence[t.Any], maximum: int) -> str:
     msg = [str(e) for e in errors[:maximum]]
     remaining = len(errors) - maximum
     if remaining > 0:


@@ -19,6 +19,7 @@ class Context:
             env (Optional[dict]): dictionary of functions within the execution context
         """
         self.tables = tables
+        self._table = None
         self.range_readers = {name: table.range_reader for name, table in self.tables.items()}
         self.row_readers = {name: table.reader for name, table in tables.items()}
         self.env = {**(env or {}), "scope": self.row_readers}
@@ -29,8 +30,27 @@ class Context:
     def eval_tuple(self, codes):
         return tuple(self.eval(code) for code in codes)

+    @property
+    def table(self):
+        if self._table is None:
+            self._table = list(self.tables.values())[0]
+            for other in self.tables.values():
+                if self._table.columns != other.columns:
+                    raise Exception(f"Columns are different.")
+                if len(self._table.rows) != len(other.rows):
+                    raise Exception(f"Rows are different.")
+        return self._table
+
+    @property
+    def columns(self):
+        return self.table.columns
+
     def __iter__(self):
-        return self.table_iter(list(self.tables)[0])
+        self.env["scope"] = self.row_readers
+        for i in range(len(self.table.rows)):
+            for table in self.tables.values():
+                reader = table[i]
+            yield reader, self

     def table_iter(self, table):
         self.env["scope"] = self.row_readers
@@ -38,8 +58,8 @@ class Context:
         for reader in self.tables[table]:
             yield reader, self

-    def sort(self, table, key):
-        table = self.tables[table]
+    def sort(self, key):
+        table = self.table

         def sort_key(row):
             table.reader.row = row
@@ -47,20 +67,20 @@ class Context:
         table.rows.sort(key=sort_key)

-    def set_row(self, table, row):
-        self.row_readers[table].row = row
+    def set_row(self, row):
+        for table in self.tables.values():
+            table.reader.row = row
         self.env["scope"] = self.row_readers

-    def set_index(self, table, index):
-        self.row_readers[table].row = self.tables[table].rows[index]
+    def set_index(self, index):
+        for table in self.tables.values():
+            table[index]
         self.env["scope"] = self.row_readers

-    def set_range(self, table, start, end):
-        self.range_readers[table].range = range(start, end)
+    def set_range(self, start, end):
+        for name in self.tables:
+            self.range_readers[name].range = range(start, end)
         self.env["scope"] = self.range_readers

-    def __getitem__(self, table):
-        return self.env["scope"][table]
-
     def __contains__(self, table):
         return table in self.tables


@@ -2,6 +2,8 @@ import datetime
 import re
 import statistics

+from sqlglot.helper import PYTHON_VERSION
+

 class reverse_key:
     def __init__(self, obj):
@@ -25,7 +27,7 @@ ENV = {
     "str": str,
     "desc": reverse_key,
     "SUM": sum,
-    "AVG": statistics.fmean if hasattr(statistics, "fmean") else statistics.mean,
+    "AVG": statistics.fmean if PYTHON_VERSION >= (3, 8) else statistics.mean,  # type: ignore
     "COUNT": lambda acc: sum(1 for e in acc if e is not None),
     "MAX": max,
     "MIN": min,


@@ -1,15 +1,14 @@
 import ast
 import collections
 import itertools
+import math

-from sqlglot import exp, planner
+from sqlglot import exp, generator, planner, tokens
 from sqlglot.dialects.dialect import Dialect, inline_array_sql
 from sqlglot.executor.context import Context
 from sqlglot.executor.env import ENV
 from sqlglot.executor.table import Table
-from sqlglot.generator import Generator
 from sqlglot.helper import csv_reader
-from sqlglot.tokens import Tokenizer


 class PythonExecutor:
@@ -26,7 +25,11 @@ class PythonExecutor:
             while queue:
                 node = queue.pop()
                 context = self.context(
-                    {name: table for dep in node.dependencies for name, table in contexts[dep].tables.items()}
+                    {
+                        name: table
+                        for dep in node.dependencies
+                        for name, table in contexts[dep].tables.items()
+                    }
                 )
                 running.add(node)
@@ -76,13 +79,10 @@ class PythonExecutor:
         return Table(expression.alias_or_name for expression in expressions)

     def scan(self, step, context):
-        if hasattr(step, "source"):
-            source = step.source
+        source = step.source

-            if isinstance(source, exp.Expression):
-                source = source.name or source.alias
-        else:
-            source = step.name
+        if isinstance(source, exp.Expression):
+            source = source.name or source.alias

         condition = self.generate(step.condition)
         projections = self.generate_tuple(step.projections)
@@ -96,14 +96,12 @@ class PythonExecutor:
         if projections:
             sink = self.table(step.projections)
-        elif source in context:
-            sink = Table(context[source].columns)
         else:
             sink = None

         for reader, ctx in table_iter:
             if sink is None:
-                sink = Table(ctx[source].columns)
+                sink = Table(reader.columns)

             if condition and not ctx.eval(condition):
                 continue
@@ -135,119 +133,102 @@ class PythonExecutor:
                     types.append(type(ast.literal_eval(v)))
                 except (ValueError, SyntaxError):
                     types.append(str)
-            context.set_row(alias, tuple(t(v) for t, v in zip(types, row)))
-            yield context[alias], context
+            context.set_row(tuple(t(v) for t, v in zip(types, row)))
+            yield context.table.reader, context

     def join(self, step, context):
         source = step.name

-        join_context = self.context({source: context.tables[source]})
-
-        def merge_context(ctx, table):
-            # create a new context where all existing tables are mapped to a new one
-            return self.context({name: table for name in ctx.tables})
+        source_table = context.tables[source]
+        source_context = self.context({source: source_table})
+        column_ranges = {source: range(0, len(source_table.columns))}

         for name, join in step.joins.items():
-            join_context = self.context({**join_context.tables, name: context.tables[name]})
+            table = context.tables[name]
+            start = max(r.stop for r in column_ranges.values())
+            column_ranges[name] = range(start, len(table.columns) + start)
+            join_context = self.context({name: table})

             if join.get("source_key"):
-                table = self.hash_join(join, source, name, join_context)
+                table = self.hash_join(join, source_context, join_context)
             else:
-                table = self.nested_loop_join(join, source, name, join_context)
+                table = self.nested_loop_join(join, source_context, join_context)

-            join_context = merge_context(join_context, table)
+            source_context = self.context(
+                {
+                    name: Table(table.columns, table.rows, column_range)
+                    for name, column_range in column_ranges.items()
+                }
+            )

-        # apply projections or conditions
-        context = self.scan(step, join_context)
+        condition = self.generate(step.condition)
+        projections = self.generate_tuple(step.projections)

-        # use the scan context since it returns a single table
-        # otherwise there are no projections so all other tables are still in scope
-        if step.projections:
-            return context
+        if not condition or not projections:
+            return source_context

-        return merge_context(join_context, context.tables[source])
+        sink = self.table(step.projections if projections else source_context.columns)
+
+        for reader, ctx in join_context:
+            if condition and not ctx.eval(condition):
+                continue
+
+            if projections:
+                sink.append(ctx.eval_tuple(projections))
+            else:
+                sink.append(reader.row)
+
+            if len(sink) >= step.limit:
+                break
+
+        return self.context({step.name: sink})

-    def nested_loop_join(self, _join, a, b, context):
-        table = Table(context.tables[a].columns + context.tables[b].columns)
+    def nested_loop_join(self, _join, source_context, join_context):
+        table = Table(source_context.columns + join_context.columns)

-        for reader_a, _ in context.table_iter(a):
-            for reader_b, _ in context.table_iter(b):
+        for reader_a, _ in source_context:
+            for reader_b, _ in join_context:
                 table.append(reader_a.row + reader_b.row)

         return table

-    def hash_join(self, join, a, b, context):
-        a_key = self.generate_tuple(join["source_key"])
-        b_key = self.generate_tuple(join["join_key"])
+    def hash_join(self, join, source_context, join_context):
+        source_key = self.generate_tuple(join["source_key"])
+        join_key = self.generate_tuple(join["join_key"])

         results = collections.defaultdict(lambda: ([], []))

-        for reader, ctx in context.table_iter(a):
-            results[ctx.eval_tuple(a_key)][0].append(reader.row)
-        for reader, ctx in context.table_iter(b):
-            results[ctx.eval_tuple(b_key)][1].append(reader.row)
+        for reader, ctx in source_context:
+            results[ctx.eval_tuple(source_key)][0].append(reader.row)
+        for reader, ctx in join_context:
+            results[ctx.eval_tuple(join_key)][1].append(reader.row)

-        table = Table(context.tables[a].columns + context.tables[b].columns)
+        table = Table(source_context.columns + join_context.columns)
         for a_group, b_group in results.values():
             for a_row, b_row in itertools.product(a_group, b_group):
                 table.append(a_row + b_row)

         return table

-    def sort_merge_join(self, join, a, b, context):
-        a_key = self.generate_tuple(join["source_key"])
-        b_key = self.generate_tuple(join["join_key"])
-
-        context.sort(a, a_key)
-        context.sort(b, b_key)
-
-        a_i = 0
-        b_i = 0
-        a_n = len(context.tables[a])
-        b_n = len(context.tables[b])
-
-        table = Table(context.tables[a].columns + context.tables[b].columns)
-
-        def get_key(source, key, i):
-            context.set_index(source, i)
-            return context.eval_tuple(key)
-
-        while a_i < a_n and b_i < b_n:
-            key = min(get_key(a, a_key, a_i), get_key(b, b_key, b_i))
-
-            a_group = []
-
-            while a_i < a_n and key == get_key(a, a_key, a_i):
-                a_group.append(context[a].row)
-                a_i += 1
-
-            b_group = []
-
-            while b_i < b_n and key == get_key(b, b_key, b_i):
-                b_group.append(context[b].row)
-                b_i += 1
-
-            for a_row, b_row in itertools.product(a_group, b_group):
-                table.append(a_row + b_row)
-
-        return table
-
     def aggregate(self, step, context):
         source = step.source
         group_by = self.generate_tuple(step.group)
         aggregations = self.generate_tuple(step.aggregations)
         operands = self.generate_tuple(step.operands)

-        context.sort(source, group_by)
-
-        if step.operands:
+        if operands:
             source_table = context.tables[source]
             operand_table = Table(source_table.columns + self.table(step.operands).columns)

             for reader, ctx in context:
                 operand_table.append(reader.row + ctx.eval_tuple(operands))

-            context = self.context({source: operand_table})
+            context = self.context(
+                {None: operand_table, **{table: operand_table for table in context.tables}}
+            )
+
+        context.sort(group_by)

         group = None
         start = 0
@@ -256,15 +237,15 @@ class PythonExecutor:
         table = self.table(step.group + step.aggregations)

         for i in range(length):
-            context.set_index(source, i)
+            context.set_index(i)
             key = context.eval_tuple(group_by)
             group = key if group is None else group
             end += 1

             if i == length - 1:
-                context.set_range(source, start, end - 1)
+                context.set_range(start, end - 1)
             elif key != group:
-                context.set_range(source, start, end - 2)
+                context.set_range(start, end - 2)
             else:
                 continue
@@ -272,13 +253,32 @@ class PythonExecutor:
             group = key
             start = end - 2

-        return self.scan(step, self.context({source: table}))
+        context = self.context({step.name: table, **{name: table for name in context.tables}})
+
+        if step.projections:
+            return self.scan(step, context)
+        return context

     def sort(self, step, context):
-        table = list(context.tables)[0]
-        key = self.generate_tuple(step.key)
-        context.sort(table, key)
-        return self.scan(step, context)
+        projections = self.generate_tuple(step.projections)
+
+        sink = self.table(step.projections)
+
+        for reader, ctx in context:
+            sink.append(ctx.eval_tuple(projections))
+
+        context = self.context(
+            {
+                None: sink,
+                **{table: sink for table in context.tables},
+            }
+        )
+        context.sort(self.generate_tuple(step.key))
+
+        if not math.isinf(step.limit):
+            context.table.rows = context.table.rows[0 : step.limit]
+
+        return self.context({step.name: context.table})


 def _cast_py(self, expression):
@@ -293,7 +293,7 @@ def _cast_py(self, expression):

 def _column_py(self, expression):
-    table = self.sql(expression, "table")
+    table = self.sql(expression, "table") or None
     this = self.sql(expression, "this")
     return f"scope[{table}][{this}]"
@@ -319,10 +319,10 @@ def _ordered_py(self, expression):

 class Python(Dialect):
-    class Tokenizer(Tokenizer):
-        ESCAPE = "\\"
+    class Tokenizer(tokens.Tokenizer):
+        ESCAPES = ["\\"]

-    class Generator(Generator):
+    class Generator(generator.Generator):
         TRANSFORMS = {
             exp.Alias: lambda self, e: self.sql(e.this),
             exp.Array: inline_array_sql,
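
The `hash_join` method kept by this refactor buckets the rows of both sides by their join-key tuples, then emits the cross product within each bucket. Stripped of the executor's `Context` and `Table` machinery, the algorithm looks roughly like this (the function and parameter names here are illustrative, not sqlglot's API):

```python
import collections
import itertools


def hash_join(left_rows, right_rows, left_key, right_key):
    # Bucket rows from both sides by their join-key tuple.
    buckets = collections.defaultdict(lambda: ([], []))
    for row in left_rows:
        buckets[left_key(row)][0].append(row)
    for row in right_rows:
        buckets[right_key(row)][1].append(row)

    # Emit the cross product within each bucket; unmatched keys have an
    # empty group on one side, so they contribute nothing (inner join).
    joined = []
    for left_group, right_group in buckets.values():
        for left, right in itertools.product(left_group, right_group):
            joined.append(left + right)
    return joined


users = [(1, "a"), (2, "b")]
orders = [(10, 1), (11, 1), (12, 3)]
rows = hash_join(users, orders, lambda r: (r[0],), lambda r: (r[1],))
# user 1 matches two orders; user 2 and order 12 have no partner
assert rows == [(1, "a", 10, 1), (1, "a", 11, 1)]
```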


@@ -1,10 +1,12 @@
 class Table:
-    def __init__(self, *columns, rows=None):
-        self.columns = tuple(columns if isinstance(columns[0], str) else columns[0])
+    def __init__(self, columns, rows=None, column_range=None):
+        self.columns = tuple(columns)
+        self.column_range = column_range
+        self.reader = RowReader(self.columns, self.column_range)
         self.rows = rows or []
         if rows:
             assert len(rows[0]) == len(self.columns)
-        self.reader = RowReader(self.columns)
         self.range_reader = RangeReader(self)

     def append(self, row):
@@ -29,15 +31,22 @@ class Table:
         return self.reader

     def __repr__(self):
-        widths = {column: len(column) for column in self.columns}
-        lines = [" ".join(column for column in self.columns)]
+        columns = tuple(
+            column
+            for i, column in enumerate(self.columns)
+            if not self.column_range or i in self.column_range
+        )
+        widths = {column: len(column) for column in columns}
+        lines = [" ".join(column for column in columns)]

         for i, row in enumerate(self):
             if i > 10:
                 break

             lines.append(
-                " ".join(str(row[column]).rjust(widths[column])[0 : widths[column]] for column in self.columns)
+                " ".join(
+                    str(row[column]).rjust(widths[column])[0 : widths[column]] for column in columns
+                )
             )
         return "\n".join(lines)
@@ -70,8 +79,10 @@ class RangeReader:

 class RowReader:
-    def __init__(self, columns):
-        self.columns = {column: i for i, column in enumerate(columns)}
+    def __init__(self, columns, column_range=None):
+        self.columns = {
+            column: i for i, column in enumerate(columns) if not column_range or i in column_range
+        }
         self.row = None

     def __getitem__(self, column):
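
The new `column_range` parameter is what lets the executor's join refactor keep several logical tables backed by one physical row tuple: each `RowReader` exposes only the slice of columns in its range. A simplified, self-contained version of the assumed mechanics:

```python
class RowReader:
    def __init__(self, columns, column_range=None):
        # Map column name -> position, restricted to this table's slice.
        self.columns = {
            column: i
            for i, column in enumerate(columns)
            if not column_range or i in column_range
        }
        self.row = None

    def __getitem__(self, column):
        return self.row[self.columns[column]]


# One joined row, two logical tables over disjoint column ranges.
columns = ("id", "name", "order_id", "user_id")
left = RowReader(columns, range(0, 2))
right = RowReader(columns, range(2, 4))
left.row = right.row = (1, "a", 10, 1)
assert left["name"] == "a"
assert right["order_id"] == 10
```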


@@ -1,6 +1,9 @@
+from __future__ import annotations
+
 import datetime
 import numbers
 import re
+import typing as t
 from collections import deque
 from copy import deepcopy
 from enum import auto
@@ -9,12 +12,15 @@ from sqlglot.errors import ParseError
 from sqlglot.helper import (
     AutoName,
     camel_to_snake_case,
-    ensure_list,
-    list_get,
+    ensure_collection,
+    seq_get,
     split_num_words,
     subclasses,
 )

+if t.TYPE_CHECKING:
+    from sqlglot.dialects.dialect import Dialect
+

 class _Expression(type):
     def __new__(cls, clsname, bases, attrs):
@@ -35,27 +41,30 @@ class Expression(metaclass=_Expression):
         or optional (False).
     """

-    key = None
+    key = "Expression"
     arg_types = {"this": True}
-    __slots__ = ("args", "parent", "arg_key", "type")
+    __slots__ = ("args", "parent", "arg_key", "type", "comment")

     def __init__(self, **args):
         self.args = args
         self.parent = None
         self.arg_key = None
         self.type = None
+        self.comment = None

         for arg_key, value in self.args.items():
             self._set_parent(arg_key, value)

-    def __eq__(self, other):
+    def __eq__(self, other) -> bool:
         return type(self) is type(other) and _norm_args(self) == _norm_args(other)

-    def __hash__(self):
+    def __hash__(self) -> int:
         return hash(
             (
                 self.key,
-                tuple((k, tuple(v) if isinstance(v, list) else v) for k, v in _norm_args(self).items()),
+                tuple(
+                    (k, tuple(v) if isinstance(v, list) else v) for k, v in _norm_args(self).items()
+                ),
             )
         )
@@ -79,6 +88,19 @@ class Expression(metaclass=_Expression):
             return field.this
         return ""

+    def find_comment(self, key: str) -> str:
+        """
+        Finds the comment that is attached to a specified child node.
+
+        Args:
+            key: the key of the target child node (e.g. "this", "expression", etc).
+
+        Returns:
+            The comment attached to the child node, or the empty string, if it doesn't exist.
+        """
+        field = self.args.get(key)
+        return field.comment if isinstance(field, Expression) else ""
+
     @property
     def is_string(self):
         return isinstance(self, Literal) and self.args["is_string"]
@@ -114,7 +136,10 @@ class Expression(metaclass=_Expression):
         return self.alias or self.name

     def __deepcopy__(self, memo):
-        return self.__class__(**deepcopy(self.args))
+        copy = self.__class__(**deepcopy(self.args))
+        copy.comment = self.comment
+        copy.type = self.type
+        return copy

     def copy(self):
         new = deepcopy(self)
@@ -249,9 +274,7 @@ class Expression(metaclass=_Expression):
             return

         for k, v in self.args.items():
-            nodes = ensure_list(v)
-
-            for node in nodes:
+            for node in ensure_collection(v):
                 if isinstance(node, Expression):
                     yield from node.dfs(self, k, prune)
@@ -274,9 +297,7 @@ class Expression(metaclass=_Expression):
             if isinstance(item, Expression):
                 for k, v in item.args.items():
-                    nodes = ensure_list(v)
-
-                    for node in nodes:
+                    for node in ensure_collection(v):
                         if isinstance(node, Expression):
                             queue.append((node, item, k))
@@ -319,7 +340,7 @@ class Expression(metaclass=_Expression):
     def __repr__(self):
         return self.to_s()

-    def sql(self, dialect=None, **opts):
+    def sql(self, dialect: Dialect | str | None = None, **opts) -> str:
         """
         Returns SQL string representation of this tree.
@@ -335,7 +356,7 @@ class Expression(metaclass=_Expression):

         return Dialect.get_or_raise(dialect)().generate(self, **opts)

-    def to_s(self, hide_missing=True, level=0):
+    def to_s(self, hide_missing: bool = True, level: int = 0) -> str:
         indent = "" if not level else "\n"
         indent += "".join([" "] * level)
         left = f"({self.key.upper()} "
@@ -343,11 +364,13 @@ class Expression(metaclass=_Expression):
         args = {
             k: ", ".join(
                 v.to_s(hide_missing=hide_missing, level=level + 1) if hasattr(v, "to_s") else str(v)
-                for v in ensure_list(vs)
+                for v in ensure_collection(vs)
                 if v is not None
             )
             for k, vs in self.args.items()
         }
+        args["comment"] = self.comment
+        args["type"] = self.type
         args = {k: v for k, v in args.items() if v or not hide_missing}

         right = ", ".join(f"{k}: {v}" for k, v in args.items())
@@ -578,17 +601,6 @@ class UDTF(DerivedTable, Unionable):
     pass


-class Annotation(Expression):
-    arg_types = {
-        "this": True,
-        "expression": True,
-    }
-
-    @property
-    def alias(self):
-        return self.expression.alias_or_name
-
-
 class Cache(Expression):
     arg_types = {
         "with": False,
@@ -623,6 +635,38 @@ class Describe(Expression):
     pass


+class Set(Expression):
+    arg_types = {"expressions": True}
+
+
+class SetItem(Expression):
+    arg_types = {
+        "this": True,
+        "kind": False,
+        "collate": False,  # MySQL SET NAMES statement
+    }
+
+
+class Show(Expression):
+    arg_types = {
+        "this": True,
+        "target": False,
+        "offset": False,
+        "limit": False,
+        "like": False,
+        "where": False,
+        "db": False,
+        "full": False,
+        "mutex": False,
+        "query": False,
+        "channel": False,
+        "global": False,
+        "log": False,
+        "position": False,
+        "types": False,
+    }
+
+
 class UserDefinedFunction(Expression):
     arg_types = {"this": True, "expressions": False}
@@ -864,18 +908,20 @@ class Literal(Condition):
     def __eq__(self, other):
         return (
-            isinstance(other, Literal) and self.this == other.this and self.args["is_string"] == other.args["is_string"]
+            isinstance(other, Literal)
+            and self.this == other.this
+            and self.args["is_string"] == other.args["is_string"]
         )

     def __hash__(self):
         return hash((self.key, self.this, self.args["is_string"]))

     @classmethod
-    def number(cls, number):
+    def number(cls, number) -> Literal:
         return cls(this=str(number), is_string=False)

     @classmethod
-    def string(cls, string):
+    def string(cls, string) -> Literal:
         return cls(this=str(string), is_string=True)
@@ -1087,7 +1133,7 @@ class Properties(Expression):
     }

     @classmethod
-    def from_dict(cls, properties_dict):
+    def from_dict(cls, properties_dict) -> Properties:
expressions = [] expressions = []
for key, value in properties_dict.items(): for key, value in properties_dict.items():
property_cls = cls.PROPERTY_KEY_MAPPING.get(key.upper(), AnonymousProperty) property_cls = cls.PROPERTY_KEY_MAPPING.get(key.upper(), AnonymousProperty)
@ -1323,7 +1369,7 @@ class Select(Subqueryable):
**QUERY_MODIFIERS, **QUERY_MODIFIERS,
} }
def from_(self, *expressions, append=True, dialect=None, copy=True, **opts): def from_(self, *expressions, append=True, dialect=None, copy=True, **opts) -> Select:
""" """
Set the FROM expression. Set the FROM expression.
@ -1356,7 +1402,7 @@ class Select(Subqueryable):
**opts, **opts,
) )
def group_by(self, *expressions, append=True, dialect=None, copy=True, **opts): def group_by(self, *expressions, append=True, dialect=None, copy=True, **opts) -> Select:
""" """
Set the GROUP BY expression. Set the GROUP BY expression.
@ -1392,7 +1438,7 @@ class Select(Subqueryable):
**opts, **opts,
) )
def order_by(self, *expressions, append=True, dialect=None, copy=True, **opts): def order_by(self, *expressions, append=True, dialect=None, copy=True, **opts) -> Select:
""" """
Set the ORDER BY expression. Set the ORDER BY expression.
@ -1425,7 +1471,7 @@ class Select(Subqueryable):
**opts, **opts,
) )
def sort_by(self, *expressions, append=True, dialect=None, copy=True, **opts): def sort_by(self, *expressions, append=True, dialect=None, copy=True, **opts) -> Select:
""" """
Set the SORT BY expression. Set the SORT BY expression.
@ -1458,7 +1504,7 @@ class Select(Subqueryable):
**opts, **opts,
) )
def cluster_by(self, *expressions, append=True, dialect=None, copy=True, **opts): def cluster_by(self, *expressions, append=True, dialect=None, copy=True, **opts) -> Select:
""" """
Set the CLUSTER BY expression. Set the CLUSTER BY expression.
@ -1491,7 +1537,7 @@ class Select(Subqueryable):
**opts, **opts,
) )
def limit(self, expression, dialect=None, copy=True, **opts): def limit(self, expression, dialect=None, copy=True, **opts) -> Select:
""" """
Set the LIMIT expression. Set the LIMIT expression.
@ -1522,7 +1568,7 @@ class Select(Subqueryable):
**opts, **opts,
) )
def offset(self, expression, dialect=None, copy=True, **opts): def offset(self, expression, dialect=None, copy=True, **opts) -> Select:
""" """
Set the OFFSET expression. Set the OFFSET expression.
@ -1553,7 +1599,7 @@ class Select(Subqueryable):
**opts, **opts,
) )
def select(self, *expressions, append=True, dialect=None, copy=True, **opts): def select(self, *expressions, append=True, dialect=None, copy=True, **opts) -> Select:
""" """
Append to or set the SELECT expressions. Append to or set the SELECT expressions.
@ -1583,7 +1629,7 @@ class Select(Subqueryable):
**opts, **opts,
) )
def lateral(self, *expressions, append=True, dialect=None, copy=True, **opts): def lateral(self, *expressions, append=True, dialect=None, copy=True, **opts) -> Select:
""" """
Append to or set the LATERAL expressions. Append to or set the LATERAL expressions.
@ -1626,7 +1672,7 @@ class Select(Subqueryable):
dialect=None, dialect=None,
copy=True, copy=True,
**opts, **opts,
): ) -> Select:
""" """
Append to or set the JOIN expressions. Append to or set the JOIN expressions.
@ -1672,7 +1718,7 @@ class Select(Subqueryable):
join.this.replace(join.this.subquery()) join.this.replace(join.this.subquery())
if join_type: if join_type:
natural, side, kind = maybe_parse(join_type, into="JOIN_TYPE", **parse_args) natural, side, kind = maybe_parse(join_type, into="JOIN_TYPE", **parse_args) # type: ignore
if natural: if natural:
join.set("natural", True) join.set("natural", True)
if side: if side:
@ -1681,12 +1727,12 @@ class Select(Subqueryable):
join.set("kind", kind.text) join.set("kind", kind.text)
if on: if on:
on = and_(*ensure_list(on), dialect=dialect, **opts) on = and_(*ensure_collection(on), dialect=dialect, **opts)
join.set("on", on) join.set("on", on)
if using: if using:
join = _apply_list_builder( join = _apply_list_builder(
*ensure_list(using), *ensure_collection(using),
instance=join, instance=join,
arg="using", arg="using",
append=append, append=append,
@ -1705,7 +1751,7 @@ class Select(Subqueryable):
**opts, **opts,
) )
def where(self, *expressions, append=True, dialect=None, copy=True, **opts): def where(self, *expressions, append=True, dialect=None, copy=True, **opts) -> Select:
""" """
Append to or set the WHERE expressions. Append to or set the WHERE expressions.
@ -1737,7 +1783,7 @@ class Select(Subqueryable):
**opts, **opts,
) )
def having(self, *expressions, append=True, dialect=None, copy=True, **opts): def having(self, *expressions, append=True, dialect=None, copy=True, **opts) -> Select:
""" """
Append to or set the HAVING expressions. Append to or set the HAVING expressions.
@ -1769,7 +1815,7 @@ class Select(Subqueryable):
**opts, **opts,
) )
def distinct(self, distinct=True, copy=True): def distinct(self, distinct=True, copy=True) -> Select:
""" """
Set the DISTINCT expression. Set the DISTINCT expression.
@ -1788,7 +1834,7 @@ class Select(Subqueryable):
instance.set("distinct", Distinct() if distinct else None) instance.set("distinct", Distinct() if distinct else None)
return instance return instance
def ctas(self, table, properties=None, dialect=None, copy=True, **opts): def ctas(self, table, properties=None, dialect=None, copy=True, **opts) -> Create:
""" """
Convert this expression to a CREATE TABLE AS statement. Convert this expression to a CREATE TABLE AS statement.
@ -1826,11 +1872,11 @@ class Select(Subqueryable):
) )
@property @property
def named_selects(self): def named_selects(self) -> t.List[str]:
return [e.alias_or_name for e in self.expressions if e.alias_or_name] return [e.alias_or_name for e in self.expressions if e.alias_or_name]
@property @property
def selects(self): def selects(self) -> t.List[Expression]:
return self.expressions return self.expressions
@ -1910,12 +1956,16 @@ class Parameter(Expression):
pass pass
class SessionParameter(Expression):
arg_types = {"this": True, "kind": False}
class Placeholder(Expression): class Placeholder(Expression):
arg_types = {"this": False} arg_types = {"this": False}
class Null(Condition): class Null(Condition):
arg_types = {} arg_types: t.Dict[str, t.Any] = {}
class Boolean(Condition): class Boolean(Condition):
@ -1936,6 +1986,7 @@ class DataType(Expression):
NVARCHAR = auto() NVARCHAR = auto()
TEXT = auto() TEXT = auto()
BINARY = auto() BINARY = auto()
VARBINARY = auto()
INT = auto() INT = auto()
TINYINT = auto() TINYINT = auto()
SMALLINT = auto() SMALLINT = auto()
@ -1975,7 +2026,7 @@ class DataType(Expression):
UNKNOWN = auto() # Sentinel value, useful for type annotation UNKNOWN = auto() # Sentinel value, useful for type annotation
@classmethod @classmethod
def build(cls, dtype, **kwargs): def build(cls, dtype, **kwargs) -> DataType:
return DataType( return DataType(
this=dtype if isinstance(dtype, DataType.Type) else DataType.Type[dtype.upper()], this=dtype if isinstance(dtype, DataType.Type) else DataType.Type[dtype.upper()],
**kwargs, **kwargs,
@ -2077,6 +2128,18 @@ class EQ(Binary, Predicate):
pass pass
class NullSafeEQ(Binary, Predicate):
pass
class NullSafeNEQ(Binary, Predicate):
pass
class Distance(Binary):
pass
class Escape(Binary): class Escape(Binary):
pass pass
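`NullSafeEQ` models the null-safe equal operator (`<=>`) mentioned in the changelog. A plain-Python sketch of its semantics, with `None` standing in for SQL `NULL`:

```python
def null_safe_eq(a, b):
    # Unlike ordinary =, NULL <=> NULL is true and
    # NULL <=> anything-else is false (never unknown).
    if a is None or b is None:
        return a is b
    return a == b
```
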
@ -2101,18 +2164,14 @@ class Is(Binary, Predicate):
pass pass
class Kwarg(Binary):
"""Kwarg in special functions like func(kwarg => y)."""
class Like(Binary, Predicate): class Like(Binary, Predicate):
pass pass
class SimilarTo(Binary, Predicate):
pass
class Distance(Binary):
pass
class LT(Binary, Predicate): class LT(Binary, Predicate):
pass pass
@ -2133,6 +2192,10 @@ class NEQ(Binary, Predicate):
pass pass
class SimilarTo(Binary, Predicate):
pass
class Sub(Binary): class Sub(Binary):
pass pass
@ -2189,7 +2252,13 @@ class Distinct(Expression):
class In(Predicate): class In(Predicate):
arg_types = {"this": True, "expressions": False, "query": False, "unnest": False, "field": False} arg_types = {
"this": True,
"expressions": False,
"query": False,
"unnest": False,
"field": False,
}
class TimeUnit(Expression): class TimeUnit(Expression):
@ -2255,7 +2324,9 @@ class Func(Condition):
@classmethod @classmethod
def sql_names(cls): def sql_names(cls):
if cls is Func: if cls is Func:
raise NotImplementedError("SQL name is only supported by concrete function implementations") raise NotImplementedError(
"SQL name is only supported by concrete function implementations"
)
if not hasattr(cls, "_sql_names"): if not hasattr(cls, "_sql_names"):
cls._sql_names = [camel_to_snake_case(cls.__name__)] cls._sql_names = [camel_to_snake_case(cls.__name__)]
return cls._sql_names return cls._sql_names
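The default SQL name above is derived via `camel_to_snake_case`. A sketch of how such a helper plausibly behaves (an assumed implementation producing uppercased snake case, shown for illustration):

```python
import re

def camel_to_snake_case(name):
    # Insert "_" before each interior uppercase letter, then uppercase,
    # so a class name like DateDiff maps to the SQL name DATE_DIFF.
    return re.sub(r"(?<!^)(?=[A-Z])", "_", name).upper()
```
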
@ -2408,8 +2479,8 @@ class DateDiff(Func, TimeUnit):
arg_types = {"this": True, "expression": True, "unit": False} arg_types = {"this": True, "expression": True, "unit": False}
class DateTrunc(Func, TimeUnit): class DateTrunc(Func):
arg_types = {"this": True, "unit": True, "zone": False} arg_types = {"this": True, "expression": True, "zone": False}
class DatetimeAdd(Func, TimeUnit): class DatetimeAdd(Func, TimeUnit):
@ -2791,6 +2862,10 @@ class Year(Func):
pass pass
class Use(Expression):
pass
def _norm_args(expression): def _norm_args(expression):
args = {} args = {}
@ -2822,7 +2897,7 @@ def maybe_parse(
dialect=None, dialect=None,
prefix=None, prefix=None,
**opts, **opts,
): ) -> t.Optional[Expression]:
"""Gracefully handle a possible string or expression. """Gracefully handle a possible string or expression.
Example: Example:
@ -3073,7 +3148,7 @@ def except_(left, right, distinct=True, dialect=None, **opts):
return Except(this=left, expression=right, distinct=distinct) return Except(this=left, expression=right, distinct=distinct)
def select(*expressions, dialect=None, **opts): def select(*expressions, dialect=None, **opts) -> Select:
""" """
Initializes a syntax tree from one or multiple SELECT expressions. Initializes a syntax tree from one or multiple SELECT expressions.
@ -3095,7 +3170,7 @@ def select(*expressions, dialect=None, **opts):
return Select().select(*expressions, dialect=dialect, **opts) return Select().select(*expressions, dialect=dialect, **opts)
def from_(*expressions, dialect=None, **opts): def from_(*expressions, dialect=None, **opts) -> Select:
""" """
Initializes a syntax tree from a FROM expression. Initializes a syntax tree from a FROM expression.
@ -3117,7 +3192,7 @@ def from_(*expressions, dialect=None, **opts):
return Select().from_(*expressions, dialect=dialect, **opts) return Select().from_(*expressions, dialect=dialect, **opts)
def update(table, properties, where=None, from_=None, dialect=None, **opts): def update(table, properties, where=None, from_=None, dialect=None, **opts) -> Update:
""" """
Creates an update statement. Creates an update statement.
@ -3139,7 +3214,10 @@ def update(table, properties, where=None, from_=None, dialect=None, **opts):
update = Update(this=maybe_parse(table, into=Table, dialect=dialect)) update = Update(this=maybe_parse(table, into=Table, dialect=dialect))
update.set( update.set(
"expressions", "expressions",
[EQ(this=maybe_parse(k, dialect=dialect, **opts), expression=convert(v)) for k, v in properties.items()], [
EQ(this=maybe_parse(k, dialect=dialect, **opts), expression=convert(v))
for k, v in properties.items()
],
) )
if from_: if from_:
update.set("from", maybe_parse(from_, into=From, dialect=dialect, prefix="FROM", **opts)) update.set("from", maybe_parse(from_, into=From, dialect=dialect, prefix="FROM", **opts))
@ -3150,7 +3228,7 @@ def update(table, properties, where=None, from_=None, dialect=None, **opts):
return update return update
def delete(table, where=None, dialect=None, **opts): def delete(table, where=None, dialect=None, **opts) -> Delete:
""" """
Builds a delete statement. Builds a delete statement.
@ -3174,7 +3252,7 @@ def delete(table, where=None, dialect=None, **opts):
) )
def condition(expression, dialect=None, **opts): def condition(expression, dialect=None, **opts) -> Condition:
""" """
Initialize a logical condition expression. Initialize a logical condition expression.
@ -3199,7 +3277,7 @@ def condition(expression, dialect=None, **opts):
Returns: Returns:
Condition: the expression Condition: the expression
""" """
return maybe_parse( return maybe_parse( # type: ignore
expression, expression,
into=Condition, into=Condition,
dialect=dialect, dialect=dialect,
@ -3207,7 +3285,7 @@ def condition(expression, dialect=None, **opts):
) )
def and_(*expressions, dialect=None, **opts): def and_(*expressions, dialect=None, **opts) -> And:
""" """
Combine multiple conditions with an AND logical operator. Combine multiple conditions with an AND logical operator.
@ -3227,7 +3305,7 @@ def and_(*expressions, dialect=None, **opts):
return _combine(expressions, And, dialect, **opts) return _combine(expressions, And, dialect, **opts)
def or_(*expressions, dialect=None, **opts): def or_(*expressions, dialect=None, **opts) -> Or:
""" """
Combine multiple conditions with an OR logical operator. Combine multiple conditions with an OR logical operator.
@ -3247,7 +3325,7 @@ def or_(*expressions, dialect=None, **opts):
return _combine(expressions, Or, dialect, **opts) return _combine(expressions, Or, dialect, **opts)
def not_(expression, dialect=None, **opts): def not_(expression, dialect=None, **opts) -> Not:
""" """
Wrap a condition with a NOT operator. Wrap a condition with a NOT operator.
@ -3272,14 +3350,14 @@ def not_(expression, dialect=None, **opts):
return Not(this=_wrap_operator(this)) return Not(this=_wrap_operator(this))
def paren(expression): def paren(expression) -> Paren:
return Paren(this=expression) return Paren(this=expression)
SAFE_IDENTIFIER_RE = re.compile(r"^[a-zA-Z][\w]*$") SAFE_IDENTIFIER_RE = re.compile(r"^[a-zA-Z][\w]*$")
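`SAFE_IDENTIFIER_RE` is what lets `to_identifier` decide when a name can be emitted unquoted. For example (`needs_quoting` is an illustrative helper, not part of the module):

```python
import re

SAFE_IDENTIFIER_RE = re.compile(r"^[a-zA-Z][\w]*$")

def needs_quoting(name):
    # A name starting with a letter and containing only word
    # characters is safe to emit without identifier quotes.
    return not SAFE_IDENTIFIER_RE.match(name)
```
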
def to_identifier(alias, quoted=None): def to_identifier(alias, quoted=None) -> t.Optional[Identifier]:
if alias is None: if alias is None:
return None return None
if isinstance(alias, Identifier): if isinstance(alias, Identifier):
@ -3293,16 +3371,16 @@ def to_identifier(alias, quoted=None):
return identifier return identifier
def to_table(sql_path: str, **kwargs) -> Table: def to_table(sql_path: t.Optional[str | Table], **kwargs) -> t.Optional[Table]:
""" """
Create a table expression from a `[catalog].[schema].[table]` sql path. Catalog and schema are optional. Create a table expression from a `[catalog].[schema].[table]` sql path. Catalog and schema are optional.
If a table is passed in then that table is returned. If a table is passed in then that table is returned.
Args: Args:
sql_path(str|Table): `[catalog].[schema].[table]` string sql_path: a `[catalog].[schema].[table]` string.
Returns: Returns:
Table: A table expression A table expression.
""" """
if sql_path is None or isinstance(sql_path, Table): if sql_path is None or isinstance(sql_path, Table):
return sql_path return sql_path
@ -3393,7 +3471,7 @@ def subquery(expression, alias=None, dialect=None, **opts):
return Select().from_(expression, dialect=dialect, **opts) return Select().from_(expression, dialect=dialect, **opts)
def column(col, table=None, quoted=None): def column(col, table=None, quoted=None) -> Column:
""" """
Build a Column. Build a Column.
Args: Args:
@ -3408,7 +3486,7 @@ def column(col, table=None, quoted=None):
) )
def table_(table, db=None, catalog=None, quoted=None, alias=None): def table_(table, db=None, catalog=None, quoted=None, alias=None) -> Table:
"""Build a Table. """Build a Table.
Args: Args:
@ -3427,7 +3505,7 @@ def table_(table, db=None, catalog=None, quoted=None, alias=None):
) )
def values(values, alias=None): def values(values, alias=None) -> Values:
"""Build VALUES statement. """Build VALUES statement.
Example: Example:
@ -3449,7 +3527,7 @@ def values(values, alias=None):
) )
def convert(value): def convert(value) -> Expression:
"""Convert a python value into an expression object. """Convert a python value into an expression object.
Raises an error if a conversion is not possible. Raises an error if a conversion is not possible.
@ -3500,15 +3578,14 @@ def replace_children(expression, fun):
for cn in child_nodes: for cn in child_nodes:
if isinstance(cn, Expression): if isinstance(cn, Expression):
cns = ensure_list(fun(cn)) for child_node in ensure_collection(fun(cn)):
for child_node in cns:
new_child_nodes.append(child_node) new_child_nodes.append(child_node)
child_node.parent = expression child_node.parent = expression
child_node.arg_key = k child_node.arg_key = k
else: else:
new_child_nodes.append(cn) new_child_nodes.append(cn)
expression.args[k] = new_child_nodes if is_list_arg else list_get(new_child_nodes, 0) expression.args[k] = new_child_nodes if is_list_arg else seq_get(new_child_nodes, 0)
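`list_get` was renamed to `seq_get` in this release; its behavior is safe positional access that returns `None` instead of raising, roughly:

```python
def seq_get(seq, index):
    # Return seq[index], or None when the index is out of range.
    try:
        return seq[index]
    except IndexError:
        return None
```
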
def column_table_names(expression): def column_table_names(expression):
@ -3529,7 +3606,7 @@ def column_table_names(expression):
return list(dict.fromkeys(column.table for column in expression.find_all(Column))) return list(dict.fromkeys(column.table for column in expression.find_all(Column)))
def table_name(table): def table_name(table) -> str:
"""Get the full name of a table as a string. """Get the full name of a table as a string.
Args: Args:
@ -3546,6 +3623,9 @@ def table_name(table):
table = maybe_parse(table, into=Table) table = maybe_parse(table, into=Table)
if not table:
raise ValueError(f"Cannot parse {table}")
return ".".join( return ".".join(
part part
for part in ( for part in (
@ -1,4 +1,8 @@
from __future__ import annotations
import logging import logging
import re
import typing as t
from sqlglot import exp from sqlglot import exp
from sqlglot.errors import ErrorLevel, UnsupportedError, concat_errors from sqlglot.errors import ErrorLevel, UnsupportedError, concat_errors
@ -8,6 +12,8 @@ from sqlglot.tokens import TokenType
logger = logging.getLogger("sqlglot") logger = logging.getLogger("sqlglot")
NEWLINE_RE = re.compile("\r\n?|\n")
class Generator: class Generator:
""" """
@ -47,8 +53,7 @@ class Generator:
The default is on the smaller end because the length only represents a segment and not the true The default is on the smaller end because the length only represents a segment and not the true
line length. line length.
Default: 80 Default: 80
annotations: Whether or not to show annotations in the SQL when `pretty` is True. comments: Whether or not to preserve comments in the output SQL code.
Annotations can only be shown in pretty mode otherwise they may clobber resulting sql.
Default: True Default: True
""" """
@ -65,14 +70,16 @@ class Generator:
exp.VolatilityProperty: lambda self, e: self.sql(e.name), exp.VolatilityProperty: lambda self, e: self.sql(e.name),
} }
# whether 'CREATE ... TRANSIENT ... TABLE' is allowed # Whether 'CREATE ... TRANSIENT ... TABLE' is allowed
# can override in dialects
CREATE_TRANSIENT = False CREATE_TRANSIENT = False
# whether or not null ordering is supported in order by
# Whether or not null ordering is supported in order by
NULL_ORDERING_SUPPORTED = True NULL_ORDERING_SUPPORTED = True
# always do union distinct or union all
# Always do union distinct or union all
EXPLICIT_UNION = False EXPLICIT_UNION = False
# wrap derived values in parens, usually standard but spark doesn't support it
# Wrap derived values in parens, usually standard but spark doesn't support it
WRAP_DERIVED_VALUES = True WRAP_DERIVED_VALUES = True
TYPE_MAPPING = { TYPE_MAPPING = {
@ -80,7 +87,7 @@ class Generator:
exp.DataType.Type.NVARCHAR: "VARCHAR", exp.DataType.Type.NVARCHAR: "VARCHAR",
} }
TOKEN_MAPPING = {} TOKEN_MAPPING: t.Dict[TokenType, str] = {}
STRUCT_DELIMITER = ("<", ">") STRUCT_DELIMITER = ("<", ">")
@ -96,6 +103,8 @@ class Generator:
exp.TableFormatProperty, exp.TableFormatProperty,
} }
WITH_SEPARATED_COMMENTS = (exp.Select,)
__slots__ = ( __slots__ = (
"time_mapping", "time_mapping",
"time_trie", "time_trie",
@ -122,7 +131,7 @@ class Generator:
"_escaped_quote_end", "_escaped_quote_end",
"_leading_comma", "_leading_comma",
"_max_text_width", "_max_text_width",
"_annotations", "_comments",
) )
def __init__( def __init__(
@ -148,7 +157,7 @@ class Generator:
max_unsupported=3, max_unsupported=3,
leading_comma=False, leading_comma=False,
max_text_width=80, max_text_width=80,
annotations=True, comments=True,
): ):
import sqlglot import sqlglot
@ -177,7 +186,7 @@ class Generator:
self._escaped_quote_end = self.escape + self.quote_end self._escaped_quote_end = self.escape + self.quote_end
self._leading_comma = leading_comma self._leading_comma = leading_comma
self._max_text_width = max_text_width self._max_text_width = max_text_width
self._annotations = annotations self._comments = comments
def generate(self, expression): def generate(self, expression):
""" """
@ -204,7 +213,6 @@ class Generator:
return sql return sql
def unsupported(self, message): def unsupported(self, message):
if self.unsupported_level == ErrorLevel.IMMEDIATE: if self.unsupported_level == ErrorLevel.IMMEDIATE:
raise UnsupportedError(message) raise UnsupportedError(message)
self.unsupported_messages.append(message) self.unsupported_messages.append(message)
@ -215,9 +223,31 @@ class Generator:
def seg(self, sql, sep=" "): def seg(self, sql, sep=" "):
return f"{self.sep(sep)}{sql}" return f"{self.sep(sep)}{sql}"
def maybe_comment(self, sql, expression, single_line=False):
comment = expression.comment if self._comments else None
if not comment:
return sql
comment = " " + comment if comment[0].strip() else comment
comment = comment + " " if comment[-1].strip() else comment
if isinstance(expression, self.WITH_SEPARATED_COMMENTS):
return f"/*{comment}*/{self.sep()}{sql}"
if not self.pretty:
return f"{sql} /*{comment}*/"
if not NEWLINE_RE.search(comment):
return f"{sql} --{comment.rstrip()}" if single_line else f"{sql} /*{comment}*/"
return f"/*{comment}*/\n{sql}"
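The new `maybe_comment` is the core of the comment-preservation feature. A simplified standalone version of its padding and placement logic (the `WITH_SEPARATED_COMMENTS` branch is omitted for brevity):

```python
import re

NEWLINE_RE = re.compile("\r\n?|\n")

def attach_comment(sql, comment, pretty=False, single_line=False):
    # Pad the comment with spaces when its edges are non-whitespace,
    # then append it as /* ... */ (or -- in pretty single-line mode).
    if not comment:
        return sql
    comment = " " + comment if comment[0].strip() else comment
    comment = comment + " " if comment[-1].strip() else comment
    if not pretty:
        return f"{sql} /*{comment}*/"
    if not NEWLINE_RE.search(comment):
        return f"{sql} --{comment.rstrip()}" if single_line else f"{sql} /*{comment}*/"
    # Multi-line comments are emitted above the SQL they annotate.
    return f"/*{comment}*/\n{sql}"
```
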
def wrap(self, expression): def wrap(self, expression):
this_sql = self.indent( this_sql = self.indent(
self.sql(expression) if isinstance(expression, (exp.Select, exp.Union)) else self.sql(expression, "this"), self.sql(expression)
if isinstance(expression, (exp.Select, exp.Union))
else self.sql(expression, "this"),
level=1, level=1,
pad=0, pad=0,
) )
@ -251,7 +281,7 @@ class Generator:
for i, line in enumerate(lines) for i, line in enumerate(lines)
) )
def sql(self, expression, key=None): def sql(self, expression, key=None, comment=True):
if not expression: if not expression:
return "" return ""
@ -264,29 +294,24 @@ class Generator:
transform = self.TRANSFORMS.get(expression.__class__) transform = self.TRANSFORMS.get(expression.__class__)
if callable(transform): if callable(transform):
return transform(self, expression) sql = transform(self, expression)
if transform: elif transform:
return transform sql = transform
elif isinstance(expression, exp.Expression):
exp_handler_name = f"{expression.key}_sql"
if not isinstance(expression, exp.Expression): if hasattr(self, exp_handler_name):
sql = getattr(self, exp_handler_name)(expression)
elif isinstance(expression, exp.Func):
sql = self.function_fallback_sql(expression)
elif isinstance(expression, exp.Property):
sql = self.property_sql(expression)
else:
raise ValueError(f"Unsupported expression type {expression.__class__.__name__}")
else:
raise ValueError(f"Expected an Expression. Received {type(expression)}: {expression}") raise ValueError(f"Expected an Expression. Received {type(expression)}: {expression}")
exp_handler_name = f"{expression.key}_sql" return self.maybe_comment(sql, expression) if self._comments and comment else sql
if hasattr(self, exp_handler_name):
return getattr(self, exp_handler_name)(expression)
if isinstance(expression, exp.Func):
return self.function_fallback_sql(expression)
if isinstance(expression, exp.Property):
return self.property_sql(expression)
raise ValueError(f"Unsupported expression type {expression.__class__.__name__}")
def annotation_sql(self, expression):
if self._annotations and self.pretty:
return f"{self.sql(expression, 'expression')} # {expression.name}"
return self.sql(expression, "expression")
def uncache_sql(self, expression): def uncache_sql(self, expression):
table = self.sql(expression, "this") table = self.sql(expression, "this")
@ -371,7 +396,9 @@ class Generator:
expression_sql = self.sql(expression, "expression") expression_sql = self.sql(expression, "expression")
expression_sql = f"AS{self.sep()}{expression_sql}" if expression_sql else "" expression_sql = f"AS{self.sep()}{expression_sql}" if expression_sql else ""
temporary = " TEMPORARY" if expression.args.get("temporary") else "" temporary = " TEMPORARY" if expression.args.get("temporary") else ""
transient = " TRANSIENT" if self.CREATE_TRANSIENT and expression.args.get("transient") else "" transient = (
" TRANSIENT" if self.CREATE_TRANSIENT and expression.args.get("transient") else ""
)
replace = " OR REPLACE" if expression.args.get("replace") else "" replace = " OR REPLACE" if expression.args.get("replace") else ""
exists_sql = " IF NOT EXISTS" if expression.args.get("exists") else "" exists_sql = " IF NOT EXISTS" if expression.args.get("exists") else ""
unique = " UNIQUE" if expression.args.get("unique") else "" unique = " UNIQUE" if expression.args.get("unique") else ""
@ -434,7 +461,9 @@ class Generator:
def delete_sql(self, expression): def delete_sql(self, expression):
this = self.sql(expression, "this") this = self.sql(expression, "this")
using_sql = ( using_sql = (
f" USING {self.expressions(expression, 'using', sep=', USING ')}" if expression.args.get("using") else "" f" USING {self.expressions(expression, 'using', sep=', USING ')}"
if expression.args.get("using")
else ""
) )
where_sql = self.sql(expression, "where") where_sql = self.sql(expression, "where")
sql = f"DELETE FROM {this}{using_sql}{where_sql}" sql = f"DELETE FROM {this}{using_sql}{where_sql}"
@ -481,15 +510,18 @@ class Generator:
return f"{this} ON {table} {columns}" return f"{this} ON {table} {columns}"
def identifier_sql(self, expression): def identifier_sql(self, expression):
value = expression.name text = expression.name
value = value.lower() if self.normalize else value text = text.lower() if self.normalize else text
if expression.args.get("quoted") or self.identify: if expression.args.get("quoted") or self.identify:
return f"{self.identifier_start}{value}{self.identifier_end}" text = f"{self.identifier_start}{text}{self.identifier_end}"
return value return text
def partition_sql(self, expression): def partition_sql(self, expression):
keys = csv( keys = csv(
*[f"{k.args['this']}='{v.args['this']}'" if v else k.args["this"] for k, v in expression.args.get("this")] *[
f"""{prop.name}='{prop.text("value")}'""" if prop.text("value") else prop.name
for prop in expression.this
]
) )
return f"PARTITION({keys})" return f"PARTITION({keys})"
@ -504,9 +536,9 @@ class Generator:
elif p_class in self.ROOT_PROPERTIES: elif p_class in self.ROOT_PROPERTIES:
root_properties.append(p) root_properties.append(p)
return self.root_properties(exp.Properties(expressions=root_properties)) + self.with_properties( return self.root_properties(
exp.Properties(expressions=with_properties) exp.Properties(expressions=root_properties)
) ) + self.with_properties(exp.Properties(expressions=with_properties))
def root_properties(self, properties): def root_properties(self, properties):
if properties.expressions: if properties.expressions:
@ -551,7 +583,9 @@ class Generator:
this = f"{this}{self.sql(expression, 'this')}" this = f"{this}{self.sql(expression, 'this')}"
exists = " IF EXISTS " if expression.args.get("exists") else " " exists = " IF EXISTS " if expression.args.get("exists") else " "
partition_sql = self.sql(expression, "partition") if expression.args.get("partition") else "" partition_sql = (
self.sql(expression, "partition") if expression.args.get("partition") else ""
)
expression_sql = self.sql(expression, "expression") expression_sql = self.sql(expression, "expression")
sep = self.sep() if partition_sql else "" sep = self.sep() if partition_sql else ""
        sql = f"INSERT {this}{exists}{partition_sql}{sep}{expression_sql}"

@@ -669,7 +703,9 @@ class Generator:
    def group_sql(self, expression):
        group_by = self.op_expressions("GROUP BY", expression)
        grouping_sets = self.expressions(expression, key="grouping_sets", indent=False)
        grouping_sets = (
            f"{self.seg('GROUPING SETS')} {self.wrap(grouping_sets)}" if grouping_sets else ""
        )
        cube = self.expressions(expression, key="cube", indent=False)
        cube = f"{self.seg('CUBE')} {self.wrap(cube)}" if cube else ""
        rollup = self.expressions(expression, key="rollup", indent=False)

@@ -711,10 +747,10 @@ class Generator:
        this_sql = self.sql(expression, "this")
        return f"{expression_sql}{op_sql} {this_sql}{on_sql}"

    def lambda_sql(self, expression, arrow_sep="->"):
        args = self.expressions(expression, flat=True)
        args = f"({args})" if len(args.split(",")) > 1 else args
        return self.no_identify(lambda: f"{args} {arrow_sep} {self.sql(expression, 'this')}")

    def lateral_sql(self, expression):
        this = self.sql(expression, "this")

@@ -748,7 +784,7 @@ class Generator:
        if self._replace_backslash:
            text = text.replace("\\", "\\\\")
        text = text.replace(self.quote_end, self._escaped_quote_end)
        text = f"{self.quote_start}{text}{self.quote_end}"
        return text

    def loaddata_sql(self, expression):

@@ -796,13 +832,21 @@ class Generator:
        sort_order = " DESC" if desc else ""
        nulls_sort_change = ""
        if nulls_first and (
            (asc and nulls_are_large) or (desc and nulls_are_small) or nulls_are_last
        ):
            nulls_sort_change = " NULLS FIRST"
        elif (
            nulls_last
            and ((asc and nulls_are_small) or (desc and nulls_are_large))
            and not nulls_are_last
        ):
            nulls_sort_change = " NULLS LAST"

        if nulls_sort_change and not self.NULL_ORDERING_SUPPORTED:
            self.unsupported(
                "Sorting in an ORDER BY on NULLS FIRST/NULLS LAST is not supported by this dialect"
            )
            nulls_sort_change = ""

        return f"{self.sql(expression, 'this')}{sort_order}{nulls_sort_change}"
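To illustrate the null-ordering decision above, here is a standalone sketch of the same branch logic. The flag names mirror the dialect attributes consulted by `sort_sql`; this is an illustration of the decision table, not sqlglot's public API:

```python
def null_sort_suffix(
    asc: bool,
    nulls_first: bool,
    nulls_last: bool,
    nulls_are_large: bool,
    nulls_are_small: bool,
    nulls_are_last: bool,
) -> str:
    """Return a NULLS FIRST/LAST suffix only when the dialect's default
    null placement differs from the placement the query asks for."""
    desc = not asc
    if nulls_first and (
        (asc and nulls_are_large) or (desc and nulls_are_small) or nulls_are_last
    ):
        return " NULLS FIRST"
    if (
        nulls_last
        and ((asc and nulls_are_small) or (desc and nulls_are_large))
        and not nulls_are_last
    ):
        return " NULLS LAST"
    return ""

# ASC in a "nulls are large" dialect would put NULLs last by default,
# so an explicit NULLS FIRST must be emitted.
print(null_sort_suffix(True, True, False, True, False, False))  # " NULLS FIRST"
```

When the requested placement matches the dialect default, the suffix is omitted entirely, which is what keeps the generated SQL minimal.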
@@ -835,7 +879,7 @@ class Generator:
        sql = self.query_modifiers(
            expression,
            f"SELECT{hint}{distinct}{expressions}",
            self.sql(expression, "from", comment=False),
        )
        return self.prepend_ctes(expression, sql)

@@ -858,6 +902,13 @@ class Generator:
    def parameter_sql(self, expression):
        return f"@{self.sql(expression, 'this')}"

    def sessionparameter_sql(self, expression):
        this = self.sql(expression, "this")
        kind = expression.text("kind")
        if kind:
            kind = f"{kind}."
        return f"@@{kind}{this}"

    def placeholder_sql(self, expression):
        return f":{expression.name}" if expression.name else "?"

@@ -931,7 +982,10 @@ class Generator:
    def window_spec_sql(self, expression):
        kind = self.sql(expression, "kind")
        start = csv(self.sql(expression, "start"), self.sql(expression, "start_side"), sep=" ")
        end = (
            csv(self.sql(expression, "end"), self.sql(expression, "end_side"), sep=" ")
            or "CURRENT ROW"
        )
        return f"{kind} BETWEEN {start} AND {end}"

    def withingroup_sql(self, expression):

@@ -1020,7 +1074,9 @@ class Generator:
        return f"UNIQUE ({columns})"

    def if_sql(self, expression):
        return self.case_sql(
            exp.Case(ifs=[expression.copy()], default=expression.args.get("false"))
        )

    def in_sql(self, expression):
        query = expression.args.get("query")

@@ -1196,6 +1252,12 @@ class Generator:
    def neq_sql(self, expression):
        return self.binary(expression, "<>")

    def nullsafeeq_sql(self, expression):
        return self.binary(expression, "IS NOT DISTINCT FROM")

    def nullsafeneq_sql(self, expression):
        return self.binary(expression, "IS DISTINCT FROM")

    def or_sql(self, expression):
        return self.connector_sql(expression, "OR")

@@ -1205,6 +1267,9 @@ class Generator:
    def trycast_sql(self, expression):
        return f"TRY_CAST({self.sql(expression, 'this')} AS {self.sql(expression, 'to')})"

    def use_sql(self, expression):
        return f"USE {self.sql(expression, 'this')}"

    def binary(self, expression, op):
        return f"{self.sql(expression, 'this')} {op} {self.sql(expression, 'expression')}"

@@ -1240,17 +1305,27 @@ class Generator:
        if flat:
            return sep.join(self.sql(e) for e in expressions)

        num_sqls = len(expressions)

        # These are calculated once in case we have the leading_comma / pretty option set, correspondingly
        pad = " " * self.pad
        stripped_sep = sep.strip()

        result_sqls = []
        for i, e in enumerate(expressions):
            sql = self.sql(e, comment=False)
            comment = self.maybe_comment("", e, single_line=True)

            if self.pretty:
                if self._leading_comma:
                    result_sqls.append(f"{sep if i > 0 else pad}{sql}{comment}")
                else:
                    result_sqls.append(f"{sql}{stripped_sep if i + 1 < num_sqls else ''}{comment}")
            else:
                result_sqls.append(f"{sql}{comment}{sep if i + 1 < num_sqls else ''}")

        result_sqls = "\n".join(result_sqls) if self.pretty else "".join(result_sqls)
        return self.indent(result_sqls, skip_first=False) if indent else result_sqls
    def op_expressions(self, op, expression, flat=False):
        expressions_sql = self.expressions(expression, flat=flat)

@@ -1264,7 +1339,9 @@ class Generator:
    def set_operation(self, expression, op):
        this = self.sql(expression, "this")
        op = self.seg(op)
        return self.query_modifiers(
            expression, f"{this}{op}{self.sep()}{self.sql(expression, 'expression')}"
        )

    def token_sql(self, token_type):
        return self.TOKEN_MAPPING.get(token_type, token_type.name)

@@ -1283,3 +1360,6 @@ class Generator:
        this = self.sql(expression, "this")
        expressions = self.expressions(expression, flat=True)
        return f"{this}({expressions})"

    def kwarg_sql(self, expression):
        return self.binary(expression, "=>")

View file

@@ -1,48 +1,125 @@
from __future__ import annotations

import inspect
import logging
import re
import sys
import typing as t
from collections.abc import Collection
from contextlib import contextmanager
from copy import copy
from enum import Enum

if t.TYPE_CHECKING:
    from sqlglot.expressions import Expression, Table

    T = t.TypeVar("T")
    E = t.TypeVar("E", bound=Expression)

CAMEL_CASE_PATTERN = re.compile("(?<!^)(?=[A-Z])")
PYTHON_VERSION = sys.version_info[:2]
logger = logging.getLogger("sqlglot")


class AutoName(Enum):
    """This is used for creating enum classes where `auto()` is the string form of the corresponding value's name."""

    def _generate_next_value_(name, _start, _count, _last_values):  # type: ignore
        return name


def seq_get(seq: t.Sequence[T], index: int) -> t.Optional[T]:
    """Returns the value in `seq` at position `index`, or `None` if `index` is out of bounds."""
    try:
        return seq[index]
    except IndexError:
        return None
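The renamed `seq_get` (formerly `list_get`, per the v10 changelog) is a small guard around indexing; a self-contained sketch:

```python
import typing as t

T = t.TypeVar("T")


def seq_get(seq: t.Sequence[T], index: int) -> t.Optional[T]:
    # Out-of-bounds access returns None instead of raising IndexError.
    try:
        return seq[index]
    except IndexError:
        return None


print(seq_get([10, 20, 30], 1))  # 20
print(seq_get([10, 20, 30], 5))  # None
```

Note that negative indices still work the way plain indexing does (`seq_get(xs, -1)` returns the last element), since only `IndexError` is swallowed.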
@t.overload
def ensure_list(value: t.Collection[T]) -> t.List[T]:
    ...


@t.overload
def ensure_list(value: T) -> t.List[T]:
    ...


def ensure_list(value):
    """
    Ensures that a value is a list, otherwise casts or wraps it into one.

    Args:
        value: the value of interest.

    Returns:
        The value cast as a list if it's a list or a tuple, or else the value wrapped in a list.
    """
    if value is None:
        return []
    elif isinstance(value, (list, tuple)):
        return list(value)
    return [value]


@t.overload
def ensure_collection(value: t.Collection[T]) -> t.Collection[T]:
    ...


@t.overload
def ensure_collection(value: T) -> t.Collection[T]:
    ...


def ensure_collection(value):
    """
    Ensures that a value is a collection (excluding `str` and `bytes`), otherwise wraps it into a list.

    Args:
        value: the value of interest.

    Returns:
        The value if it's a collection, or else the value wrapped in a list.
    """
    if value is None:
        return []
    return (
        value if isinstance(value, Collection) and not isinstance(value, (str, bytes)) else [value]
    )
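The difference between the two helpers is worth pinning down with a runnable sketch: `ensure_list` always produces a fresh list, while `ensure_collection` passes through any non-string collection unchanged:

```python
from collections.abc import Collection


def ensure_list(value):
    # None -> [], list/tuple -> list copy, anything else -> single-item list.
    if value is None:
        return []
    elif isinstance(value, (list, tuple)):
        return list(value)
    return [value]


def ensure_collection(value):
    # Collections pass through untouched; str/bytes are deliberately wrapped,
    # since iterating them character-by-character is almost never intended.
    if value is None:
        return []
    return (
        value if isinstance(value, Collection) and not isinstance(value, (str, bytes)) else [value]
    )


print(ensure_list((1, 2)))        # [1, 2]
print(ensure_collection({1, 2}))  # the same set, not a list
```

So a set survives `ensure_collection` as a set but becomes a wrapped single element under `ensure_list`, which only special-cases `list` and `tuple`.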
def csv(*args, sep: str = ", ") -> str:
    """
    Formats any number of string arguments as CSV.

    Args:
        args: the string arguments to format.
        sep: the argument separator.

    Returns:
        The arguments formatted as a CSV string.
    """
    return sep.join(arg for arg in args if arg)
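One subtlety of `csv` is that falsy arguments are filtered out before joining, so empty fragments produced elsewhere in the generator never leave stray separators behind:

```python
def csv(*args, sep=", "):
    # Empty strings (and other falsy args) are dropped, so "a", "", "b"
    # joins cleanly as "a, b" rather than "a, , b".
    return sep.join(arg for arg in args if arg)


print(csv("a", "", "b"))        # a, b
print(csv("x", "y", sep=" "))   # x y
```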
def subclasses(
    module_name: str,
    classes: t.Type | t.Tuple[t.Type, ...],
    exclude: t.Type | t.Tuple[t.Type, ...] = (),
) -> t.List[t.Type]:
    """
    Returns all subclasses for a collection of classes, possibly excluding some of them.

    Args:
        module_name: the name of the module to search for subclasses in.
        classes: class(es) we want to find the subclasses of.
        exclude: class(es) we want to exclude from the returned list.

    Returns:
        The target subclasses.
    """
    return [
        obj

@@ -53,7 +130,18 @@ def subclasses(module_name, classes, exclude=()):
    ]


def apply_index_offset(expressions: t.List[E], offset: int) -> t.List[E]:
    """
    Applies an offset to a given integer literal expression.

    Args:
        expressions: the expression the offset will be applied to, wrapped in a list.
        offset: the offset that will be applied.

    Returns:
        The original expression with the offset applied to it, wrapped in a list. If the provided
        `expressions` argument contains more than one expressions, it's returned unaffected.
    """
    if not offset or len(expressions) != 1:
        return expressions

@@ -64,14 +152,28 @@ def apply_index_offset(expressions, offset):
            logger.warning("Applying array index offset (%s)", offset)
            expression.args["this"] = str(int(expression.args["this"]) + offset)
            return [expression]

    return expressions


def camel_to_snake_case(name: str) -> str:
    """Converts `name` from camelCase to snake_case and returns the result."""
    return CAMEL_CASE_PATTERN.sub("_", name).upper()
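The camelCase conversion relies on a zero-width lookaround pattern: an underscore is inserted at every position that precedes an interior capital letter, then the whole name is upper-cased. A self-contained sketch:

```python
import re

# Matches the empty position before any capital letter that is not at the start.
CAMEL_CASE_PATTERN = re.compile("(?<!^)(?=[A-Z])")


def camel_to_snake_case(name: str) -> str:
    # "groupConcat" -> "group_Concat" -> "GROUP_CONCAT"
    return CAMEL_CASE_PATTERN.sub("_", name).upper()


print(camel_to_snake_case("groupConcat"))  # GROUP_CONCAT
```

Because the pattern is zero-width, no characters of the original name are consumed; the substitution only injects separators.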
def while_changing(
    expression: t.Optional[Expression], func: t.Callable[[t.Optional[Expression]], E]
) -> E:
    """
    Applies a transformation to a given expression until a fix point is reached.

    Args:
        expression: the expression to be transformed.
        func: the transformation to be applied.

    Returns:
        The transformed expression.
    """
    while True:
        start = hash(expression)
        expression = func(expression)

@@ -80,10 +182,19 @@ def while_changing(expression, func):
    return expression


def tsort(dag: t.Dict[T, t.List[T]]) -> t.List[T]:
    """
    Sorts a given directed acyclic graph in topological order.

    Args:
        dag: the graph to be sorted.

    Returns:
        A list that contains all of the graph's nodes in topological order.
    """
    result = []

    def visit(node: T, visited: t.Set[T]) -> None:
        if node in result:
            return
        if node in visited:

@@ -103,10 +214,8 @@ def tsort(dag):
    return result
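The hunk truncates the body of `visit`, but the visible skeleton is a classic DFS topological sort. The following is an illustrative reconstruction under that assumption (the exact cycle-handling error raised by sqlglot is hidden by the diff and is assumed here):

```python
def tsort(dag):
    """DFS topological sort: a node is appended only after all of its
    dependencies have been appended, so dependencies come first."""
    result = []

    def visit(node, visited):
        if node in result:
            return
        if node in visited:
            # A node re-entered on the current DFS path means a cycle.
            raise ValueError("Cycle error")
        visited.add(node)
        for dep in dag.get(node, []):
            visit(dep, visited)
        visited.remove(node)
        result.append(node)

    for node in dag:
        visit(node, set())
    return result


# b depends on a, c depends on b: the result lists dependencies first.
print(tsort({"c": ["b"], "b": ["a"], "a": []}))  # ['a', 'b', 'c']
```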
def open_file(file_name: str) -> t.TextIO:
    """Open a file that may be compressed as gzip and return it in universal newline mode."""
    with open(file_name, "rb") as f:
        gzipped = f.read(2) == b"\x1f\x8b"

@@ -119,14 +228,14 @@ def open_file(file_name):


@contextmanager
def csv_reader(table: Table) -> t.Any:
    """
    Returns a csv reader given the expression `READ_CSV(name, ['delimiter', '|', ...])`.

    Args:
        table: a `Table` expression with an anonymous function `READ_CSV` in it.

    Yields:
        A python csv reader.
    """
    file, *args = table.this.expressions

@@ -147,13 +256,16 @@ def csv_reader(table):
        file.close()


def find_new_name(taken: t.Sequence[str], base: str) -> str:
    """
    Searches for a new name.

    Args:
        taken: a collection of taken names.
        base: base name to alter.

    Returns:
        The new, available name.
    """
    if base not in taken:
        return base

@@ -163,22 +275,26 @@ def find_new_name(taken, base):
    while new in taken:
        i += 1
        new = f"{base}_{i}"

    return new
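The middle of `find_new_name` (the initial value of the counter) is hidden by the hunk boundary; the sketch below assumes the suffix starts at `_2`, which is consistent with the visible increment loop but is an assumption, not confirmed by the diff:

```python
def find_new_name(taken, base):
    """Return `base` if it's free, otherwise append an increasing
    numeric suffix (_2, _3, ...) until an unused name is found."""
    if base not in taken:
        return base
    i = 2  # assumed starting suffix; the diff hunk hides this line
    new = f"{base}_{i}"
    while new in taken:
        i += 1
        new = f"{base}_{i}"
    return new


print(find_new_name({"a", "a_2"}, "a"))  # a_3
```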
def object_to_dict(obj: t.Any, **kwargs) -> t.Dict:
    """Returns a dictionary created from an object's attributes."""
    return {**{k: copy(v) for k, v in vars(obj).copy().items()}, **kwargs}


def split_num_words(
    value: str, sep: str, min_num_words: int, fill_from_start: bool = True
) -> t.List[t.Optional[str]]:
    """
    Perform a split on a value and return N words as a result with `None` used for words that don't exist.

    Args:
        value: the value to be split.
        sep: the value to use to split on.
        min_num_words: the minimum number of words that are going to be in the result.
        fill_from_start: indicates that if `None` values should be inserted at the start or end of the list.

    Examples:
        >>> split_num_words("db.table", ".", 3)

@@ -187,6 +303,9 @@ def split_num_words(value: str, sep: str, min_num_words: int, fill_from_start: b
        ['db', 'table', None]
        >>> split_num_words("db.table", ".", 1)
        ['db', 'table']

    Returns:
        The list of words returned by `split`, possibly augmented by a number of `None` values.
    """
    words = value.split(sep)
    if fill_from_start:
@ -196,7 +315,7 @@ def split_num_words(value: str, sep: str, min_num_words: int, fill_from_start: b
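The padding behaviour documented in the `split_num_words` docstring can be reproduced with a short sketch; note that Python's `[None] * n` is empty for non-positive `n`, which is what makes the `min_num_words=1` case pass through unpadded:

```python
import typing as t


def split_num_words(value, sep, min_num_words, fill_from_start=True):
    words: t.List[t.Optional[str]] = list(value.split(sep))
    missing = min_num_words - len(words)  # <= 0 means no padding at all
    if fill_from_start:
        return [None] * missing + words
    return words + [None] * missing


print(split_num_words("db.table", ".", 3))                         # [None, 'db', 'table']
print(split_num_words("db.table", ".", 3, fill_from_start=False))  # ['db', 'table', None]
print(split_num_words("db.table", ".", 1))                         # ['db', 'table']
```

Filling from the start is the useful default for qualified SQL names, where the most specific part (the table) should stay rightmost.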
def is_iterable(value: t.Any) -> bool:
    """
    Checks if the value is an iterable, excluding the types `str` and `bytes`.

    Examples:
        >>> is_iterable([1,2])

@@ -205,28 +324,30 @@ def is_iterable(value: t.Any) -> bool:
        False

    Args:
        value: the value to check if it is an iterable.

    Returns:
        A `bool` value indicating if it is an iterable.
    """
    return hasattr(value, "__iter__") and not isinstance(value, (str, bytes))


def flatten(values: t.Iterable[t.Iterable[t.Any] | t.Any]) -> t.Generator[t.Any, None, None]:
    """
    Flattens an iterable that can contain both iterable and non-iterable elements. Objects of
    type `str` and `bytes` are not regarded as iterables.

    Examples:
        >>> list(flatten([[1, 2], 3, {4}, (5, "bla")]))
        [1, 2, 3, 4, 5, 'bla']
        >>> list(flatten([1, 2, 3]))
        [1, 2, 3]

    Args:
        values: the value to be flattened.

    Yields:
        Non-iterable elements in `values`.
    """
    for value in values:
        if is_iterable(value):
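The hunk cuts off inside `flatten`'s loop; a self-contained sketch of the recursive generator it describes, with strings and bytes treated as leaves rather than iterables:

```python
def is_iterable(value):
    # str/bytes define __iter__ but should be treated as scalar values here.
    return hasattr(value, "__iter__") and not isinstance(value, (str, bytes))


def flatten(values):
    # Recursively yield leaves from arbitrarily nested iterables.
    for value in values:
        if is_iterable(value):
            yield from flatten(value)
        else:
            yield value


print(list(flatten([[1, 2], 3, (5, "bla")])))  # [1, 2, 3, 5, 'bla']
print(list(flatten([1, 2, 3])))                # [1, 2, 3]
```

Because the function is a generator, flattening is lazy: nothing is traversed until the caller iterates.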

View file

@@ -1,5 +1,5 @@
from sqlglot import exp
from sqlglot.helper import ensure_collection, ensure_list, subclasses
from sqlglot.optimizer.scope import Scope, traverse_scope
from sqlglot.schema import ensure_schema

@@ -48,35 +48,65 @@ class TypeAnnotator:
        exp.Boolean: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.BOOLEAN),
        exp.Null: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.NULL),
        exp.Anonymous: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.UNKNOWN),
        exp.ApproxDistinct: lambda self, expr: self._annotate_with_type(
            expr, exp.DataType.Type.BIGINT
        ),
        exp.Avg: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
        exp.Ceil: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT),
        exp.Count: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.BIGINT),
        exp.CurrentDate: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DATE),
        exp.CurrentDatetime: lambda self, expr: self._annotate_with_type(
            expr, exp.DataType.Type.DATETIME
        ),
        exp.CurrentTime: lambda self, expr: self._annotate_with_type(
            expr, exp.DataType.Type.TIMESTAMP
        ),
        exp.CurrentTimestamp: lambda self, expr: self._annotate_with_type(
            expr, exp.DataType.Type.TIMESTAMP
        ),
        exp.DateAdd: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DATE),
        exp.DateSub: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DATE),
        exp.DateDiff: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT),
        exp.DatetimeAdd: lambda self, expr: self._annotate_with_type(
            expr, exp.DataType.Type.DATETIME
        ),
        exp.DatetimeSub: lambda self, expr: self._annotate_with_type(
            expr, exp.DataType.Type.DATETIME
        ),
        exp.DatetimeDiff: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT),
        exp.Extract: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT),
        exp.TimestampAdd: lambda self, expr: self._annotate_with_type(
            expr, exp.DataType.Type.TIMESTAMP
        ),
        exp.TimestampSub: lambda self, expr: self._annotate_with_type(
            expr, exp.DataType.Type.TIMESTAMP
        ),
        exp.TimestampDiff: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT),
        exp.TimeAdd: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.TIMESTAMP),
        exp.TimeSub: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.TIMESTAMP),
        exp.TimeDiff: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT),
        exp.DateStrToDate: lambda self, expr: self._annotate_with_type(
            expr, exp.DataType.Type.DATE
        ),
        exp.DateToDateStr: lambda self, expr: self._annotate_with_type(
            expr, exp.DataType.Type.VARCHAR
        ),
        exp.DateToDi: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT),
        exp.Day: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.TINYINT),
        exp.DiToDate: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DATE),
        exp.Exp: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
        exp.Floor: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT),
        exp.Case: lambda self, expr: self._annotate_by_args(expr, "default", "ifs"),
        exp.If: lambda self, expr: self._annotate_by_args(expr, "true", "false"),
        exp.Coalesce: lambda self, expr: self._annotate_by_args(expr, "this", "expressions"),
        exp.IfNull: lambda self, expr: self._annotate_by_args(expr, "this", "expression"),
        exp.ConcatWs: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.VARCHAR),
        exp.GroupConcat: lambda self, expr: self._annotate_with_type(
            expr, exp.DataType.Type.VARCHAR
        ),
        exp.ArrayConcat: lambda self, expr: self._annotate_with_type(
            expr, exp.DataType.Type.VARCHAR
        ),
        exp.Initcap: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.VARCHAR),
        exp.Length: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.BIGINT),
        exp.Levenshtein: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT),

@@ -88,32 +118,52 @@ class TypeAnnotator:
        exp.Month: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.TINYINT),
        exp.Pow: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
        exp.Quantile: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
        exp.ApproxQuantile: lambda self, expr: self._annotate_with_type(
            expr, exp.DataType.Type.DOUBLE
        ),
        exp.RegexpLike: lambda self, expr: self._annotate_with_type(
            expr, exp.DataType.Type.BOOLEAN
        ),
        exp.Round: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
        exp.SafeDivide: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
        exp.Substring: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.VARCHAR),
        exp.StrPosition: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT),
        exp.StrToDate: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DATE),
        exp.StrToTime: lambda self, expr: self._annotate_with_type(
            expr, exp.DataType.Type.TIMESTAMP
        ),
        exp.Sqrt: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
        exp.Stddev: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
        exp.StddevPop: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
        exp.StddevSamp: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
        exp.TimeToStr: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.VARCHAR),
        exp.TimeToTimeStr: lambda self, expr: self._annotate_with_type(
            expr, exp.DataType.Type.VARCHAR
        ),
        exp.TimeStrToDate: lambda self, expr: self._annotate_with_type(
            expr, exp.DataType.Type.DATE
        ),
        exp.TimeStrToTime: lambda self, expr: self._annotate_with_type(
            expr, exp.DataType.Type.TIMESTAMP
        ),
        exp.Trim: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.VARCHAR),
        exp.TsOrDsToDateStr: lambda self, expr: self._annotate_with_type(
            expr, exp.DataType.Type.VARCHAR
        ),
        exp.TsOrDsToDate: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DATE),
        exp.TsOrDiToDi: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT),
        exp.UnixToStr: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.VARCHAR),
        exp.UnixToTime: lambda self, expr: self._annotate_with_type(
            expr, exp.DataType.Type.TIMESTAMP
        ),
        exp.UnixToTimeStr: lambda self, expr: self._annotate_with_type(
            expr, exp.DataType.Type.VARCHAR
        ),
        exp.Upper: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.VARCHAR),
        exp.Variance: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
        exp.VariancePop: lambda self, expr: self._annotate_with_type(
            expr, exp.DataType.Type.DOUBLE
        ),
        exp.Week: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.TINYINT),
        exp.Year: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.TINYINT),
    }
@@ -124,7 +174,11 @@ class TypeAnnotator:
        exp.DataType.Type.TEXT: set(),
        exp.DataType.Type.NVARCHAR: {exp.DataType.Type.TEXT},
        exp.DataType.Type.VARCHAR: {exp.DataType.Type.NVARCHAR, exp.DataType.Type.TEXT},
        exp.DataType.Type.NCHAR: {
            exp.DataType.Type.VARCHAR,
            exp.DataType.Type.NVARCHAR,
            exp.DataType.Type.TEXT,
        },
        exp.DataType.Type.CHAR: {
            exp.DataType.Type.NCHAR,
            exp.DataType.Type.VARCHAR,

@@ -135,7 +189,11 @@ class TypeAnnotator:
        exp.DataType.Type.DOUBLE: set(),
        exp.DataType.Type.FLOAT: {exp.DataType.Type.DOUBLE},
        exp.DataType.Type.DECIMAL: {exp.DataType.Type.FLOAT, exp.DataType.Type.DOUBLE},
        exp.DataType.Type.BIGINT: {
            exp.DataType.Type.DECIMAL,
            exp.DataType.Type.FLOAT,
            exp.DataType.Type.DOUBLE,
        },
        exp.DataType.Type.INT: {
            exp.DataType.Type.BIGINT,
            exp.DataType.Type.DECIMAL,

@@ -160,7 +218,10 @@ class TypeAnnotator:
        # DATE < DATETIME < TIMESTAMP < TIMESTAMPTZ < TIMESTAMPLTZ
        exp.DataType.Type.TIMESTAMPLTZ: set(),
        exp.DataType.Type.TIMESTAMPTZ: {exp.DataType.Type.TIMESTAMPLTZ},
        exp.DataType.Type.TIMESTAMP: {
            exp.DataType.Type.TIMESTAMPTZ,
            exp.DataType.Type.TIMESTAMPLTZ,
        },
        exp.DataType.Type.DATETIME: {
            exp.DataType.Type.TIMESTAMP,
            exp.DataType.Type.TIMESTAMPTZ,

@@ -219,7 +280,7 @@ class TypeAnnotator:
    def _annotate_args(self, expression):
        for value in expression.args.values():
for v in ensure_list(value): for v in ensure_collection(value):
self._maybe_annotate(v) self._maybe_annotate(v)
return expression return expression
@ -243,7 +304,9 @@ class TypeAnnotator:
if left_type == exp.DataType.Type.NULL and right_type == exp.DataType.Type.NULL: if left_type == exp.DataType.Type.NULL and right_type == exp.DataType.Type.NULL:
expression.type = exp.DataType.Type.NULL expression.type = exp.DataType.Type.NULL
elif exp.DataType.Type.NULL in (left_type, right_type): elif exp.DataType.Type.NULL in (left_type, right_type):
expression.type = exp.DataType.build("NULLABLE", expressions=exp.DataType.build("BOOLEAN")) expression.type = exp.DataType.build(
"NULLABLE", expressions=exp.DataType.build("BOOLEAN")
)
else: else:
expression.type = exp.DataType.Type.BOOLEAN expression.type = exp.DataType.Type.BOOLEAN
elif isinstance(expression, (exp.Condition, exp.Predicate)): elif isinstance(expression, (exp.Condition, exp.Predicate)):
@ -276,3 +339,17 @@ class TypeAnnotator:
def _annotate_with_type(self, expression, target_type): def _annotate_with_type(self, expression, target_type):
expression.type = target_type expression.type = target_type
return self._annotate_args(expression) return self._annotate_args(expression)
def _annotate_by_args(self, expression, *args):
self._annotate_args(expression)
expressions = []
for arg in args:
arg_expr = expression.args.get(arg)
expressions.extend(expr for expr in ensure_list(arg_expr) if expr)
last_datatype = None
for expr in expressions:
last_datatype = self._maybe_coerce(last_datatype or expr.type, expr.type)
expression.type = last_datatype or exp.DataType.Type.UNKNOWN
return expression
@@ -60,7 +60,9 @@ def _join_is_used(scope, join, alias):
         on_clause_columns = set(id(column) for column in on.find_all(exp.Column))
     else:
         on_clause_columns = set()
-    return any(column for column in scope.source_columns(alias) if id(column) not in on_clause_columns)
+    return any(
+        column for column in scope.source_columns(alias) if id(column) not in on_clause_columns
+    )
 
 
 def _is_joined_on_all_unique_outputs(scope, join):
@@ -45,7 +45,13 @@ def eliminate_subqueries(expression):
     # All table names are taken
     for scope in root.traverse():
-        taken.update({source.name: source for _, source in scope.sources.items() if isinstance(source, exp.Table)})
+        taken.update(
+            {
+                source.name: source
+                for _, source in scope.sources.items()
+                if isinstance(source, exp.Table)
+            }
+        )
 
     # Map of Expression->alias
     # Existing CTES in the root expression. We'll use this for deduplication.
@@ -70,7 +76,9 @@ def eliminate_subqueries(expression):
             new_ctes.append(cte_scope.expression.parent)
 
     # Now append the rest
-    for scope in itertools.chain(root.union_scopes, root.subquery_scopes, root.derived_table_scopes):
+    for scope in itertools.chain(
+        root.union_scopes, root.subquery_scopes, root.derived_table_scopes
+    ):
         for child_scope in scope.traverse():
             new_cte = _eliminate(child_scope, existing_ctes, taken)
             if new_cte:
@@ -122,7 +122,9 @@ def _mergeable(outer_scope, inner_select, leave_tables_isolated, from_or_join):
         unmergable_window_columns = [
             column
             for column in outer_scope.columns
-            if column.find_ancestor(exp.Where, exp.Group, exp.Order, exp.Join, exp.Having, exp.AggFunc)
+            if column.find_ancestor(
+                exp.Where, exp.Group, exp.Order, exp.Join, exp.Having, exp.AggFunc
+            )
         ]
         window_expressions_in_unmergable = [
             column
@@ -147,7 +149,9 @@ def _mergeable(outer_scope, inner_select, leave_tables_isolated, from_or_join):
         and not (
             isinstance(from_or_join, exp.From)
             and inner_select.args.get("where")
-            and any(j.side in {"FULL", "RIGHT"} for j in outer_scope.expression.args.get("joins", []))
+            and any(
+                j.side in {"FULL", "RIGHT"} for j in outer_scope.expression.args.get("joins", [])
+            )
         )
         and not _is_a_window_expression_in_unmergable_operation()
     )
@@ -203,7 +207,9 @@ def _merge_from(outer_scope, inner_scope, node_to_replace, alias):
         if table.alias_or_name == node_to_replace.alias_or_name:
             table.set("this", exp.to_identifier(new_subquery.alias_or_name))
     outer_scope.remove_source(alias)
-    outer_scope.add_source(new_subquery.alias_or_name, inner_scope.sources[new_subquery.alias_or_name])
+    outer_scope.add_source(
+        new_subquery.alias_or_name, inner_scope.sources[new_subquery.alias_or_name]
+    )
 
 
 def _merge_joins(outer_scope, inner_scope, from_or_join):
@@ -296,7 +302,9 @@ def _merge_order(outer_scope, inner_scope):
         inner_scope (sqlglot.optimizer.scope.Scope)
     """
     if (
-        any(outer_scope.expression.args.get(arg) for arg in ["group", "distinct", "having", "order"])
+        any(
+            outer_scope.expression.args.get(arg) for arg in ["group", "distinct", "having", "order"]
+        )
         or len(outer_scope.selected_sources) != 1
         or any(expression.find(exp.AggFunc) for expression in outer_scope.expression.expressions)
     ):
@@ -50,7 +50,9 @@ def normalization_distance(expression, dnf=False):
     Returns:
         int: difference
     """
-    return sum(_predicate_lengths(expression, dnf)) - (len(list(expression.find_all(exp.Connector))) + 1)
+    return sum(_predicate_lengths(expression, dnf)) - (
+        len(list(expression.find_all(exp.Connector))) + 1
+    )
 
 
 def _predicate_lengths(expression, dnf):
@@ -68,4 +68,8 @@ def normalize(expression):
 def other_table_names(join, exclude):
-    return [name for name in (exp.column_table_names(join.args.get("on") or exp.TRUE)) if name != exclude]
+    return [
+        name
+        for name in (exp.column_table_names(join.args.get("on") or exp.TRUE))
+        if name != exclude
+    ]
@@ -58,6 +58,8 @@ def optimize(expression, schema=None, db=None, catalog=None, rules=RULES, **kwar
         # Find any additional rule parameters, beyond `expression`
         rule_params = rule.__code__.co_varnames
-        rule_kwargs = {param: possible_kwargs[param] for param in rule_params if param in possible_kwargs}
+        rule_kwargs = {
+            param: possible_kwargs[param] for param in rule_params if param in possible_kwargs
+        }
         expression = rule(expression, **rule_kwargs)
     return expression
@@ -45,7 +45,11 @@ def pushdown(condition, sources, scope_ref_count):
     condition = condition.replace(simplify(condition))
     cnf_like = normalized(condition) or not normalized(condition, dnf=True)
 
-    predicates = list(condition.flatten() if isinstance(condition, exp.And if cnf_like else exp.Or) else [condition])
+    predicates = list(
+        condition.flatten()
+        if isinstance(condition, exp.And if cnf_like else exp.Or)
+        else [condition]
+    )
 
     if cnf_like:
         pushdown_cnf(predicates, sources, scope_ref_count)
@@ -104,11 +108,17 @@ def pushdown_dnf(predicates, scope, scope_ref_count):
             for column in predicate.find_all(exp.Column):
                 if column.table == table:
                     condition = column.find_ancestor(exp.Condition)
-                    predicate_condition = exp.and_(predicate_condition, condition) if predicate_condition else condition
+                    predicate_condition = (
+                        exp.and_(predicate_condition, condition)
+                        if predicate_condition
+                        else condition
+                    )
 
             if predicate_condition:
                 conditions[table] = (
-                    exp.or_(conditions[table], predicate_condition) if table in conditions else predicate_condition
+                    exp.or_(conditions[table], predicate_condition)
+                    if table in conditions
+                    else predicate_condition
                 )
 
     for name, node in nodes.items():
@@ -146,10 +156,16 @@ def nodes_for_predicate(predicate, sources, scope_ref_count):
             nodes[table] = node
         elif isinstance(node, exp.Select) and len(tables) == 1:
             # We can't push down window expressions
-            has_window_expression = any(select for select in node.selects if select.find(exp.Window))
+            has_window_expression = any(
+                select for select in node.selects if select.find(exp.Window)
+            )
             # we can't push down predicates to select statements if they are referenced in
             # multiple places.
-            if not node.args.get("group") and scope_ref_count[id(source)] < 2 and not has_window_expression:
+            if (
+                not node.args.get("group")
+                and scope_ref_count[id(source)] < 2
+                and not has_window_expression
+            ):
                 nodes[table] = node
     return nodes
@@ -165,7 +181,7 @@ def replace_aliases(source, predicate):
     def _replace_alias(column):
         if isinstance(column, exp.Column) and column.name in aliases:
-            return aliases[column.name]
+            return aliases[column.name].copy()
         return column
 
     return predicate.transform(_replace_alias)
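The one-character-looking change above (`aliases[column.name]` → `aliases[column.name].copy()`) matters because an AST node records a single parent: substituting the same node object at two sites in the predicate would silently detach it from the first site. A toy illustration of that invariant (deliberately not SQLGlot's classes):

```python
# Toy tree nodes, hypothetical and only for illustrating why the fix
# returns a copy: each node tracks exactly one parent.
import copy


class Node:
    def __init__(self, name, children=()):
        self.name = name
        self.parent = None
        self.children = list(children)
        for child in self.children:
            child.parent = self  # reparenting overwrites any previous parent


alias_target = Node("a + 1")

# Substituting fresh copies gives every use site its own node, so both
# parents remain correct:
left = Node("left", [copy.deepcopy(alias_target)])
right = Node("right", [copy.deepcopy(alias_target)])
```

Had both sites shared the single `alias_target` object, the second insertion would have rewritten its `parent` pointer and corrupted the first subtree.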
@@ -98,7 +98,9 @@ def _remove_unused_selections(scope, parent_selections):
 def _remove_indexed_selections(scope, indexes_to_remove):
-    new_selections = [selection for i, selection in enumerate(scope.selects) if i not in indexes_to_remove]
+    new_selections = [
+        selection for i, selection in enumerate(scope.selects) if i not in indexes_to_remove
+    ]
     if not new_selections:
         new_selections.append(DEFAULT_SELECTION)
     scope.expression.set("expressions", new_selections)
@@ -215,13 +215,21 @@ def _qualify_columns(scope, resolver):
     # Determine whether each reference in the order by clause is to a column or an alias.
     for ordered in scope.find_all(exp.Ordered):
         for column in ordered.find_all(exp.Column):
-            if not column.table and column.parent is not ordered and column.name in resolver.all_columns:
+            if (
+                not column.table
+                and column.parent is not ordered
+                and column.name in resolver.all_columns
+            ):
                 columns_missing_from_scope.append(column)
 
     # Determine whether each reference in the having clause is to a column or an alias.
     for having in scope.find_all(exp.Having):
         for column in having.find_all(exp.Column):
-            if not column.table and column.find_ancestor(exp.AggFunc) and column.name in resolver.all_columns:
+            if (
+                not column.table
+                and column.find_ancestor(exp.AggFunc)
+                and column.name in resolver.all_columns
+            ):
                 columns_missing_from_scope.append(column)
 
     for column in columns_missing_from_scope:
@@ -295,7 +303,9 @@ def _qualify_outputs(scope):
     """Ensure all output columns are aliased"""
     new_selections = []
 
-    for i, (selection, aliased_column) in enumerate(itertools.zip_longest(scope.selects, scope.outer_column_list)):
+    for i, (selection, aliased_column) in enumerate(
+        itertools.zip_longest(scope.selects, scope.outer_column_list)
+    ):
         if isinstance(selection, exp.Column):
             # convoluted setter because a simple selection.replace(alias) would require a copy
             alias_ = alias(exp.column(""), alias=selection.name)
@@ -343,14 +353,18 @@ class _Resolver:
             (str) table name
         """
         if self._unambiguous_columns is None:
-            self._unambiguous_columns = self._get_unambiguous_columns(self._get_all_source_columns())
+            self._unambiguous_columns = self._get_unambiguous_columns(
+                self._get_all_source_columns()
+            )
         return self._unambiguous_columns.get(column_name)
 
     @property
     def all_columns(self):
         """All available columns of all sources in this scope"""
         if self._all_columns is None:
-            self._all_columns = set(column for columns in self._get_all_source_columns().values() for column in columns)
+            self._all_columns = set(
+                column for columns in self._get_all_source_columns().values() for column in columns
+            )
         return self._all_columns
 
     def get_source_columns(self, name, only_visible=False):
@@ -377,7 +391,9 @@ class _Resolver:
     def _get_all_source_columns(self):
         if self._source_columns is None:
-            self._source_columns = {k: self.get_source_columns(k) for k in self.scope.selected_sources}
+            self._source_columns = {
+                k: self.get_source_columns(k) for k in self.scope.selected_sources
+            }
         return self._source_columns
 
     def _get_unambiguous_columns(self, source_columns):
@@ -226,7 +226,9 @@ class Scope:
         self._ensure_collected()
         columns = self._raw_columns
 
-        external_columns = [column for scope in self.subquery_scopes for column in scope.external_columns]
+        external_columns = [
+            column for scope in self.subquery_scopes for column in scope.external_columns
+        ]
 
         named_outputs = {e.alias_or_name for e in self.expression.expressions}
@@ -278,7 +280,11 @@ class Scope:
         Returns:
             dict[str, Scope]: Mapping of source alias to Scope
         """
-        return {alias: scope for alias, scope in self.sources.items() if isinstance(scope, Scope) and scope.is_cte}
+        return {
+            alias: scope
+            for alias, scope in self.sources.items()
+            if isinstance(scope, Scope) and scope.is_cte
+        }
 
     @property
     def selects(self):
@@ -307,7 +313,9 @@ class Scope:
         sources in the current scope.
         """
         if self._external_columns is None:
-            self._external_columns = [c for c in self.columns if c.table not in self.selected_sources]
+            self._external_columns = [
+                c for c in self.columns if c.table not in self.selected_sources
+            ]
         return self._external_columns
 
     @property
@@ -229,7 +229,9 @@ def simplify_literals(expression):
             operands.append(a)
 
         if len(operands) < size:
-            return functools.reduce(lambda a, b: expression.__class__(this=a, expression=b), operands)
+            return functools.reduce(
+                lambda a, b: expression.__class__(this=a, expression=b), operands
+            )
     elif isinstance(expression, exp.Neg):
         this = expression.this
         if this.is_number:
@@ -255,6 +257,12 @@ def _simplify_binary(expression, a, b):
             return TRUE if not_ else FALSE
         if a == NULL:
             return FALSE if not_ else TRUE
+    elif isinstance(expression, exp.NullSafeEQ):
+        if a == b:
+            return TRUE
+    elif isinstance(expression, exp.NullSafeNEQ):
+        if a == b:
+            return FALSE
     elif NULL in (a, b):
         return NULL
@@ -357,7 +365,7 @@ def extract_date(cast):
 def extract_interval(interval):
     try:
-        from dateutil.relativedelta import relativedelta
+        from dateutil.relativedelta import relativedelta  # type: ignore
     except ModuleNotFoundError:
         return None
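The new `exp.NullSafeEQ`/`exp.NullSafeNEQ` branches above model MySQL's null-safe equal operator `<=>` (listed in the v10 changelog): unlike plain `=`, it never evaluates to NULL, and `NULL <=> NULL` is true. A small truth-table sketch of the two semantics:

```python
# Stand-in for SQL NULL; the functions below contrast null-safe equality
# (<=>) with ordinary `=` semantics.
NULL = object()


def null_safe_eq(a, b):
    if a is NULL or b is NULL:
        return a is b  # true only when BOTH sides are NULL, never NULL itself
    return a == b


def plain_eq(a, b):
    if a is NULL or b is NULL:
        return NULL  # ordinary `=` yields NULL when either side is NULL
    return a == b
```

This is why the simplifier can fold `x <=> x` style comparisons of equal literals straight to `TRUE`, while the `NULL in (a, b)` fallback right below still applies to the ordinary comparison operators.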
@@ -89,7 +89,11 @@ def decorrelate(select, parent_select, external_columns, sequence):
             return
 
         if isinstance(predicate, exp.Binary):
-            key = predicate.right if any(node is column for node, *_ in predicate.left.walk()) else predicate.left
+            key = (
+                predicate.right
+                if any(node is column for node, *_ in predicate.left.walk())
+                else predicate.left
+            )
         else:
             return
@@ -145,7 +149,9 @@ def decorrelate(select, parent_select, external_columns, sequence):
         else:
             parent_predicate = _replace(parent_predicate, "TRUE")
     elif isinstance(parent_predicate, exp.All):
-        parent_predicate = _replace(parent_predicate.parent, f"ARRAY_ALL({alias}, _x -> _x = {other})")
+        parent_predicate = _replace(
+            parent_predicate.parent, f"ARRAY_ALL({alias}, _x -> _x = {other})"
+        )
     elif isinstance(parent_predicate, exp.Any):
         if value.this in group_by:
             parent_predicate = _replace(parent_predicate.parent, f"{other} = {alias}")
@@ -168,7 +174,9 @@ def decorrelate(select, parent_select, external_columns, sequence):
         if key in group_by:
             key.replace(nested)
-            parent_predicate = _replace(parent_predicate, f"({parent_predicate} AND NOT {nested} IS NULL)")
+            parent_predicate = _replace(
+                parent_predicate, f"({parent_predicate} AND NOT {nested} IS NULL)"
+            )
         elif isinstance(predicate, exp.EQ):
             parent_predicate = _replace(
                 parent_predicate,
@@ -1,9 +1,13 @@
+from __future__ import annotations
+
 import logging
+import typing as t
 
 from sqlglot import exp
 from sqlglot.errors import ErrorLevel, ParseError, concat_errors
-from sqlglot.helper import apply_index_offset, ensure_list, list_get
+from sqlglot.helper import apply_index_offset, ensure_collection, seq_get
 from sqlglot.tokens import Token, Tokenizer, TokenType
+from sqlglot.trie import in_trie, new_trie
 
 logger = logging.getLogger("sqlglot")
@@ -20,7 +24,15 @@ def parse_var_map(args):
     )
 
 
-class Parser:
+class _Parser(type):
+    def __new__(cls, clsname, bases, attrs):
+        klass = super().__new__(cls, clsname, bases, attrs)
+        klass._show_trie = new_trie(key.split(" ") for key in klass.SHOW_PARSERS)
+        klass._set_trie = new_trie(key.split(" ") for key in klass.SET_PARSERS)
+        return klass
+
+
+class Parser(metaclass=_Parser):
     """
     Parser consumes a list of tokens produced by the :class:`~sqlglot.tokens.Tokenizer`
     and produces a parsed syntax tree.
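The `_Parser` metaclass above precomputes a word trie from each dialect's multi-word `SHOW_PARSERS`/`SET_PARSERS` keys (e.g. MySQL's new SHOW commands), so the parser can match command phrases token by token. A minimal sketch in the spirit of `sqlglot.trie.new_trie`/`in_trie` (the return conventions here are assumptions for illustration, not SQLGlot's exact API):

```python
# Minimal word trie: each level maps a word to the next level; the
# sentinel key 0 marks that a complete registered phrase ends here.
def new_trie(keywords):
    trie = {}
    for words in keywords:
        node = trie
        for word in words:
            node = node.setdefault(word, {})
        node[0] = True
    return trie


def in_trie(trie, words):
    node = trie
    for word in words:
        if word not in node:
            return False
        node = node[word]
    return 0 in node


# Hypothetical SHOW command phrases, split into words as the metaclass does:
SHOW_TRIE = new_trie(key.split(" ") for key in ("ENGINE", "FULL TABLES", "TABLES"))
```

Matching `["FULL", "TABLES"]` succeeds, while the bare prefix `["FULL"]` does not, which is exactly why a trie (rather than a flat set) is needed for longest-phrase matching.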
@@ -45,16 +57,16 @@ class Parser:
     FUNCTIONS = {
         **{name: f.from_arg_list for f in exp.ALL_FUNCTIONS for name in f.sql_names()},
         "DATE_TO_DATE_STR": lambda args: exp.Cast(
-            this=list_get(args, 0),
+            this=seq_get(args, 0),
             to=exp.DataType(this=exp.DataType.Type.TEXT),
         ),
         "TIME_TO_TIME_STR": lambda args: exp.Cast(
-            this=list_get(args, 0),
+            this=seq_get(args, 0),
             to=exp.DataType(this=exp.DataType.Type.TEXT),
         ),
         "TS_OR_DS_TO_DATE_STR": lambda args: exp.Substring(
             this=exp.Cast(
-                this=list_get(args, 0),
+                this=seq_get(args, 0),
                 to=exp.DataType(this=exp.DataType.Type.TEXT),
             ),
             start=exp.Literal.number(1),
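The `list_get` → `seq_get` rename (flagged as breaking in the changelog) is a safe positional lookup: an out-of-range index yields `None` instead of raising `IndexError`, which is what lets optional trailing SQL function arguments simply be absent. A sketch of the behavior (the real helper in `sqlglot.helper` may be implemented differently):

```python
# Safe positional access: missing arguments come back as None rather
# than raising, so builders like exp.Cast(this=seq_get(args, 0), ...)
# work even when the argument list is short.
def seq_get(seq, index):
    try:
        return seq[index]
    except IndexError:
        return None
```

For example, `seq_get(args, 1)` on a one-element `args` list returns `None`, and the expression builder then treats that argument as unset.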
@@ -90,6 +102,7 @@ class Parser:
         TokenType.NVARCHAR,
         TokenType.TEXT,
         TokenType.BINARY,
+        TokenType.VARBINARY,
         TokenType.JSON,
         TokenType.INTERVAL,
         TokenType.TIMESTAMP,
@@ -243,6 +256,7 @@ class Parser:
     EQUALITY = {
         TokenType.EQ: exp.EQ,
         TokenType.NEQ: exp.NEQ,
+        TokenType.NULLSAFE_EQ: exp.NullSafeEQ,
     }
 
     COMPARISON = {
@@ -298,6 +312,21 @@ class Parser:
         TokenType.ANTI,
     }
 
+    LAMBDAS = {
+        TokenType.ARROW: lambda self, expressions: self.expression(
+            exp.Lambda,
+            this=self._parse_conjunction().transform(
+                self._replace_lambda, {node.name for node in expressions}
+            ),
+            expressions=expressions,
+        ),
+        TokenType.FARROW: lambda self, expressions: self.expression(
+            exp.Kwarg,
+            this=exp.Var(this=expressions[0].name),
+            expression=self._parse_conjunction(),
+        ),
+    }
+
     COLUMN_OPERATORS = {
         TokenType.DOT: None,
         TokenType.DCOLON: lambda self, this, to: self.expression(
@@ -362,20 +391,30 @@ class Parser:
         TokenType.DELETE: lambda self: self._parse_delete(),
         TokenType.CACHE: lambda self: self._parse_cache(),
         TokenType.UNCACHE: lambda self: self._parse_uncache(),
+        TokenType.USE: lambda self: self._parse_use(),
     }
 
     PRIMARY_PARSERS = {
-        TokenType.STRING: lambda _, token: exp.Literal.string(token.text),
-        TokenType.NUMBER: lambda _, token: exp.Literal.number(token.text),
-        TokenType.STAR: lambda self, _: exp.Star(**{"except": self._parse_except(), "replace": self._parse_replace()}),
-        TokenType.NULL: lambda *_: exp.Null(),
-        TokenType.TRUE: lambda *_: exp.Boolean(this=True),
-        TokenType.FALSE: lambda *_: exp.Boolean(this=False),
-        TokenType.PARAMETER: lambda self, _: exp.Parameter(this=self._parse_var() or self._parse_primary()),
-        TokenType.BIT_STRING: lambda _, token: exp.BitString(this=token.text),
-        TokenType.HEX_STRING: lambda _, token: exp.HexString(this=token.text),
-        TokenType.BYTE_STRING: lambda _, token: exp.ByteString(this=token.text),
+        TokenType.STRING: lambda self, token: self.expression(
+            exp.Literal, this=token.text, is_string=True
+        ),
+        TokenType.NUMBER: lambda self, token: self.expression(
+            exp.Literal, this=token.text, is_string=False
+        ),
+        TokenType.STAR: lambda self, _: self.expression(
+            exp.Star, **{"except": self._parse_except(), "replace": self._parse_replace()}
+        ),
+        TokenType.NULL: lambda self, _: self.expression(exp.Null),
+        TokenType.TRUE: lambda self, _: self.expression(exp.Boolean, this=True),
+        TokenType.FALSE: lambda self, _: self.expression(exp.Boolean, this=False),
+        TokenType.PARAMETER: lambda self, _: self.expression(
+            exp.Parameter, this=self._parse_var() or self._parse_primary()
+        ),
+        TokenType.BIT_STRING: lambda self, token: self.expression(exp.BitString, this=token.text),
+        TokenType.HEX_STRING: lambda self, token: self.expression(exp.HexString, this=token.text),
+        TokenType.BYTE_STRING: lambda self, token: self.expression(exp.ByteString, this=token.text),
         TokenType.INTRODUCER: lambda self, token: self._parse_introducer(token),
+        TokenType.SESSION_PARAMETER: lambda self, _: self._parse_session_parameter(),
     }
 
     RANGE_PARSERS = {
@@ -411,16 +450,24 @@ class Parser:
         TokenType.COLLATE: lambda self: self._parse_property_assignment(exp.CollateProperty),
         TokenType.COMMENT: lambda self: self._parse_property_assignment(exp.SchemaCommentProperty),
         TokenType.FORMAT: lambda self: self._parse_property_assignment(exp.FileFormatProperty),
-        TokenType.TABLE_FORMAT: lambda self: self._parse_property_assignment(exp.TableFormatProperty),
+        TokenType.TABLE_FORMAT: lambda self: self._parse_property_assignment(
+            exp.TableFormatProperty
+        ),
         TokenType.USING: lambda self: self._parse_property_assignment(exp.TableFormatProperty),
         TokenType.LANGUAGE: lambda self: self._parse_property_assignment(exp.LanguageProperty),
         TokenType.EXECUTE: lambda self: self._parse_execute_as(),
         TokenType.DETERMINISTIC: lambda self: self.expression(
             exp.VolatilityProperty, this=exp.Literal.string("IMMUTABLE")
         ),
-        TokenType.IMMUTABLE: lambda self: self.expression(exp.VolatilityProperty, this=exp.Literal.string("IMMUTABLE")),
-        TokenType.STABLE: lambda self: self.expression(exp.VolatilityProperty, this=exp.Literal.string("STABLE")),
-        TokenType.VOLATILE: lambda self: self.expression(exp.VolatilityProperty, this=exp.Literal.string("VOLATILE")),
+        TokenType.IMMUTABLE: lambda self: self.expression(
+            exp.VolatilityProperty, this=exp.Literal.string("IMMUTABLE")
+        ),
+        TokenType.STABLE: lambda self: self.expression(
+            exp.VolatilityProperty, this=exp.Literal.string("STABLE")
+        ),
+        TokenType.VOLATILE: lambda self: self.expression(
+            exp.VolatilityProperty, this=exp.Literal.string("VOLATILE")
+        ),
     }
 
     CONSTRAINT_PARSERS = {
@@ -450,7 +497,8 @@ class Parser:
         "group": lambda self: self._parse_group(),
         "having": lambda self: self._parse_having(),
         "qualify": lambda self: self._parse_qualify(),
-        "window": lambda self: self._match(TokenType.WINDOW) and self._parse_window(self._parse_id_var(), alias=True),
+        "window": lambda self: self._match(TokenType.WINDOW)
+        and self._parse_window(self._parse_id_var(), alias=True),
         "distribute": lambda self: self._parse_sort(TokenType.DISTRIBUTE_BY, exp.Distribute),
         "sort": lambda self: self._parse_sort(TokenType.SORT_BY, exp.Sort),
         "cluster": lambda self: self._parse_sort(TokenType.CLUSTER_BY, exp.Cluster),
@@ -459,6 +507,9 @@ class Parser:
         "offset": lambda self: self._parse_offset(),
     }
 
+    SHOW_PARSERS: t.Dict[str, t.Callable] = {}
+    SET_PARSERS: t.Dict[str, t.Callable] = {}
+
     MODIFIABLES = (exp.Subquery, exp.Subqueryable, exp.Table)
 
     CREATABLES = {
@@ -488,7 +539,9 @@ class Parser:
         "_curr",
         "_next",
         "_prev",
-        "_greedy_subqueries",
+        "_prev_comment",
+        "_show_trie",
+        "_set_trie",
     )
 
     def __init__(
@@ -519,7 +572,7 @@ class Parser:
         self._curr = None
         self._next = None
         self._prev = None
-        self._greedy_subqueries = False
+        self._prev_comment = None
 
     def parse(self, raw_tokens, sql=None):
         """
@@ -533,10 +586,12 @@ class Parser:
         Returns
             the list of syntax trees (:class:`~sqlglot.expressions.Expression`).
         """
-        return self._parse(parse_method=self.__class__._parse_statement, raw_tokens=raw_tokens, sql=sql)
+        return self._parse(
+            parse_method=self.__class__._parse_statement, raw_tokens=raw_tokens, sql=sql
+        )
 
     def parse_into(self, expression_types, raw_tokens, sql=None):
-        for expression_type in ensure_list(expression_types):
+        for expression_type in ensure_collection(expression_types):
             parser = self.EXPRESSION_PARSERS.get(expression_type)
             if not parser:
                 raise TypeError(f"No parser registered for {expression_type}")
@@ -597,6 +652,9 @@ class Parser:
     def expression(self, exp_class, **kwargs):
         instance = exp_class(**kwargs)
+        if self._prev_comment:
+            instance.comment = self._prev_comment
+            self._prev_comment = None
         self.validate_expression(instance)
         return instance
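The `_prev_comment` handshake above is how comments survive transpilation (the headline change of v10): `_advance()` remembers the comment attached to the token just consumed, and the next call to `expression()` pins it to the AST node it builds, then clears it so it is not attached twice. A toy model of that two-step protocol (not SQLGlot's real classes):

```python
# Hypothetical miniature of the parser's comment-attachment protocol:
# advance() stashes the previous token's comment; expression() transfers
# it onto the next node built and resets the stash.
class MiniParser:
    def __init__(self):
        self._prev_comment = None

    def advance(self, token_comment):
        # corresponds to _advance() recording self._prev.comment
        self._prev_comment = token_comment

    def expression(self, node):
        # corresponds to expression() consuming the pending comment
        if self._prev_comment:
            node["comment"] = self._prev_comment
            self._prev_comment = None
        return node
```

One consumed comment decorates exactly one node: after it is attached, building further nodes without advancing past another commented token leaves them comment-free.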
@ -633,14 +691,16 @@ class Parser:
return index return index
-    def _get_token(self, index):
-        return list_get(self._tokens, index)
-
     def _advance(self, times=1):
         self._index += times
-        self._curr = self._get_token(self._index)
-        self._next = self._get_token(self._index + 1)
-        self._prev = self._get_token(self._index - 1) if self._index > 0 else None
+        self._curr = seq_get(self._tokens, self._index)
+        self._next = seq_get(self._tokens, self._index + 1)
+        if self._index > 0:
+            self._prev = self._tokens[self._index - 1]
+            self._prev_comment = self._prev.comment
+        else:
+            self._prev = None
+            self._prev_comment = None
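The `seq_get` helper that replaces `list_get` here (the rename noted in the changelog) is just bounds-safe indexing; a minimal sketch consistent with how it is used above:

```python
def seq_get(seq, index):
    """Return seq[index], or None when the index is out of range.

    This mirrors how the parser probes self._tokens for lookahead
    without guarding every access with an explicit length check.
    """
    try:
        return seq[index]
    except IndexError:
        return None


tokens = ["SELECT", "*", "FROM", "t"]
print(seq_get(tokens, 0))   # SELECT
print(seq_get(tokens, 99))  # None
```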
def _retreat(self, index): def _retreat(self, index):
self._advance(index - self._index) self._advance(index - self._index)
@ -661,6 +721,7 @@ class Parser:
expression = self._parse_expression() expression = self._parse_expression()
expression = self._parse_set_operations(expression) if expression else self._parse_select() expression = self._parse_set_operations(expression) if expression else self._parse_select()
self._parse_query_modifiers(expression) self._parse_query_modifiers(expression)
return expression return expression
@ -682,7 +743,11 @@ class Parser:
) )
     def _parse_exists(self, not_=False):
-        return self._match(TokenType.IF) and (not not_ or self._match(TokenType.NOT)) and self._match(TokenType.EXISTS)
+        return (
+            self._match(TokenType.IF)
+            and (not not_ or self._match(TokenType.NOT))
+            and self._match(TokenType.EXISTS)
+        )
def _parse_create(self): def _parse_create(self):
replace = self._match(TokenType.OR) and self._match(TokenType.REPLACE) replace = self._match(TokenType.OR) and self._match(TokenType.REPLACE)
@ -931,7 +996,9 @@ class Parser:
return self.expression( return self.expression(
exp.Delete, exp.Delete,
this=self._parse_table(schema=True), this=self._parse_table(schema=True),
-            using=self._parse_csv(lambda: self._match(TokenType.USING) and self._parse_table(schema=True)),
+            using=self._parse_csv(
+                lambda: self._match(TokenType.USING) and self._parse_table(schema=True)
+            ),
where=self._parse_where(), where=self._parse_where(),
) )
@ -983,11 +1050,13 @@ class Parser:
return None return None
         def parse_values():
-            k = self._parse_var()
+            key = self._parse_var()
+            value = None
+
             if self._match(TokenType.EQ):
-                v = self._parse_string()
-                return (k, v)
-            return (k, None)
+                value = self._parse_string()
+
+            return exp.Property(this=key, value=value)
self._match_l_paren() self._match_l_paren()
values = self._parse_csv(parse_values) values = self._parse_csv(parse_values)
@ -1019,6 +1088,8 @@ class Parser:
self.raise_error(f"{this.key} does not support CTE") self.raise_error(f"{this.key} does not support CTE")
this = cte this = cte
elif self._match(TokenType.SELECT): elif self._match(TokenType.SELECT):
comment = self._prev_comment
hint = self._parse_hint() hint = self._parse_hint()
all_ = self._match(TokenType.ALL) all_ = self._match(TokenType.ALL)
distinct = self._match(TokenType.DISTINCT) distinct = self._match(TokenType.DISTINCT)
@ -1033,7 +1104,7 @@ class Parser:
self.raise_error("Cannot specify both ALL and DISTINCT after SELECT") self.raise_error("Cannot specify both ALL and DISTINCT after SELECT")
limit = self._parse_limit(top=True) limit = self._parse_limit(top=True)
-            expressions = self._parse_csv(lambda: self._parse_annotation(self._parse_expression()))
+            expressions = self._parse_csv(self._parse_expression)
this = self.expression( this = self.expression(
exp.Select, exp.Select,
@ -1042,6 +1113,7 @@ class Parser:
expressions=expressions, expressions=expressions,
limit=limit, limit=limit,
) )
this.comment = comment
from_ = self._parse_from() from_ = self._parse_from()
if from_: if from_:
this.set("from", from_) this.set("from", from_)
@ -1072,8 +1144,10 @@ class Parser:
while True: while True:
expressions.append(self._parse_cte()) expressions.append(self._parse_cte())
-            if not self._match(TokenType.COMMA):
+            if not self._match(TokenType.COMMA) and not self._match(TokenType.WITH):
                 break
+            else:
+                self._match(TokenType.WITH)
return self.expression( return self.expression(
exp.With, exp.With,
@ -1111,11 +1185,7 @@ class Parser:
if not alias and not columns: if not alias and not columns:
return None return None
-        return self.expression(
-            exp.TableAlias,
-            this=alias,
-            columns=columns,
-        )
+        return self.expression(exp.TableAlias, this=alias, columns=columns)
def _parse_subquery(self, this): def _parse_subquery(self, this):
return self.expression( return self.expression(
@ -1150,12 +1220,6 @@ class Parser:
if expression: if expression:
this.set(key, expression) this.set(key, expression)
-    def _parse_annotation(self, expression):
-        if self._match(TokenType.ANNOTATION):
-            return self.expression(exp.Annotation, this=self._prev.text.strip(), expression=expression)
-
-        return expression
def _parse_hint(self): def _parse_hint(self):
if self._match(TokenType.HINT): if self._match(TokenType.HINT):
hints = self._parse_csv(self._parse_function) hints = self._parse_csv(self._parse_function)
@ -1295,7 +1359,9 @@ class Parser:
if not table: if not table:
self.raise_error("Expected table name") self.raise_error("Expected table name")
-        this = self.expression(exp.Table, this=table, db=db, catalog=catalog, pivots=self._parse_pivots())
+        this = self.expression(
+            exp.Table, this=table, db=db, catalog=catalog, pivots=self._parse_pivots()
+        )
if schema: if schema:
return self._parse_schema(this=this) return self._parse_schema(this=this)
@ -1500,7 +1566,9 @@ class Parser:
if not skip_order_token and not self._match(TokenType.ORDER_BY): if not skip_order_token and not self._match(TokenType.ORDER_BY):
return this return this
-        return self.expression(exp.Order, this=this, expressions=self._parse_csv(self._parse_ordered))
+        return self.expression(
+            exp.Order, this=this, expressions=self._parse_csv(self._parse_ordered)
+        )
def _parse_sort(self, token_type, exp_class): def _parse_sort(self, token_type, exp_class):
if not self._match(token_type): if not self._match(token_type):
@ -1521,7 +1589,8 @@ class Parser:
if ( if (
not explicitly_null_ordered not explicitly_null_ordered
and ( and (
-                (asc and self.null_ordering == "nulls_are_small") or (desc and self.null_ordering != "nulls_are_small")
+                (asc and self.null_ordering == "nulls_are_small")
+                or (desc and self.null_ordering != "nulls_are_small")
) )
and self.null_ordering != "nulls_are_last" and self.null_ordering != "nulls_are_last"
): ):
@ -1606,6 +1675,9 @@ class Parser:
def _parse_is(self, this): def _parse_is(self, this):
negate = self._match(TokenType.NOT) negate = self._match(TokenType.NOT)
if self._match(TokenType.DISTINCT_FROM):
klass = exp.NullSafeEQ if negate else exp.NullSafeNEQ
return self.expression(klass, this=this, expression=self._parse_expression())
this = self.expression( this = self.expression(
exp.Is, exp.Is,
this=this, this=this,
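The `NullSafeEQ`/`NullSafeNEQ` nodes introduced in the `_parse_is` hunk model `IS [NOT] DISTINCT FROM` — the same semantics as MySQL's `<=>` operator mentioned in the changelog. Its only difference from plain `=` is around NULL; a sketch of the comparison rule (the function name is illustrative, not sqlglot API):

```python
def null_safe_eq(a, b):
    # Plain SQL `=` yields NULL if either operand is NULL. Null-safe
    # equality instead treats two NULLs as equal and NULL vs. a value
    # as unequal, so it always returns a boolean.
    if a is None or b is None:
        return a is b
    return a == b


print(null_safe_eq(None, None))  # True
print(null_safe_eq(None, 1))     # False
print(null_safe_eq(1, 1))        # True
```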
@ -1653,9 +1725,13 @@ class Parser:
expression=self._parse_term(), expression=self._parse_term(),
) )
             elif self._match_pair(TokenType.LT, TokenType.LT):
-                this = self.expression(exp.BitwiseLeftShift, this=this, expression=self._parse_term())
+                this = self.expression(
+                    exp.BitwiseLeftShift, this=this, expression=self._parse_term()
+                )
             elif self._match_pair(TokenType.GT, TokenType.GT):
-                this = self.expression(exp.BitwiseRightShift, this=this, expression=self._parse_term())
+                this = self.expression(
+                    exp.BitwiseRightShift, this=this, expression=self._parse_term()
+                )
else: else:
break break
@ -1685,7 +1761,7 @@ class Parser:
) )
index = self._index index = self._index
-        type_token = self._parse_types()
+        type_token = self._parse_types(check_func=True)
this = self._parse_column() this = self._parse_column()
if type_token: if type_token:
@ -1698,7 +1774,7 @@ class Parser:
return this return this
-    def _parse_types(self):
+    def _parse_types(self, check_func=False):
index = self._index index = self._index
if not self._match_set(self.TYPE_TOKENS): if not self._match_set(self.TYPE_TOKENS):
@ -1708,10 +1784,13 @@ class Parser:
nested = type_token in self.NESTED_TYPE_TOKENS nested = type_token in self.NESTED_TYPE_TOKENS
is_struct = type_token == TokenType.STRUCT is_struct = type_token == TokenType.STRUCT
expressions = None expressions = None
maybe_func = False
         if not nested and self._match_pair(TokenType.L_BRACKET, TokenType.R_BRACKET):
             return exp.DataType(
-                this=exp.DataType.Type.ARRAY, expressions=[exp.DataType.build(type_token.value)], nested=True
+                this=exp.DataType.Type.ARRAY,
+                expressions=[exp.DataType.build(type_token.value)],
+                nested=True,
             )
if self._match(TokenType.L_BRACKET): if self._match(TokenType.L_BRACKET):
@ -1731,6 +1810,7 @@ class Parser:
return None return None
self._match_r_paren() self._match_r_paren()
maybe_func = True
if nested and self._match(TokenType.LT): if nested and self._match(TokenType.LT):
if is_struct: if is_struct:
@ -1741,26 +1821,47 @@ class Parser:
if not self._match(TokenType.GT): if not self._match(TokenType.GT):
self.raise_error("Expecting >") self.raise_error("Expecting >")
+        value = None
         if type_token in self.TIMESTAMPS:
-            tz = self._match(TokenType.WITH_TIME_ZONE) or type_token == TokenType.TIMESTAMPTZ
-            if tz:
-                return exp.DataType(
+            if self._match(TokenType.WITH_TIME_ZONE) or type_token == TokenType.TIMESTAMPTZ:
+                value = exp.DataType(
                     this=exp.DataType.Type.TIMESTAMPTZ,
                     expressions=expressions,
                 )
-            ltz = self._match(TokenType.WITH_LOCAL_TIME_ZONE) or type_token == TokenType.TIMESTAMPLTZ
-            if ltz:
-                return exp.DataType(
+            elif (
+                self._match(TokenType.WITH_LOCAL_TIME_ZONE) or type_token == TokenType.TIMESTAMPLTZ
+            ):
+                value = exp.DataType(
                     this=exp.DataType.Type.TIMESTAMPLTZ,
                     expressions=expressions,
                 )
-            self._match(TokenType.WITHOUT_TIME_ZONE)
-
-            return exp.DataType(
+            elif self._match(TokenType.WITHOUT_TIME_ZONE):
+                value = exp.DataType(
                     this=exp.DataType.Type.TIMESTAMP,
                     expressions=expressions,
                 )
+
+            maybe_func = maybe_func and value is None
+
+            if value is None:
+                value = exp.DataType(
+                    this=exp.DataType.Type.TIMESTAMP,
+                    expressions=expressions,
+                )
+
+        if maybe_func and check_func:
+            index2 = self._index
+            peek = self._parse_string()
+
+            if not peek:
+                self._retreat(index)
+                return None
+
+            self._retreat(index2)
+
+        if value:
+            return value
return exp.DataType( return exp.DataType(
this=exp.DataType.Type[type_token.value.upper()], this=exp.DataType.Type[type_token.value.upper()],
expressions=expressions, expressions=expressions,
@ -1826,22 +1927,29 @@ class Parser:
return exp.Literal.number(f"0.{self._prev.text}") return exp.Literal.number(f"0.{self._prev.text}")
         if self._match(TokenType.L_PAREN):
+            comment = self._prev_comment
             query = self._parse_select()
             if query:
                 expressions = [query]
             else:
-                expressions = self._parse_csv(lambda: self._parse_alias(self._parse_conjunction(), explicit=True))
+                expressions = self._parse_csv(
+                    lambda: self._parse_alias(self._parse_conjunction(), explicit=True)
+                )
 
-            this = list_get(expressions, 0)
+            this = seq_get(expressions, 0)
             self._parse_query_modifiers(this)
             self._match_r_paren()
 
             if isinstance(this, exp.Subqueryable):
-                return self._parse_set_operations(self._parse_subquery(this))
-            if len(expressions) > 1:
-                return self.expression(exp.Tuple, expressions=expressions)
-            return self.expression(exp.Paren, this=this)
+                this = self._parse_set_operations(self._parse_subquery(this))
+            elif len(expressions) > 1:
+                this = self.expression(exp.Tuple, expressions=expressions)
+            else:
+                this = self.expression(exp.Paren, this=this)
+
+            if comment:
+                this.comment = comment
+            return this
return None return None
@ -1894,7 +2002,8 @@ class Parser:
self.validate_expression(this, args) self.validate_expression(this, args)
else: else:
this = self.expression(exp.Anonymous, this=this, expressions=args) this = self.expression(exp.Anonymous, this=this, expressions=args)
-        self._match_r_paren()
+        self._match_r_paren(this)
return self._parse_window(this) return self._parse_window(this)
def _parse_user_defined_function(self): def _parse_user_defined_function(self):
@ -1920,6 +2029,18 @@ class Parser:
return self.expression(exp.Identifier, this=token.text) return self.expression(exp.Identifier, this=token.text)
def _parse_session_parameter(self):
kind = None
this = self._parse_id_var() or self._parse_primary()
if self._match(TokenType.DOT):
kind = this.name
this = self._parse_var() or self._parse_primary()
return self.expression(
exp.SessionParameter,
this=this,
kind=kind,
)
def _parse_udf_kwarg(self): def _parse_udf_kwarg(self):
this = self._parse_id_var() this = self._parse_id_var()
kind = self._parse_types() kind = self._parse_types()
@ -1938,11 +2059,15 @@ class Parser:
else: else:
expressions = [self._parse_id_var()] expressions = [self._parse_id_var()]
-        if not self._match(TokenType.ARROW):
-            self._retreat(index)
+        if self._match_set(self.LAMBDAS):
+            return self.LAMBDAS[self._prev.token_type](self, expressions)
+
+        self._retreat(index)
 
         if self._match(TokenType.DISTINCT):
-            this = self.expression(exp.Distinct, expressions=self._parse_csv(self._parse_conjunction))
+            this = self.expression(
+                exp.Distinct, expressions=self._parse_csv(self._parse_conjunction)
+            )
         else:
             this = self._parse_conjunction()
@ -1953,20 +2078,15 @@ class Parser:
         return self._parse_alias(self._parse_limit(self._parse_order(this)))
 
-        conjunction = self._parse_conjunction().transform(self._replace_lambda, {node.name for node in expressions})
-        return self.expression(
-            exp.Lambda,
-            this=conjunction,
-            expressions=expressions,
-        )
def _parse_schema(self, this=None): def _parse_schema(self, this=None):
index = self._index index = self._index
if not self._match(TokenType.L_PAREN) or self._match(TokenType.SELECT): if not self._match(TokenType.L_PAREN) or self._match(TokenType.SELECT):
self._retreat(index) self._retreat(index)
return this return this
-        args = self._parse_csv(lambda: self._parse_constraint() or self._parse_column_def(self._parse_field(True)))
+        args = self._parse_csv(
+            lambda: self._parse_constraint() or self._parse_column_def(self._parse_field(True))
+        )
self._match_r_paren() self._match_r_paren()
return self.expression(exp.Schema, this=this, expressions=args) return self.expression(exp.Schema, this=this, expressions=args)
@ -2104,6 +2224,7 @@ class Parser:
if not self._match(TokenType.R_BRACKET): if not self._match(TokenType.R_BRACKET):
self.raise_error("Expected ]") self.raise_error("Expected ]")
this.comment = self._prev_comment
return self._parse_bracket(this) return self._parse_bracket(this)
def _parse_case(self): def _parse_case(self):
@ -2124,7 +2245,9 @@ class Parser:
if not self._match(TokenType.END): if not self._match(TokenType.END):
self.raise_error("Expected END after CASE", self._prev) self.raise_error("Expected END after CASE", self._prev)
-        return self._parse_window(self.expression(exp.Case, this=expression, ifs=ifs, default=default))
+        return self._parse_window(
+            self.expression(exp.Case, this=expression, ifs=ifs, default=default)
+        )
def _parse_if(self): def _parse_if(self):
if self._match(TokenType.L_PAREN): if self._match(TokenType.L_PAREN):
@ -2331,7 +2454,9 @@ class Parser:
self._match(TokenType.BETWEEN) self._match(TokenType.BETWEEN)
return { return {
"value": (self._match_set((TokenType.UNBOUNDED, TokenType.CURRENT_ROW)) and self._prev.text) "value": (
self._match_set((TokenType.UNBOUNDED, TokenType.CURRENT_ROW)) and self._prev.text
)
or self._parse_bitwise(), or self._parse_bitwise(),
"side": self._match_set((TokenType.PRECEDING, TokenType.FOLLOWING)) and self._prev.text, "side": self._match_set((TokenType.PRECEDING, TokenType.FOLLOWING)) and self._prev.text,
} }
@ -2348,7 +2473,7 @@ class Parser:
this=this, this=this,
expressions=self._parse_csv(lambda: self._parse_id_var(any_token)), expressions=self._parse_csv(lambda: self._parse_id_var(any_token)),
) )
-            self._match_r_paren()
+            self._match_r_paren(aliases)
return aliases return aliases
alias = self._parse_id_var(any_token) alias = self._parse_id_var(any_token)
@ -2365,28 +2490,29 @@ class Parser:
return identifier return identifier
if any_token and self._curr and self._curr.token_type not in self.RESERVED_KEYWORDS: if any_token and self._curr and self._curr.token_type not in self.RESERVED_KEYWORDS:
-            return self._advance() or exp.Identifier(this=self._prev.text, quoted=False)
-
-        return self._match_set(tokens or self.ID_VAR_TOKENS) and exp.Identifier(this=self._prev.text, quoted=False)
+            self._advance()
+        elif not self._match_set(tokens or self.ID_VAR_TOKENS):
+            return None
+
+        return exp.Identifier(this=self._prev.text, quoted=False)
     def _parse_string(self):
         if self._match(TokenType.STRING):
-            return exp.Literal.string(self._prev.text)
+            return self.PRIMARY_PARSERS[TokenType.STRING](self, self._prev)
         return self._parse_placeholder()
 
     def _parse_number(self):
         if self._match(TokenType.NUMBER):
-            return exp.Literal.number(self._prev.text)
+            return self.PRIMARY_PARSERS[TokenType.NUMBER](self, self._prev)
         return self._parse_placeholder()
 
     def _parse_identifier(self):
         if self._match(TokenType.IDENTIFIER):
-            return exp.Identifier(this=self._prev.text, quoted=True)
+            return self.expression(exp.Identifier, this=self._prev.text, quoted=True)
         return self._parse_placeholder()
 
     def _parse_var(self):
         if self._match(TokenType.VAR):
-            return exp.Var(this=self._prev.text)
+            return self.expression(exp.Var, this=self._prev.text)
         return self._parse_placeholder()
 
     def _parse_var_or_string(self):
@ -2394,27 +2520,27 @@ class Parser:
     def _parse_null(self):
         if self._match(TokenType.NULL):
-            return exp.Null()
+            return self.PRIMARY_PARSERS[TokenType.NULL](self, self._prev)
         return None
 
     def _parse_boolean(self):
         if self._match(TokenType.TRUE):
-            return exp.Boolean(this=True)
+            return self.PRIMARY_PARSERS[TokenType.TRUE](self, self._prev)
         if self._match(TokenType.FALSE):
-            return exp.Boolean(this=False)
+            return self.PRIMARY_PARSERS[TokenType.FALSE](self, self._prev)
         return None
 
     def _parse_star(self):
         if self._match(TokenType.STAR):
-            return exp.Star(**{"except": self._parse_except(), "replace": self._parse_replace()})
+            return self.PRIMARY_PARSERS[TokenType.STAR](self, self._prev)
         return None
 
     def _parse_placeholder(self):
         if self._match(TokenType.PLACEHOLDER):
-            return exp.Placeholder()
+            return self.expression(exp.Placeholder)
         elif self._match(TokenType.COLON):
             self._advance()
-            return exp.Placeholder(this=self._prev.text)
+            return self.expression(exp.Placeholder, this=self._prev.text)
         return None
def _parse_except(self): def _parse_except(self):
@ -2432,22 +2558,27 @@ class Parser:
self._match_r_paren() self._match_r_paren()
return columns return columns
-    def _parse_csv(self, parse):
-        parse_result = parse()
+    def _parse_csv(self, parse_method):
+        parse_result = parse_method()
         items = [parse_result] if parse_result is not None else []
 
         while self._match(TokenType.COMMA):
-            parse_result = parse()
+            if parse_result and self._prev_comment is not None:
+                parse_result.comment = self._prev_comment
+
+            parse_result = parse_method()
             if parse_result is not None:
                 items.append(parse_result)
 
         return items
 
-    def _parse_tokens(self, parse, expressions):
-        this = parse()
+    def _parse_tokens(self, parse_method, expressions):
+        this = parse_method()
 
         while self._match_set(expressions):
-            this = self.expression(expressions[self._prev.token_type], this=this, expression=parse())
+            this = self.expression(
+                expressions[self._prev.token_type], this=this, expression=parse_method()
+            )
 
         return this
@ -2460,6 +2591,47 @@ class Parser:
def _parse_select_or_expression(self): def _parse_select_or_expression(self):
return self._parse_select() or self._parse_expression() return self._parse_select() or self._parse_expression()
def _parse_use(self):
return self.expression(exp.Use, this=self._parse_id_var())
def _parse_show(self):
parser = self._find_parser(self.SHOW_PARSERS, self._show_trie)
if parser:
return parser(self)
self._advance()
return self.expression(exp.Show, this=self._prev.text.upper())
def _default_parse_set_item(self):
return self.expression(
exp.SetItem,
this=self._parse_statement(),
)
def _parse_set_item(self):
parser = self._find_parser(self.SET_PARSERS, self._set_trie)
return parser(self) if parser else self._default_parse_set_item()
def _parse_set(self):
return self.expression(exp.Set, expressions=self._parse_csv(self._parse_set_item))
def _find_parser(self, parsers, trie):
index = self._index
this = []
while True:
# The current token might be multiple words
curr = self._curr.text.upper()
key = curr.split(" ")
this.append(curr)
self._advance()
result, trie = in_trie(trie, key)
if result == 0:
break
if result == 2:
subparser = parsers[" ".join(this)]
return subparser
self._retreat(index)
return None
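`_find_parser` above greedily consumes tokens and walks a keyword trie so that multi-word commands such as MySQL's `SHOW FULL TABLES` can dispatch to a registered subparser. The trie helpers it relies on live in `sqlglot.trie`; a minimal sketch under the same return convention the loop checks (0 = miss, 1 = prefix of a longer keyword, 2 = exact match):

```python
def new_trie(keywords):
    # Nested-dict trie; the sentinel key 0 marks the end of a keyword.
    trie = {}
    for words in keywords:
        node = trie
        for word in words:
            node = node.setdefault(word, {})
        node[0] = True
    return trie


def in_trie(trie, keys):
    # 0 -> no match, 1 -> prefix of a longer keyword, 2 -> exact match.
    node = trie
    for key in keys:
        if key not in node:
            return 0, trie
        node = node[key]
    return (2, node) if 0 in node else (1, node)


trie = new_trie([("SHOW", "TABLES"), ("SHOW", "FULL", "TABLES")])
print(in_trie(trie, ["SHOW"])[0])                    # 1 (keep consuming tokens)
print(in_trie(trie, ["SHOW", "FULL", "TABLES"])[0])  # 2 (dispatch subparser)
print(in_trie(trie, ["DESCRIBE"])[0])                # 0 (retreat and give up)
```

The prefix result (1) is what lets the loop keep advancing past `SHOW` toward the longer `SHOW FULL TABLES` entry before committing.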
def _match(self, token_type): def _match(self, token_type):
if not self._curr: if not self._curr:
return None return None
@ -2491,13 +2663,17 @@ class Parser:
return None return None
-    def _match_l_paren(self):
+    def _match_l_paren(self, expression=None):
         if not self._match(TokenType.L_PAREN):
             self.raise_error("Expecting (")
+        if expression and self._prev_comment:
+            expression.comment = self._prev_comment
 
-    def _match_r_paren(self):
+    def _match_r_paren(self, expression=None):
         if not self._match(TokenType.R_PAREN):
             self.raise_error("Expecting )")
+        if expression and self._prev_comment:
+            expression.comment = self._prev_comment
def _match_text(self, *texts): def _match_text(self, *texts):
index = self._index index = self._index


@ -72,7 +72,9 @@ class Step:
if from_: if from_:
from_ = from_.expressions from_ = from_.expressions
if len(from_) > 1: if len(from_) > 1:
raise UnsupportedError("Multi-from statements are unsupported. Run it through the optimizer") raise UnsupportedError(
"Multi-from statements are unsupported. Run it through the optimizer"
)
step = Scan.from_expression(from_[0], ctes) step = Scan.from_expression(from_[0], ctes)
else: else:
@ -102,7 +104,7 @@ class Step:
continue continue
if operand not in operands: if operand not in operands:
operands[operand] = f"_a_{next(sequence)}" operands[operand] = f"_a_{next(sequence)}"
-                        operand.replace(exp.column(operands[operand], step.name, quoted=True))
+                        operand.replace(exp.column(operands[operand], quoted=True))
else: else:
projections.append(e) projections.append(e)
@ -117,9 +119,11 @@ class Step:
aggregate = Aggregate() aggregate = Aggregate()
aggregate.source = step.name aggregate.source = step.name
aggregate.name = step.name aggregate.name = step.name
-            aggregate.operands = tuple(alias(operand, alias_) for operand, alias_ in operands.items())
+            aggregate.operands = tuple(
+                alias(operand, alias_) for operand, alias_ in operands.items()
+            )
             aggregate.aggregations = aggregations
-            aggregate.group = [exp.column(e.alias_or_name, step.name, quoted=True) for e in group.expressions]
+            aggregate.group = group.expressions
aggregate.add_dependency(step) aggregate.add_dependency(step)
step = aggregate step = aggregate
@ -136,9 +140,6 @@ class Step:
sort.key = order.expressions sort.key = order.expressions
sort.add_dependency(step) sort.add_dependency(step)
step = sort step = sort
-            for k in sort.key + projections:
-                for column in k.find_all(exp.Column):
-                    column.set("table", exp.to_identifier(step.name, quoted=True))
step.projections = projections step.projections = projections
@ -203,7 +204,9 @@ class Scan(Step):
alias_ = expression.alias alias_ = expression.alias
if not alias_: if not alias_:
raise UnsupportedError("Tables/Subqueries must be aliased. Run it through the optimizer") raise UnsupportedError(
"Tables/Subqueries must be aliased. Run it through the optimizer"
)
if isinstance(expression, exp.Subquery): if isinstance(expression, exp.Subquery):
table = expression.this table = expression.this

sqlglot/py.typed Normal file


@ -1,44 +1,60 @@
+from __future__ import annotations
+
 import abc
+import typing as t
 
 from sqlglot import expressions as exp
-from sqlglot.errors import OptimizeError
+from sqlglot.errors import SchemaError
 from sqlglot.helper import csv_reader
+from sqlglot.trie import in_trie, new_trie
+
+if t.TYPE_CHECKING:
+    from sqlglot.dataframe.sql.types import StructType
+
+    ColumnMapping = t.Union[t.Dict, str, StructType, t.List]
+
+TABLE_ARGS = ("this", "db", "catalog")
class Schema(abc.ABC): class Schema(abc.ABC):
"""Abstract base class for database schemas""" """Abstract base class for database schemas"""
     @abc.abstractmethod
-    def add_table(self, table, column_mapping=None):
+    def add_table(
+        self, table: exp.Table | str, column_mapping: t.Optional[ColumnMapping] = None
+    ) -> None:
         """
-        Register or update a table. Some implementing classes may require column information to also be provided
+        Register or update a table. Some implementing classes may require column information to also be provided.
 
         Args:
-            table (sqlglot.expressions.Table|str): Table expression instance or string representing the table
-            column_mapping (dict|str|sqlglot.dataframe.sql.types.StructType|list): A column mapping that describes the structure of the table
+            table: table expression instance or string representing the table.
+            column_mapping: a column mapping that describes the structure of the table.
         """
 
     @abc.abstractmethod
-    def column_names(self, table, only_visible=False):
+    def column_names(self, table: exp.Table | str, only_visible: bool = False) -> t.List[str]:
         """
         Get the column names for a table.
 
         Args:
-            table (sqlglot.expressions.Table): Table expression instance
-            only_visible (bool): Whether to include invisible columns
+            table: the `Table` expression instance.
+            only_visible: whether to include invisible columns.
 
         Returns:
-            list[str]: list of column names
+            The list of column names.
         """
 
     @abc.abstractmethod
-    def get_column_type(self, table, column):
+    def get_column_type(self, table: exp.Table | str, column: exp.Column) -> exp.DataType.Type:
         """
-        Get the exp.DataType type of a column in the schema.
+        Get the :class:`sqlglot.exp.DataType` type of a column in the schema.
 
         Args:
-            table (sqlglot.expressions.Table): The source table.
-            column (sqlglot.expressions.Column): The target column.
+            table: the source table.
+            column: the target column.
 
         Returns:
-            sqlglot.expressions.DataType.Type: The resulting column type.
+            The resulting column type.
         """
@ -60,132 +76,179 @@ class MappingSchema(Schema):
dialect (str): The dialect to be used for custom type mappings. dialect (str): The dialect to be used for custom type mappings.
""" """
-    def __init__(self, schema=None, visible=None, dialect=None):
+    def __init__(
+        self,
+        schema: t.Optional[t.Dict] = None,
+        visible: t.Optional[t.Dict] = None,
+        dialect: t.Optional[str] = None,
+    ) -> None:
         self.schema = schema or {}
-        self.visible = visible
-        self.dialect = dialect
-        self._type_mapping_cache = {}
-        self.supported_table_args = []
-        self.forbidden_table_args = set()
-        if self.schema:
-            self._initialize_supported_args()
+        self.visible = visible or {}
+        self.schema_trie = self._build_trie(self.schema)
+        self.dialect = dialect
+        self._type_mapping_cache: t.Dict[str, exp.DataType.Type] = {}
+        self._supported_table_args: t.Tuple[str, ...] = tuple()
     @classmethod
-    def from_mapping_schema(cls, mapping_schema):
+    def from_mapping_schema(cls, mapping_schema: MappingSchema) -> MappingSchema:
         return MappingSchema(
-            schema=mapping_schema.schema, visible=mapping_schema.visible, dialect=mapping_schema.dialect
+            schema=mapping_schema.schema,
+            visible=mapping_schema.visible,
+            dialect=mapping_schema.dialect,
         )
 
-    def copy(self, **kwargs):
-        return MappingSchema(**{"schema": self.schema.copy(), **kwargs})
+    def copy(self, **kwargs) -> MappingSchema:
+        return MappingSchema(
+            **{  # type: ignore
+                "schema": self.schema.copy(),
+                "visible": self.visible.copy(),
+                "dialect": self.dialect,
+                **kwargs,
+            }
+        )
+    @property
+    def supported_table_args(self):
+        if not self._supported_table_args and self.schema:
+            depth = _dict_depth(self.schema)
+
+            if not depth or depth == 1:  # {}
+                self._supported_table_args = tuple()
+            elif 2 <= depth <= 4:
+                self._supported_table_args = TABLE_ARGS[: depth - 1]
+            else:
+                raise SchemaError(f"Invalid schema shape. Depth: {depth}")
+
+        return self._supported_table_args
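The `_dict_depth` helper referenced above only needs to measure how deeply the schema mapping nests, so the same dict can be read as `{table: columns}`, `{db: {table: columns}}`, or `{catalog: {db: {table: columns}}}`. A sketch consistent with that use:

```python
def dict_depth(d):
    # {} -> 1, {"t": {"c": "int"}} -> 2, {"db": {"t": {"c": "int"}}} -> 3.
    try:
        return 1 + dict_depth(next(iter(d.values())))
    except AttributeError:
        return 0  # reached a leaf value such as "int"
    except StopIteration:
        return 1  # empty dict


print(dict_depth({"db": {"t": {"c": "int"}}}))  # 3
```

Depth 2 maps to just `("this",)`, depth 4 to `("this", "db", "catalog")` — which is how the refactored schema stays lenient about how qualified the table names are.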
-    def add_table(self, table, column_mapping=None):
+    def add_table(
+        self, table: exp.Table | str, column_mapping: t.Optional[ColumnMapping] = None
+    ) -> None:
""" """
Register or update a table. Updates are only performed if a new column mapping is provided. Register or update a table. Updates are only performed if a new column mapping is provided.
Args: Args:
table (sqlglot.expressions.Table|str): Table expression instance or string representing the table table: the `Table` expression instance or string representing the table.
column_mapping (dict|str|sqlglot.dataframe.sql.types.StructType|list): A column mapping that describes the structure of the table column_mapping: a column mapping that describes the structure of the table.
""" """
table = exp.to_table(table) table_ = self._ensure_table(table)
self._validate_table(table)
column_mapping = ensure_column_mapping(column_mapping) column_mapping = ensure_column_mapping(column_mapping)
table_args = [table.text(p) for p in self.supported_table_args or self._get_table_args_from_table(table)] schema = self.find_schema(table_, raise_on_missing=False)
existing_column_mapping = _nested_get(
self.schema, *zip(self.supported_table_args, table_args), raise_on_missing=False if schema and not column_mapping:
)
if existing_column_mapping and not column_mapping:
return return
_nested_set( _nested_set(
self.schema, self.schema,
[table.text(p) for p in self.supported_table_args or self._get_table_args_from_table(table)], list(reversed(self.table_parts(table_))),
column_mapping, column_mapping,
) )
self._initialize_supported_args() self.schema_trie = self._build_trie(self.schema)
-    def _get_table_args_from_table(self, table):
-        if table.args.get("catalog") is not None:
-            return "catalog", "db", "this"
-        if table.args.get("db") is not None:
-            return "db", "this"
-        return ("this",)
-
-    def _validate_table(self, table):
-        if not self.supported_table_args and isinstance(table, exp.Table):
-            return
-        for forbidden in self.forbidden_table_args:
-            if table.text(forbidden):
-                raise ValueError(f"Schema doesn't support {forbidden}. Received: {table.sql()}")
-        for expected in self.supported_table_args:
-            if not table.text(expected):
-                raise ValueError(f"Table is expected to have {expected}. Received: {table.sql()} ")
-
-    def column_names(self, table, only_visible=False):
-        table = exp.to_table(table)
-        if not isinstance(table.this, exp.Identifier):
-            return fs_get(table)
-        args = tuple(table.text(p) for p in self.supported_table_args)
-        for forbidden in self.forbidden_table_args:
-            if table.text(forbidden):
-                raise ValueError(f"Schema doesn't support {forbidden}. Received: {table.sql()}")
-        columns = list(_nested_get(self.schema, *zip(self.supported_table_args, args)))
-        if not only_visible or not self.visible:
-            return columns
-        visible = _nested_get(self.visible, *zip(self.supported_table_args, args))
-        return [col for col in columns if col in visible]
+    def _ensure_table(self, table: exp.Table | str) -> exp.Table:
+        table_ = exp.to_table(table)
+
+        if not table_:
+            raise SchemaError(f"Not a valid table '{table}'")
+
+        return table_
+
+    def table_parts(self, table: exp.Table) -> t.List[str]:
+        return [table.text(part) for part in TABLE_ARGS if table.text(part)]
+
+    def column_names(self, table: exp.Table | str, only_visible: bool = False) -> t.List[str]:
+        table_ = self._ensure_table(table)
+
+        if not isinstance(table_.this, exp.Identifier):
+            return fs_get(table)  # type: ignore
+
+        schema = self.find_schema(table_)
+
+        if schema is None:
+            raise SchemaError(f"Could not find table schema {table}")
+
+        if not only_visible or not self.visible:
+            return list(schema)
+
+        visible = self._nested_get(self.table_parts(table_), self.visible)
+        return [col for col in schema if col in visible]  # type: ignore
def get_column_type(self, table, column): def find_schema(
try: self, table: exp.Table, trie: t.Optional[t.Dict] = None, raise_on_missing: bool = True
schema_type = self.schema.get(table.name, {}).get(column.name).upper() ) -> t.Optional[t.Dict[str, str]]:
parts = self.table_parts(table)[0 : len(self.supported_table_args)]
value, trie = in_trie(self.schema_trie if trie is None else trie, parts)
if value == 0:
if raise_on_missing:
raise SchemaError(f"Cannot find schema for {table}.")
else:
return None
elif value == 1:
possibilities = flatten_schema(trie)
if len(possibilities) == 1:
parts.extend(possibilities[0])
else:
message = ", ".join(".".join(parts) for parts in possibilities)
if raise_on_missing:
raise SchemaError(f"Ambiguous schema for {table}: {message}.")
return None
return self._nested_get(parts, raise_on_missing=raise_on_missing)
def get_column_type(
self, table: exp.Table | str, column: exp.Column | str
) -> exp.DataType.Type:
column_name = column if isinstance(column, str) else column.name
table_ = exp.to_table(table)
if table_:
table_schema = self.find_schema(table_)
schema_type = table_schema.get(column_name).upper() # type: ignore
return self._convert_type(schema_type) return self._convert_type(schema_type)
except: raise SchemaError(f"Could not convert table '{table}'")
raise OptimizeError(f"Failed to get type for column {column.sql()}")
def _convert_type(self, schema_type): def _convert_type(self, schema_type: str) -> exp.DataType.Type:
""" """
Convert a type represented as a string to the corresponding exp.DataType.Type object. Convert a type represented as a string to the corresponding :class:`sqlglot.exp.DataType` object.
Args: Args:
schema_type (str): The type we want to convert. schema_type: the type we want to convert.
Returns: Returns:
sqlglot.expressions.DataType.Type: The resulting expression type. The resulting expression type.
""" """
if schema_type not in self._type_mapping_cache: if schema_type not in self._type_mapping_cache:
try: try:
self._type_mapping_cache[schema_type] = exp.maybe_parse( expression = exp.maybe_parse(schema_type, into=exp.DataType, dialect=self.dialect)
schema_type, into=exp.DataType, dialect=self.dialect if expression is None:
).this raise ValueError(f"Could not parse {schema_type}")
self._type_mapping_cache[schema_type] = expression.this
except AttributeError: except AttributeError:
raise OptimizeError(f"Failed to convert type {schema_type}") raise SchemaError(f"Failed to convert type {schema_type}")
return self._type_mapping_cache[schema_type] return self._type_mapping_cache[schema_type]
def _initialize_supported_args(self): def _build_trie(self, schema: t.Dict):
if not self.supported_table_args: return new_trie(tuple(reversed(t)) for t in flatten_schema(schema))
depth = _dict_depth(self.schema)
all_args = ["this", "db", "catalog"] def _nested_get(
if not depth or depth == 1: # {} self, parts: t.Sequence[str], d: t.Optional[t.Dict] = None, raise_on_missing=True
self.supported_table_args = [] ) -> t.Optional[t.Any]:
elif 2 <= depth <= 4: return _nested_get(
self.supported_table_args = tuple(reversed(all_args[: depth - 1])) d or self.schema,
else: *zip(self.supported_table_args, reversed(parts)),
raise OptimizeError(f"Invalid schema shape. Depth: {depth}") raise_on_missing=raise_on_missing,
)
self.forbidden_table_args = {"catalog", "db", "this"} - set(self.supported_table_args)
def ensure_schema(schema): def ensure_schema(schema: t.Any) -> Schema:
if isinstance(schema, Schema): if isinstance(schema, Schema):
return schema return schema
return MappingSchema(schema) return MappingSchema(schema)
def ensure_column_mapping(mapping): def ensure_column_mapping(mapping: t.Optional[ColumnMapping]):
if isinstance(mapping, dict): if isinstance(mapping, dict):
return mapping return mapping
elif isinstance(mapping, str): elif isinstance(mapping, str):
@@ -196,7 +259,7 @@ def ensure_column_mapping(mapping):
         }
     # Check if mapping looks like a DataFrame StructType
     elif hasattr(mapping, "simpleString"):
-        return {struct_field.name: struct_field.dataType.simpleString() for struct_field in mapping}
+        return {struct_field.name: struct_field.dataType.simpleString() for struct_field in mapping}  # type: ignore
     elif isinstance(mapping, list):
         return {x.strip(): None for x in mapping}
     elif mapping is None:
@@ -204,7 +267,20 @@ def ensure_column_mapping(mapping):
     raise ValueError(f"Invalid mapping provided: {type(mapping)}")


-def fs_get(table):
+def flatten_schema(schema: t.Dict, keys: t.Optional[t.List[str]] = None) -> t.List[t.List[str]]:
+    tables = []
+    keys = keys or []
+    depth = _dict_depth(schema)
+
+    for k, v in schema.items():
+        if depth >= 3:
+            tables.extend(flatten_schema(v, keys + [k]))
+        elif depth == 2:
+            tables.append(keys + [k])
+
+    return tables
def fs_get(table: exp.Table) -> t.List[str]:
name = table.this.name name = table.this.name
if name.upper() == "READ_CSV": if name.upper() == "READ_CSV":
@ -214,21 +290,23 @@ def fs_get(table):
raise ValueError(f"Cannot read schema for {table}") raise ValueError(f"Cannot read schema for {table}")
def _nested_get(d, *path, raise_on_missing=True): def _nested_get(
d: t.Dict, *path: t.Tuple[str, str], raise_on_missing: bool = True
) -> t.Optional[t.Any]:
""" """
Get a value for a nested dictionary. Get a value for a nested dictionary.
Args: Args:
d (dict): dictionary d: the dictionary to search.
*path (tuple[str, str]): tuples of (name, key) *path: tuples of (name, key), where:
`key` is the key in the dictionary to get. `key` is the key in the dictionary to get.
`name` is a string to use in the error if `key` isn't found. `name` is a string to use in the error if `key` isn't found.
Returns: Returns:
The value or None if it doesn't exist The value or None if it doesn't exist.
""" """
for name, key in path: for name, key in path:
d = d.get(key) d = d.get(key) # type: ignore
if d is None: if d is None:
if raise_on_missing: if raise_on_missing:
name = "table" if name == "this" else name name = "table" if name == "this" else name
@ -237,36 +315,44 @@ def _nested_get(d, *path, raise_on_missing=True):
return d return d
def _nested_set(d, keys, value): def _nested_set(d: t.Dict, keys: t.List[str], value: t.Any) -> t.Dict:
""" """
In-place set a value for a nested dictionary In-place set a value for a nested dictionary
Ex: Example:
>>> _nested_set({}, ["top_key", "second_key"], "value") >>> _nested_set({}, ["top_key", "second_key"], "value")
{'top_key': {'second_key': 'value'}} {'top_key': {'second_key': 'value'}}
>>> _nested_set({"top_key": {"third_key": "third_value"}}, ["top_key", "second_key"], "value") >>> _nested_set({"top_key": {"third_key": "third_value"}}, ["top_key", "second_key"], "value")
{'top_key': {'third_key': 'third_value', 'second_key': 'value'}} {'top_key': {'third_key': 'third_value', 'second_key': 'value'}}
d (dict): dictionary Args:
keys (Iterable[str]): ordered iterable of keys that makeup path to value d: dictionary to update.
value (Any): The value to set in the dictionary for the given key path keys: the keys that makeup the path to `value`.
value: the value to set in the dictionary for the given key path.
Returns:
The (possibly) updated dictionary.
""" """
if not keys: if not keys:
return return d
if len(keys) == 1: if len(keys) == 1:
d[keys[0]] = value d[keys[0]] = value
return return d
subd = d subd = d
for key in keys[:-1]: for key in keys[:-1]:
if key not in subd: if key not in subd:
subd = subd.setdefault(key, {}) subd = subd.setdefault(key, {})
else: else:
subd = subd[key] subd = subd[key]
subd[keys[-1]] = value subd[keys[-1]] = value
return d return d
def _dict_depth(d): def _dict_depth(d: t.Dict) -> int:
""" """
Get the nesting depth of a dictionary. Get the nesting depth of a dictionary.
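The schema rewrite in this file replaces the old `supported_table_args` validation with a trie keyed on reversed table parts, so a partially qualified name resolves whenever it is unambiguous. A minimal sketch of the `in_trie` contract that `find_schema` relies on (0 = no match, 1 = proper prefix, 2 = exact match); this reimplements the assumed behavior of `sqlglot.trie`, not the library module itself:

```python
import typing as t


def new_trie(keywords: t.Iterable[t.Sequence]) -> t.Dict:
    # Build a trie; the sentinel key 0 marks the end of a stored sequence.
    trie: t.Dict = {}
    for key in keywords:
        current = trie
        for part in key:
            current = current.setdefault(part, {})
        current[0] = True
    return trie


def in_trie(trie: t.Dict, key: t.Sequence) -> t.Tuple[int, t.Dict]:
    # 0: not in trie, 1: proper prefix of stored keys, 2: exact match.
    if not key:
        return 0, trie
    current = trie
    for part in key:
        if part not in current:
            return 0, current
        current = current[part]
    if 0 in current:
        return 2, current
    return 1, current


# Table parts are stored reversed (table name first), so a bare table name is a prefix.
trie = new_trie([("users", "db1"), ("users", "db2"), ("events", "db1")])
print(in_trie(trie, ("events",))[0])       # 1 -> prefix; only one completion, so it resolves
print(in_trie(trie, ("users",))[0])        # 1 -> prefix; ambiguous between db1 and db2
print(in_trie(trie, ("users", "db1"))[0])  # 2 -> exact match
```

On a prefix hit, `find_schema` flattens the remaining trie node: one possibility completes the name, several raise the "Ambiguous schema" error.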
@@ -1,9 +1,13 @@
-# the generic time format is based on python time.strftime
+import typing as t
+
+# The generic time format is based on python time.strftime.
 # https://docs.python.org/3/library/time.html#time.strftime
 from sqlglot.trie import in_trie, new_trie


-def format_time(string, mapping, trie=None):
+def format_time(
+    string: str, mapping: t.Dict[str, str], trie: t.Optional[t.Dict] = None
+) -> t.Optional[str]:
     """
     Converts a time string given a mapping.

@@ -11,11 +15,16 @@ def format_time(string, mapping, trie=None):
     >>> format_time("%Y", {"%Y": "YYYY"})
     'YYYY'

-    mapping: Dictionary of time format to target time format
-    trie: Optional trie, can be passed in for performance
+    Args:
+        mapping: dictionary of time format to target time format.
+        trie: optional trie, can be passed in for performance.
+
+    Returns:
+        The converted time string.
     """
     if not string:
         return None

     start = 0
     end = 1
     size = len(string)
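`format_time` scans the input and greedily rewrites the longest matching source token, using a trie over the mapping keys for performance. A simplified sketch of the same idea without the trie (the function name here is hypothetical, not the library API):

```python
import typing as t


def format_time_sketch(string: str, mapping: t.Dict[str, str]) -> t.Optional[str]:
    # Greedily replace the longest matching source token at each position;
    # characters that match no token are copied through unchanged.
    if not string:
        return None

    out = []
    i = 0
    tokens = sorted(mapping, key=len, reverse=True)  # longest match wins

    while i < len(string):
        for tok in tokens:
            if string.startswith(tok, i):
                out.append(mapping[tok])
                i += len(tok)
                break
        else:
            out.append(string[i])
            i += 1

    return "".join(out)


print(format_time_sketch("%Y-%m", {"%Y": "yyyy", "%m": "MM"}))  # yyyy-MM
```

The real implementation avoids the linear scan over tokens by walking the trie with the `start`/`end` pointers initialized above.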
@@ -1,3 +1,6 @@
+from __future__ import annotations
+
+import typing as t
 from enum import auto

 from sqlglot.helper import AutoName
@@ -27,6 +30,7 @@ class TokenType(AutoName):
     NOT = auto()
     EQ = auto()
     NEQ = auto()
+    NULLSAFE_EQ = auto()
     AND = auto()
     OR = auto()
     AMP = auto()
@@ -36,12 +40,14 @@ class TokenType(AutoName):
     TILDA = auto()
     ARROW = auto()
     DARROW = auto()
+    FARROW = auto()
+    HASH = auto()
     HASH_ARROW = auto()
     DHASH_ARROW = auto()
     LR_ARROW = auto()
-    ANNOTATION = auto()
     DOLLAR = auto()
     PARAMETER = auto()
+    SESSION_PARAMETER = auto()
     SPACE = auto()
     BREAK = auto()
@@ -73,7 +79,7 @@ class TokenType(AutoName):
     NVARCHAR = auto()
     TEXT = auto()
     BINARY = auto()
-    BYTEA = auto()
+    VARBINARY = auto()
     JSON = auto()
     TIMESTAMP = auto()
     TIMESTAMPTZ = auto()
@@ -142,6 +148,7 @@ class TokenType(AutoName):
     DESCRIBE = auto()
     DETERMINISTIC = auto()
     DISTINCT = auto()
+    DISTINCT_FROM = auto()
     DISTRIBUTE_BY = auto()
     DIV = auto()
     DROP = auto()
@@ -238,6 +245,7 @@ class TokenType(AutoName):
     RETURNS = auto()
     RIGHT = auto()
     RLIKE = auto()
+    ROLLBACK = auto()
     ROLLUP = auto()
     ROW = auto()
     ROWS = auto()
@@ -287,37 +295,49 @@ class TokenType(AutoName):

 class Token:
-    __slots__ = ("token_type", "text", "line", "col")
+    __slots__ = ("token_type", "text", "line", "col", "comment")

     @classmethod
-    def number(cls, number):
+    def number(cls, number: int) -> Token:
+        """Returns a NUMBER token with `number` as its text."""
         return cls(TokenType.NUMBER, str(number))

     @classmethod
-    def string(cls, string):
+    def string(cls, string: str) -> Token:
+        """Returns a STRING token with `string` as its text."""
         return cls(TokenType.STRING, string)

     @classmethod
-    def identifier(cls, identifier):
+    def identifier(cls, identifier: str) -> Token:
+        """Returns an IDENTIFIER token with `identifier` as its text."""
         return cls(TokenType.IDENTIFIER, identifier)

     @classmethod
-    def var(cls, var):
+    def var(cls, var: str) -> Token:
+        """Returns an VAR token with `var` as its text."""
         return cls(TokenType.VAR, var)

-    def __init__(self, token_type, text, line=1, col=1):
+    def __init__(
+        self,
+        token_type: TokenType,
+        text: str,
+        line: int = 1,
+        col: int = 1,
+        comment: t.Optional[str] = None,
+    ) -> None:
         self.token_type = token_type
         self.text = text
         self.line = line
         self.col = max(col - len(text), 1)
+        self.comment = comment

-    def __repr__(self):
+    def __repr__(self) -> str:
         attributes = ", ".join(f"{k}: {getattr(self, k)}" for k in self.__slots__)
         return f"<Token {attributes}>"


 class _Tokenizer(type):
-    def __new__(cls, clsname, bases, attrs):
+    def __new__(cls, clsname, bases, attrs):  # type: ignore
         klass = super().__new__(cls, clsname, bases, attrs)
         klass._QUOTES = cls._delimeter_list_to_dict(klass.QUOTES)
@@ -325,27 +345,29 @@ class _Tokenizer(type):
         klass._HEX_STRINGS = cls._delimeter_list_to_dict(klass.HEX_STRINGS)
         klass._BYTE_STRINGS = cls._delimeter_list_to_dict(klass.BYTE_STRINGS)
         klass._IDENTIFIERS = cls._delimeter_list_to_dict(klass.IDENTIFIERS)
+        klass._ESCAPES = set(klass.ESCAPES)
         klass._COMMENTS = dict(
-            (comment, None) if isinstance(comment, str) else (comment[0], comment[1]) for comment in klass.COMMENTS
+            (comment, None) if isinstance(comment, str) else (comment[0], comment[1])
+            for comment in klass.COMMENTS
         )

         klass.KEYWORD_TRIE = new_trie(
             key.upper()
-            for key, value in {
+            for key in {
                 **klass.KEYWORDS,
                 **{comment: TokenType.COMMENT for comment in klass._COMMENTS},
                 **{quote: TokenType.QUOTE for quote in klass._QUOTES},
                 **{bit_string: TokenType.BIT_STRING for bit_string in klass._BIT_STRINGS},
                 **{hex_string: TokenType.HEX_STRING for hex_string in klass._HEX_STRINGS},
                 **{byte_string: TokenType.BYTE_STRING for byte_string in klass._BYTE_STRINGS},
-            }.items()
+            }
             if " " in key or any(single in key for single in klass.SINGLE_TOKENS)
         )

         return klass

     @staticmethod
-    def _delimeter_list_to_dict(list):
+    def _delimeter_list_to_dict(list: t.List[str | t.Tuple[str, str]]) -> t.Dict[str, str]:
         return dict((item, item) if isinstance(item, str) else (item[0], item[1]) for item in list)
@@ -375,26 +397,26 @@ class Tokenizer(metaclass=_Tokenizer):
         "*": TokenType.STAR,
         "~": TokenType.TILDA,
         "?": TokenType.PLACEHOLDER,
-        "#": TokenType.ANNOTATION,
         "@": TokenType.PARAMETER,
         # used for breaking a var like x'y' but nothing else
         # the token type doesn't matter
         "'": TokenType.QUOTE,
         "`": TokenType.IDENTIFIER,
         '"': TokenType.IDENTIFIER,
+        "#": TokenType.HASH,
     }

-    QUOTES = ["'"]
-    BIT_STRINGS = []
-    HEX_STRINGS = []
-    BYTE_STRINGS = []
-    IDENTIFIERS = ['"']
-    ESCAPE = "'"
+    QUOTES: t.List[t.Tuple[str, str] | str] = ["'"]
+    BIT_STRINGS: t.List[str | t.Tuple[str, str]] = []
+    HEX_STRINGS: t.List[str | t.Tuple[str, str]] = []
+    BYTE_STRINGS: t.List[str | t.Tuple[str, str]] = []
+    IDENTIFIERS: t.List[str | t.Tuple[str, str]] = ['"']
+    ESCAPES = ["'"]

     KEYWORDS = {
         "/*+": TokenType.HINT,
@@ -406,8 +428,10 @@ class Tokenizer(metaclass=_Tokenizer):
         "<=": TokenType.LTE,
         "<>": TokenType.NEQ,
         "!=": TokenType.NEQ,
+        "<=>": TokenType.NULLSAFE_EQ,
         "->": TokenType.ARROW,
         "->>": TokenType.DARROW,
+        "=>": TokenType.FARROW,
         "#>": TokenType.HASH_ARROW,
         "#>>": TokenType.DHASH_ARROW,
         "<->": TokenType.LR_ARROW,
@@ -454,6 +478,7 @@ class Tokenizer(metaclass=_Tokenizer):
         "DESCRIBE": TokenType.DESCRIBE,
         "DETERMINISTIC": TokenType.DETERMINISTIC,
         "DISTINCT": TokenType.DISTINCT,
+        "DISTINCT FROM": TokenType.DISTINCT_FROM,
         "DISTRIBUTE BY": TokenType.DISTRIBUTE_BY,
         "DIV": TokenType.DIV,
         "DROP": TokenType.DROP,
@@ -543,6 +568,7 @@ class Tokenizer(metaclass=_Tokenizer):
         "RETURNS": TokenType.RETURNS,
         "RIGHT": TokenType.RIGHT,
         "RLIKE": TokenType.RLIKE,
+        "ROLLBACK": TokenType.ROLLBACK,
         "ROLLUP": TokenType.ROLLUP,
         "ROW": TokenType.ROW,
         "ROWS": TokenType.ROWS,
@@ -622,8 +648,9 @@ class Tokenizer(metaclass=_Tokenizer):
         "TEXT": TokenType.TEXT,
         "CLOB": TokenType.TEXT,
         "BINARY": TokenType.BINARY,
-        "BLOB": TokenType.BINARY,
-        "BYTEA": TokenType.BINARY,
+        "BLOB": TokenType.VARBINARY,
+        "BYTEA": TokenType.VARBINARY,
+        "VARBINARY": TokenType.VARBINARY,
         "TIMESTAMP": TokenType.TIMESTAMP,
         "TIMESTAMPTZ": TokenType.TIMESTAMPTZ,
         "TIMESTAMPLTZ": TokenType.TIMESTAMPLTZ,
@@ -655,13 +682,13 @@ class Tokenizer(metaclass=_Tokenizer):
         TokenType.SET,
         TokenType.SHOW,
         TokenType.TRUNCATE,
-        TokenType.USE,
         TokenType.VACUUM,
+        TokenType.ROLLBACK,
     }

     # handle numeric literals like in hive (3L = BIGINT)
-    NUMERIC_LITERALS = {}
-    ENCODE = None
+    NUMERIC_LITERALS: t.Dict[str, str] = {}
+    ENCODE: t.Optional[str] = None

     COMMENTS = ["--", ("/*", "*/")]
     KEYWORD_TRIE = None  # autofilled
@@ -674,33 +701,39 @@ class Tokenizer(metaclass=_Tokenizer):
         "_current",
         "_line",
         "_col",
+        "_comment",
         "_char",
         "_end",
         "_peek",
+        "_prev_token_line",
+        "_prev_token_comment",
         "_prev_token_type",
+        "_replace_backslash",
     )

-    def __init__(self):
-        """
-        Tokenizer consumes a sql string and produces an array of :class:`~sqlglot.tokens.Token`
-        """
+    def __init__(self) -> None:
+        self._replace_backslash = "\\" in self._ESCAPES  # type: ignore
         self.reset()

-    def reset(self):
+    def reset(self) -> None:
         self.sql = ""
         self.size = 0
-        self.tokens = []
+        self.tokens: t.List[Token] = []
         self._start = 0
         self._current = 0
         self._line = 1
         self._col = 1
+        self._comment = None

         self._char = None
         self._end = None
         self._peek = None
+        self._prev_token_line = -1
+        self._prev_token_comment = None
         self._prev_token_type = None

-    def tokenize(self, sql):
+    def tokenize(self, sql: str) -> t.List[Token]:
+        """Returns a list of tokens corresponding to the SQL string `sql`."""
         self.reset()
         self.sql = sql
         self.size = len(sql)
@@ -712,14 +745,14 @@ class Tokenizer(metaclass=_Tokenizer):
             if not self._char:
                 break

-            white_space = self.WHITE_SPACE.get(self._char)
-            identifier_end = self._IDENTIFIERS.get(self._char)
+            white_space = self.WHITE_SPACE.get(self._char)  # type: ignore
+            identifier_end = self._IDENTIFIERS.get(self._char)  # type: ignore

             if white_space:
                 if white_space == TokenType.BREAK:
                     self._col = 1
                     self._line += 1
-            elif self._char.isdigit():
+            elif self._char.isdigit():  # type:ignore
                 self._scan_number()
             elif identifier_end:
                 self._scan_identifier(identifier_end)
@@ -727,38 +760,51 @@ class Tokenizer(metaclass=_Tokenizer):
                 self._scan_keywords()
         return self.tokens

-    def _chars(self, size):
+    def _chars(self, size: int) -> str:
         if size == 1:
-            return self._char
+            return self._char  # type: ignore
         start = self._current - 1
         end = start + size
         if end <= self.size:
             return self.sql[start:end]
         return ""

-    def _advance(self, i=1):
+    def _advance(self, i: int = 1) -> None:
         self._col += i
         self._current += i
-        self._end = self._current >= self.size
-        self._char = self.sql[self._current - 1]
-        self._peek = self.sql[self._current] if self._current < self.size else ""
+        self._end = self._current >= self.size  # type: ignore
+        self._char = self.sql[self._current - 1]  # type: ignore
+        self._peek = self.sql[self._current] if self._current < self.size else ""  # type: ignore

     @property
-    def _text(self):
+    def _text(self) -> str:
         return self.sql[self._start : self._current]

-    def _add(self, token_type, text=None):
-        self._prev_token_type = token_type
-        self.tokens.append(Token(token_type, self._text if text is None else text, self._line, self._col))
+    def _add(self, token_type: TokenType, text: t.Optional[str] = None) -> None:
+        self._prev_token_line = self._line
+        self._prev_token_comment = self._comment
+        self._prev_token_type = token_type  # type: ignore
+        self.tokens.append(
+            Token(
+                token_type,
+                self._text if text is None else text,
+                self._line,
+                self._col,
+                self._comment,
+            )
+        )
+        self._comment = None

-        if token_type in self.COMMANDS and (len(self.tokens) == 1 or self.tokens[-2].token_type == TokenType.SEMICOLON):
+        if token_type in self.COMMANDS and (
+            len(self.tokens) == 1 or self.tokens[-2].token_type == TokenType.SEMICOLON
+        ):
             self._start = self._current
             while not self._end and self._peek != ";":
                 self._advance()
             if self._start < self._current:
                 self._add(TokenType.STRING)

-    def _scan_keywords(self):
+    def _scan_keywords(self) -> None:
         size = 0
         word = None
         chars = self._text
@@ -771,7 +817,7 @@ class Tokenizer(metaclass=_Tokenizer):
                 if skip:
                     result = 1
                 else:
-                    result, trie = in_trie(trie, char.upper())
+                    result, trie = in_trie(trie, char.upper())  # type: ignore

                 if result == 0:
                     break
@@ -793,15 +839,11 @@ class Tokenizer(metaclass=_Tokenizer):
                 else:
                     skip = True
             else:
-                chars = None
+                chars = None  # type: ignore

         if not word:
             if self._char in self.SINGLE_TOKENS:
-                token = self.SINGLE_TOKENS[self._char]
-                if token == TokenType.ANNOTATION:
-                    self._scan_annotation()
-                    return
-                self._add(token)
+                self._add(self.SINGLE_TOKENS[self._char])  # type: ignore
                 return
             self._scan_var()
             return
@@ -816,31 +858,41 @@ class Tokenizer(metaclass=_Tokenizer):
         self._advance(size - 1)
         self._add(self.KEYWORDS[word.upper()])

-    def _scan_comment(self, comment_start):
-        if comment_start not in self._COMMENTS:
+    def _scan_comment(self, comment_start: str) -> bool:
+        if comment_start not in self._COMMENTS:  # type: ignore
             return False

-        comment_end = self._COMMENTS[comment_start]
+        comment_start_line = self._line
+        comment_start_size = len(comment_start)
+        comment_end = self._COMMENTS[comment_start]  # type: ignore

         if comment_end:
             comment_end_size = len(comment_end)

             while not self._end and self._chars(comment_end_size) != comment_end:
                 self._advance()

+            self._comment = self._text[comment_start_size : -comment_end_size + 1]  # type: ignore
             self._advance(comment_end_size - 1)
         else:
-            while not self._end and self.WHITE_SPACE.get(self._peek) != TokenType.BREAK:
+            while not self._end and self.WHITE_SPACE.get(self._peek) != TokenType.BREAK:  # type: ignore
                 self._advance()
+            self._comment = self._text[comment_start_size:]  # type: ignore
+
+        # Leading comment is attached to the succeeding token, whilst trailing comment to the preceding. If both
+        # types of comment can be attached to a token, the trailing one is discarded in favour of the leading one.
+        if comment_start_line == self._prev_token_line:
+            if self._prev_token_comment is None:
+                self.tokens[-1].comment = self._comment
+            self._comment = None

         return True

-    def _scan_annotation(self):
-        while not self._end and self.WHITE_SPACE.get(self._peek) != TokenType.BREAK and self._peek != ",":
-            self._advance()
-        self._add(TokenType.ANNOTATION, self._text[1:])
-
-    def _scan_number(self):
+    def _scan_number(self) -> None:
         if self._char == "0":
-            peek = self._peek.upper()
+            peek = self._peek.upper()  # type: ignore
             if peek == "B":
                 return self._scan_bits()
             elif peek == "X":
@@ -850,7 +902,7 @@ class Tokenizer(metaclass=_Tokenizer):
         scientific = 0

         while True:
-            if self._peek.isdigit():
+            if self._peek.isdigit():  # type: ignore
                 self._advance()
             elif self._peek == "." and not decimal:
                 decimal = True
@@ -858,25 +910,25 @@ class Tokenizer(metaclass=_Tokenizer):
             elif self._peek in ("-", "+") and scientific == 1:
                 scientific += 1
                 self._advance()
-            elif self._peek.upper() == "E" and not scientific:
+            elif self._peek.upper() == "E" and not scientific:  # type: ignore
                 scientific += 1
                 self._advance()
-            elif self._peek.isalpha():
+            elif self._peek.isalpha():  # type: ignore
                 self._add(TokenType.NUMBER)
                 literal = []
-                while self._peek.isalpha():
-                    literal.append(self._peek.upper())
+                while self._peek.isalpha():  # type: ignore
+                    literal.append(self._peek.upper())  # type: ignore
                     self._advance()
-                literal = "".join(literal)
-                token_type = self.KEYWORDS.get(self.NUMERIC_LITERALS.get(literal))
+                literal = "".join(literal)  # type: ignore
+                token_type = self.KEYWORDS.get(self.NUMERIC_LITERALS.get(literal))  # type: ignore
                 if token_type:
                     self._add(TokenType.DCOLON, "::")
-                    return self._add(token_type, literal)
+                    return self._add(token_type, literal)  # type: ignore

                 return self._advance(-len(literal))
             else:
                 return self._add(TokenType.NUMBER)

-    def _scan_bits(self):
+    def _scan_bits(self) -> None:
         self._advance()
         value = self._extract_value()
         try:
@@ -884,7 +936,7 @@ class Tokenizer(metaclass=_Tokenizer):
         except ValueError:
             self._add(TokenType.IDENTIFIER)

-    def _scan_hex(self):
+    def _scan_hex(self) -> None:
         self._advance()
         value = self._extract_value()
         try:
@@ -892,9 +944,9 @@ class Tokenizer(metaclass=_Tokenizer):
         except ValueError:
             self._add(TokenType.IDENTIFIER)

-    def _extract_value(self):
+    def _extract_value(self) -> str:
         while True:
-            char = self._peek.strip()
+            char = self._peek.strip()  # type: ignore
             if char and char not in self.SINGLE_TOKENS:
                 self._advance()
             else:
@@ -902,31 +954,30 @@ class Tokenizer(metaclass=_Tokenizer):

         return self._text

-    def _scan_string(self, quote):
-        quote_end = self._QUOTES.get(quote)
+    def _scan_string(self, quote: str) -> bool:
+        quote_end = self._QUOTES.get(quote)  # type: ignore
         if quote_end is None:
             return False

         self._advance(len(quote))
         text = self._extract_string(quote_end)
-
-        text = text.encode(self.ENCODE).decode(self.ENCODE) if self.ENCODE else text
-        text = text.replace("\\\\", "\\") if self.ESCAPE == "\\" else text
+        text = text.encode(self.ENCODE).decode(self.ENCODE) if self.ENCODE else text  # type: ignore
+        text = text.replace("\\\\", "\\") if self._replace_backslash else text
         self._add(TokenType.STRING, text)
         return True

     # X'1234, b'0110', E'\\\\\' etc.
-    def _scan_formatted_string(self, string_start):
-        if string_start in self._HEX_STRINGS:
-            delimiters = self._HEX_STRINGS
+    def _scan_formatted_string(self, string_start: str) -> bool:
+        if string_start in self._HEX_STRINGS:  # type: ignore
+            delimiters = self._HEX_STRINGS  # type: ignore
             token_type = TokenType.HEX_STRING
             base = 16
-        elif string_start in self._BIT_STRINGS:
-            delimiters = self._BIT_STRINGS
+        elif string_start in self._BIT_STRINGS:  # type: ignore
+            delimiters = self._BIT_STRINGS  # type: ignore
             token_type = TokenType.BIT_STRING
             base = 2
-        elif string_start in self._BYTE_STRINGS:
-            delimiters = self._BYTE_STRINGS
+        elif string_start in self._BYTE_STRINGS:  # type: ignore
+            delimiters = self._BYTE_STRINGS  # type: ignore
             token_type = TokenType.BYTE_STRING
             base = None
         else:
@@ -942,11 +993,13 @@ class Tokenizer(metaclass=_Tokenizer):
             try:
                 self._add(token_type, f"{int(text, base)}")
             except:
-                raise RuntimeError(f"Numeric string contains invalid characters from {self._line}:{self._start}")
+                raise RuntimeError(
+                    f"Numeric string contains invalid characters from {self._line}:{self._start}"
+                )

         return True

-    def _scan_identifier(self, identifier_end):
+    def _scan_identifier(self, identifier_end: str) -> None:
         while self._peek != identifier_end:
             if self._end:
                 raise RuntimeError(f"Missing {identifier_end} from {self._line}:{self._start}")
@@ -954,9 +1007,9 @@ class Tokenizer(metaclass=_Tokenizer):
             self._advance()
         self._add(TokenType.IDENTIFIER, self._text[1:-1])

-    def _scan_var(self):
+    def _scan_var(self) -> None:
         while True:
-            char = self._peek.strip()
+            char = self._peek.strip()  # type: ignore
             if char and char not in self.SINGLE_TOKENS:
                 self._advance()
             else:
@@ -967,12 +1020,12 @@ class Tokenizer(metaclass=_Tokenizer):
             else self.KEYWORDS.get(self._text.upper(), TokenType.VAR)
         )

-    def _extract_string(self, delimiter):
+    def _extract_string(self, delimiter: str) -> str:
         text = ""
         delim_size = len(delimiter)

         while True:
-            if self._char == self.ESCAPE and self._peek == delimiter:
+            if self._char in self._ESCAPES and self._peek == delimiter:  # type: ignore
                 text += delimiter
                 self._advance(2)
             else:
@@ -983,7 +1036,7 @@ class Tokenizer(metaclass=_Tokenizer):
                 if self._end:
                     raise RuntimeError(f"Missing {delimiter} from {self._line}:{self._start}")

-                text += self._char
+                text += self._char  # type: ignore
                 self._advance()

         return text


@@ -1,7 +1,14 @@
+from __future__ import annotations
+
+import typing as t
+
+if t.TYPE_CHECKING:
+    from sqlglot.generator import Generator
+
 from sqlglot import expressions as exp


-def unalias_group(expression):
+def unalias_group(expression: exp.Expression) -> exp.Expression:
     """
     Replace references to select aliases in GROUP BY clauses.
@@ -9,6 +16,12 @@ def unalias_group(expression):
         >>> import sqlglot
         >>> sqlglot.parse_one("SELECT a AS b FROM x GROUP BY b").transform(unalias_group).sql()
         'SELECT a AS b FROM x GROUP BY 1'
+
+    Args:
+        expression: the expression that will be transformed.
+
+    Returns:
+        The transformed expression.
     """
     if isinstance(expression, exp.Group) and isinstance(expression.parent, exp.Select):
         aliased_selects = {
@@ -30,18 +43,19 @@ def unalias_group(expression):
     return expression


-def preprocess(transforms, to_sql):
+def preprocess(
+    transforms: t.List[t.Callable[[exp.Expression], exp.Expression]],
+    to_sql: t.Callable[[Generator, exp.Expression], str],
+) -> t.Callable[[Generator, exp.Expression], str]:
     """
-    Create a new transform function that can be used a value in `Generator.TRANSFORMS`
-    to convert expressions to SQL.
+    Creates a new transform by chaining a sequence of transformations and converts the resulting
+    expression to SQL, using an appropriate `Generator.TRANSFORMS` function.

     Args:
-        transforms (list[(exp.Expression) -> exp.Expression]):
-            Sequence of transform functions. These will be called in order.
-        to_sql ((sqlglot.generator.Generator, exp.Expression) -> str):
-            Final transform that converts the resulting expression to a SQL string.
+        transforms: sequence of transform functions. These will be called in order.
+        to_sql: final transform that converts the resulting expression to a SQL string.

     Returns:
-        (sqlglot.generator.Generator, exp.Expression) -> str:
-            Function that can be used as a generator transform.
+        Function that can be used as a generator transform.
     """
@@ -54,12 +68,10 @@ def preprocess(transforms, to_sql):
     return _to_sql


-def delegate(attr):
+def delegate(attr: str) -> t.Callable:
     """
-    Create a new method that delegates to `attr`.
-
-    This is useful for creating `Generator.TRANSFORMS` functions that delegate
-    to existing generator methods.
+    Create a new method that delegates to `attr`. This is useful for creating `Generator.TRANSFORMS`
+    functions that delegate to existing generator methods.
     """

     def _transform(self, *args, **kwargs):


@@ -1,5 +1,26 @@
-def new_trie(keywords):
-    trie = {}
+import typing as t
+
+key = t.Sequence[t.Hashable]
+
+
+def new_trie(keywords: t.Iterable[key]) -> t.Dict:
+    """
+    Creates a new trie out of a collection of keywords.
+
+    The trie is represented as a sequence of nested dictionaries keyed by either single character
+    strings, or by 0, which is used to designate that a keyword is in the trie.
+
+    Example:
+        >>> new_trie(["bla", "foo", "blab"])
+        {'b': {'l': {'a': {0: True, 'b': {0: True}}}}, 'f': {'o': {'o': {0: True}}}}
+
+    Args:
+        keywords: the keywords to create the trie from.
+
+    Returns:
+        The trie corresponding to `keywords`.
+    """
+    trie: t.Dict = {}

     for key in keywords:
         current = trie
@@ -11,7 +32,28 @@ def new_trie(keywords):
     return trie


-def in_trie(trie, key):
+def in_trie(trie: t.Dict, key: key) -> t.Tuple[int, t.Dict]:
+    """
+    Checks whether a key is in a trie.
+
+    Examples:
+        >>> in_trie(new_trie(["cat"]), "bob")
+        (0, {'c': {'a': {'t': {0: True}}}})
+
+        >>> in_trie(new_trie(["cat"]), "ca")
+        (1, {'t': {0: True}})
+
+        >>> in_trie(new_trie(["cat"]), "cat")
+        (2, {0: True})
+
+    Args:
+        trie: the trie to be searched.
+        key: the target key.
+
+    Returns:
+        A pair `(value, subtrie)`, where `subtrie` is the sub-trie we get at the point where the
+        search stops, and `value` is either 0 (search was unsuccessful), 1 (`key` is a prefix of
+        a keyword in `trie`) or 2 (`key` is in `trie`).
+    """
     if not key:
         return (0, trie)


@@ -1,9 +1,9 @@
-import sys
 import typing as t
 import unittest
 import warnings

 import sqlglot
+from sqlglot.helper import PYTHON_VERSION
 from tests.helpers import SKIP_INTEGRATION

 if t.TYPE_CHECKING:
@@ -11,7 +11,8 @@ if t.TYPE_CHECKING:
 @unittest.skipIf(
-    SKIP_INTEGRATION or sys.version_info[:2] > (3, 10), "Skipping Integration Tests since `SKIP_INTEGRATION` is set"
+    SKIP_INTEGRATION or PYTHON_VERSION > (3, 10),
+    "Skipping Integration Tests since `SKIP_INTEGRATION` is set",
 )
 class DataFrameValidator(unittest.TestCase):
     spark = None
@@ -36,7 +37,12 @@ class DataFrameValidator(unittest.TestCase):
         # This is for test `test_branching_root_dataframes`
         config = SparkConf().setAll([("spark.sql.analyzer.failAmbiguousSelfJoin", "false")])
-        cls.spark = SparkSession.builder.master("local[*]").appName("Unit-tests").config(conf=config).getOrCreate()
+        cls.spark = (
+            SparkSession.builder.master("local[*]")
+            .appName("Unit-tests")
+            .config(conf=config)
+            .getOrCreate()
+        )
         cls.spark.sparkContext.setLogLevel("ERROR")
         cls.sqlglot = SqlglotSparkSession()
         cls.spark_employee_schema = types.StructType(
@@ -50,7 +56,9 @@ class DataFrameValidator(unittest.TestCase):
         )
         cls.sqlglot_employee_schema = sqlglotSparkTypes.StructType(
             [
-                sqlglotSparkTypes.StructField("employee_id", sqlglotSparkTypes.IntegerType(), False),
+                sqlglotSparkTypes.StructField(
+                    "employee_id", sqlglotSparkTypes.IntegerType(), False
+                ),
                 sqlglotSparkTypes.StructField("fname", sqlglotSparkTypes.StringType(), False),
                 sqlglotSparkTypes.StructField("lname", sqlglotSparkTypes.StringType(), False),
                 sqlglotSparkTypes.StructField("age", sqlglotSparkTypes.IntegerType(), False),
@@ -64,8 +72,12 @@ class DataFrameValidator(unittest.TestCase):
             (4, "Claire", "Littleton", 27, 2),
             (5, "Hugo", "Reyes", 29, 100),
         ]
-        cls.df_employee = cls.spark.createDataFrame(data=employee_data, schema=cls.spark_employee_schema)
-        cls.dfs_employee = cls.sqlglot.createDataFrame(data=employee_data, schema=cls.sqlglot_employee_schema)
+        cls.df_employee = cls.spark.createDataFrame(
+            data=employee_data, schema=cls.spark_employee_schema
+        )
+        cls.dfs_employee = cls.sqlglot.createDataFrame(
+            data=employee_data, schema=cls.sqlglot_employee_schema
+        )
         cls.df_employee.createOrReplaceTempView("employee")

         cls.spark_store_schema = types.StructType(
@@ -80,7 +92,9 @@ class DataFrameValidator(unittest.TestCase):
             [
                 sqlglotSparkTypes.StructField("store_id", sqlglotSparkTypes.IntegerType(), False),
                 sqlglotSparkTypes.StructField("store_name", sqlglotSparkTypes.StringType(), False),
-                sqlglotSparkTypes.StructField("district_id", sqlglotSparkTypes.IntegerType(), False),
+                sqlglotSparkTypes.StructField(
+                    "district_id", sqlglotSparkTypes.IntegerType(), False
+                ),
                 sqlglotSparkTypes.StructField("num_sales", sqlglotSparkTypes.IntegerType(), False),
             ]
         )
@@ -89,7 +103,9 @@ class DataFrameValidator(unittest.TestCase):
             (2, "Arrow", 2, 2000),
         ]
         cls.df_store = cls.spark.createDataFrame(data=store_data, schema=cls.spark_store_schema)
-        cls.dfs_store = cls.sqlglot.createDataFrame(data=store_data, schema=cls.sqlglot_store_schema)
+        cls.dfs_store = cls.sqlglot.createDataFrame(
+            data=store_data, schema=cls.sqlglot_store_schema
+        )
         cls.df_store.createOrReplaceTempView("store")

         cls.spark_district_schema = types.StructType(
@@ -101,17 +117,27 @@ class DataFrameValidator(unittest.TestCase):
         )
         cls.sqlglot_district_schema = sqlglotSparkTypes.StructType(
             [
-                sqlglotSparkTypes.StructField("district_id", sqlglotSparkTypes.IntegerType(), False),
-                sqlglotSparkTypes.StructField("district_name", sqlglotSparkTypes.StringType(), False),
-                sqlglotSparkTypes.StructField("manager_name", sqlglotSparkTypes.StringType(), False),
+                sqlglotSparkTypes.StructField(
+                    "district_id", sqlglotSparkTypes.IntegerType(), False
+                ),
+                sqlglotSparkTypes.StructField(
+                    "district_name", sqlglotSparkTypes.StringType(), False
+                ),
+                sqlglotSparkTypes.StructField(
+                    "manager_name", sqlglotSparkTypes.StringType(), False
+                ),
             ]
         )
         district_data = [
             (1, "Temple", "Dogen"),
             (2, "Lighthouse", "Jacob"),
         ]
-        cls.df_district = cls.spark.createDataFrame(data=district_data, schema=cls.spark_district_schema)
-        cls.dfs_district = cls.sqlglot.createDataFrame(data=district_data, schema=cls.sqlglot_district_schema)
+        cls.df_district = cls.spark.createDataFrame(
+            data=district_data, schema=cls.spark_district_schema
+        )
+        cls.dfs_district = cls.sqlglot.createDataFrame(
+            data=district_data, schema=cls.sqlglot_district_schema
+        )
         cls.df_district.createOrReplaceTempView("district")

         sqlglot.schema.add_table("employee", cls.sqlglot_employee_schema)
         sqlglot.schema.add_table("store", cls.sqlglot_store_schema)


@@ -41,22 +41,32 @@ class TestDataframeFunc(DataFrameValidator):
     def test_alias_with_select(self):
         df_employee = self.df_spark_employee.alias("df_employee").select(
-            self.df_spark_employee["employee_id"], F.col("df_employee.fname"), self.df_spark_employee.lname
+            self.df_spark_employee["employee_id"],
+            F.col("df_employee.fname"),
+            self.df_spark_employee.lname,
         )
         dfs_employee = self.df_sqlglot_employee.alias("dfs_employee").select(
-            self.df_sqlglot_employee["employee_id"], SF.col("dfs_employee.fname"), self.df_sqlglot_employee.lname
+            self.df_sqlglot_employee["employee_id"],
+            SF.col("dfs_employee.fname"),
+            self.df_sqlglot_employee.lname,
         )
         self.compare_spark_with_sqlglot(df_employee, dfs_employee)

     def test_case_when_otherwise(self):
         df = self.df_spark_employee.select(
-            F.when((F.col("age") >= F.lit(40)) & (F.col("age") <= F.lit(60)), F.lit("between 40 and 60"))
+            F.when(
+                (F.col("age") >= F.lit(40)) & (F.col("age") <= F.lit(60)),
+                F.lit("between 40 and 60"),
+            )
             .when(F.col("age") < F.lit(40), "less than 40")
             .otherwise("greater than 60")
         )

         dfs = self.df_sqlglot_employee.select(
-            SF.when((SF.col("age") >= SF.lit(40)) & (SF.col("age") <= SF.lit(60)), SF.lit("between 40 and 60"))
+            SF.when(
+                (SF.col("age") >= SF.lit(40)) & (SF.col("age") <= SF.lit(60)),
+                SF.lit("between 40 and 60"),
+            )
             .when(SF.col("age") < SF.lit(40), "less than 40")
             .otherwise("greater than 60")
         )
@@ -65,15 +75,17 @@ class TestDataframeFunc(DataFrameValidator):
     def test_case_when_no_otherwise(self):
         df = self.df_spark_employee.select(
-            F.when((F.col("age") >= F.lit(40)) & (F.col("age") <= F.lit(60)), F.lit("between 40 and 60")).when(
-                F.col("age") < F.lit(40), "less than 40"
-            )
+            F.when(
+                (F.col("age") >= F.lit(40)) & (F.col("age") <= F.lit(60)),
+                F.lit("between 40 and 60"),
+            ).when(F.col("age") < F.lit(40), "less than 40")
         )

         dfs = self.df_sqlglot_employee.select(
-            SF.when((SF.col("age") >= SF.lit(40)) & (SF.col("age") <= SF.lit(60)), SF.lit("between 40 and 60")).when(
-                SF.col("age") < SF.lit(40), "less than 40"
-            )
+            SF.when(
+                (SF.col("age") >= SF.lit(40)) & (SF.col("age") <= SF.lit(60)),
+                SF.lit("between 40 and 60"),
+            ).when(SF.col("age") < SF.lit(40), "less than 40")
         )

         self.compare_spark_with_sqlglot(df, dfs, skip_schema_compare=True)
@@ -84,7 +96,9 @@ class TestDataframeFunc(DataFrameValidator):
         self.compare_spark_with_sqlglot(df_employee, dfs_employee)

     def test_where_clause_multiple_and(self):
-        df_employee = self.df_spark_employee.where((F.col("age") == F.lit(37)) & (F.col("fname") == F.lit("Jack")))
+        df_employee = self.df_spark_employee.where(
+            (F.col("age") == F.lit(37)) & (F.col("fname") == F.lit("Jack"))
+        )
         dfs_employee = self.df_sqlglot_employee.where(
             (SF.col("age") == SF.lit(37)) & (SF.col("fname") == SF.lit("Jack"))
         )
@@ -106,7 +120,9 @@ class TestDataframeFunc(DataFrameValidator):
         self.compare_spark_with_sqlglot(df_employee, dfs_employee)

     def test_where_clause_multiple_or(self):
-        df_employee = self.df_spark_employee.where((F.col("age") == F.lit(37)) | (F.col("fname") == F.lit("Kate")))
+        df_employee = self.df_spark_employee.where(
+            (F.col("age") == F.lit(37)) | (F.col("fname") == F.lit("Kate"))
+        )
         dfs_employee = self.df_sqlglot_employee.where(
             (SF.col("age") == SF.lit(37)) | (SF.col("fname") == SF.lit("Kate"))
         )
@@ -172,28 +188,43 @@ class TestDataframeFunc(DataFrameValidator):
         dfs_employee = self.df_sqlglot_employee.where(self.df_sqlglot_employee["age"] == SF.lit(37))
         self.compare_spark_with_sqlglot(df_employee, dfs_employee)

-        df_employee = self.df_spark_employee.where(self.df_spark_employee["age"] % F.lit(5) == F.lit(0))
-        dfs_employee = self.df_sqlglot_employee.where(self.df_sqlglot_employee["age"] % SF.lit(5) == SF.lit(0))
+        df_employee = self.df_spark_employee.where(
+            self.df_spark_employee["age"] % F.lit(5) == F.lit(0)
+        )
+        dfs_employee = self.df_sqlglot_employee.where(
+            self.df_sqlglot_employee["age"] % SF.lit(5) == SF.lit(0)
+        )
         self.compare_spark_with_sqlglot(df_employee, dfs_employee)

-        df_employee = self.df_spark_employee.where(self.df_spark_employee["age"] + F.lit(5) > F.lit(28))
-        dfs_employee = self.df_sqlglot_employee.where(self.df_sqlglot_employee["age"] + SF.lit(5) > SF.lit(28))
+        df_employee = self.df_spark_employee.where(
+            self.df_spark_employee["age"] + F.lit(5) > F.lit(28)
+        )
+        dfs_employee = self.df_sqlglot_employee.where(
+            self.df_sqlglot_employee["age"] + SF.lit(5) > SF.lit(28)
+        )
         self.compare_spark_with_sqlglot(df_employee, dfs_employee)

-        df_employee = self.df_spark_employee.where(self.df_spark_employee["age"] - F.lit(5) > F.lit(28))
-        dfs_employee = self.df_sqlglot_employee.where(self.df_sqlglot_employee["age"] - SF.lit(5) > SF.lit(28))
+        df_employee = self.df_spark_employee.where(
+            self.df_spark_employee["age"] - F.lit(5) > F.lit(28)
+        )
+        dfs_employee = self.df_sqlglot_employee.where(
+            self.df_sqlglot_employee["age"] - SF.lit(5) > SF.lit(28)
+        )
         self.compare_spark_with_sqlglot(df_employee, dfs_employee)

         df_employee = self.df_spark_employee.where(
             self.df_spark_employee["age"] * F.lit(0.5) == self.df_spark_employee["age"] / F.lit(2)
         )
         dfs_employee = self.df_sqlglot_employee.where(
-            self.df_sqlglot_employee["age"] * SF.lit(0.5) == self.df_sqlglot_employee["age"] / SF.lit(2)
+            self.df_sqlglot_employee["age"] * SF.lit(0.5)
+            == self.df_sqlglot_employee["age"] / SF.lit(2)
         )
         self.compare_spark_with_sqlglot(df_employee, dfs_employee)

     def test_join_inner(self):
-        df_joined = self.df_spark_employee.join(self.df_spark_store, on=["store_id"], how="inner").select(
+        df_joined = self.df_spark_employee.join(
+            self.df_spark_store, on=["store_id"], how="inner"
+        ).select(
             self.df_spark_employee.employee_id,
             self.df_spark_employee["fname"],
             F.col("lname"),
@@ -202,7 +233,9 @@ class TestDataframeFunc(DataFrameValidator):
             self.df_spark_store.store_name,
             self.df_spark_store["num_sales"],
         )
-        dfs_joined = self.df_sqlglot_employee.join(self.df_sqlglot_store, on=["store_id"], how="inner").select(
+        dfs_joined = self.df_sqlglot_employee.join(
+            self.df_sqlglot_store, on=["store_id"], how="inner"
+        ).select(
             self.df_sqlglot_employee.employee_id,
             self.df_sqlglot_employee["fname"],
             SF.col("lname"),
@@ -214,17 +247,27 @@ class TestDataframeFunc(DataFrameValidator):
         self.compare_spark_with_sqlglot(df_joined, dfs_joined)

     def test_join_inner_no_select(self):
-        df_joined = self.df_spark_employee.select(F.col("store_id"), F.col("fname"), F.col("lname")).join(
-            self.df_spark_store.select(F.col("store_id"), F.col("store_name")), on=["store_id"], how="inner"
+        df_joined = self.df_spark_employee.select(
+            F.col("store_id"), F.col("fname"), F.col("lname")
+        ).join(
+            self.df_spark_store.select(F.col("store_id"), F.col("store_name")),
+            on=["store_id"],
+            how="inner",
         )
-        dfs_joined = self.df_sqlglot_employee.select(SF.col("store_id"), SF.col("fname"), SF.col("lname")).join(
-            self.df_sqlglot_store.select(SF.col("store_id"), SF.col("store_name")), on=["store_id"], how="inner"
+        dfs_joined = self.df_sqlglot_employee.select(
+            SF.col("store_id"), SF.col("fname"), SF.col("lname")
+        ).join(
+            self.df_sqlglot_store.select(SF.col("store_id"), SF.col("store_name")),
+            on=["store_id"],
+            how="inner",
         )
         self.compare_spark_with_sqlglot(df_joined, dfs_joined)

     def test_join_inner_equality_single(self):
         df_joined = self.df_spark_employee.join(
-            self.df_spark_store, on=self.df_spark_employee.store_id == self.df_spark_store.store_id, how="inner"
+            self.df_spark_store,
+            on=self.df_spark_employee.store_id == self.df_spark_store.store_id,
+            how="inner",
         ).select(
             self.df_spark_employee.employee_id,
             self.df_spark_employee["fname"],
@@ -235,7 +278,9 @@ class TestDataframeFunc(DataFrameValidator):
             self.df_spark_store["num_sales"],
         )
         dfs_joined = self.df_sqlglot_employee.join(
-            self.df_sqlglot_store, on=self.df_sqlglot_employee.store_id == self.df_sqlglot_store.store_id, how="inner"
+            self.df_sqlglot_store,
+            on=self.df_sqlglot_employee.store_id == self.df_sqlglot_store.store_id,
+            how="inner",
         ).select(
             self.df_sqlglot_employee.employee_id,
             self.df_sqlglot_employee["fname"],
@@ -343,7 +388,9 @@ class TestDataframeFunc(DataFrameValidator):
         self.compare_spark_with_sqlglot(df_joined, dfs_joined)

     def test_join_full_outer(self):
-        df_joined = self.df_spark_employee.join(self.df_spark_store, on=["store_id"], how="full_outer").select(
+        df_joined = self.df_spark_employee.join(
+            self.df_spark_store, on=["store_id"], how="full_outer"
+        ).select(
             self.df_spark_employee.employee_id,
             self.df_spark_employee["fname"],
             F.col("lname"),
@@ -352,7 +399,9 @@ class TestDataframeFunc(DataFrameValidator):
             self.df_spark_store.store_name,
             self.df_spark_store["num_sales"],
         )
-        dfs_joined = self.df_sqlglot_employee.join(self.df_sqlglot_store, on=["store_id"], how="full_outer").select(
+        dfs_joined = self.df_sqlglot_employee.join(
+            self.df_sqlglot_store, on=["store_id"], how="full_outer"
+        ).select(
             self.df_sqlglot_employee.employee_id,
             self.df_sqlglot_employee["fname"],
             SF.col("lname"),
@@ -365,7 +414,9 @@ class TestDataframeFunc(DataFrameValidator):
     def test_triple_join(self):
         df = (
-            self.df_employee.join(self.df_store, on=self.df_employee.employee_id == self.df_store.store_id)
+            self.df_employee.join(
+                self.df_store, on=self.df_employee.employee_id == self.df_store.store_id
+            )
             .join(self.df_district, on=self.df_store.store_id == self.df_district.district_id)
             .select(
                 self.df_employee.employee_id,
@@ -377,7 +428,9 @@ class TestDataframeFunc(DataFrameValidator):
             )
         )
         dfs = (
-            self.dfs_employee.join(self.dfs_store, on=self.dfs_employee.employee_id == self.dfs_store.store_id)
+            self.dfs_employee.join(
+                self.dfs_store, on=self.dfs_employee.employee_id == self.dfs_store.store_id
+            )
             .join(self.dfs_district, on=self.dfs_store.store_id == self.dfs_district.district_id)
             .select(
                 self.dfs_employee.employee_id,
@@ -391,13 +444,13 @@ class TestDataframeFunc(DataFrameValidator):
         self.compare_spark_with_sqlglot(df, dfs)

     def test_join_select_and_select_start(self):
-        df = self.df_spark_employee.select(F.col("fname"), F.col("lname"), F.col("age"), F.col("store_id")).join(
-            self.df_spark_store, "store_id", "inner"
-        )
+        df = self.df_spark_employee.select(
+            F.col("fname"), F.col("lname"), F.col("age"), F.col("store_id")
+        ).join(self.df_spark_store, "store_id", "inner")

-        dfs = self.df_sqlglot_employee.select(SF.col("fname"), SF.col("lname"), SF.col("age"), SF.col("store_id")).join(
-            self.df_sqlglot_store, "store_id", "inner"
-        )
+        dfs = self.df_sqlglot_employee.select(
+            SF.col("fname"), SF.col("lname"), SF.col("age"), SF.col("store_id")
+        ).join(self.df_sqlglot_store, "store_id", "inner")

         self.compare_spark_with_sqlglot(df, dfs)
@@ -485,13 +538,17 @@ class TestDataframeFunc(DataFrameValidator):
         dfs_unioned = (
             self.df_sqlglot_employee.select(SF.col("employee_id"), SF.col("fname"))
             .unionAll(self.df_sqlglot_store.select(SF.col("store_id"), SF.col("store_name")))
-            .unionAll(self.df_sqlglot_district.select(SF.col("district_id"), SF.col("district_name")))
+            .unionAll(
+                self.df_sqlglot_district.select(SF.col("district_id"), SF.col("district_name"))
+            )
         )

         self.compare_spark_with_sqlglot(df_unioned, dfs_unioned)

     def test_union_by_name(self):
-        df = self.df_spark_employee.select(F.col("employee_id"), F.col("fname"), F.col("lname")).unionByName(
+        df = self.df_spark_employee.select(
+            F.col("employee_id"), F.col("fname"), F.col("lname")
+        ).unionByName(
             self.df_spark_store.select(
                 F.col("store_name").alias("lname"),
                 F.col("store_id").alias("employee_id"),
@@ -499,7 +556,9 @@ class TestDataframeFunc(DataFrameValidator):
             )
         )

-        dfs = self.df_sqlglot_employee.select(SF.col("employee_id"), SF.col("fname"), SF.col("lname")).unionByName(
+        dfs = self.df_sqlglot_employee.select(
+            SF.col("employee_id"), SF.col("fname"), SF.col("lname")
+        ).unionByName(
             self.df_sqlglot_store.select(
                 SF.col("store_name").alias("lname"),
                 SF.col("store_id").alias("employee_id"),
@@ -537,10 +596,16 @@ class TestDataframeFunc(DataFrameValidator):
         self.compare_spark_with_sqlglot(df, dfs)

     def test_order_by_default(self):
-        df = self.df_spark_store.groupBy(F.col("district_id")).agg(F.min("num_sales")).orderBy(F.col("district_id"))
+        df = (
+            self.df_spark_store.groupBy(F.col("district_id"))
+            .agg(F.min("num_sales"))
+            .orderBy(F.col("district_id"))
+        )

         dfs = (
-            self.df_sqlglot_store.groupBy(SF.col("district_id")).agg(SF.min("num_sales")).orderBy(SF.col("district_id"))
+            self.df_sqlglot_store.groupBy(SF.col("district_id"))
+            .agg(SF.min("num_sales"))
+            .orderBy(SF.col("district_id"))
         )

         self.compare_spark_with_sqlglot(df, dfs, skip_schema_compare=True)
@@ -594,13 +659,17 @@ class TestDataframeFunc(DataFrameValidator):
         df = (
             self.df_spark_store.groupBy(F.col("district_id"))
             .agg(F.min("num_sales").alias("total_sales"))
-            .orderBy(F.when(F.col("district_id") == F.lit(2), F.col("district_id")).asc_nulls_last())
+            .orderBy(
+                F.when(F.col("district_id") == F.lit(2), F.col("district_id")).asc_nulls_last()
+            )
         )
         dfs = (
             self.df_sqlglot_store.groupBy(SF.col("district_id"))
             .agg(SF.min("num_sales").alias("total_sales"))
-            .orderBy(SF.when(SF.col("district_id") == SF.lit(2), SF.col("district_id")).asc_nulls_last())
+            .orderBy(
+                SF.when(SF.col("district_id") == SF.lit(2), SF.col("district_id")).asc_nulls_last()
+            )
         )
         self.compare_spark_with_sqlglot(df, dfs)
@@ -609,81 +678,87 @@ class TestDataframeFunc(DataFrameValidator):
         df = (
             self.df_spark_store.groupBy(F.col("district_id"))
             .agg(F.min("num_sales").alias("total_sales"))
-            .orderBy(F.when(F.col("district_id") == F.lit(1), F.col("district_id")).desc_nulls_first())
+            .orderBy(
+                F.when(F.col("district_id") == F.lit(1), F.col("district_id")).desc_nulls_first()
+            )
         )
         dfs = (
             self.df_sqlglot_store.groupBy(SF.col("district_id"))
             .agg(SF.min("num_sales").alias("total_sales"))
-            .orderBy(SF.when(SF.col("district_id") == SF.lit(1), SF.col("district_id")).desc_nulls_first())
+            .orderBy(
+                SF.when(
+                    SF.col("district_id") == SF.lit(1), SF.col("district_id")
+                ).desc_nulls_first()
+            )
         )
         self.compare_spark_with_sqlglot(df, dfs)

     def test_intersect(self):
-        df_employee_duplicate = self.df_spark_employee.select(F.col("employee_id"), F.col("store_id")).union(
-            self.df_spark_employee.select(F.col("employee_id"), F.col("store_id"))
-        )
+        df_employee_duplicate = self.df_spark_employee.select(
+            F.col("employee_id"), F.col("store_id")
+        ).union(self.df_spark_employee.select(F.col("employee_id"), F.col("store_id")))

-        df_store_duplicate = self.df_spark_store.select(F.col("store_id"), F.col("district_id")).union(
-            self.df_spark_store.select(F.col("store_id"), F.col("district_id"))
-        )
+        df_store_duplicate = self.df_spark_store.select(
+            F.col("store_id"), F.col("district_id")
+        ).union(self.df_spark_store.select(F.col("store_id"), F.col("district_id")))

         df = df_employee_duplicate.intersect(df_store_duplicate)

-        dfs_employee_duplicate = self.df_sqlglot_employee.select(SF.col("employee_id"), SF.col("store_id")).union(
-            self.df_sqlglot_employee.select(SF.col("employee_id"), SF.col("store_id"))
-        )
+        dfs_employee_duplicate = self.df_sqlglot_employee.select(
+            SF.col("employee_id"), SF.col("store_id")
+        ).union(self.df_sqlglot_employee.select(SF.col("employee_id"), SF.col("store_id")))

-        dfs_store_duplicate = self.df_sqlglot_store.select(SF.col("store_id"), SF.col("district_id")).union(
-            self.df_sqlglot_store.select(SF.col("store_id"), SF.col("district_id"))
-        )
+        dfs_store_duplicate = self.df_sqlglot_store.select(
+            SF.col("store_id"), SF.col("district_id")
+        ).union(self.df_sqlglot_store.select(SF.col("store_id"), SF.col("district_id")))

         dfs = dfs_employee_duplicate.intersect(dfs_store_duplicate)

         self.compare_spark_with_sqlglot(df, dfs)

     def test_intersect_all(self):
-        df_employee_duplicate = self.df_spark_employee.select(F.col("employee_id"), F.col("store_id")).union(
-            self.df_spark_employee.select(F.col("employee_id"), F.col("store_id"))
-        )
+        df_employee_duplicate = self.df_spark_employee.select(
+            F.col("employee_id"), F.col("store_id")
+        ).union(self.df_spark_employee.select(F.col("employee_id"), F.col("store_id")))

-        df_store_duplicate = self.df_spark_store.select(F.col("store_id"), F.col("district_id")).union(
-            self.df_spark_store.select(F.col("store_id"), F.col("district_id"))
-        )
+        df_store_duplicate = self.df_spark_store.select(
+            F.col("store_id"), F.col("district_id")
+        ).union(self.df_spark_store.select(F.col("store_id"), F.col("district_id")))

         df = df_employee_duplicate.intersectAll(df_store_duplicate)

-        dfs_employee_duplicate = self.df_sqlglot_employee.select(SF.col("employee_id"), SF.col("store_id")).union(
-            self.df_sqlglot_employee.select(SF.col("employee_id"), SF.col("store_id"))
-        )
+        dfs_employee_duplicate = self.df_sqlglot_employee.select(
+            SF.col("employee_id"), SF.col("store_id")
+        ).union(self.df_sqlglot_employee.select(SF.col("employee_id"), SF.col("store_id")))

-        dfs_store_duplicate = self.df_sqlglot_store.select(SF.col("store_id"), SF.col("district_id")).union(
-            self.df_sqlglot_store.select(SF.col("store_id"), SF.col("district_id"))
-        )
+        dfs_store_duplicate = self.df_sqlglot_store.select(
+            SF.col("store_id"), SF.col("district_id")
+        ).union(self.df_sqlglot_store.select(SF.col("store_id"), SF.col("district_id")))

         dfs = dfs_employee_duplicate.intersectAll(dfs_store_duplicate)

         self.compare_spark_with_sqlglot(df, dfs)

     def test_except_all(self):
-        df_employee_duplicate = self.df_spark_employee.select(F.col("employee_id"), F.col("store_id")).union(
-            self.df_spark_employee.select(F.col("employee_id"), F.col("store_id"))
+        df_employee_duplicate = self.df_spark_employee.select(
+            F.col("employee_id"), F.col("store_id")
) ).union(self.df_spark_employee.select(F.col("employee_id"), F.col("store_id")))
df_store_duplicate = self.df_spark_store.select(F.col("store_id"), F.col("district_id")).union( df_store_duplicate = self.df_spark_store.select(
self.df_spark_store.select(F.col("store_id"), F.col("district_id")) F.col("store_id"), F.col("district_id")
) ).union(self.df_spark_store.select(F.col("store_id"), F.col("district_id")))
df = df_employee_duplicate.exceptAll(df_store_duplicate) df = df_employee_duplicate.exceptAll(df_store_duplicate)
dfs_employee_duplicate = self.df_sqlglot_employee.select(SF.col("employee_id"), SF.col("store_id")).union( dfs_employee_duplicate = self.df_sqlglot_employee.select(
self.df_sqlglot_employee.select(SF.col("employee_id"), SF.col("store_id")) SF.col("employee_id"), SF.col("store_id")
) ).union(self.df_sqlglot_employee.select(SF.col("employee_id"), SF.col("store_id")))
dfs_store_duplicate = self.df_sqlglot_store.select(SF.col("store_id"), SF.col("district_id")).union( dfs_store_duplicate = self.df_sqlglot_store.select(
self.df_sqlglot_store.select(SF.col("store_id"), SF.col("district_id")) SF.col("store_id"), SF.col("district_id")
) ).union(self.df_sqlglot_store.select(SF.col("store_id"), SF.col("district_id")))
dfs = dfs_employee_duplicate.exceptAll(dfs_store_duplicate) dfs = dfs_employee_duplicate.exceptAll(dfs_store_duplicate)
@@ -721,7 +796,9 @@ class TestDataframeFunc(DataFrameValidator):
        self.compare_spark_with_sqlglot(df, dfs)

    def test_drop_na_default(self):
        df = self.df_spark_employee.select(
            F.when(F.col("age") < F.lit(50), F.col("age")).alias("the_age")
        ).dropna()
        dfs = self.df_sqlglot_employee.select(
            SF.when(SF.col("age") < SF.lit(50), SF.col("age")).alias("the_age")
@@ -746,7 +823,9 @@ class TestDataframeFunc(DataFrameValidator):
        ).dropna(how="any", thresh=2)
        dfs = self.df_sqlglot_employee.select(
            SF.lit(None),
            SF.lit(1),
            SF.when(SF.col("age") < SF.lit(50), SF.col("age")).alias("the_age"),
        ).dropna(how="any", thresh=2)
        self.compare_spark_with_sqlglot(df, dfs, skip_schema_compare=True)
@@ -757,13 +836,17 @@ class TestDataframeFunc(DataFrameValidator):
        ).dropna(thresh=1, subset="the_age")
        dfs = self.df_sqlglot_employee.select(
            SF.lit(None),
            SF.lit(1),
            SF.when(SF.col("age") < SF.lit(50), SF.col("age")).alias("the_age"),
        ).dropna(thresh=1, subset="the_age")
        self.compare_spark_with_sqlglot(df, dfs, skip_schema_compare=True)

    def test_dropna_na_function(self):
        df = self.df_spark_employee.select(
            F.when(F.col("age") < F.lit(50), F.col("age")).alias("the_age")
        ).na.drop()
        dfs = self.df_sqlglot_employee.select(
            SF.when(SF.col("age") < SF.lit(50), SF.col("age")).alias("the_age")
@@ -772,7 +855,9 @@ class TestDataframeFunc(DataFrameValidator):
        self.compare_spark_with_sqlglot(df, dfs)

    def test_fillna_default(self):
        df = self.df_spark_employee.select(
            F.when(F.col("age") < F.lit(50), F.col("age")).alias("the_age")
        ).fillna(100)
        dfs = self.df_sqlglot_employee.select(
            SF.when(SF.col("age") < SF.lit(50), SF.col("age")).alias("the_age")
@@ -798,7 +883,9 @@ class TestDataframeFunc(DataFrameValidator):
        self.compare_spark_with_sqlglot(df, dfs, skip_schema_compare=True)

    def test_fillna_na_func(self):
        df = self.df_spark_employee.select(
            F.when(F.col("age") < F.lit(50), F.col("age")).alias("the_age")
        ).na.fill(100)
        dfs = self.df_sqlglot_employee.select(
            SF.when(SF.col("age") < SF.lit(50), SF.col("age")).alias("the_age")
@@ -807,7 +894,9 @@ class TestDataframeFunc(DataFrameValidator):
        self.compare_spark_with_sqlglot(df, dfs)

    def test_replace_basic(self):
        df = self.df_spark_employee.select(F.col("age"), F.lit(37).alias("test_col")).replace(
            to_replace=37, value=100
        )
        dfs = self.df_sqlglot_employee.select(SF.col("age"), SF.lit(37).alias("test_col")).replace(
            to_replace=37, value=100
@@ -827,9 +916,13 @@ class TestDataframeFunc(DataFrameValidator):
        self.compare_spark_with_sqlglot(df, dfs, skip_schema_compare=True)

    def test_replace_mapping(self):
        df = self.df_spark_employee.select(F.col("age"), F.lit(37).alias("test_col")).replace(
            {37: 100}
        )
        dfs = self.df_sqlglot_employee.select(SF.col("age"), SF.lit(37).alias("test_col")).replace(
            {37: 100}
        )
        self.compare_spark_with_sqlglot(df, dfs, skip_schema_compare=True)
@@ -849,9 +942,9 @@ class TestDataframeFunc(DataFrameValidator):
            to_replace=37, value=100
        )
        dfs = self.df_sqlglot_employee.select(
            SF.col("age"), SF.lit(37).alias("test_col")
        ).na.replace(to_replace=37, value=100)
        self.compare_spark_with_sqlglot(df, dfs, skip_schema_compare=True)
@@ -881,16 +974,18 @@ class TestDataframeFunc(DataFrameValidator):
            "first_name", "first_name_again"
        )
        dfs = self.df_sqlglot_employee.select(
            SF.col("fname").alias("first_name")
        ).withColumnRenamed("first_name", "first_name_again")
        self.compare_spark_with_sqlglot(df, dfs)

    def test_drop_column_single(self):
        df = self.df_spark_employee.select(F.col("fname"), F.col("lname"), F.col("age")).drop("age")
        dfs = self.df_sqlglot_employee.select(SF.col("fname"), SF.col("lname"), SF.col("age")).drop(
            "age"
        )
        self.compare_spark_with_sqlglot(df, dfs)
@@ -906,7 +1001,9 @@ class TestDataframeFunc(DataFrameValidator):
        df_sqlglot_employee_cols = self.df_sqlglot_employee.select(
            SF.col("fname"), SF.col("lname"), SF.col("age"), SF.col("store_id")
        )
        df_sqlglot_store_cols = self.df_sqlglot_store.select(
            SF.col("store_id"), SF.col("store_name")
        )
        dfs = df_sqlglot_employee_cols.join(df_sqlglot_store_cols, on="store_id", how="inner").drop(
            df_sqlglot_employee_cols.age,
        )
@@ -23,6 +23,14 @@ class TestSessionFunc(DataFrameValidator):
        ON
            e.store_id = s.store_id
        """
        df = (
            self.spark.sql(query)
            .groupBy(F.col("store_id"))
            .agg(F.countDistinct(F.col("employee_id")))
        )
        dfs = (
            self.sqlglot.sql(query)
            .groupBy(SF.col("store_id"))
            .agg(SF.countDistinct(SF.col("employee_id")))
        )
        self.compare_spark_with_sqlglot(df, dfs, skip_schema_compare=True)
@@ -25,11 +25,17 @@ class DataFrameSQLValidator(unittest.TestCase):
            (4, "Claire", "Littleton", 27, 2),
            (5, "Hugo", "Reyes", 29, 100),
        ]
        self.df_employee = self.spark.createDataFrame(
            data=employee_data, schema=self.employee_schema
        )

    def compare_sql(
        self, df: DataFrame, expected_statements: t.Union[str, t.List[str]], pretty=False
    ):
        actual_sqls = df.sql(pretty=pretty)
        expected_statements = (
            [expected_statements] if isinstance(expected_statements, str) else expected_statements
        )
        self.assertEqual(len(expected_statements), len(actual_sqls))
        for expected, actual in zip(expected_statements, actual_sqls):
            self.assertEqual(expected, actual)
@@ -26,12 +26,14 @@ class TestDataframeColumn(unittest.TestCase):
    def test_and(self):
        self.assertEqual(
            "cola = colb AND colc = cold",
            ((F.col("cola") == F.col("colb")) & (F.col("colc") == F.col("cold"))).sql(),
        )

    def test_or(self):
        self.assertEqual(
            "cola = colb OR colc = cold",
            ((F.col("cola") == F.col("colb")) | (F.col("colc") == F.col("cold"))).sql(),
        )

    def test_mod(self):
@@ -112,7 +114,9 @@ class TestDataframeColumn(unittest.TestCase):
    def test_when_otherwise(self):
        self.assertEqual("CASE WHEN cola = 1 THEN 2 END", F.when(F.col("cola") == 1, 2).sql())
        self.assertEqual(
            "CASE WHEN cola = 1 THEN 2 END", F.col("cola").when(F.col("cola") == 1, 2).sql()
        )
        self.assertEqual(
            "CASE WHEN cola = 1 THEN 2 WHEN colb = 2 THEN 3 END",
            (F.when(F.col("cola") == 1, 2).when(F.col("colb") == 2, 3)).sql(),
@@ -148,7 +152,9 @@ class TestDataframeColumn(unittest.TestCase):
        self.assertEqual(
            "cola BETWEEN CAST('2022-01-01 01:01:01.000000' AS TIMESTAMP) "
            "AND CAST('2022-03-01 01:01:01.000000' AS TIMESTAMP)",
            F.col("cola")
            .between(datetime.datetime(2022, 1, 1, 1, 1, 1), datetime.datetime(2022, 3, 1, 1, 1, 1))
            .sql(),
        )

    def test_over(self):
@@ -9,7 +9,9 @@ class TestDataframe(DataFrameSQLValidator):
        self.assertEqual("t17051", DataFrame._create_hash_from_expression(expression))

    def test_columns(self):
        self.assertEqual(
            ["employee_id", "fname", "lname", "age", "store_id"], self.df_employee.columns
        )

    def test_cache(self):
        df = self.df_employee.select("fname").cache()
@@ -925,12 +925,17 @@ class TestFunctions(unittest.TestCase):
        col = SF.window(SF.col("cola"), "10 minutes")
        self.assertEqual("WINDOW(cola, '10 minutes')", col.sql())
        col_all_values = SF.window("cola", "2 minutes 30 seconds", "30 seconds", "15 seconds")
        self.assertEqual(
            "WINDOW(cola, '2 minutes 30 seconds', '30 seconds', '15 seconds')", col_all_values.sql()
        )
        col_no_start_time = SF.window("cola", "2 minutes 30 seconds", "30 seconds")
        self.assertEqual(
            "WINDOW(cola, '2 minutes 30 seconds', '30 seconds')", col_no_start_time.sql()
        )
        col_no_slide = SF.window("cola", "2 minutes 30 seconds", startTime="15 seconds")
        self.assertEqual(
            "WINDOW(cola, '2 minutes 30 seconds', '2 minutes 30 seconds', '15 seconds')",
            col_no_slide.sql(),
        )

    def test_session_window(self):
@@ -1359,9 +1364,13 @@ class TestFunctions(unittest.TestCase):
    def test_from_json(self):
        col_str = SF.from_json("cola", "cola INT", dict(timestampFormat="dd/MM/yyyy"))
        self.assertEqual(
            "FROM_JSON(cola, 'cola INT', MAP('timestampFormat', 'dd/MM/yyyy'))", col_str.sql()
        )
        col = SF.from_json(SF.col("cola"), "cola INT", dict(timestampFormat="dd/MM/yyyy"))
        self.assertEqual(
            "FROM_JSON(cola, 'cola INT', MAP('timestampFormat', 'dd/MM/yyyy'))", col.sql()
        )
        col_no_option = SF.from_json("cola", "cola INT")
        self.assertEqual("FROM_JSON(cola, 'cola INT')", col_no_option.sql())
@@ -1375,7 +1384,9 @@ class TestFunctions(unittest.TestCase):
    def test_schema_of_json(self):
        col_str = SF.schema_of_json("cola", dict(timestampFormat="dd/MM/yyyy"))
        self.assertEqual(
            "SCHEMA_OF_JSON(cola, MAP('timestampFormat', 'dd/MM/yyyy'))", col_str.sql()
        )
        col = SF.schema_of_json(SF.col("cola"), dict(timestampFormat="dd/MM/yyyy"))
        self.assertEqual("SCHEMA_OF_JSON(cola, MAP('timestampFormat', 'dd/MM/yyyy'))", col.sql())
        col_no_option = SF.schema_of_json("cola")
@@ -1429,7 +1440,10 @@ class TestFunctions(unittest.TestCase):
        col = SF.array_sort(SF.col("cola"))
        self.assertEqual("ARRAY_SORT(cola)", col.sql())
        col_comparator = SF.array_sort(
            "cola",
            lambda x, y: SF.when(x.isNull() | y.isNull(), SF.lit(0)).otherwise(
                SF.length(y) - SF.length(x)
            ),
        )
        self.assertEqual(
            "ARRAY_SORT(cola, (x, y) -> CASE WHEN x IS NULL OR y IS NULL THEN 0 ELSE LENGTH(y) - LENGTH(x) END)",
@@ -1504,9 +1518,13 @@ class TestFunctions(unittest.TestCase):
    def test_from_csv(self):
        col_str = SF.from_csv("cola", "cola INT", dict(timestampFormat="dd/MM/yyyy"))
        self.assertEqual(
            "FROM_CSV(cola, 'cola INT', MAP('timestampFormat', 'dd/MM/yyyy'))", col_str.sql()
        )
        col = SF.from_csv(SF.col("cola"), "cola INT", dict(timestampFormat="dd/MM/yyyy"))
        self.assertEqual(
            "FROM_CSV(cola, 'cola INT', MAP('timestampFormat', 'dd/MM/yyyy'))", col.sql()
        )
        col_no_option = SF.from_csv("cola", "cola INT")
        self.assertEqual("FROM_CSV(cola, 'cola INT')", col_no_option.sql())
@@ -1535,7 +1553,9 @@ class TestFunctions(unittest.TestCase):
        self.assertEqual("TRANSFORM(cola, (x, i) -> x * i)", col.sql())
        col_custom_names = SF.transform("cola", lambda target, row_count: target * row_count)
        self.assertEqual(
            "TRANSFORM(cola, (target, row_count) -> target * row_count)", col_custom_names.sql()
        )

    def test_exists(self):
        col_str = SF.exists("cola", lambda x: x % 2 == 0)
@@ -1558,10 +1578,13 @@ class TestFunctions(unittest.TestCase):
        self.assertEqual("FILTER(cola, x -> MONTH(TO_DATE(x)) > 6)", col_str.sql())
        col = SF.filter(SF.col("cola"), lambda x, i: SF.month(SF.to_date(x)) > SF.lit(i))
        self.assertEqual("FILTER(cola, (x, i) -> MONTH(TO_DATE(x)) > i)", col.sql())
        col_custom_names = SF.filter(
            "cola", lambda target, row_count: SF.month(SF.to_date(target)) > SF.lit(row_count)
        )
        self.assertEqual(
            "FILTER(cola, (target, row_count) -> MONTH(TO_DATE(target)) > row_count)",
            col_custom_names.sql(),
        )

    def test_zip_with(self):
@@ -1570,7 +1593,9 @@ class TestFunctions(unittest.TestCase):
        col = SF.zip_with(SF.col("cola"), SF.col("colb"), lambda x, y: SF.concat_ws("_", x, y))
        self.assertEqual("ZIP_WITH(cola, colb, (x, y) -> CONCAT_WS('_', x, y))", col.sql())
        col_custom_names = SF.zip_with("cola", "colb", lambda l, r: SF.concat_ws("_", l, r))
        self.assertEqual(
            "ZIP_WITH(cola, colb, (l, r) -> CONCAT_WS('_', l, r))", col_custom_names.sql()
        )

    def test_transform_keys(self):
        col_str = SF.transform_keys("cola", lambda k, v: SF.upper(k))
@@ -1586,7 +1611,9 @@ class TestFunctions(unittest.TestCase):
        col = SF.transform_values(SF.col("cola"), lambda k, v: SF.upper(v))
        self.assertEqual("TRANSFORM_VALUES(cola, (k, v) -> UPPER(v))", col.sql())
        col_custom_names = SF.transform_values("cola", lambda _, value: SF.upper(value))
        self.assertEqual(
            "TRANSFORM_VALUES(cola, (_, value) -> UPPER(value))", col_custom_names.sql()
        )

    def test_map_filter(self):
        col_str = SF.map_filter("cola", lambda k, v: k > v)
@@ -21,9 +21,7 @@ class TestDataframeSession(DataFrameSQLValidator):
    def test_cdf_no_schema(self):
        df = self.spark.createDataFrame([[1, 2], [3, 4], [None, 6]])
        expected = "SELECT `a2`.`_1` AS `_1`, `a2`.`_2` AS `_2` FROM (VALUES (1, 2), (3, 4), (NULL, 6)) AS `a2`(`_1`, `_2`)"
        self.compare_sql(df, expected)

    def test_cdf_row_mixed_primitives(self):
@@ -77,7 +75,8 @@ class TestDataframeSession(DataFrameSQLValidator):
        sqlglot.schema.add_table("table", {"cola": "string", "colb": "string"})
        df = self.spark.sql(query)
        self.assertIn(
            "SELECT `table`.`cola` AS `cola`, `table`.`colb` AS `colb` FROM `table` AS `table`",
            df.sql(pretty=False),
        )

    @mock.patch("sqlglot.schema", MappingSchema())
@@ -104,9 +103,7 @@ class TestDataframeSession(DataFrameSQLValidator):
        query = "WITH t1 AS (SELECT cola, colb FROM table) INSERT INTO new_table SELECT cola, colb FROM t1"
        sqlglot.schema.add_table("table", {"cola": "string", "colb": "string"})
        df = self.spark.sql(query)
        expected = "INSERT INTO new_table SELECT `table`.`cola` AS `cola`, `table`.`colb` AS `colb` FROM `table` AS `table`"
        self.compare_sql(df, expected)

    def test_session_create_builder_patterns(self):
@@ -53,7 +53,10 @@ class TestDataframeTypes(unittest.TestCase):
        self.assertEqual("array<int>", types.ArrayType(types.IntegerType()).simpleString())

    def test_map(self):
        self.assertEqual(
            "map<int, string>",
            types.MapType(types.IntegerType(), types.StringType()).simpleString(),
        )

    def test_struct_field(self):
        self.assertEqual("cola:int", types.StructField("cola", types.IntegerType()).simpleString())
@@ -39,22 +39,38 @@ class TestDataframeWindow(unittest.TestCase):
    def test_window_rows_unbounded(self):
        rows_between_unbounded_start = Window.rowsBetween(Window.unboundedPreceding, 2)
        self.assertEqual(
            "OVER ( ROWS BETWEEN UNBOUNDED PRECEDING AND 2 FOLLOWING)",
            rows_between_unbounded_start.sql(),
        )
        rows_between_unbounded_end = Window.rowsBetween(1, Window.unboundedFollowing)
        self.assertEqual(
            "OVER ( ROWS BETWEEN 1 PRECEDING AND UNBOUNDED FOLLOWING)",
            rows_between_unbounded_end.sql(),
        )
        rows_between_unbounded_both = Window.rowsBetween(
            Window.unboundedPreceding, Window.unboundedFollowing
        )
        self.assertEqual(
            "OVER ( ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING)",
            rows_between_unbounded_both.sql(),
        )

    def test_window_range_unbounded(self):
        range_between_unbounded_start = Window.rangeBetween(Window.unboundedPreceding, 2)
        self.assertEqual(
            "OVER ( RANGE BETWEEN UNBOUNDED PRECEDING AND 2 FOLLOWING)",
            range_between_unbounded_start.sql(),
        )
        range_between_unbounded_end = Window.rangeBetween(1, Window.unboundedFollowing)
        self.assertEqual(
            "OVER ( RANGE BETWEEN 1 PRECEDING AND UNBOUNDED FOLLOWING)",
            range_between_unbounded_end.sql(),
        )
        range_between_unbounded_both = Window.rangeBetween(
            Window.unboundedPreceding, Window.unboundedFollowing
        )
        self.assertEqual(
            "OVER ( RANGE BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING)",
            range_between_unbounded_both.sql(),
        )
@@ -157,6 +157,14 @@ class TestBigQuery(Validator):
            },
        )
        self.validate_all(
            "DIV(x, y)",
            write={
                "bigquery": "DIV(x, y)",
                "duckdb": "CAST(x / y AS INT)",
            },
        )
        self.validate_identity(
            "SELECT ROW() OVER (y ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING) FROM x WINDOW y AS (PARTITION BY CATEGORY)"
        )
@@ -284,4 +292,6 @@ class TestBigQuery(Validator):
            "CREATE TEMPORARY FUNCTION a(x FLOAT64, y FLOAT64) RETURNS FLOAT64 NOT DETERMINISTIC LANGUAGE js AS 'return x*y;'"
        )
        self.validate_identity("CREATE TEMPORARY FUNCTION a(x FLOAT64, y FLOAT64) AS ((x + 4) / y)")
        self.validate_identity(
            "CREATE TABLE FUNCTION a(x INT64) RETURNS TABLE <q STRING, r INT64> AS SELECT s, t"
        )
@@ -18,7 +18,6 @@ class TestClickhouse(Validator):
                "spark": "SELECT fname, lname, age FROM person ORDER BY age DESC NULLS FIRST, fname NULLS LAST, lname NULLS LAST",
            },
        )
        self.validate_all(
            "CAST(1 AS NULLABLE(Int64))",
            write={
@@ -31,3 +30,7 @@ class TestClickhouse(Validator):
                "clickhouse": "CAST(1 AS Nullable(DateTime64(6, 'UTC')))",
            },
        )
        self.validate_all(
            "SELECT x #! comment",
            write={"": "SELECT x /* comment */"},
        )
@@ -22,7 +22,8 @@ class TestDatabricks(Validator):
            },
        )
        self.validate_all(
            "SELECT DATEDIFF('end', 'start')",
            write={"databricks": "SELECT DATEDIFF(DAY, 'start', 'end')"},
        )
        self.validate_all(
            "SELECT DATE_ADD('2020-01-01', 1)",
@@ -1,20 +1,18 @@
import unittest

from sqlglot import Dialect, Dialects, ErrorLevel, UnsupportedError, parse_one


class Validator(unittest.TestCase):
    dialect = None

    def parse_one(self, sql):
        return parse_one(sql, read=self.dialect)

    def validate_identity(self, sql, write_sql=None):
        expression = self.parse_one(sql)
        self.assertEqual(write_sql or sql, expression.sql(dialect=self.dialect))
        return expression
    def validate_all(self, sql, read=None, write=None, pretty=False):
        """
@@ -28,12 +26,14 @@ class Validator(unittest.TestCase):
        read (dict): Mapping of dialect -> SQL
        write (dict): Mapping of dialect -> SQL
        """
        expression = self.parse_one(sql)

        for read_dialect, read_sql in (read or {}).items():
            with self.subTest(f"{read_dialect} -> {sql}"):
                self.assertEqual(
                    parse_one(read_sql, read_dialect).sql(
                        self.dialect, unsupported_level=ErrorLevel.IGNORE, pretty=pretty
                    ),
                    sql,
                )
@@ -83,10 +83,6 @@ class TestDialect(Validator):
)
self.validate_all(
    "CAST(a AS BINARY(4))",
    write={
        "bigquery": "CAST(a AS BINARY(4))",
        "clickhouse": "CAST(a AS BINARY(4))",
@@ -103,6 +99,24 @@ class TestDialect(Validator):
        "starrocks": "CAST(a AS BINARY(4))",
    },
)
self.validate_all(
"CAST(a AS VARBINARY(4))",
write={
"bigquery": "CAST(a AS VARBINARY(4))",
"clickhouse": "CAST(a AS VARBINARY(4))",
"duckdb": "CAST(a AS VARBINARY(4))",
"mysql": "CAST(a AS VARBINARY(4))",
"hive": "CAST(a AS BINARY(4))",
"oracle": "CAST(a AS BLOB(4))",
"postgres": "CAST(a AS BYTEA(4))",
"presto": "CAST(a AS VARBINARY(4))",
"redshift": "CAST(a AS VARBYTE(4))",
"snowflake": "CAST(a AS VARBINARY(4))",
"sqlite": "CAST(a AS BLOB(4))",
"spark": "CAST(a AS BINARY(4))",
"starrocks": "CAST(a AS VARBINARY(4))",
},
)
self.validate_all(
    "CAST(MAP('a', '1') AS MAP(TEXT, TEXT))",
    write={
@@ -472,45 +486,57 @@ class TestDialect(Validator):
    },
)
self.validate_all(
    "DATE_TRUNC('day', x)",
    write={
        "mysql": "DATE(x)",
    },
)
self.validate_all(
    "DATE_TRUNC('week', x)",
    write={
        "mysql": "STR_TO_DATE(CONCAT(YEAR(x), ' ', WEEK(x, 1), ' 1'), '%Y %u %w')",
    },
)
self.validate_all(
    "DATE_TRUNC('month', x)",
    write={
        "mysql": "STR_TO_DATE(CONCAT(YEAR(x), ' ', MONTH(x), ' 1'), '%Y %c %e')",
    },
)
self.validate_all(
    "DATE_TRUNC('quarter', x)",
    write={
        "mysql": "STR_TO_DATE(CONCAT(YEAR(x), ' ', QUARTER(x) * 3 - 2, ' 1'), '%Y %c %e')",
    },
)
self.validate_all(
    "DATE_TRUNC('year', x)",
    write={
        "mysql": "STR_TO_DATE(CONCAT(YEAR(x), ' 1 1'), '%Y %c %e')",
    },
)
self.validate_all(
    "DATE_TRUNC('millenium', x)",
    write={
        "mysql": UnsupportedError,
    },
)
self.validate_all(
"DATE_TRUNC('year', x)",
read={
"starrocks": "DATE_TRUNC('year', x)",
},
write={
"starrocks": "DATE_TRUNC('year', x)",
},
)
self.validate_all(
"DATE_TRUNC(x, year)",
read={
"bigquery": "DATE_TRUNC(x, year)",
},
write={
"bigquery": "DATE_TRUNC(x, year)",
    },
)
self.validate_all(
@@ -564,6 +590,22 @@ class TestDialect(Validator):
        "spark": "DATE_ADD(CAST('2020-01-01' AS DATE), 1)",
    },
)
self.validate_all(
"TIMESTAMP '2022-01-01'",
write={
"mysql": "CAST('2022-01-01' AS TIMESTAMP)",
"starrocks": "CAST('2022-01-01' AS DATETIME)",
"hive": "CAST('2022-01-01' AS TIMESTAMP)",
},
)
self.validate_all(
"TIMESTAMP('2022-01-01')",
write={
"mysql": "TIMESTAMP('2022-01-01')",
"starrocks": "TIMESTAMP('2022-01-01')",
"hive": "TIMESTAMP('2022-01-01')",
},
)
for unit in ("DAY", "MONTH", "YEAR"):
    self.validate_all(
@@ -1002,7 +1044,10 @@ class TestDialect(Validator):
)
def test_limit(self):
    self.validate_all(
        "SELECT * FROM data LIMIT 10, 20",
        write={"sqlite": "SELECT * FROM data LIMIT 10 OFFSET 20"},
    )
    self.validate_all(
        "SELECT x FROM y LIMIT 10",
        write={
@@ -1132,3 +1177,56 @@ class TestDialect(Validator):
        "sqlite": "WITH cte1 AS (SELECT a, b FROM table1), cte2 AS (SELECT c, e AS d FROM table2) SELECT b, d AS dd FROM cte1 AS t JOIN cte2 WHERE cte1.a = cte2.c",
    },
)
def test_nullsafe_eq(self):
self.validate_all(
"SELECT a IS NOT DISTINCT FROM b",
read={
"mysql": "SELECT a <=> b",
"postgres": "SELECT a IS NOT DISTINCT FROM b",
},
write={
"mysql": "SELECT a <=> b",
"postgres": "SELECT a IS NOT DISTINCT FROM b",
},
)
def test_nullsafe_neq(self):
self.validate_all(
"SELECT a IS DISTINCT FROM b",
read={
"postgres": "SELECT a IS DISTINCT FROM b",
},
write={
"mysql": "SELECT NOT a <=> b",
"postgres": "SELECT a IS DISTINCT FROM b",
},
)
def test_hash_comments(self):
self.validate_all(
"SELECT 1 /* arbitrary content,,, until end-of-line */",
read={
"mysql": "SELECT 1 # arbitrary content,,, until end-of-line",
"bigquery": "SELECT 1 # arbitrary content,,, until end-of-line",
"clickhouse": "SELECT 1 #! arbitrary content,,, until end-of-line",
},
)
self.validate_all(
"""/* comment1 */
SELECT
x, -- comment2
y -- comment3""",
read={
"mysql": """SELECT # comment1
x, # comment2
y # comment3""",
"bigquery": """SELECT # comment1
x, # comment2
y # comment3""",
"clickhouse": """SELECT # comment1
x, # comment2
y # comment3""",
},
pretty=True,
)
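The `test_nullsafe_eq` and `test_nullsafe_neq` cases above exercise SQL's null-safe comparison (`<=>` in MySQL, `IS [NOT] DISTINCT FROM` elsewhere). As a plain-Python sketch of the semantics under test — for intuition only, not sqlglot or MySQL code:

```python
def nullsafe_eq(a, b):
    # NULL <=> NULL is TRUE; NULL <=> non-NULL is FALSE;
    # for two non-NULL operands it behaves like ordinary equality.
    if a is None and b is None:
        return True
    if a is None or b is None:
        return False
    return a == b


def nullsafe_neq(a, b):
    # IS DISTINCT FROM is simply the negation of the null-safe equal.
    return not nullsafe_eq(a, b)
```

This is why the tests transpile `a IS DISTINCT FROM b` to `NOT a <=> b` for MySQL: the negation is applied outside the null-safe operator.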
@@ -1,3 +1,4 @@
from sqlglot import expressions as exp
from tests.dialects.test_dialect import Validator
@@ -20,6 +21,52 @@ class TestMySQL(Validator):
self.validate_identity("SELECT TRIM(TRAILING 'bla' FROM ' XXX ')")
self.validate_identity("SELECT TRIM(BOTH 'bla' FROM ' XXX ')")
self.validate_identity("SELECT TRIM('bla' FROM ' XXX ')")
self.validate_identity("@@GLOBAL.max_connections")
# SET Commands
self.validate_identity("SET @var_name = expr")
self.validate_identity("SET @name = 43")
self.validate_identity("SET @total_tax = (SELECT SUM(tax) FROM taxable_transactions)")
self.validate_identity("SET GLOBAL max_connections = 1000")
self.validate_identity("SET @@GLOBAL.max_connections = 1000")
self.validate_identity("SET SESSION sql_mode = 'TRADITIONAL'")
self.validate_identity("SET LOCAL sql_mode = 'TRADITIONAL'")
self.validate_identity("SET @@SESSION.sql_mode = 'TRADITIONAL'")
self.validate_identity("SET @@LOCAL.sql_mode = 'TRADITIONAL'")
self.validate_identity("SET @@sql_mode = 'TRADITIONAL'")
self.validate_identity("SET sql_mode = 'TRADITIONAL'")
self.validate_identity("SET PERSIST max_connections = 1000")
self.validate_identity("SET @@PERSIST.max_connections = 1000")
self.validate_identity("SET PERSIST_ONLY back_log = 100")
self.validate_identity("SET @@PERSIST_ONLY.back_log = 100")
self.validate_identity("SET @@SESSION.max_join_size = DEFAULT")
self.validate_identity("SET @@SESSION.max_join_size = @@GLOBAL.max_join_size")
self.validate_identity("SET @x = 1, SESSION sql_mode = ''")
self.validate_identity(
"SET GLOBAL sort_buffer_size = 1000000, SESSION sort_buffer_size = 1000000"
)
self.validate_identity(
"SET @@GLOBAL.sort_buffer_size = 1000000, @@LOCAL.sort_buffer_size = 1000000"
)
self.validate_identity("SET GLOBAL max_connections = 1000, sort_buffer_size = 1000000")
self.validate_identity("SET @@GLOBAL.sort_buffer_size = 50000, sort_buffer_size = 1000000")
self.validate_identity("SET CHARACTER SET 'utf8'")
self.validate_identity("SET CHARACTER SET utf8")
self.validate_identity("SET CHARACTER SET DEFAULT")
self.validate_identity("SET NAMES 'utf8'")
self.validate_identity("SET NAMES DEFAULT")
self.validate_identity("SET NAMES 'utf8' COLLATE 'utf8_unicode_ci'")
self.validate_identity("SET NAMES utf8 COLLATE utf8_unicode_ci")
self.validate_identity("SET autocommit = ON")
def test_escape(self):
self.validate_all(
r"'a \' b '' '",
write={
"mysql": r"'a '' b '' '",
"spark": r"'a \' b \' '",
},
)
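The `test_escape` case above round-trips a quoted literal between MySQL's doubled-quote style and Spark's backslash style. A rough standalone illustration of that re-quoting (a sketch, not sqlglot's tokenizer):

```python
def quote_mysql(body: str) -> str:
    # MySQL convention asserted above: escape a quote by doubling it.
    return "'" + body.replace("'", "''") + "'"


def quote_spark(body: str) -> str:
    # Spark convention asserted above: escape a quote with a backslash.
    return "'" + body.replace("'", "\\'") + "'"
```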
def test_introducers(self):
    self.validate_all(
@@ -115,14 +162,6 @@ class TestMySQL(Validator):
    },
)
def test_mysql(self):
    self.validate_all(
        "GROUP_CONCAT(DISTINCT x ORDER BY y DESC)",
@@ -174,3 +213,242 @@ COMMENT='客户账户表'"""
    },
    pretty=True,
)
def test_show_simple(self):
for key, write_key in [
("BINARY LOGS", "BINARY LOGS"),
("MASTER LOGS", "BINARY LOGS"),
("STORAGE ENGINES", "ENGINES"),
("ENGINES", "ENGINES"),
("EVENTS", "EVENTS"),
("MASTER STATUS", "MASTER STATUS"),
("PLUGINS", "PLUGINS"),
("PRIVILEGES", "PRIVILEGES"),
("PROFILES", "PROFILES"),
("REPLICAS", "REPLICAS"),
("SLAVE HOSTS", "REPLICAS"),
]:
show = self.validate_identity(f"SHOW {key}", f"SHOW {write_key}")
self.assertIsInstance(show, exp.Show)
self.assertEqual(show.name, write_key)
def test_show_events(self):
for key in ["BINLOG", "RELAYLOG"]:
show = self.validate_identity(f"SHOW {key} EVENTS")
self.assertIsInstance(show, exp.Show)
self.assertEqual(show.name, f"{key} EVENTS")
show = self.validate_identity(f"SHOW {key} EVENTS IN 'log' FROM 1 LIMIT 2, 3")
self.assertEqual(show.text("log"), "log")
self.assertEqual(show.text("position"), "1")
self.assertEqual(show.text("limit"), "3")
self.assertEqual(show.text("offset"), "2")
show = self.validate_identity(f"SHOW {key} EVENTS LIMIT 1")
self.assertEqual(show.text("limit"), "1")
self.assertIsNone(show.args.get("offset"))
def test_show_like_or_where(self):
for key, write_key in [
("CHARSET", "CHARACTER SET"),
("CHARACTER SET", "CHARACTER SET"),
("COLLATION", "COLLATION"),
("DATABASES", "DATABASES"),
("FUNCTION STATUS", "FUNCTION STATUS"),
("PROCEDURE STATUS", "PROCEDURE STATUS"),
("GLOBAL STATUS", "GLOBAL STATUS"),
("SESSION STATUS", "STATUS"),
("STATUS", "STATUS"),
("GLOBAL VARIABLES", "GLOBAL VARIABLES"),
("SESSION VARIABLES", "VARIABLES"),
("VARIABLES", "VARIABLES"),
]:
expected_name = write_key.strip("GLOBAL").strip()
template = "SHOW {}"
show = self.validate_identity(template.format(key), template.format(write_key))
self.assertIsInstance(show, exp.Show)
self.assertEqual(show.name, expected_name)
template = "SHOW {} LIKE '%foo%'"
show = self.validate_identity(template.format(key), template.format(write_key))
self.assertIsInstance(show, exp.Show)
self.assertIsInstance(show.args["like"], exp.Literal)
self.assertEqual(show.text("like"), "%foo%")
template = "SHOW {} WHERE Column_name LIKE '%foo%'"
show = self.validate_identity(template.format(key), template.format(write_key))
self.assertIsInstance(show, exp.Show)
self.assertIsInstance(show.args["where"], exp.Where)
self.assertEqual(show.args["where"].sql(), "WHERE Column_name LIKE '%foo%'")
def test_show_columns(self):
show = self.validate_identity("SHOW COLUMNS FROM tbl_name")
self.assertIsInstance(show, exp.Show)
self.assertEqual(show.name, "COLUMNS")
self.assertEqual(show.text("target"), "tbl_name")
self.assertFalse(show.args["full"])
show = self.validate_identity("SHOW FULL COLUMNS FROM tbl_name FROM db_name LIKE '%foo%'")
self.assertIsInstance(show, exp.Show)
self.assertEqual(show.text("target"), "tbl_name")
self.assertTrue(show.args["full"])
self.assertEqual(show.text("db"), "db_name")
self.assertIsInstance(show.args["like"], exp.Literal)
self.assertEqual(show.text("like"), "%foo%")
def test_show_name(self):
for key in [
"CREATE DATABASE",
"CREATE EVENT",
"CREATE FUNCTION",
"CREATE PROCEDURE",
"CREATE TABLE",
"CREATE TRIGGER",
"CREATE VIEW",
"FUNCTION CODE",
"PROCEDURE CODE",
]:
show = self.validate_identity(f"SHOW {key} foo")
self.assertIsInstance(show, exp.Show)
self.assertEqual(show.name, key)
self.assertEqual(show.text("target"), "foo")
def test_show_grants(self):
show = self.validate_identity("SHOW GRANTS FOR foo")
self.assertIsInstance(show, exp.Show)
self.assertEqual(show.name, "GRANTS")
self.assertEqual(show.text("target"), "foo")
def test_show_engine(self):
show = self.validate_identity("SHOW ENGINE foo STATUS")
self.assertIsInstance(show, exp.Show)
self.assertEqual(show.name, "ENGINE")
self.assertEqual(show.text("target"), "foo")
self.assertFalse(show.args["mutex"])
show = self.validate_identity("SHOW ENGINE foo MUTEX")
self.assertEqual(show.name, "ENGINE")
self.assertEqual(show.text("target"), "foo")
self.assertTrue(show.args["mutex"])
def test_show_errors(self):
for key in ["ERRORS", "WARNINGS"]:
show = self.validate_identity(f"SHOW {key}")
self.assertIsInstance(show, exp.Show)
self.assertEqual(show.name, key)
show = self.validate_identity(f"SHOW {key} LIMIT 2, 3")
self.assertEqual(show.text("limit"), "3")
self.assertEqual(show.text("offset"), "2")
def test_show_index(self):
show = self.validate_identity("SHOW INDEX FROM foo")
self.assertIsInstance(show, exp.Show)
self.assertEqual(show.name, "INDEX")
self.assertEqual(show.text("target"), "foo")
show = self.validate_identity("SHOW INDEX FROM foo FROM bar")
self.assertEqual(show.text("db"), "bar")
def test_show_db_like_or_where_sql(self):
for key in [
"OPEN TABLES",
"TABLE STATUS",
"TRIGGERS",
]:
show = self.validate_identity(f"SHOW {key}")
self.assertIsInstance(show, exp.Show)
self.assertEqual(show.name, key)
show = self.validate_identity(f"SHOW {key} FROM db_name")
self.assertEqual(show.name, key)
self.assertEqual(show.text("db"), "db_name")
show = self.validate_identity(f"SHOW {key} LIKE '%foo%'")
self.assertEqual(show.name, key)
self.assertIsInstance(show.args["like"], exp.Literal)
self.assertEqual(show.text("like"), "%foo%")
show = self.validate_identity(f"SHOW {key} WHERE Column_name LIKE '%foo%'")
self.assertEqual(show.name, key)
self.assertIsInstance(show.args["where"], exp.Where)
self.assertEqual(show.args["where"].sql(), "WHERE Column_name LIKE '%foo%'")
def test_show_processlist(self):
show = self.validate_identity("SHOW PROCESSLIST")
self.assertIsInstance(show, exp.Show)
self.assertEqual(show.name, "PROCESSLIST")
self.assertFalse(show.args["full"])
show = self.validate_identity("SHOW FULL PROCESSLIST")
self.assertEqual(show.name, "PROCESSLIST")
self.assertTrue(show.args["full"])
def test_show_profile(self):
show = self.validate_identity("SHOW PROFILE")
self.assertIsInstance(show, exp.Show)
self.assertEqual(show.name, "PROFILE")
show = self.validate_identity("SHOW PROFILE BLOCK IO")
self.assertEqual(show.args["types"][0].name, "BLOCK IO")
show = self.validate_identity(
"SHOW PROFILE BLOCK IO, PAGE FAULTS FOR QUERY 1 OFFSET 2 LIMIT 3"
)
self.assertEqual(show.args["types"][0].name, "BLOCK IO")
self.assertEqual(show.args["types"][1].name, "PAGE FAULTS")
self.assertEqual(show.text("query"), "1")
self.assertEqual(show.text("offset"), "2")
self.assertEqual(show.text("limit"), "3")
def test_show_replica_status(self):
show = self.validate_identity("SHOW REPLICA STATUS")
self.assertIsInstance(show, exp.Show)
self.assertEqual(show.name, "REPLICA STATUS")
show = self.validate_identity("SHOW SLAVE STATUS", "SHOW REPLICA STATUS")
self.assertIsInstance(show, exp.Show)
self.assertEqual(show.name, "REPLICA STATUS")
show = self.validate_identity("SHOW REPLICA STATUS FOR CHANNEL channel_name")
self.assertEqual(show.text("channel"), "channel_name")
def test_show_tables(self):
show = self.validate_identity("SHOW TABLES")
self.assertIsInstance(show, exp.Show)
self.assertEqual(show.name, "TABLES")
show = self.validate_identity("SHOW FULL TABLES FROM db_name LIKE '%foo%'")
self.assertTrue(show.args["full"])
self.assertEqual(show.text("db"), "db_name")
self.assertIsInstance(show.args["like"], exp.Literal)
self.assertEqual(show.text("like"), "%foo%")
def test_set_variable(self):
cmd = self.parse_one("SET SESSION x = 1")
item = cmd.expressions[0]
self.assertEqual(item.text("kind"), "SESSION")
self.assertIsInstance(item.this, exp.EQ)
self.assertEqual(item.this.left.name, "x")
self.assertEqual(item.this.right.name, "1")
cmd = self.parse_one("SET @@GLOBAL.x = @@GLOBAL.y")
item = cmd.expressions[0]
self.assertEqual(item.text("kind"), "")
self.assertIsInstance(item.this, exp.EQ)
self.assertIsInstance(item.this.left, exp.SessionParameter)
self.assertIsInstance(item.this.right, exp.SessionParameter)
cmd = self.parse_one("SET NAMES 'charset_name' COLLATE 'collation_name'")
item = cmd.expressions[0]
self.assertEqual(item.text("kind"), "NAMES")
self.assertEqual(item.name, "charset_name")
self.assertEqual(item.text("collate"), "collation_name")
cmd = self.parse_one("SET CHARSET DEFAULT")
item = cmd.expressions[0]
self.assertEqual(item.text("kind"), "CHARACTER SET")
self.assertEqual(item.this.name, "DEFAULT")
cmd = self.parse_one("SET x = 1, y = 2")
self.assertEqual(len(cmd.expressions), 2)
@@ -8,7 +8,9 @@ class TestPostgres(Validator):
def test_ddl(self):
    self.validate_all(
        "CREATE TABLE products (product_no INT UNIQUE, name TEXT, price DECIMAL)",
        write={
            "postgres": "CREATE TABLE products (product_no INT UNIQUE, name TEXT, price DECIMAL)"
        },
    )
    self.validate_all(
        "CREATE TABLE products (product_no INT CONSTRAINT must_be_different UNIQUE, name TEXT CONSTRAINT present NOT NULL, price DECIMAL)",
@@ -59,15 +61,27 @@ class TestPostgres(Validator):
def test_postgres(self):
    self.validate_identity("SELECT CASE WHEN SUBSTRING('abcdefg') IN ('ab') THEN 1 ELSE 0 END")
    self.validate_identity(
        "SELECT CASE WHEN SUBSTRING('abcdefg' FROM 1) IN ('ab') THEN 1 ELSE 0 END"
    )
    self.validate_identity(
        "SELECT CASE WHEN SUBSTRING('abcdefg' FROM 1 FOR 2) IN ('ab') THEN 1 ELSE 0 END"
    )
    self.validate_identity(
        'SELECT * FROM "x" WHERE SUBSTRING("x"."foo" FROM 1 FOR 2) IN (\'mas\')'
    )
    self.validate_identity("SELECT * FROM x WHERE SUBSTRING('Thomas' FROM '...$') IN ('mas')")
    self.validate_identity(
        "SELECT * FROM x WHERE SUBSTRING('Thomas' FROM '%#\"o_a#\"_' FOR '#') IN ('mas')"
    )
    self.validate_identity(
        "SELECT SUBSTRING('bla' + 'foo' || 'bar' FROM 3 - 1 + 5 FOR 4 + SOME_FUNC(arg1, arg2))"
    )
    self.validate_identity("SELECT TRIM(' X' FROM ' XXX ')")
    self.validate_identity("SELECT TRIM(LEADING 'bla' FROM ' XXX ' COLLATE utf8_bin)")
    self.validate_identity(
        "SELECT TO_TIMESTAMP(1284352323.5), TO_TIMESTAMP('05 Dec 2000', 'DD Mon YYYY')"
    )
    self.validate_identity("COMMENT ON TABLE mytable IS 'this'")
    self.validate_identity("SELECT e'\\xDEADBEEF'")
    self.validate_identity("SELECT CAST(e'\\176' AS BYTEA)")
@@ -75,7 +89,7 @@ class TestPostgres(Validator):
self.validate_all(
    "CREATE TABLE x (a UUID, b BYTEA)",
    write={
        "duckdb": "CREATE TABLE x (a UUID, b VARBINARY)",
        "presto": "CREATE TABLE x (a UUID, b VARBINARY)",
        "hive": "CREATE TABLE x (a UUID, b BINARY)",
        "spark": "CREATE TABLE x (a UUID, b BINARY)",
@@ -153,7 +167,9 @@ class TestPostgres(Validator):
)
self.validate_all(
    "SELECT * FROM foo, LATERAL (SELECT * FROM bar WHERE bar.id = foo.bar_id) AS ss",
    read={
        "postgres": "SELECT * FROM foo, LATERAL (SELECT * FROM bar WHERE bar.id = foo.bar_id) AS ss"
    },
)
self.validate_all(
    "SELECT m.name FROM manufacturers AS m LEFT JOIN LATERAL GET_PRODUCT_NAMES(m.id) AS pname ON TRUE WHERE pname IS NULL",
@@ -169,11 +185,15 @@ class TestPostgres(Validator):
)
self.validate_all(
    "SELECT id, email, CAST(deleted AS TEXT) FROM users WHERE NOT deleted IS NULL",
    read={
        "postgres": "SELECT id, email, CAST(deleted AS TEXT) FROM users WHERE deleted NOTNULL"
    },
)
self.validate_all(
    "SELECT id, email, CAST(deleted AS TEXT) FROM users WHERE NOT deleted IS NULL",
    read={
        "postgres": "SELECT id, email, CAST(deleted AS TEXT) FROM users WHERE NOT deleted ISNULL"
    },
)
self.validate_all(
    "'[1,2,3]'::json->2",
@@ -184,7 +204,8 @@ class TestPostgres(Validator):
    write={"postgres": """CAST('{"a":1,"b":2}' AS JSON)->'b'"""},
)
self.validate_all(
    """'{"x": {"y": 1}}'::json->'x'->'y'""",
    write={"postgres": """CAST('{"x": {"y": 1}}' AS JSON)->'x'->'y'"""},
)
self.validate_all(
    """'{"x": {"y": 1}}'::json->'x'::json->'y'""",
@@ -61,4 +61,6 @@ class TestRedshift(Validator):
    "SELECT caldate + INTERVAL '1 second' AS dateplus FROM date WHERE caldate = '12-31-2008'"
)
self.validate_identity("CREATE TABLE datetable (start_date DATE, end_date DATE)")
self.validate_identity(
    "SELECT COUNT(*) FROM event WHERE eventname LIKE '%Ring%' OR eventname LIKE '%Die%'"
)
@@ -336,7 +336,8 @@ class TestSnowflake(Validator):
def test_table_literal(self):
    # All examples from https://docs.snowflake.com/en/sql-reference/literals-table.html
    self.validate_all(
        r"""SELECT * FROM TABLE('MYTABLE')""",
        write={"snowflake": r"""SELECT * FROM TABLE('MYTABLE')"""},
    )
    self.validate_all(
@@ -352,15 +353,123 @@ class TestSnowflake(Validator):
        write={"snowflake": r"""SELECT * FROM TABLE('MYDB. "MYSCHEMA"."MYTABLE"')"""},
    )
    self.validate_all(
        r"""SELECT * FROM TABLE($MYVAR)""",
        write={"snowflake": r"""SELECT * FROM TABLE($MYVAR)"""},
    )
    self.validate_all(
        r"""SELECT * FROM TABLE(?)""", write={"snowflake": r"""SELECT * FROM TABLE(?)"""}
    )
    self.validate_all(
        r"""SELECT * FROM TABLE(:BINDING)""",
        write={"snowflake": r"""SELECT * FROM TABLE(:BINDING)"""},
    )
    self.validate_all(
        r"""SELECT * FROM TABLE($MYVAR) WHERE COL1 = 10""",
        write={"snowflake": r"""SELECT * FROM TABLE($MYVAR) WHERE COL1 = 10"""},
    )
def test_flatten(self):
self.validate_all(
"""
select
dag_report.acct_id,
dag_report.report_date,
dag_report.report_uuid,
dag_report.airflow_name,
dag_report.dag_id,
f.value::varchar as operator
from cs.telescope.dag_report,
table(flatten(input=>split(operators, ','))) f
""",
write={
"snowflake": """SELECT
dag_report.acct_id,
dag_report.report_date,
dag_report.report_uuid,
dag_report.airflow_name,
dag_report.dag_id,
CAST(f.value AS VARCHAR) AS operator
FROM cs.telescope.dag_report, TABLE(FLATTEN(input => SPLIT(operators, ','))) AS f"""
},
pretty=True,
)
# All examples from https://docs.snowflake.com/en/sql-reference/functions/flatten.html#syntax
self.validate_all(
"SELECT * FROM TABLE(FLATTEN(input => parse_json('[1, ,77]'))) f",
write={
"snowflake": "SELECT * FROM TABLE(FLATTEN(input => PARSE_JSON('[1, ,77]'))) AS f"
},
)
self.validate_all(
"""SELECT * FROM TABLE(FLATTEN(input => parse_json('{"a":1, "b":[77,88]}'), outer => true)) f""",
write={
"snowflake": """SELECT * FROM TABLE(FLATTEN(input => PARSE_JSON('{"a":1, "b":[77,88]}'), outer => TRUE)) AS f"""
},
)
self.validate_all(
"""SELECT * FROM TABLE(FLATTEN(input => parse_json('{"a":1, "b":[77,88]}'), path => 'b')) f""",
write={
"snowflake": """SELECT * FROM TABLE(FLATTEN(input => PARSE_JSON('{"a":1, "b":[77,88]}'), path => 'b')) AS f"""
},
)
self.validate_all(
"""SELECT * FROM TABLE(FLATTEN(input => parse_json('[]'))) f""",
write={"snowflake": """SELECT * FROM TABLE(FLATTEN(input => PARSE_JSON('[]'))) AS f"""},
)
self.validate_all(
"""SELECT * FROM TABLE(FLATTEN(input => parse_json('[]'), outer => true)) f""",
write={
"snowflake": """SELECT * FROM TABLE(FLATTEN(input => PARSE_JSON('[]'), outer => TRUE)) AS f"""
},
)
self.validate_all(
"""SELECT * FROM TABLE(FLATTEN(input => parse_json('{"a":1, "b":[77,88], "c": {"d":"X"}}'))) f""",
write={
"snowflake": """SELECT * FROM TABLE(FLATTEN(input => PARSE_JSON('{"a":1, "b":[77,88], "c": {"d":"X"}}'))) AS f"""
},
)
self.validate_all(
"""SELECT * FROM TABLE(FLATTEN(input => parse_json('{"a":1, "b":[77,88], "c": {"d":"X"}}'), recursive => true)) f""",
write={
"snowflake": """SELECT * FROM TABLE(FLATTEN(input => PARSE_JSON('{"a":1, "b":[77,88], "c": {"d":"X"}}'), recursive => TRUE)) AS f"""
},
)
self.validate_all(
"""SELECT * FROM TABLE(FLATTEN(input => parse_json('{"a":1, "b":[77,88], "c": {"d":"X"}}'), recursive => true, mode => 'object')) f""",
write={
"snowflake": """SELECT * FROM TABLE(FLATTEN(input => PARSE_JSON('{"a":1, "b":[77,88], "c": {"d":"X"}}'), recursive => TRUE, mode => 'object')) AS f"""
},
)
self.validate_all(
"""
SELECT id as "ID",
f.value AS "Contact",
f1.value:type AS "Type",
f1.value:content AS "Details"
FROM persons p,
lateral flatten(input => p.c, path => 'contact') f,
lateral flatten(input => f.value:business) f1
""",
write={
"snowflake": """SELECT
id AS "ID",
f.value AS "Contact",
f1.value['type'] AS "Type",
f1.value['content'] AS "Details"
FROM persons AS p, LATERAL FLATTEN(input => p.c, path => 'contact') f, LATERAL FLATTEN(input => f.value['business']) f1""",
},
pretty=True,
)
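The FLATTEN cases above follow Snowflake's documented behavior of exploding an array or object into rows, with `path` selecting a sub-element and `outer` keeping empty inputs. A minimal pure-Python analogue, for intuition only — the names and semantics are simplified assumptions, not sqlglot or Snowflake code:

```python
def flatten(input, path=None, outer=False):
    """Yield (key, value) rows from a parsed-JSON array or object,
    loosely mirroring FLATTEN(input => ..., path => ..., outer => ...)."""
    # path => 'b' first descends into that sub-element of the input.
    if path is not None:
        input = input.get(path, [])
    # Objects flatten to (key, value) rows, arrays to (index, value) rows.
    if isinstance(input, dict):
        rows = list(input.items())
    else:
        rows = list(enumerate(input))
    # outer => TRUE emits a single NULL row instead of zero rows for empty input.
    if not rows and outer:
        rows = [(None, None)]
    return rows
```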
@@ -284,4 +284,6 @@ TBLPROPERTIES (
)
def test_iif(self):
    self.validate_all(
        "SELECT IIF(cond, 'True', 'False')", write={"spark": "SELECT IF(cond, 'True', 'False')"}
    )
@@ -6,3 +6,6 @@ class TestMySQL(Validator):
def test_identity(self):
    self.validate_identity("SELECT CAST(`a`.`b` AS INT) FROM foo")
def test_time(self):
self.validate_identity("TIMESTAMP('2022-01-01')")
@@ -278,12 +278,19 @@ class TestTSQL(Validator):
def test_add_date(self):
    self.validate_identity("SELECT DATEADD(year, 1, '2017/08/25')")
    self.validate_all(
        "SELECT DATEADD(year, 1, '2017/08/25')",
        write={"spark": "SELECT ADD_MONTHS('2017/08/25', 12)"},
    )
    self.validate_all(
        "SELECT DATEADD(qq, 1, '2017/08/25')",
        write={"spark": "SELECT ADD_MONTHS('2017/08/25', 3)"},
    )
    self.validate_all(
        "SELECT DATEADD(wk, 1, '2017/08/25')",
        write={
            "spark": "SELECT DATE_ADD('2017/08/25', 7)",
            "databricks": "SELECT DATEADD(week, 1, '2017/08/25')",
        },
    )
def test_date_diff(self):
@@ -370,13 +377,21 @@ class TestTSQL(Validator):
    "SELECT FORMAT(1000000.01,'###,###.###')",
    write={"spark": "SELECT FORMAT_NUMBER(1000000.01, '###,###.###')"},
)
self.validate_all(
    "SELECT FORMAT(1234567, 'f')", write={"spark": "SELECT FORMAT_NUMBER(1234567, 'f')"}
)
self.validate_all(
    "SELECT FORMAT('01-01-1991', 'dd.mm.yyyy')",
    write={"spark": "SELECT DATE_FORMAT('01-01-1991', 'dd.mm.yyyy')"},
)
self.validate_all(
    "SELECT FORMAT(date_col, 'dd.mm.yyyy')",
    write={"spark": "SELECT DATE_FORMAT(date_col, 'dd.mm.yyyy')"},
)
self.validate_all(
    "SELECT FORMAT(date_col, 'm')",
    write={"spark": "SELECT DATE_FORMAT(date_col, 'MMMM d')"},
)
self.validate_all(
    "SELECT FORMAT(num_col, 'c')", write={"spark": "SELECT FORMAT_NUMBER(num_col, 'c')"}
)


@@ -523,6 +523,8 @@ DROP VIEW a.b
DROP VIEW IF EXISTS a
DROP VIEW IF EXISTS a.b
SHOW TABLES
USE db
ROLLBACK
EXPLAIN SELECT * FROM x
INSERT INTO x SELECT * FROM y
INSERT INTO x (SELECT * FROM y)
@@ -569,3 +571,13 @@ SELECT * FROM (tbl1 LEFT JOIN tbl2 ON 1 = 1)
SELECT * FROM (tbl1 JOIN tbl2 JOIN tbl3)
SELECT * FROM (tbl1 JOIN (tbl2 JOIN tbl3) ON bla = foo)
SELECT * FROM (tbl1 JOIN LATERAL (SELECT * FROM bla) AS tbl)
SELECT CAST(x AS INT) /* comment */ FROM foo
SELECT a /* x */, b /* x */
SELECT * FROM foo /* x */, bla /* x */
SELECT 1 /* comment */ + 1
SELECT 1 /* c1 */ + 2 /* c2 */
SELECT 1 /* c1 */ + 2 /* c2 */ + 3 /* c3 */
SELECT 1 /* c1 */ + 2 /* c2 */, 3 /* c3 */
SELECT x FROM a.b.c /* x */, e.f.g /* x */
SELECT FOO(x /* c */) /* FOO */, b /* b */
SELECT FOO(x /* c1 */ + y /* c2 */ + BLA(5 /* c3 */)) FROM VALUES (1 /* c4 */, "test" /* c5 */) /* c6 */


@@ -104,6 +104,16 @@ SELECT x.b AS b, x.a AS a FROM x AS x LEFT JOIN y AS y ON x.b = y.b QUALIFY ROW_
SELECT AGGREGATE(ARRAY(a, x.b), 0, (x, acc) -> x + acc + a) AS sum_agg FROM x;
SELECT AGGREGATE(ARRAY(x.a, x.b), 0, (x, acc) -> x + acc + x.a) AS sum_agg FROM x AS x;

# dialect: starrocks
# execute: false
SELECT DATE_TRUNC('week', a) AS a FROM x;
SELECT DATE_TRUNC('week', x.a) AS a FROM x AS x;

# dialect: bigquery
# execute: false
SELECT DATE_TRUNC(a, MONTH) AS a FROM x;
SELECT DATE_TRUNC(x.a, MONTH) AS a FROM x AS x;

--------------------------------------
-- Derived tables
--------------------------------------


@@ -79,6 +79,15 @@ NULL;
NULL = NULL;
NULL;

NULL <=> NULL;
TRUE;

a IS NOT DISTINCT FROM a;
TRUE;

NULL IS DISTINCT FROM NULL;
FALSE;

NOT (NOT TRUE);
TRUE;


@@ -287,3 +287,31 @@ SELECT
    "fffffff"
  )
);
/*
multi
line
comment
*/
SELECT * FROM foo;
/*
multi
line
comment
*/
SELECT
*
FROM foo;
SELECT x FROM a.b.c /*x*/, e.f.g /*x*/;
SELECT
x
FROM a.b.c /* x */, e.f.g /* x */;
SELECT x FROM (SELECT * FROM bla /*x*/WHERE id = 1) /*x*/;
SELECT
x
FROM (
SELECT
*
FROM bla /* x */
WHERE
id = 1
) /* x */;


@@ -100,15 +100,21 @@ class TestBuild(unittest.TestCase):
                "SELECT x FROM tbl LEFT OUTER JOIN tbl2",
            ),
            (
                lambda: select("x")
                .from_("tbl")
                .join(exp.Table(this="tbl2"), join_type="left outer"),
                "SELECT x FROM tbl LEFT OUTER JOIN tbl2",
            ),
            (
                lambda: select("x")
                .from_("tbl")
                .join(exp.Table(this="tbl2"), join_type="left outer", join_alias="foo"),
                "SELECT x FROM tbl LEFT OUTER JOIN tbl2 AS foo",
            ),
            (
                lambda: select("x")
                .from_("tbl")
                .join(select("y").from_("tbl2"), join_type="left outer"),
                "SELECT x FROM tbl LEFT OUTER JOIN (SELECT y FROM tbl2)",
            ),
            (
@@ -131,7 +137,9 @@ class TestBuild(unittest.TestCase):
                "SELECT x FROM tbl LEFT OUTER JOIN (SELECT y FROM tbl2) AS aliased",
            ),
            (
                lambda: select("x")
                .from_("tbl")
                .join(parse_one("left join x", into=exp.Join), on="a=b"),
                "SELECT x FROM tbl LEFT JOIN x ON a = b",
            ),
            (
@@ -139,7 +147,9 @@ class TestBuild(unittest.TestCase):
                "SELECT x FROM tbl LEFT JOIN x ON a = b",
            ),
            (
                lambda: select("x")
                .from_("tbl")
                .join("select b from tbl2", on="a=b", join_type="left"),
                "SELECT x FROM tbl LEFT JOIN (SELECT b FROM tbl2) ON a = b",
            ),
            (
@@ -162,7 +172,10 @@ class TestBuild(unittest.TestCase):
            (
                lambda: select("x", "y", "z")
                .from_("merged_df")
                .join(
                    "vte_diagnosis_df",
                    using=[exp.to_identifier("patient_id"), exp.to_identifier("encounter_id")],
                ),
                "SELECT x, y, z FROM merged_df JOIN vte_diagnosis_df USING (patient_id, encounter_id)",
            ),
            (
@@ -222,7 +235,10 @@ class TestBuild(unittest.TestCase):
                "SELECT x, y, z, a FROM tbl ORDER BY x, y, z, a",
            ),
            (
                lambda: select("x", "y", "z", "a")
                .from_("tbl")
                .cluster_by("x, y", "z")
                .cluster_by("a"),
                "SELECT x, y, z, a FROM tbl CLUSTER BY x, y, z, a",
            ),
            (
@@ -239,7 +255,9 @@ class TestBuild(unittest.TestCase):
                "WITH tbl AS (SELECT x FROM tbl2) SELECT x FROM tbl",
            ),
            (
                lambda: select("x")
                .from_("tbl")
                .with_("tbl", as_="SELECT x FROM tbl2", recursive=True),
                "WITH RECURSIVE tbl AS (SELECT x FROM tbl2) SELECT x FROM tbl",
            ),
            (
@@ -247,7 +265,9 @@ class TestBuild(unittest.TestCase):
                "WITH tbl AS (SELECT x FROM tbl2) SELECT x FROM tbl",
            ),
            (
                lambda: select("x")
                .from_("tbl")
                .with_("tbl (x, y)", as_=select("x", "y").from_("tbl2")),
                "WITH tbl(x, y) AS (SELECT x, y FROM tbl2) SELECT x FROM tbl",
            ),
            (
@@ -258,7 +278,10 @@ class TestBuild(unittest.TestCase):
                "WITH tbl AS (SELECT x FROM tbl2), tbl2 AS (SELECT x FROM tbl3) SELECT x FROM tbl",
            ),
            (
                lambda: select("x")
                .from_("tbl")
                .with_("tbl", as_=select("x", "y").from_("tbl2"))
                .select("y"),
                "WITH tbl AS (SELECT x, y FROM tbl2) SELECT x, y FROM tbl",
            ),
            (
@@ -266,35 +289,59 @@ class TestBuild(unittest.TestCase):
                "WITH tbl AS (SELECT x FROM tbl2) SELECT x FROM tbl",
            ),
            (
                lambda: select("x")
                .with_("tbl", as_=select("x").from_("tbl2"))
                .from_("tbl")
                .group_by("x"),
                "WITH tbl AS (SELECT x FROM tbl2) SELECT x FROM tbl GROUP BY x",
            ),
            (
                lambda: select("x")
                .with_("tbl", as_=select("x").from_("tbl2"))
                .from_("tbl")
                .order_by("x"),
                "WITH tbl AS (SELECT x FROM tbl2) SELECT x FROM tbl ORDER BY x",
            ),
            (
                lambda: select("x")
                .with_("tbl", as_=select("x").from_("tbl2"))
                .from_("tbl")
                .limit(10),
                "WITH tbl AS (SELECT x FROM tbl2) SELECT x FROM tbl LIMIT 10",
            ),
            (
                lambda: select("x")
                .with_("tbl", as_=select("x").from_("tbl2"))
                .from_("tbl")
                .offset(10),
                "WITH tbl AS (SELECT x FROM tbl2) SELECT x FROM tbl OFFSET 10",
            ),
            (
                lambda: select("x")
                .with_("tbl", as_=select("x").from_("tbl2"))
                .from_("tbl")
                .join("tbl3"),
                "WITH tbl AS (SELECT x FROM tbl2) SELECT x FROM tbl JOIN tbl3",
            ),
            (
                lambda: select("x")
                .with_("tbl", as_=select("x").from_("tbl2"))
                .from_("tbl")
                .distinct(),
                "WITH tbl AS (SELECT x FROM tbl2) SELECT DISTINCT x FROM tbl",
            ),
            (
                lambda: select("x")
                .with_("tbl", as_=select("x").from_("tbl2"))
                .from_("tbl")
                .where("x > 10"),
                "WITH tbl AS (SELECT x FROM tbl2) SELECT x FROM tbl WHERE x > 10",
            ),
            (
                lambda: select("x")
                .with_("tbl", as_=select("x").from_("tbl2"))
                .from_("tbl")
                .having("x > 20"),
                "WITH tbl AS (SELECT x FROM tbl2) SELECT x FROM tbl HAVING x > 20",
            ),
            (lambda: select("x").from_("tbl").subquery(), "(SELECT x FROM tbl)"),
@@ -354,7 +401,9 @@ class TestBuild(unittest.TestCase):
                "SELECT x FROM (SELECT x FROM tbl) AS foo WHERE x > 0",
            ),
            (
                lambda: exp.subquery("select x from tbl UNION select x from bar", "unioned").select(
                    "x"
                ),
                "SELECT x FROM (SELECT x FROM tbl UNION SELECT x FROM bar) AS unioned",
            ),
            (


@@ -33,7 +33,10 @@ class TestExecutor(unittest.TestCase):
        )
        cls.cache = {}
        cls.sqls = [
            (sql, expected)
            for _, sql, expected in load_sql_fixture_pairs("optimizer/tpc-h/tpc-h.sql")
        ]

    @classmethod
    def tearDownClass(cls):
@@ -63,7 +66,9 @@ class TestExecutor(unittest.TestCase):
    def test_execute_tpch(self):
        def to_csv(expression):
            if isinstance(expression, exp.Table):
                return parse_one(
                    f"READ_CSV('{DIR}{expression.name}.csv.gz', 'delimiter', '|') AS {expression.name}"
                )
            return expression

        for sql, _ in self.sqls[0:3]:


@@ -30,7 +30,9 @@ class TestExpressions(unittest.TestCase):
        self.assertEqual(parse_one("TO_DATE(x)", read="hive"), parse_one("ts_or_ds_to_date(x)"))
        self.assertEqual(exp.Table(pivots=[]), exp.Table())
        self.assertNotEqual(exp.Table(pivots=[None]), exp.Table())
        self.assertEqual(
            exp.DataType.build("int"), exp.DataType(this=exp.DataType.Type.INT, nested=False)
        )

    def test_find(self):
        expression = parse_one("CREATE TABLE x STORED AS PARQUET AS SELECT * FROM y")
@@ -89,7 +91,9 @@ class TestExpressions(unittest.TestCase):
        self.assertIsNone(column.find_ancestor(exp.Join))

    def test_alias_or_name(self):
        expression = parse_one(
            "SELECT a, b AS B, c + d AS e, *, 'zz', 'zz' AS z FROM foo as bar, baz"
        )
        self.assertEqual(
            [e.alias_or_name for e in expression.expressions],
            ["a", "B", "e", "*", "zz", "z"],
@@ -166,7 +170,9 @@ class TestExpressions(unittest.TestCase):
            "SELECT * FROM foo WHERE ? > 100",
        )
        self.assertEqual(
            exp.replace_placeholders(
                parse_one("select * from :name WHERE ? > 100"), another_name="bla"
            ).sql(),
            "SELECT * FROM :name WHERE ? > 100",
        )
        self.assertEqual(
@@ -183,7 +189,9 @@ class TestExpressions(unittest.TestCase):
        )

    def test_named_selects(self):
        expression = parse_one(
            "SELECT a, b AS B, c + d AS e, *, 'zz', 'zz' AS z FROM foo as bar, baz"
        )
        self.assertEqual(expression.named_selects, ["a", "B", "e", "*", "zz", "z"])

        expression = parse_one(
@@ -367,7 +375,9 @@ class TestExpressions(unittest.TestCase):
        self.assertEqual(len(list(expression.walk())), 9)
        self.assertEqual(len(list(expression.walk(bfs=False))), 9)
        self.assertTrue(all(isinstance(e, exp.Expression) for e, _, _ in expression.walk()))
        self.assertTrue(
            all(isinstance(e, exp.Expression) for e, _, _ in expression.walk(bfs=False))
        )

    def test_functions(self):
        self.assertIsInstance(parse_one("ABS(a)"), exp.Abs)
@@ -512,14 +522,21 @@ class TestExpressions(unittest.TestCase):
            ),
            exp.Properties(
                expressions=[
                    exp.FileFormatProperty(
                        this=exp.Literal.string("FORMAT"), value=exp.Literal.string("parquet")
                    ),
                    exp.PartitionedByProperty(
                        this=exp.Literal.string("PARTITIONED_BY"),
                        value=exp.Tuple(
                            expressions=[exp.to_identifier("a"), exp.to_identifier("b")]
                        ),
                    ),
                    exp.AnonymousProperty(
                        this=exp.Literal.string("custom"), value=exp.Literal.number(1)
                    ),
                    exp.TableFormatProperty(
                        this=exp.Literal.string("TABLE_FORMAT"),
                        value=exp.to_identifier("test_format"),
                    ),
                    exp.EngineProperty(this=exp.Literal.string("ENGINE"), value=exp.NULL),
                    exp.CollateProperty(this=exp.Literal.string("COLLATE"), value=exp.TRUE),
@@ -538,7 +555,10 @@ class TestExpressions(unittest.TestCase):
            ((1, "2", None), "(1, '2', NULL)"),
            ([1, "2", None], "ARRAY(1, '2', NULL)"),
            ({"x": None}, "MAP('x', NULL)"),
            (
                datetime.datetime(2022, 10, 1, 1, 1, 1),
                "TIME_STR_TO_TIME('2022-10-01 01:01:01.000000')",
            ),
            (
                datetime.datetime(2022, 10, 1, 1, 1, 1, tzinfo=datetime.timezone.utc),
                "TIME_STR_TO_TIME('2022-10-01 01:01:01.000000+0000')",
@@ -548,30 +568,48 @@ class TestExpressions(unittest.TestCase):
            with self.subTest(value):
                self.assertEqual(exp.convert(value).sql(), expected)

    def test_comment_alias(self):
        sql = """
        SELECT
            a,
            b AS B,
            c, /*comment*/
            d AS D, -- another comment
            CAST(x AS INT) -- final comment
        FROM foo
        """
        expression = parse_one(sql)
        self.assertEqual(
            [e.alias_or_name for e in expression.expressions],
            ["a", "B", "c", "D", "x"],
        )
        self.assertEqual(
            expression.sql(),
            "SELECT a, b AS B, c /* comment */, d AS D /* another comment */, CAST(x AS INT) /* final comment */ FROM foo",
        )
        self.assertEqual(
            expression.sql(comments=False),
            "SELECT a, b AS B, c, d AS D, CAST(x AS INT) FROM foo",
        )
        self.assertEqual(
            expression.sql(pretty=True, comments=False),
            """SELECT
  a,
  b AS B,
  c,
  d AS D,
  CAST(x AS INT)
FROM foo""",
        )
        self.assertEqual(
            expression.sql(pretty=True),
            """SELECT
  a,
  b AS B,
  c, -- comment
  d AS D, -- another comment
  CAST(x AS INT) -- final comment
FROM foo""",
        )

    def test_to_table(self):
@@ -605,5 +643,9 @@ class TestExpressions(unittest.TestCase):
        self.assertIsInstance(expression, exp.Union)
        self.assertEqual(expression.named_selects, ["cola", "colb"])
        self.assertEqual(
            expression.selects,
            [
                exp.Column(this=exp.to_identifier("cola")),
                exp.Column(this=exp.to_identifier("colb")),
            ],
        )


@@ -67,7 +67,9 @@ class TestOptimizer(unittest.TestCase):
        }

    def check_file(self, file, func, pretty=False, execute=False, **kwargs):
        for i, (meta, sql, expected) in enumerate(
            load_sql_fixture_pairs(f"optimizer/{file}.sql"), start=1
        ):
            title = meta.get("title") or f"{i}, {sql}"
            dialect = meta.get("dialect")
            leave_tables_isolated = meta.get("leave_tables_isolated")
@@ -90,7 +92,9 @@ class TestOptimizer(unittest.TestCase):
                if string_to_bool(should_execute):
                    with self.subTest(f"(execute) {title}"):
                        df1 = self.conn.execute(
                            sqlglot.transpile(sql, read=dialect, write="duckdb")[0]
                        ).df()
                        df2 = self.conn.execute(optimized.sql(pretty=pretty, dialect="duckdb")).df()
                        assert_frame_equal(df1, df2)
@@ -268,7 +272,8 @@ FROM READ_CSV('tests/fixtures/optimizer/tpc-h/nation.csv.gz', 'delimiter', '|')
        self.assertEqual(scopes[1].expression.sql(), "SELECT y.b FROM y")
        self.assertEqual(scopes[2].expression.sql(), "(VALUES (1, 'test')) AS tab(cola, colb)")
        self.assertEqual(
            scopes[3].expression.sql(),
            "SELECT cola, colb FROM (VALUES (1, 'test')) AS tab(cola, colb)",
        )
        self.assertEqual(scopes[4].expression.sql(), "SELECT y.c AS b FROM y")
        self.assertEqual(scopes[5].expression.sql(), "SELECT MAX(x.a) FROM x WHERE x.b = s.b")
@@ -287,7 +292,11 @@ FROM READ_CSV('tests/fixtures/optimizer/tpc-h/nation.csv.gz', 'delimiter', '|')
        # Check that we can walk in scope from an arbitrary node
        self.assertEqual(
            {
                node.sql()
                for node, *_ in walk_in_scope(expression.find(exp.Where))
                if isinstance(node, exp.Column)
            },
            {"s.b"},
        )
@@ -324,7 +333,9 @@ FROM READ_CSV('tests/fixtures/optimizer/tpc-h/nation.csv.gz', 'delimiter', '|')
        self.assertEqual(expression.args["to"].expressions[0].type, exp.DataType.Type.INT)

    def test_cache_annotation(self):
        expression = annotate_types(
            parse_one("CACHE LAZY TABLE x OPTIONS('storageLevel' = 'value') AS SELECT 1")
        )
        self.assertEqual(expression.expression.expressions[0].type, exp.DataType.Type.INT)

    def test_binary_annotation(self):
@@ -384,7 +395,9 @@ FROM READ_CSV('tests/fixtures/optimizer/tpc-h/nation.csv.gz', 'delimiter', '|')
        """
        expression = annotate_types(parse_one(sql), schema=schema)

        self.assertEqual(
            expression.expressions[0].type, exp.DataType.Type.TEXT
        )  # tbl.cola + tbl.colb + 'foo' AS col

        outer_addition = expression.expressions[0].this  # (tbl.cola + tbl.colb) + 'foo'
        self.assertEqual(outer_addition.type, exp.DataType.Type.TEXT)
@@ -396,7 +409,9 @@ FROM READ_CSV('tests/fixtures/optimizer/tpc-h/nation.csv.gz', 'delimiter', '|')
        self.assertEqual(inner_addition.right.type, exp.DataType.Type.TEXT)

        cte_select = expression.args["with"].expressions[0].this
        self.assertEqual(
            cte_select.expressions[0].type, exp.DataType.Type.VARCHAR
        )  # x.cola + 'bla' AS cola
        self.assertEqual(cte_select.expressions[1].type, exp.DataType.Type.TEXT)  # y.colb AS colb

        cte_select_addition = cte_select.expressions[0].this  # x.cola + 'bla'
@@ -405,7 +420,9 @@ FROM READ_CSV('tests/fixtures/optimizer/tpc-h/nation.csv.gz', 'delimiter', '|')
        self.assertEqual(cte_select_addition.right.type, exp.DataType.Type.VARCHAR)

        # Check that x.cola AS cola and y.colb AS colb have types CHAR and TEXT, respectively
        for d, t in zip(
            cte_select.find_all(exp.Subquery), [exp.DataType.Type.CHAR, exp.DataType.Type.TEXT]
        ):
            self.assertEqual(d.this.expressions[0].this.type, t)

    def test_function_annotation(self):
@@ -421,6 +438,19 @@ FROM READ_CSV('tests/fixtures/optimizer/tpc-h/nation.csv.gz', 'delimiter', '|')
        self.assertEqual(concat_expr.right.type, exp.DataType.Type.VARCHAR)  # TRIM(x.colb)
        self.assertEqual(concat_expr.right.this.type, exp.DataType.Type.CHAR)  # x.colb

        sql = "SELECT CASE WHEN 1=1 THEN x.cola ELSE x.colb END AS col FROM x AS x"
        case_expr_alias = annotate_types(parse_one(sql), schema=schema).expressions[0]
        self.assertEqual(case_expr_alias.type, exp.DataType.Type.VARCHAR)

        case_expr = case_expr_alias.this
        self.assertEqual(case_expr.type, exp.DataType.Type.VARCHAR)
        self.assertEqual(case_expr.args["default"].type, exp.DataType.Type.CHAR)

        case_ifs_expr = case_expr.args["ifs"][0]
        self.assertEqual(case_ifs_expr.type, exp.DataType.Type.VARCHAR)
        self.assertEqual(case_ifs_expr.args["true"].type, exp.DataType.Type.VARCHAR)

    def test_unknown_annotation(self):
        schema = {"x": {"cola": "VARCHAR"}}
        sql = "SELECT x.cola || SOME_ANONYMOUS_FUNC(x.cola) AS col FROM x AS x"
@@ -431,8 +461,12 @@ FROM READ_CSV('tests/fixtures/optimizer/tpc-h/nation.csv.gz', 'delimiter', '|')
        concat_expr = concat_expr_alias.this
        self.assertEqual(concat_expr.type, exp.DataType.Type.UNKNOWN)
        self.assertEqual(concat_expr.left.type, exp.DataType.Type.VARCHAR)  # x.cola
        self.assertEqual(
            concat_expr.right.type, exp.DataType.Type.UNKNOWN
        )  # SOME_ANONYMOUS_FUNC(x.cola)
        self.assertEqual(
            concat_expr.right.expressions[0].type, exp.DataType.Type.VARCHAR
        )  # x.cola (arg)

    def test_null_annotation(self):
        expression = annotate_types(parse_one("SELECT NULL + 2 AS col")).expressions[0].this


@@ -23,8 +23,6 @@ class TestParser(unittest.TestCase):
    def test_float(self):
        self.assertEqual(parse_one(".2"), parse_one("0.2"))

    def test_table(self):
        tables = [t.sql() for t in parse_one("select * from a, b.c, .d").find_all(exp.Table)]
@@ -33,7 +31,9 @@ class TestParser(unittest.TestCase):
    def test_select(self):
        self.assertIsNotNone(parse_one("select 1 natural"))
        self.assertIsNotNone(parse_one("select * from (select 1) x order by x.y").args["order"])
        self.assertIsNotNone(
            parse_one("select * from x where a = (select 1) order by x.y").args["order"]
        )
        self.assertEqual(len(parse_one("select * from (select 1) x cross join y").args["joins"]), 1)
        self.assertEqual(
            parse_one("""SELECT * FROM x CROSS JOIN y, z LATERAL VIEW EXPLODE(y)""").sql(),
@@ -125,26 +125,70 @@ class TestParser(unittest.TestCase):
    def test_var(self):
        self.assertEqual(parse_one("SELECT @JOIN, @'foo'").sql(), "SELECT @JOIN, @'foo'")

    def test_comments(self):
        expression = parse_one(
            """
            --comment1
            SELECT /* this won't be used */
                a, --comment2
                b as B, --comment3:testing
                "test--annotation",
                c, --comment4 --foo
                e, --
                f -- space
            FROM foo
            """
        )

        self.assertEqual(expression.comment, "comment1")
        self.assertEqual(expression.expressions[0].comment, "comment2")
        self.assertEqual(expression.expressions[1].comment, "comment3:testing")
        self.assertEqual(expression.expressions[2].comment, None)
        self.assertEqual(expression.expressions[3].comment, "comment4 --foo")
        self.assertEqual(expression.expressions[4].comment, "")
        self.assertEqual(expression.expressions[5].comment, " space")
    def test_type_literals(self):
        self.assertEqual(parse_one("int 1"), parse_one("CAST(1 AS INT)"))
        self.assertEqual(parse_one("int.5"), parse_one("CAST(0.5 AS INT)"))
        self.assertEqual(
            parse_one("TIMESTAMP '2022-01-01'").sql(), "CAST('2022-01-01' AS TIMESTAMP)"
        )
        self.assertEqual(
            parse_one("TIMESTAMP(1) '2022-01-01'").sql(), "CAST('2022-01-01' AS TIMESTAMP(1))"
        )
        self.assertEqual(
            parse_one("TIMESTAMP WITH TIME ZONE '2022-01-01'").sql(),
            "CAST('2022-01-01' AS TIMESTAMPTZ)",
        )
        self.assertEqual(
            parse_one("TIMESTAMP WITH LOCAL TIME ZONE '2022-01-01'").sql(),
            "CAST('2022-01-01' AS TIMESTAMPLTZ)",
        )
        self.assertEqual(
            parse_one("TIMESTAMP WITHOUT TIME ZONE '2022-01-01'").sql(),
            "CAST('2022-01-01' AS TIMESTAMP)",
        )
        self.assertEqual(
            parse_one("TIMESTAMP(1) WITH TIME ZONE '2022-01-01'").sql(),
            "CAST('2022-01-01' AS TIMESTAMPTZ(1))",
        )
        self.assertEqual(
            parse_one("TIMESTAMP(1) WITH LOCAL TIME ZONE '2022-01-01'").sql(),
            "CAST('2022-01-01' AS TIMESTAMPLTZ(1))",
        )
        self.assertEqual(
            parse_one("TIMESTAMP(1) WITHOUT TIME ZONE '2022-01-01'").sql(),
            "CAST('2022-01-01' AS TIMESTAMP(1))",
        )
        self.assertEqual(parse_one("TIMESTAMP(1) WITH TIME ZONE").sql(), "TIMESTAMPTZ(1)")
        self.assertEqual(parse_one("TIMESTAMP(1) WITH LOCAL TIME ZONE").sql(), "TIMESTAMPLTZ(1)")
        self.assertEqual(parse_one("TIMESTAMP(1) WITHOUT TIME ZONE").sql(), "TIMESTAMP(1)")
        self.assertEqual(parse_one("""JSON '{"x":"y"}'""").sql(), """CAST('{"x":"y"}' AS JSON)""")
        self.assertIsInstance(parse_one("TIMESTAMP(1)"), exp.Func)
        self.assertIsInstance(parse_one("TIMESTAMP('2022-01-01')"), exp.Func)
        self.assertIsInstance(parse_one("TIMESTAMP()"), exp.Func)
        self.assertIsInstance(parse_one("map.x"), exp.Column)

    def test_pretty_config_override(self):
        self.assertEqual(parse_one("SELECT col FROM x").sql(), "SELECT col FROM x")

tests/test_schema.py

@@ -1,281 +1,141 @@
 import unittest
 
-from sqlglot import table
-from sqlglot.dataframe.sql import types as df_types
+from sqlglot import exp, to_table
+from sqlglot.errors import SchemaError
 from sqlglot.schema import MappingSchema, ensure_schema
 
 
 class TestSchema(unittest.TestCase):
+    def assert_column_names(self, schema, *table_results):
+        for table, result in table_results:
+            with self.subTest(f"{table} -> {result}"):
+                self.assertEqual(schema.column_names(to_table(table)), result)
+
+    def assert_column_names_raises(self, schema, *tables):
+        for table in tables:
+            with self.subTest(table):
+                with self.assertRaises(SchemaError):
+                    schema.column_names(to_table(table))
+
     def test_schema(self):
         schema = ensure_schema(
             {
                 "x": {
                     "a": "uint64",
-                }
-            }
+                },
+                "y": {
+                    "b": "uint64",
+                    "c": "uint64",
+                },
+            },
         )
-        self.assertEqual(
-            schema.column_names(
-                table(
-                    "x",
-                )
-            ),
-            ["a"],
-        )
-        with self.assertRaises(ValueError):
-            schema.column_names(table("x", db="db", catalog="c"))
-        with self.assertRaises(ValueError):
-            schema.column_names(table("x", db="db"))
-        with self.assertRaises(ValueError):
-            schema.column_names(table("x2"))
-        with self.assertRaises(ValueError):
-            schema.add_table(table("y", db="db"), {"b": "string"})
-        with self.assertRaises(ValueError):
-            schema.add_table(table("y", db="db", catalog="c"), {"b": "string"})
-        schema.add_table(table("y"), {"b": "string"})
-        schema_with_y = {
-            "x": {
-                "a": "uint64",
-            },
-            "y": {
-                "b": "string",
-            },
-        }
-        self.assertEqual(schema.schema, schema_with_y)
-        new_schema = schema.copy()
-        new_schema.add_table(table("z"), {"c": "string"})
-        self.assertEqual(schema.schema, schema_with_y)
-        self.assertEqual(
-            new_schema.schema,
-            {
-                "x": {
-                    "a": "uint64",
-                },
-                "y": {
-                    "b": "string",
-                },
-                "z": {
-                    "c": "string",
-                },
-            },
-        )
-        schema.add_table(table("m"), {"d": "string"})
-        schema.add_table(table("n"), {"e": "string"})
-        schema_with_m_n = {
-            "x": {
-                "a": "uint64",
-            },
-            "y": {
-                "b": "string",
-            },
-            "m": {
-                "d": "string",
-            },
-            "n": {
-                "e": "string",
-            },
-        }
-        self.assertEqual(schema.schema, schema_with_m_n)
-        new_schema = schema.copy()
-        new_schema.add_table(table("o"), {"f": "string"})
-        new_schema.add_table(table("p"), {"g": "string"})
-        self.assertEqual(schema.schema, schema_with_m_n)
-        self.assertEqual(
-            new_schema.schema,
-            {
-                "x": {
-                    "a": "uint64",
-                },
-                "y": {
-                    "b": "string",
-                },
-                "m": {
-                    "d": "string",
-                },
-                "n": {
-                    "e": "string",
-                },
-                "o": {
-                    "f": "string",
-                },
-                "p": {
-                    "g": "string",
-                },
-            },
-        )
+
+        self.assert_column_names(
+            schema,
+            ("x", ["a"]),
+            ("y", ["b", "c"]),
+            ("z.x", ["a"]),
+            ("z.x.y", ["b", "c"]),
+        )
+
+        self.assert_column_names_raises(
+            schema,
+            "z",
+            "z.z",
+            "z.z.z",
+        )
+
+    def test_schema_db(self):
+        schema = ensure_schema(
+            {
+                "d1": {
+                    "x": {
+                        "a": "uint64",
+                    },
+                    "y": {
+                        "b": "uint64",
+                    },
+                },
+                "d2": {
+                    "x": {
+                        "c": "uint64",
+                    },
+                },
+            },
+        )
+
+        self.assert_column_names(
+            schema,
+            ("d1.x", ["a"]),
+            ("d2.x", ["c"]),
+            ("y", ["b"]),
+            ("d1.y", ["b"]),
+            ("z.d1.y", ["b"]),
+        )
+
+        self.assert_column_names_raises(
+            schema,
+            "x",
+            "z.x",
+            "z.y",
+        )
 
     def test_schema_catalog(self):
         schema = ensure_schema(
             {
-                "db": {
-                    "x": {
-                        "a": "uint64",
-                    }
-                }
-            }
-        )
-        self.assertEqual(schema.column_names(table("x", db="db")), ["a"])
-        with self.assertRaises(ValueError):
-            schema.column_names(table("x", db="db", catalog="c"))
-        with self.assertRaises(ValueError):
-            schema.column_names(table("x"))
-        with self.assertRaises(ValueError):
-            schema.column_names(table("x", db="db2"))
-        with self.assertRaises(ValueError):
-            schema.column_names(table("x2", db="db"))
-        with self.assertRaises(ValueError):
-            schema.add_table(table("y"), {"b": "string"})
-        with self.assertRaises(ValueError):
-            schema.add_table(table("y", db="db", catalog="c"), {"b": "string"})
-        schema.add_table(table("y", db="db"), {"b": "string"})
-        self.assertEqual(
-            schema.schema,
-            {
-                "db": {
-                    "x": {
-                        "a": "uint64",
-                    },
-                    "y": {
-                        "b": "string",
-                    },
-                }
-            },
-        )
-        schema = ensure_schema(
-            {
-                "c": {
-                    "db": {
-                        "x": {
-                            "a": "uint64",
-                        }
-                    }
-                }
-            }
-        )
-        self.assertEqual(schema.column_names(table("x", db="db", catalog="c")), ["a"])
-        with self.assertRaises(ValueError):
-            schema.column_names(table("x", db="db"))
-        with self.assertRaises(ValueError):
-            schema.column_names(table("x"))
-        with self.assertRaises(ValueError):
-            schema.column_names(table("x", db="db", catalog="c2"))
-        with self.assertRaises(ValueError):
-            schema.column_names(table("x", db="db2"))
-        with self.assertRaises(ValueError):
-            schema.column_names(table("x2", db="db"))
-        with self.assertRaises(ValueError):
-            schema.add_table(table("x"), {"b": "string"})
-        with self.assertRaises(ValueError):
-            schema.add_table(table("x", db="db"), {"b": "string"})
-        schema.add_table(table("y", db="db", catalog="c"), {"a": "string", "b": "int"})
-        self.assertEqual(
-            schema.schema,
-            {
-                "c": {
-                    "db": {
-                        "x": {
-                            "a": "uint64",
-                        },
-                        "y": {
-                            "a": "string",
-                            "b": "int",
-                        },
-                    }
-                }
-            },
-        )
-        schema.add_table(table("z", db="db2", catalog="c"), {"c": "string", "d": "int"})
-        self.assertEqual(
-            schema.schema,
-            {
-                "c": {
-                    "db": {
-                        "x": {
-                            "a": "uint64",
-                        },
-                        "y": {
-                            "a": "string",
-                            "b": "int",
-                        },
-                    },
-                    "db2": {
-                        "z": {
-                            "c": "string",
-                            "d": "int",
-                        }
-                    },
-                }
-            },
-        )
-        schema.add_table(table("m", db="db2", catalog="c2"), {"e": "string", "f": "int"})
-        self.assertEqual(
-            schema.schema,
-            {
-                "c": {
-                    "db": {
-                        "x": {
-                            "a": "uint64",
-                        },
-                        "y": {
-                            "a": "string",
-                            "b": "int",
-                        },
-                    },
-                    "db2": {
-                        "z": {
-                            "c": "string",
-                            "d": "int",
-                        }
-                    },
-                },
-                "c2": {
-                    "db2": {
-                        "m": {
-                            "e": "string",
-                            "f": "int",
-                        }
-                    }
-                },
-            },
-        )
-        schema = ensure_schema(
-            {
-                "x": {
-                    "a": "uint64",
-                }
-            }
-        )
-        self.assertEqual(schema.column_names(table("x")), ["a"])
-        schema = MappingSchema()
-        schema.add_table(table("x"), {"a": "string"})
-        self.assertEqual(
-            schema.schema,
-            {
-                "x": {
-                    "a": "string",
-                }
-            },
-        )
-        schema.add_table(
-            table("y"), df_types.StructType([df_types.StructField("b", df_types.StringType())])
-        )
-        self.assertEqual(
-            schema.schema,
-            {
-                "x": {
-                    "a": "string",
-                },
-                "y": {
-                    "b": "string",
-                },
-            },
-        )
+                "c1": {
+                    "d1": {
+                        "x": {
+                            "a": "uint64",
+                        },
+                        "y": {
+                            "b": "uint64",
+                        },
+                        "z": {
+                            "c": "uint64",
+                        },
+                    },
+                },
+                "c2": {
+                    "d1": {
+                        "y": {
+                            "d": "uint64",
+                        },
+                        "z": {
+                            "e": "uint64",
+                        },
+                    },
+                    "d2": {
+                        "z": {
+                            "f": "uint64",
+                        },
+                    },
+                },
+            }
+        )
+
+        self.assert_column_names(
+            schema,
+            ("x", ["a"]),
+            ("d1.x", ["a"]),
+            ("c1.d1.x", ["a"]),
+            ("c1.d1.y", ["b"]),
+            ("c1.d1.z", ["c"]),
+            ("c2.d1.y", ["d"]),
+            ("c2.d1.z", ["e"]),
+            ("d2.z", ["f"]),
+            ("c2.d2.z", ["f"]),
+        )
+
+        self.assert_column_names_raises(
+            schema,
+            "q",
+            "d2.x",
+            "y",
+            "z",
+            "d1.y",
+            "d1.z",
+            "a.b.c",
+        )
     def test_schema_add_table_with_and_without_mapping(self):
@@ -288,3 +148,34 @@ class TestSchema(unittest.TestCase):
         self.assertEqual(schema.column_names("test"), ["x", "y"])
         schema.add_table("test")
         self.assertEqual(schema.column_names("test"), ["x", "y"])
+
+    def test_schema_get_column_type(self):
+        schema = MappingSchema({"a": {"b": "varchar"}})
+        self.assertEqual(schema.get_column_type("a", "b"), exp.DataType.Type.VARCHAR)
+        self.assertEqual(
+            schema.get_column_type(exp.Table(this="a"), exp.Column(this="b")),
+            exp.DataType.Type.VARCHAR,
+        )
+        self.assertEqual(
+            schema.get_column_type("a", exp.Column(this="b")), exp.DataType.Type.VARCHAR
+        )
+        self.assertEqual(
+            schema.get_column_type(exp.Table(this="a"), "b"), exp.DataType.Type.VARCHAR
+        )
+        schema = MappingSchema({"a": {"b": {"c": "varchar"}}})
+        self.assertEqual(
+            schema.get_column_type(exp.Table(this="b", db="a"), exp.Column(this="c")),
+            exp.DataType.Type.VARCHAR,
+        )
+        self.assertEqual(
+            schema.get_column_type(exp.Table(this="b", db="a"), "c"), exp.DataType.Type.VARCHAR
+        )
+        schema = MappingSchema({"a": {"b": {"c": {"d": "varchar"}}}})
+        self.assertEqual(
+            schema.get_column_type(exp.Table(this="c", db="b", catalog="a"), exp.Column(this="d")),
+            exp.DataType.Type.VARCHAR,
+        )
+        self.assertEqual(
+            schema.get_column_type(exp.Table(this="c", db="b", catalog="a"), "d"),
+            exp.DataType.Type.VARCHAR,
+        )
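The rewritten schema tests above exercise the changelog's schema refactor: a table no longer needs an exact db.table match and is resolved by name alone when that is unambiguous, raising otherwise. A rough stdlib-only sketch of that lookup idea (hypothetical names, not sqlglot's actual `MappingSchema` internals):

```python
# Illustrative "lenient" nested-schema lookup: resolve a table by whatever
# qualifiers are given (catalog.db.table), succeeding only when the match is
# unambiguous. Names and structure here are assumptions, not sqlglot's code.

class SchemaError(Exception):
    pass


def column_names(schema, table, db=None, catalog=None):
    """Find `table` in a {catalog: {db: {table: columns}}} mapping.

    Missing qualifiers act as wildcards; zero or multiple matches raise.
    """
    matches = []
    for cat_name, dbs in schema.items():
        if catalog not in (None, cat_name):
            continue
        for db_name, tables in dbs.items():
            if db not in (None, db_name):
                continue
            if table in tables:
                matches.append(list(tables[table]))
    if not matches:
        raise SchemaError(f"table {table} not found")
    if len(matches) > 1:
        raise SchemaError(f"table {table} is ambiguous")
    return matches[0]


schema = {
    "c1": {"d1": {"x": {"a": "uint64"}}},
    "c2": {"d1": {"y": {"d": "uint64"}}, "d2": {"z": {"f": "uint64"}}},
}

print(column_names(schema, "x"))           # unqualified but unambiguous: ["a"]
print(column_names(schema, "z", db="d2"))  # partially qualified: ["f"]
```

Zero matches and more than one match both raise, which is the shape the `assert_column_names_raises` helper checks for.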

tests/test_tokens.py (new file)

@@ -0,0 +1,18 @@
+import unittest
+
+from sqlglot.tokens import Tokenizer
+
+
+class TestTokens(unittest.TestCase):
+    def test_comment_attachment(self):
+        tokenizer = Tokenizer()
+        sql_comment = [
+            ("/*comment*/ foo", "comment"),
+            ("/*comment*/ foo --test", "comment"),
+            ("--comment\nfoo --test", "comment"),
+            ("foo --comment", "comment"),
+            ("foo", None),
+        ]
+
+        for sql, comment in sql_comment:
+            self.assertEqual(tokenizer.tokenize(sql)[0].comment, comment)
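The new test above pins down the attachment rule the tokenizer follows: a comment belongs to the next token if one follows, otherwise to the previous one, and the first comment wins. A toy stdlib sketch of that rule (illustrative only, not sqlglot's `Tokenizer`):

```python
import re
from dataclasses import dataclass
from typing import Optional


@dataclass
class Token:
    text: str
    comment: Optional[str] = None


# /* block */ comments, -- line comments, else whitespace-delimited chunks.
TOKEN_RE = re.compile(r"/\*(?P<block>.*?)\*/|--(?P<line>[^\n]*)|(?P<word>\S+)", re.S)


def tokenize(sql):
    tokens, pending = [], None
    for m in TOKEN_RE.finditer(sql):
        if m.lastgroup in ("block", "line"):
            if pending is None:  # the first comment wins
                pending = m.group(m.lastgroup)
        else:
            token = Token(m.group("word"))
            token.comment, pending = pending, None
            tokens.append(token)
    # a trailing comment attaches backwards, but never clobbers one already set
    if pending is not None and tokens and tokens[-1].comment is None:
        tokens[-1].comment = pending
    return tokens
```

Note how the `("foo", None)` case falls out naturally: with no comment in the input, nothing is ever pending.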

tests/test_transpile.py

@@ -49,6 +49,12 @@ class TestTranspile(unittest.TestCase):
             leading_comma=True,
             pretty=True,
         )
+        self.validate(
+            "SELECT FOO, /*x*/\nBAR, /*y*/\nBAZ",
+            "SELECT\n  FOO -- x\n  , BAR -- y\n  , BAZ",
+            leading_comma=True,
+            pretty=True,
+        )
 
         # without pretty, this should be a no-op
         self.validate(
             "SELECT FOO, BAR, BAZ",
@@ -63,24 +69,61 @@ class TestTranspile(unittest.TestCase):
         self.validate("SELECT 3>=3", "SELECT 3 >= 3")
 
     def test_comments(self):
-        self.validate("SELECT 1 FROM foo -- comment", "SELECT 1 FROM foo")
-        self.validate("SELECT 1 /* inline */ FROM foo -- comment", "SELECT 1 FROM foo")
+        self.validate("SELECT 1 FROM foo -- comment", "SELECT 1 FROM foo /* comment */")
+        self.validate("SELECT --+5\nx FROM foo", "/* +5 */ SELECT x FROM foo")
+        self.validate("SELECT --!5\nx FROM foo", "/* !5 */ SELECT x FROM foo")
+        self.validate(
+            "SELECT 1 /* inline */ FROM foo -- comment",
+            "SELECT 1 /* inline */ FROM foo /* comment */",
+        )
+        self.validate(
+            "SELECT FUN(x) /*x*/, [1,2,3] /*y*/", "SELECT FUN(x) /* x */, ARRAY(1, 2, 3) /* y */"
+        )
         self.validate(
             """
             SELECT 1 -- comment
             FROM foo -- comment
             """,
-            "SELECT 1 FROM foo",
+            "SELECT 1 /* comment */ FROM foo /* comment */",
         )
         self.validate(
             """
             SELECT 1 /* big comment
             like this */
             FROM foo -- comment
             """,
-            "SELECT 1 FROM foo",
+            """SELECT 1 /* big comment
+            like this */ FROM foo /* comment */""",
         )
+        self.validate(
+            "select x from foo -- x",
+            "SELECT x FROM foo /* x */",
+        )
+        self.validate(
+            """
+            /* multi
+            line
+            comment
+            */
+            SELECT
+            tbl.cola /* comment 1 */ + tbl.colb /* comment 2 */,
+            CAST(x AS INT), # comment 3
+            y -- comment 4
+            FROM
+            bar /* comment 5 */,
+            tbl # comment 6
+            """,
+            """/* multi
+            line
+            comment
+            */
+SELECT
+  tbl.cola /* comment 1 */ + tbl.colb /* comment 2 */,
+  CAST(x AS INT), -- comment 3
+  y -- comment 4
+FROM bar /* comment 5 */, tbl /* comment 6 */""",
+            read="mysql",
+            pretty=True,
+        )
 
     def test_types(self):
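In the expected outputs above, `--` and MySQL `#` line comments come out as `/* ... */` block comments whenever the SQL is printed on a single line: a trailing line comment would otherwise swallow everything after it. A rough regex-only illustration of that normalization step (an assumed helper, not sqlglot's generator; real code must tokenize first so comment markers inside string literals are left untouched):

```python
import re

# Rewrite -- and # line comments as /* */ so newlines can be collapsed safely.
# Regex-only and therefore naive: it would also rewrite markers inside string
# literals, which a real SQL generator must avoid by working on tokens.

LINE_COMMENT = re.compile(r"--([^\n]*)|#([^\n]*)")


def inline_comments(sql):
    """Return `sql` on one line, with line comments converted to block form."""
    def repl(m):
        text = (m.group(1) or m.group(2) or "").strip()
        return f"/* {text} */"

    one_line = " ".join(
        part.strip() for part in LINE_COMMENT.sub(repl, sql).splitlines()
    )
    return re.sub(r"\s+", " ", one_line).strip()


print(inline_comments("SELECT 1 FROM foo -- comment"))
```

This mirrors why `"SELECT 1 FROM foo -- comment"` round-trips as `SELECT 1 FROM foo /* comment */` in the tests; the pretty printer can keep `--` form because each comment stays on its own line there.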
@@ -146,6 +189,16 @@ class TestTranspile(unittest.TestCase):
     def test_ignore_nulls(self):
         self.validate("SELECT COUNT(x RESPECT NULLS)", "SELECT COUNT(x)")
 
+    def test_with(self):
+        self.validate(
+            "WITH a AS (SELECT 1) WITH b AS (SELECT 2) SELECT *",
+            "WITH a AS (SELECT 1), b AS (SELECT 2) SELECT *",
+        )
+        self.validate(
+            "WITH a AS (SELECT 1), WITH b AS (SELECT 2) SELECT *",
+            "WITH a AS (SELECT 1), b AS (SELECT 2) SELECT *",
+        )
+
     def test_time(self):
         self.validate("TIMESTAMP '2020-01-01'", "CAST('2020-01-01' AS TIMESTAMP)")
         self.validate("TIMESTAMP WITH TIME ZONE '2020-01-01'", "CAST('2020-01-01' AS TIMESTAMPTZ)")