
Merging upstream version 10.0.1.

Signed-off-by: Daniel Baumann <daniel@debian.org>
Daniel Baumann 2025-02-13 14:53:05 +01:00
parent 528822bfd4
commit b7d21c45b7
Signed by: daniel
GPG key ID: FBB4F0E80A80222F
98 changed files with 4080 additions and 1666 deletions

View file

@ -1,6 +1,29 @@
Changelog
=========
v10.0.0
------
Changes:
- Breaking: replaced SQLGlot annotations with comments. Now comments can be preserved after transpilation, and they can appear in other places besides SELECT's expressions.
- Breaking: renamed list_get to seq_get (see the sketch after this list).
- Breaking: activated mypy type checking for SQLGlot.
- New: Azure Databricks support.
- New: placeholders can now be replaced in an expression.
- New: null safe equal operator (<=>).
- New: [SET statements](https://github.com/tobymao/sqlglot/pull/673) for MySQL.
- New: [SHOW commands](https://dev.mysql.com/doc/refman/8.0/en/show.html) for MySQL.
- New: [FORMAT function](https://www.w3schools.com/sql/func_sqlserver_format.asp) for TSQL.
- New: CROSS APPLY / OUTER APPLY [support](https://github.com/tobymao/sqlglot/pull/641) for TSQL.
- New: added formats for TSQL's [DATENAME/DATEPART functions](https://learn.microsoft.com/en-us/sql/t-sql/functions/datename-transact-sql?view=sql-server-ver16).
- New: added styles for TSQL's [CONVERT function](https://learn.microsoft.com/en-us/sql/t-sql/functions/cast-and-convert-transact-sql?view=sql-server-ver16).
- Improvement: [refactored the schema](https://github.com/tobymao/sqlglot/pull/668) to be more lenient; it previously required an exact db.table match, but now it finds the table as long as there is no ambiguity.
- Improvement: allow functions to [inherit](https://github.com/tobymao/sqlglot/pull/674) their arguments' types, so that annotating CASE, IF, etc. is possible.
- Improvement: allow [joining with the same names](https://github.com/tobymao/sqlglot/pull/660) in the Python executor.
- Improvement: the "using" field can now be set for the [join expression builders](https://github.com/tobymao/sqlglot/pull/636).
- Improvement: qualify_columns [now qualifies](https://github.com/tobymao/sqlglot/pull/635) only non-alias columns in the having clause.
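As a rough illustration of the list_get → seq_get rename above (a hedged sketch, not part of the upstream changelog): the helper is a bounds-safe index lookup.
```python
from sqlglot.helper import seq_get

args = ["a", "b"]
assert seq_get(args, 1) == "b"   # in-range index returns the element
assert seq_get(args, 5) is None  # out-of-range index returns None instead of raising
```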
v9.0.0
------

View file

@ -14,7 +14,7 @@ Contributions are very welcome in SQLGlot; read the [contribution guide](https:/
* [Install](#install)
* [Documentation](#documentation)
* [Run Tests & Lint](#run-tests-and-lint)
* [Run Tests and Lint](#run-tests-and-lint)
* [Examples](#examples)
* [Formatting and Transpiling](#formatting-and-transpiling)
* [Metadata](#metadata)
@ -22,7 +22,6 @@ Contributions are very welcome in SQLGlot; read the [contribution guide](https:/
* [Unsupported Errors](#unsupported-errors)
* [Build and Modify SQL](#build-and-modify-sql)
* [SQL Optimizer](#sql-optimizer)
* [SQL Annotations](#sql-annotations)
* [AST Introspection](#ast-introspection)
* [AST Diff](#ast-diff)
* [Custom Dialects](#custom-dialects)
@ -51,7 +50,7 @@ pip3 install -r dev-requirements.txt
## Documentation
SQLGlot's uses [pdocs](https://pdoc.dev/) to serve its API documentation:
SQLGlot uses [pdocs](https://pdoc.dev/) to serve its API documentation:
```
pdoc sqlglot --docformat google
@ -121,6 +120,39 @@ LEFT JOIN `baz`
ON `f`.`a` = `baz`.`a`
```
Comments are also preserved on a best-effort basis when transpiling SQL code:
```python
sql = """
/* multi
line
comment
*/
SELECT
tbl.cola /* comment 1 */ + tbl.colb /* comment 2 */,
CAST(x AS INT), # comment 3
y -- comment 4
FROM
bar /* comment 5 */,
tbl # comment 6
"""
print(sqlglot.transpile(sql, read='mysql', pretty=True)[0])
```
```sql
/* multi
line
comment
*/
SELECT
tbl.cola /* comment 1 */ + tbl.colb /* comment 2 */,
CAST(x AS INT), -- comment 3
y -- comment 4
FROM bar /* comment 5 */, tbl /* comment 6 */
```
### Metadata
You can explore SQL with expression helpers to do things like find columns and tables:
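The example code is elided from this hunk; here is a minimal sketch of the idea using the public `parse_one` and `find_all` helpers (the README's actual snippet may differ):
```python
from sqlglot import exp, parse_one

tree = parse_one("SELECT a, b + 1 AS c FROM d")

# print every column reference: a and b
for column in tree.find_all(exp.Column):
    print(column.alias_or_name)

# print every table name: d
for table in tree.find_all(exp.Table):
    print(table.name)
```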
@ -249,17 +281,6 @@ WHERE
"x"."Z" = CAST('2021-02-01' AS DATE)
```
### SQL Annotations
SQLGlot supports annotations in the sql expression. This is an experimental feature that is not part of any of the SQL standards but it can be useful when needing to annotate what a selected field is supposed to be. Below is an example:
```sql
SELECT
user # primary_key,
country
FROM users
```
### AST Introspection
You can see the AST version of the sql by calling `repr`:
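The snippet itself is elided from this hunk; it looks something like this (an illustrative sketch):
```python
from sqlglot import parse_one

# repr() renders the parsed statement as a nested Expression tree
print(repr(parse_one("SELECT a + 1 AS z")))
```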

View file

@ -1,15 +1,8 @@
#!/bin/bash -e
[[ -z "${GITHUB_ACTIONS}" ]] && RETURN_ERROR_CODE='' || RETURN_ERROR_CODE='--check'
python -m autoflake -i -r ${RETURN_ERROR_CODE} \
--expand-star-imports \
--remove-all-unused-imports \
--ignore-init-module-imports \
--remove-duplicate-keys \
--remove-unused-variables \
sqlglot/ tests/
python -m isort --profile black sqlglot/ tests/
python -m black ${RETURN_ERROR_CODE} --line-length 120 sqlglot/ tests/
python -m mypy sqlglot tests
TARGETS="sqlglot/ tests/"
python -m mypy $TARGETS
python -m autoflake -i -r ${RETURN_ERROR_CODE} $TARGETS
python -m isort $TARGETS
python -m black --line-length 100 ${RETURN_ERROR_CODE} $TARGETS
python -m unittest

View file

@ -3,7 +3,7 @@ disallow_untyped_calls = False
no_implicit_optional = True
[mypy-sqlglot.*]
ignore_errors = True
ignore_errors = False
[mypy-sqlglot.dataframe.*]
ignore_errors = False
@ -13,3 +13,16 @@ ignore_errors = True
[mypy-tests.dataframe.*]
ignore_errors = False
[autoflake]
in-place = True
expand-star-imports = True
remove-all-unused-imports = True
ignore-init-module-imports = True
remove-duplicate-keys = True
remove-unused-variables = True
quiet = True
[isort]
profile=black
known_first_party=sqlglot

View file

@ -21,6 +21,7 @@ setup(
author_email="toby.mao@gmail.com",
license="MIT",
packages=find_packages(include=["sqlglot", "sqlglot.*"]),
package_data={"sqlglot": ["py.typed"]},
classifiers=[
"Development Status :: 5 - Production/Stable",
"Intended Audience :: Developers",

View file

@ -1,5 +1,9 @@
"""## Python SQL parser, transpiler and optimizer."""
from __future__ import annotations
import typing as t
from sqlglot import expressions as exp
from sqlglot.dialects import Dialect, Dialects
from sqlglot.diff import diff
@ -20,51 +24,54 @@ from sqlglot.expressions import (
subquery,
)
from sqlglot.expressions import table_ as table
from sqlglot.expressions import union
from sqlglot.expressions import to_column, to_table, union
from sqlglot.generator import Generator
from sqlglot.parser import Parser
from sqlglot.schema import MappingSchema
from sqlglot.tokens import Tokenizer, TokenType
__version__ = "9.0.6"
__version__ = "10.0.1"
pretty = False
schema = MappingSchema()
def parse(sql, read=None, **opts):
def parse(
sql: str, read: t.Optional[str | Dialect] = None, **opts
) -> t.List[t.Optional[Expression]]:
"""
Parses the given SQL string into a collection of syntax trees, one per
parsed SQL statement.
Parses the given SQL string into a collection of syntax trees, one per parsed SQL statement.
Args:
sql (str): the SQL code string to parse.
read (str): the SQL dialect to apply during parsing
(eg. "spark", "hive", "presto", "mysql").
sql: the SQL code string to parse.
read: the SQL dialect to apply during parsing (eg. "spark", "hive", "presto", "mysql").
**opts: other options.
Returns:
typing.List[Expression]: the list of parsed syntax trees.
The resulting syntax tree collection.
"""
dialect = Dialect.get_or_raise(read)()
return dialect.parse(sql, **opts)
def parse_one(sql, read=None, into=None, **opts):
def parse_one(
sql: str,
read: t.Optional[str | Dialect] = None,
into: t.Optional[Expression | str] = None,
**opts,
) -> t.Optional[Expression]:
"""
Parses the given SQL string and returns a syntax tree for the first
parsed SQL statement.
Parses the given SQL string and returns a syntax tree for the first parsed SQL statement.
Args:
sql (str): the SQL code string to parse.
read (str): the SQL dialect to apply during parsing
(eg. "spark", "hive", "presto", "mysql").
into (Expression): the SQLGlot Expression to parse into
sql: the SQL code string to parse.
read: the SQL dialect to apply during parsing (eg. "spark", "hive", "presto", "mysql").
into: the SQLGlot Expression to parse into.
**opts: other options.
Returns:
Expression: the syntax tree for the first parsed statement.
The syntax tree for the first parsed statement.
"""
dialect = Dialect.get_or_raise(read)()
@ -77,25 +84,29 @@ def parse_one(sql, read=None, into=None, **opts):
return result[0] if result else None
def transpile(sql, read=None, write=None, identity=True, error_level=None, **opts):
def transpile(
sql: str,
read: t.Optional[str | Dialect] = None,
write: t.Optional[str | Dialect] = None,
identity: bool = True,
error_level: t.Optional[ErrorLevel] = None,
**opts,
) -> t.List[str]:
"""
Parses the given SQL string using the source dialect and returns a list of SQL strings
transformed to conform to the target dialect. Each string in the returned list represents
a single transformed SQL statement.
Parses the given SQL string in accordance with the source dialect and returns a list of SQL strings transformed
to conform to the target dialect. Each string in the returned list represents a single transformed SQL statement.
Args:
sql (str): the SQL code string to transpile.
read (str): the source dialect used to parse the input string
(eg. "spark", "hive", "presto", "mysql").
write (str): the target dialect into which the input should be transformed
(eg. "spark", "hive", "presto", "mysql").
identity (bool): if set to True and if the target dialect is not specified
the source dialect will be used as both: the source and the target dialect.
error_level (ErrorLevel): the desired error level of the parser.
sql: the SQL code string to transpile.
read: the source dialect used to parse the input string (eg. "spark", "hive", "presto", "mysql").
write: the target dialect into which the input should be transformed (eg. "spark", "hive", "presto", "mysql").
identity: if set to `True` and the target dialect is not specified, the source dialect will be used as both
the source and the target dialect.
error_level: the desired error level of the parser.
**opts: other options.
Returns:
typing.List[str]: the list of transpiled SQL statements / expressions.
The list of transpiled SQL statements.
"""
write = write or read if identity else write
return [

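Putting the three entry points above together, a quick usage sketch (illustrative only, not part of the diff):
```python
import sqlglot

trees = sqlglot.parse("SELECT 1; SELECT 2")        # a list of two Expression trees
tree = sqlglot.parse_one("SELECT 1 AS x")          # a single Expression tree
out = sqlglot.transpile("SELECT 1", read="mysql")  # a list with one SQL string
```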
View file

@ -49,7 +49,10 @@ args = parser.parse_args()
error_level = sqlglot.ErrorLevel[args.error_level.upper()]
if args.parse:
sqls = [repr(expression) for expression in sqlglot.parse(args.sql, read=args.read, error_level=error_level)]
sqls = [
repr(expression)
for expression in sqlglot.parse(args.sql, read=args.read, error_level=error_level)
]
else:
sqls = sqlglot.transpile(
args.sql,

View file

@ -10,11 +10,17 @@ if t.TYPE_CHECKING:
from sqlglot.dataframe.sql.types import StructType
ColumnLiterals = t.TypeVar(
"ColumnLiterals", bound=t.Union[str, float, int, bool, t.List, t.Tuple, datetime.date, datetime.datetime]
"ColumnLiterals",
bound=t.Union[str, float, int, bool, t.List, t.Tuple, datetime.date, datetime.datetime],
)
ColumnOrName = t.TypeVar("ColumnOrName", bound=t.Union[Column, str])
ColumnOrLiteral = t.TypeVar(
"ColumnOrLiteral", bound=t.Union[Column, str, float, int, bool, t.List, t.Tuple, datetime.date, datetime.datetime]
"ColumnOrLiteral",
bound=t.Union[Column, str, float, int, bool, t.List, t.Tuple, datetime.date, datetime.datetime],
)
SchemaInput = t.TypeVar(
"SchemaInput", bound=t.Union[str, t.List[str], StructType, t.Dict[str, str]]
)
OutputExpressionContainer = t.TypeVar(
"OutputExpressionContainer", bound=t.Union[exp.Select, exp.Create, exp.Insert]
)
SchemaInput = t.TypeVar("SchemaInput", bound=t.Union[str, t.List[str], StructType, t.Dict[str, str]])
OutputExpressionContainer = t.TypeVar("OutputExpressionContainer", bound=t.Union[exp.Select, exp.Create, exp.Insert])

View file

@ -18,7 +18,11 @@ class Column:
expression = expression.expression # type: ignore
elif expression is None or not isinstance(expression, (str, exp.Expression)):
expression = self._lit(expression).expression # type: ignore
self.expression: exp.Expression = sqlglot.maybe_parse(expression, dialect="spark")
expression = sqlglot.maybe_parse(expression, dialect="spark")
if expression is None:
raise ValueError(f"Could not parse {expression}")
self.expression: exp.Expression = expression
def __repr__(self):
return repr(self.expression)
@ -135,21 +139,29 @@ class Column:
) -> Column:
ensured_column = None if column is None else cls.ensure_col(column)
ensure_expression_values = {
k: [Column.ensure_col(x).expression for x in v] if is_iterable(v) else Column.ensure_col(v).expression
k: [Column.ensure_col(x).expression for x in v]
if is_iterable(v)
else Column.ensure_col(v).expression
for k, v in kwargs.items()
}
new_expression = (
callable_expression(**ensure_expression_values)
if ensured_column is None
else callable_expression(this=ensured_column.column_expression, **ensure_expression_values)
else callable_expression(
this=ensured_column.column_expression, **ensure_expression_values
)
)
return Column(new_expression)
def binary_op(self, klass: t.Callable, other: ColumnOrLiteral, **kwargs) -> Column:
return Column(klass(this=self.column_expression, expression=Column(other).column_expression, **kwargs))
return Column(
klass(this=self.column_expression, expression=Column(other).column_expression, **kwargs)
)
def inverse_binary_op(self, klass: t.Callable, other: ColumnOrLiteral, **kwargs) -> Column:
return Column(klass(this=Column(other).column_expression, expression=self.column_expression, **kwargs))
return Column(
klass(this=Column(other).column_expression, expression=self.column_expression, **kwargs)
)
def unary_op(self, klass: t.Callable, **kwargs) -> Column:
return Column(klass(this=self.column_expression, **kwargs))
@ -188,7 +200,7 @@ class Column:
expression.set("table", exp.to_identifier(table_name))
return Column(expression)
def sql(self, **kwargs) -> Column:
def sql(self, **kwargs) -> str:
return self.expression.sql(**{"dialect": "spark", **kwargs})
def alias(self, name: str) -> Column:
@ -265,10 +277,14 @@ class Column:
)
def like(self, other: str):
return self.invoke_expression_over_column(self, exp.Like, expression=self._lit(other).expression)
return self.invoke_expression_over_column(
self, exp.Like, expression=self._lit(other).expression
)
def ilike(self, other: str):
return self.invoke_expression_over_column(self, exp.ILike, expression=self._lit(other).expression)
return self.invoke_expression_over_column(
self, exp.ILike, expression=self._lit(other).expression
)
def substr(self, startPos: t.Union[int, Column], length: t.Union[int, Column]) -> Column:
startPos = self._lit(startPos) if not isinstance(startPos, Column) else startPos
@ -287,10 +303,18 @@ class Column:
lowerBound: t.Union[ColumnOrLiteral],
upperBound: t.Union[ColumnOrLiteral],
) -> Column:
lower_bound_exp = self._lit(lowerBound) if not isinstance(lowerBound, Column) else lowerBound
upper_bound_exp = self._lit(upperBound) if not isinstance(upperBound, Column) else upperBound
lower_bound_exp = (
self._lit(lowerBound) if not isinstance(lowerBound, Column) else lowerBound
)
upper_bound_exp = (
self._lit(upperBound) if not isinstance(upperBound, Column) else upperBound
)
return Column(
exp.Between(this=self.column_expression, low=lower_bound_exp.expression, high=upper_bound_exp.expression)
exp.Between(
this=self.column_expression,
low=lower_bound_exp.expression,
high=upper_bound_exp.expression,
)
)
def over(self, window: WindowSpec) -> Column:

View file

@ -21,7 +21,12 @@ from sqlglot.optimizer import optimize as optimize_func
from sqlglot.optimizer.qualify_columns import qualify_columns
if t.TYPE_CHECKING:
from sqlglot.dataframe.sql._typing import ColumnLiterals, OutputExpressionContainer
from sqlglot.dataframe.sql._typing import (
ColumnLiterals,
ColumnOrLiteral,
ColumnOrName,
OutputExpressionContainer,
)
from sqlglot.dataframe.sql.session import SparkSession
@ -83,7 +88,9 @@ class DataFrame:
return from_exp.alias_or_name
table_alias = from_exp.find(exp.TableAlias)
if not table_alias:
raise RuntimeError(f"Could not find an alias name for this expression: {self.expression}")
raise RuntimeError(
f"Could not find an alias name for this expression: {self.expression}"
)
return table_alias.alias_or_name
return self.expression.ctes[-1].alias
@ -132,12 +139,16 @@ class DataFrame:
cte.set("sequence_id", sequence_id or self.sequence_id)
return cte, name
def _ensure_list_of_columns(
self, cols: t.Union[str, t.Iterable[str], Column, t.Iterable[Column]]
) -> t.List[Column]:
columns = ensure_list(cols)
columns = Column.ensure_cols(columns)
return columns
@t.overload
def _ensure_list_of_columns(self, cols: t.Collection[ColumnOrLiteral]) -> t.List[Column]:
...
@t.overload
def _ensure_list_of_columns(self, cols: ColumnOrLiteral) -> t.List[Column]:
...
def _ensure_list_of_columns(self, cols):
return Column.ensure_cols(ensure_list(cols))
def _ensure_and_normalize_cols(self, cols):
cols = self._ensure_list_of_columns(cols)
@ -153,10 +164,16 @@ class DataFrame:
df = self._resolve_pending_hints()
sequence_id = sequence_id or df.sequence_id
expression = df.expression.copy()
cte_expression, cte_name = df._create_cte_from_expression(expression=expression, sequence_id=sequence_id)
new_expression = df._add_ctes_to_expression(exp.Select(), expression.ctes + [cte_expression])
cte_expression, cte_name = df._create_cte_from_expression(
expression=expression, sequence_id=sequence_id
)
new_expression = df._add_ctes_to_expression(
exp.Select(), expression.ctes + [cte_expression]
)
sel_columns = df._get_outer_select_columns(cte_expression)
new_expression = new_expression.from_(cte_name).select(*[x.alias_or_name for x in sel_columns])
new_expression = new_expression.from_(cte_name).select(
*[x.alias_or_name for x in sel_columns]
)
return df.copy(expression=new_expression, sequence_id=sequence_id)
def _resolve_pending_hints(self) -> DataFrame:
@ -169,16 +186,23 @@ class DataFrame:
hint_expression.args.get("expressions").append(hint)
df.pending_hints.remove(hint)
join_aliases = {join_table.alias_or_name for join_table in get_tables_from_expression_with_join(expression)}
join_aliases = {
join_table.alias_or_name
for join_table in get_tables_from_expression_with_join(expression)
}
if join_aliases:
for hint in df.pending_join_hints:
for sequence_id_expression in hint.expressions:
sequence_id_or_name = sequence_id_expression.alias_or_name
sequence_ids_to_match = [sequence_id_or_name]
if sequence_id_or_name in df.spark.name_to_sequence_id_mapping:
sequence_ids_to_match = df.spark.name_to_sequence_id_mapping[sequence_id_or_name]
sequence_ids_to_match = df.spark.name_to_sequence_id_mapping[
sequence_id_or_name
]
matching_ctes = [
cte for cte in reversed(expression.ctes) if cte.args["sequence_id"] in sequence_ids_to_match
cte
for cte in reversed(expression.ctes)
if cte.args["sequence_id"] in sequence_ids_to_match
]
for matching_cte in matching_ctes:
if matching_cte.alias_or_name in join_aliases:
@ -193,9 +217,14 @@ class DataFrame:
def _hint(self, hint_name: str, args: t.List[Column]) -> DataFrame:
hint_name = hint_name.upper()
hint_expression = (
exp.JoinHint(this=hint_name, expressions=[exp.to_table(parameter.alias_or_name) for parameter in args])
exp.JoinHint(
this=hint_name,
expressions=[exp.to_table(parameter.alias_or_name) for parameter in args],
)
if hint_name in JOIN_HINTS
else exp.Anonymous(this=hint_name, expressions=[parameter.expression for parameter in args])
else exp.Anonymous(
this=hint_name, expressions=[parameter.expression for parameter in args]
)
)
new_df = self.copy()
new_df.pending_hints.append(hint_expression)
@ -245,7 +274,9 @@ class DataFrame:
def _get_select_expressions(
self,
) -> t.List[t.Tuple[t.Union[t.Type[exp.Cache], OutputExpressionContainer], exp.Select]]:
select_expressions: t.List[t.Tuple[t.Union[t.Type[exp.Cache], OutputExpressionContainer], exp.Select]] = []
select_expressions: t.List[
t.Tuple[t.Union[t.Type[exp.Cache], OutputExpressionContainer], exp.Select]
] = []
main_select_ctes: t.List[exp.CTE] = []
for cte in self.expression.ctes:
cache_storage_level = cte.args.get("cache_storage_level")
@ -279,14 +310,19 @@ class DataFrame:
cache_table_name = df._create_hash_from_expression(select_expression)
cache_table = exp.to_table(cache_table_name)
original_alias_name = select_expression.args["cte_alias_name"]
replacement_mapping[exp.to_identifier(original_alias_name)] = exp.to_identifier(cache_table_name)
replacement_mapping[exp.to_identifier(original_alias_name)] = exp.to_identifier( # type: ignore
cache_table_name
)
sqlglot.schema.add_table(cache_table_name, select_expression.named_selects)
cache_storage_level = select_expression.args["cache_storage_level"]
options = [
exp.Literal.string("storageLevel"),
exp.Literal.string(cache_storage_level),
]
expression = exp.Cache(this=cache_table, expression=select_expression, lazy=True, options=options)
expression = exp.Cache(
this=cache_table, expression=select_expression, lazy=True, options=options
)
# We will drop the "view" if it exists before running the cache table
output_expressions.append(exp.Drop(this=cache_table, exists=True, kind="VIEW"))
elif expression_type == exp.Create:
@ -305,7 +341,9 @@ class DataFrame:
raise ValueError(f"Invalid expression type: {expression_type}")
output_expressions.append(expression)
return [expression.sql(**{"dialect": dialect, **kwargs}) for expression in output_expressions]
return [
expression.sql(**{"dialect": dialect, **kwargs}) for expression in output_expressions
]
def copy(self, **kwargs) -> DataFrame:
return DataFrame(**object_to_dict(self, **kwargs))
@ -317,7 +355,9 @@ class DataFrame:
if self.expression.args.get("joins"):
ambiguous_cols = [col for col in cols if not col.column_expression.table]
if ambiguous_cols:
join_table_identifiers = [x.this for x in get_tables_from_expression_with_join(self.expression)]
join_table_identifiers = [
x.this for x in get_tables_from_expression_with_join(self.expression)
]
cte_names_in_join = [x.this for x in join_table_identifiers]
for ambiguous_col in ambiguous_cols:
ctes_with_column = [
@ -367,14 +407,20 @@ class DataFrame:
@operation(Operation.FROM)
def join(
self, other_df: DataFrame, on: t.Union[str, t.List[str], Column, t.List[Column]], how: str = "inner", **kwargs
self,
other_df: DataFrame,
on: t.Union[str, t.List[str], Column, t.List[Column]],
how: str = "inner",
**kwargs,
) -> DataFrame:
other_df = other_df._convert_leaf_to_cte()
pre_join_self_latest_cte_name = self.latest_cte_name
columns = self._ensure_and_normalize_cols(on)
join_type = how.replace("_", " ")
if isinstance(columns[0].expression, exp.Column):
join_columns = [Column(x).set_table_name(pre_join_self_latest_cte_name) for x in columns]
join_columns = [
Column(x).set_table_name(pre_join_self_latest_cte_name) for x in columns
]
join_clause = functools.reduce(
lambda x, y: x & y,
[
@ -402,7 +448,9 @@ class DataFrame:
for column in self._get_outer_select_columns(other_df)
]
column_value_mapping = {
column.alias_or_name if not isinstance(column.expression.this, exp.Star) else column.sql(): column
column.alias_or_name
if not isinstance(column.expression.this, exp.Star)
else column.sql(): column
for column in other_columns + self_columns + join_columns
}
all_columns = [
@ -410,16 +458,22 @@ class DataFrame:
for name in {x.alias_or_name: None for x in join_columns + self_columns + other_columns}
]
new_df = self.copy(
expression=self.expression.join(other_df.latest_cte_name, on=join_clause.expression, join_type=join_type)
expression=self.expression.join(
other_df.latest_cte_name, on=join_clause.expression, join_type=join_type
)
)
new_df.expression = new_df._add_ctes_to_expression(
new_df.expression, other_df.expression.ctes
)
new_df.expression = new_df._add_ctes_to_expression(new_df.expression, other_df.expression.ctes)
new_df.pending_hints.extend(other_df.pending_hints)
new_df = new_df.select.__wrapped__(new_df, *all_columns)
return new_df
@operation(Operation.ORDER_BY)
def orderBy(
self, *cols: t.Union[str, Column], ascending: t.Optional[t.Union[t.Any, t.List[t.Any]]] = None
self,
*cols: t.Union[str, Column],
ascending: t.Optional[t.Union[t.Any, t.List[t.Any]]] = None,
) -> DataFrame:
"""
This implementation lets any ordered columns take priority over whatever is provided in `ascending`. Spark
@ -429,7 +483,10 @@ class DataFrame:
columns = self._ensure_and_normalize_cols(cols)
pre_ordered_col_indexes = [
x
for x in [i if isinstance(col.expression, exp.Ordered) else None for i, col in enumerate(columns)]
for x in [
i if isinstance(col.expression, exp.Ordered) else None
for i, col in enumerate(columns)
]
if x is not None
]
if ascending is None:
@ -478,7 +535,9 @@ class DataFrame:
for r_column in r_columns_unused:
l_expressions.append(exp.alias_(exp.Null(), r_column))
r_expressions.append(r_column)
r_df = other.copy()._convert_leaf_to_cte().select(*self._ensure_list_of_columns(r_expressions))
r_df = (
other.copy()._convert_leaf_to_cte().select(*self._ensure_list_of_columns(r_expressions))
)
l_df = self.copy()
if allowMissingColumns:
l_df = l_df._convert_leaf_to_cte().select(*self._ensure_list_of_columns(l_expressions))
@ -536,7 +595,9 @@ class DataFrame:
f"The minimum num nulls for dropna must be less than or equal to the number of columns. "
f"Minimum num nulls: {minimum_num_nulls}, Num Columns: {len(null_check_columns)}"
)
if_null_checks = [F.when(column.isNull(), F.lit(1)).otherwise(F.lit(0)) for column in null_check_columns]
if_null_checks = [
F.when(column.isNull(), F.lit(1)).otherwise(F.lit(0)) for column in null_check_columns
]
nulls_added_together = functools.reduce(lambda x, y: x + y, if_null_checks)
num_nulls = nulls_added_together.alias("num_nulls")
new_df = new_df.select(num_nulls, append=True)
@ -576,11 +637,15 @@ class DataFrame:
value_columns = [lit(value) for value in values]
null_replacement_mapping = {
column.alias_or_name: (F.when(column.isNull(), value).otherwise(column).alias(column.alias_or_name))
column.alias_or_name: (
F.when(column.isNull(), value).otherwise(column).alias(column.alias_or_name)
)
for column, value in zip(columns, value_columns)
}
null_replacement_mapping = {**all_column_mapping, **null_replacement_mapping}
null_replacement_columns = [null_replacement_mapping[column.alias_or_name] for column in all_columns]
null_replacement_columns = [
null_replacement_mapping[column.alias_or_name] for column in all_columns
]
new_df = new_df.select(*null_replacement_columns)
return new_df
@ -589,12 +654,11 @@ class DataFrame:
self,
to_replace: t.Union[bool, int, float, str, t.List, t.Dict],
value: t.Optional[t.Union[bool, int, float, str, t.List]] = None,
subset: t.Optional[t.Union[str, t.List[str]]] = None,
subset: t.Optional[t.Collection[ColumnOrName] | ColumnOrName] = None,
) -> DataFrame:
from sqlglot.dataframe.sql.functions import lit
old_values = None
subset = ensure_list(subset)
new_df = self.copy()
all_columns = self._get_outer_select_columns(new_df.expression)
all_column_mapping = {column.alias_or_name: column for column in all_columns}
@ -605,7 +669,9 @@ class DataFrame:
new_values = list(to_replace.values())
elif not old_values and isinstance(to_replace, list):
assert isinstance(value, list), "value must be a list since the replacements are a list"
assert len(to_replace) == len(value), "the replacements and values must be the same length"
assert len(to_replace) == len(
value
), "the replacements and values must be the same length"
old_values = to_replace
new_values = value
else:
@ -635,7 +701,9 @@ class DataFrame:
def withColumn(self, colName: str, col: Column) -> DataFrame:
col = self._ensure_and_normalize_col(col)
existing_col_names = self.expression.named_selects
existing_col_index = existing_col_names.index(colName) if colName in existing_col_names else None
existing_col_index = (
existing_col_names.index(colName) if colName in existing_col_names else None
)
if existing_col_index:
expression = self.expression.copy()
expression.expressions[existing_col_index] = col.expression
@ -645,7 +713,11 @@ class DataFrame:
@operation(Operation.SELECT)
def withColumnRenamed(self, existing: str, new: str):
expression = self.expression.copy()
existing_columns = [expression for expression in expression.expressions if expression.alias_or_name == existing]
existing_columns = [
expression
for expression in expression.expressions
if expression.alias_or_name == existing
]
if not existing_columns:
raise ValueError("Tried to rename a column that doesn't exist")
for existing_column in existing_columns:
@ -674,15 +746,19 @@ class DataFrame:
def hint(self, name: str, *parameters: t.Optional[t.Union[str, int]]) -> DataFrame:
parameter_list = ensure_list(parameters)
parameter_columns = (
self._ensure_list_of_columns(parameter_list) if parameters else Column.ensure_cols([self.sequence_id])
self._ensure_list_of_columns(parameter_list)
if parameters
else Column.ensure_cols([self.sequence_id])
)
return self._hint(name, parameter_columns)
@operation(Operation.NO_OP)
def repartition(self, numPartitions: t.Union[int, str], *cols: t.Union[int, str]) -> DataFrame:
num_partitions = Column.ensure_cols(ensure_list(numPartitions))
def repartition(
self, numPartitions: t.Union[int, ColumnOrName], *cols: ColumnOrName
) -> DataFrame:
num_partition_cols = self._ensure_list_of_columns(numPartitions)
columns = self._ensure_and_normalize_cols(cols)
args = num_partitions + columns
args = num_partition_cols + columns
return self._hint("repartition", args)
@operation(Operation.NO_OP)

View file

@ -45,7 +45,11 @@ def countDistinct(col: ColumnOrName, *cols: ColumnOrName) -> Column:
def when(condition: Column, value: t.Any) -> Column:
true_value = value if isinstance(value, Column) else lit(value)
return Column(glotexp.Case(ifs=[glotexp.If(this=condition.column_expression, true=true_value.column_expression)]))
return Column(
glotexp.Case(
ifs=[glotexp.If(this=condition.column_expression, true=true_value.column_expression)]
)
)
def asc(col: ColumnOrName) -> Column:
@ -407,7 +411,9 @@ def percentile_approx(
return Column.invoke_expression_over_column(
col, glotexp.ApproxQuantile, quantile=lit(percentage), accuracy=accuracy
)
return Column.invoke_expression_over_column(col, glotexp.ApproxQuantile, quantile=lit(percentage))
return Column.invoke_expression_over_column(
col, glotexp.ApproxQuantile, quantile=lit(percentage)
)
def rand(seed: t.Optional[ColumnOrLiteral] = None) -> Column:
@ -471,7 +477,9 @@ def factorial(col: ColumnOrName) -> Column:
return Column.invoke_anonymous_function(col, "FACTORIAL")
def lag(col: ColumnOrName, offset: t.Optional[int] = 1, default: t.Optional[ColumnOrLiteral] = None) -> Column:
def lag(
col: ColumnOrName, offset: t.Optional[int] = 1, default: t.Optional[ColumnOrLiteral] = None
) -> Column:
if default is not None:
return Column.invoke_anonymous_function(col, "LAG", offset, default)
if offset != 1:
@ -479,7 +487,9 @@ def lag(col: ColumnOrName, offset: t.Optional[int] = 1, default: t.Optional[Colu
return Column.invoke_anonymous_function(col, "LAG")
def lead(col: ColumnOrName, offset: t.Optional[int] = 1, default: t.Optional[t.Any] = None) -> Column:
def lead(
col: ColumnOrName, offset: t.Optional[int] = 1, default: t.Optional[t.Any] = None
) -> Column:
if default is not None:
return Column.invoke_anonymous_function(col, "LEAD", offset, default)
if offset != 1:
@ -487,7 +497,9 @@ def lead(col: ColumnOrName, offset: t.Optional[int] = 1, default: t.Optional[t.A
return Column.invoke_anonymous_function(col, "LEAD")
def nth_value(col: ColumnOrName, offset: t.Optional[int] = 1, ignoreNulls: t.Optional[bool] = None) -> Column:
def nth_value(
col: ColumnOrName, offset: t.Optional[int] = 1, ignoreNulls: t.Optional[bool] = None
) -> Column:
if ignoreNulls is not None:
raise NotImplementedError("There is currently no support for the `ignoreNulls` parameter")
if offset != 1:
@ -571,7 +583,9 @@ def add_months(start: ColumnOrName, months: t.Union[ColumnOrName, int]) -> Colum
return Column.invoke_anonymous_function(start, "ADD_MONTHS", months)
def months_between(date1: ColumnOrName, date2: ColumnOrName, roundOff: t.Optional[bool] = None) -> Column:
def months_between(
date1: ColumnOrName, date2: ColumnOrName, roundOff: t.Optional[bool] = None
) -> Column:
if roundOff is None:
return Column.invoke_anonymous_function(date1, "MONTHS_BETWEEN", date2)
return Column.invoke_anonymous_function(date1, "MONTHS_BETWEEN", date2, roundOff)
@ -611,9 +625,13 @@ def from_unixtime(col: ColumnOrName, format: t.Optional[str] = None) -> Column:
return Column.invoke_expression_over_column(col, glotexp.UnixToStr)
def unix_timestamp(timestamp: t.Optional[ColumnOrName] = None, format: t.Optional[str] = None) -> Column:
def unix_timestamp(
timestamp: t.Optional[ColumnOrName] = None, format: t.Optional[str] = None
) -> Column:
if format is not None:
return Column.invoke_expression_over_column(timestamp, glotexp.StrToUnix, format=lit(format))
return Column.invoke_expression_over_column(
timestamp, glotexp.StrToUnix, format=lit(format)
)
return Column.invoke_expression_over_column(timestamp, glotexp.StrToUnix)
@ -642,7 +660,9 @@ def window(
timeColumn, "WINDOW", lit(windowDuration), lit(slideDuration), lit(startTime)
)
if slideDuration is not None:
return Column.invoke_anonymous_function(timeColumn, "WINDOW", lit(windowDuration), lit(slideDuration))
return Column.invoke_anonymous_function(
timeColumn, "WINDOW", lit(windowDuration), lit(slideDuration)
)
if startTime is not None:
return Column.invoke_anonymous_function(
timeColumn, "WINDOW", lit(windowDuration), lit(windowDuration), lit(startTime)
@ -731,7 +751,9 @@ def trim(col: ColumnOrName) -> Column:
def concat_ws(sep: str, *cols: ColumnOrName) -> Column:
return Column.invoke_expression_over_column(None, glotexp.ConcatWs, expressions=[lit(sep)] + list(cols))
return Column.invoke_expression_over_column(
None, glotexp.ConcatWs, expressions=[lit(sep)] + list(cols)
)
def decode(col: ColumnOrName, charset: str) -> Column:
@ -768,7 +790,9 @@ def overlay(
def sentences(
string: ColumnOrName, language: t.Optional[ColumnOrName] = None, country: t.Optional[ColumnOrName] = None
string: ColumnOrName,
language: t.Optional[ColumnOrName] = None,
country: t.Optional[ColumnOrName] = None,
) -> Column:
if language is not None and country is not None:
return Column.invoke_anonymous_function(string, "SENTENCES", language, country)
@ -794,7 +818,9 @@ def levenshtein(left: ColumnOrName, right: ColumnOrName) -> Column:
def locate(substr: str, str: ColumnOrName, pos: t.Optional[int] = None) -> Column:
substr_col = lit(substr)
if pos is not None:
return Column.invoke_expression_over_column(str, glotexp.StrPosition, substr=substr_col, position=pos)
return Column.invoke_expression_over_column(
str, glotexp.StrPosition, substr=substr_col, position=pos
)
return Column.invoke_expression_over_column(str, glotexp.StrPosition, substr=substr_col)
@ -872,7 +898,10 @@ def array(*cols: t.Union[ColumnOrName, t.Iterable[ColumnOrName]]) -> Column:
def create_map(*cols: t.Union[ColumnOrName, t.Iterable[ColumnOrName]]) -> Column:
cols = list(_flatten(cols)) if not isinstance(cols[0], (str, Column)) else cols # type: ignore
return Column.invoke_expression_over_column(
None, glotexp.VarMap, keys=array(*cols[::2]).expression, values=array(*cols[1::2]).expression
None,
glotexp.VarMap,
keys=array(*cols[::2]).expression,
values=array(*cols[1::2]).expression,
)
@ -882,29 +911,39 @@ def map_from_arrays(col1: ColumnOrName, col2: ColumnOrName) -> Column:
def array_contains(col: ColumnOrName, value: ColumnOrLiteral) -> Column:
value_col = value if isinstance(value, Column) else lit(value)
return Column.invoke_expression_over_column(col, glotexp.ArrayContains, expression=value_col.expression)
return Column.invoke_expression_over_column(
col, glotexp.ArrayContains, expression=value_col.expression
)
def arrays_overlap(col1: ColumnOrName, col2: ColumnOrName) -> Column:
return Column.invoke_anonymous_function(col1, "ARRAYS_OVERLAP", Column.ensure_col(col2))
def slice(x: ColumnOrName, start: t.Union[ColumnOrName, int], length: t.Union[ColumnOrName, int]) -> Column:
def slice(
x: ColumnOrName, start: t.Union[ColumnOrName, int], length: t.Union[ColumnOrName, int]
) -> Column:
start_col = start if isinstance(start, Column) else lit(start)
length_col = length if isinstance(length, Column) else lit(length)
return Column.invoke_anonymous_function(x, "SLICE", start_col, length_col)
def array_join(col: ColumnOrName, delimiter: str, null_replacement: t.Optional[str] = None) -> Column:
def array_join(
col: ColumnOrName, delimiter: str, null_replacement: t.Optional[str] = None
) -> Column:
if null_replacement is not None:
return Column.invoke_anonymous_function(col, "ARRAY_JOIN", lit(delimiter), lit(null_replacement))
return Column.invoke_anonymous_function(
col, "ARRAY_JOIN", lit(delimiter), lit(null_replacement)
)
return Column.invoke_anonymous_function(col, "ARRAY_JOIN", lit(delimiter))
def concat(*cols: ColumnOrName) -> Column:
if len(cols) == 1:
return Column.invoke_anonymous_function(cols[0], "CONCAT")
return Column.invoke_anonymous_function(cols[0], "CONCAT", *[Column.ensure_col(x).expression for x in cols[1:]])
return Column.invoke_anonymous_function(
cols[0], "CONCAT", *[Column.ensure_col(x).expression for x in cols[1:]]
)
def array_position(col: ColumnOrName, value: ColumnOrLiteral) -> Column:
@ -1076,7 +1115,9 @@ def map_concat(*cols: t.Union[ColumnOrName, t.Iterable[ColumnOrName]]) -> Column
return Column.invoke_anonymous_function(columns[0], "MAP_CONCAT", *columns[1:])
def sequence(start: ColumnOrName, stop: ColumnOrName, step: t.Optional[ColumnOrName] = None) -> Column:
def sequence(
start: ColumnOrName, stop: ColumnOrName, step: t.Optional[ColumnOrName] = None
) -> Column:
if step is not None:
return Column.invoke_anonymous_function(start, "SEQUENCE", stop, step)
return Column.invoke_anonymous_function(start, "SEQUENCE", stop)
@ -1103,12 +1144,15 @@ def aggregate(
merge_exp = _get_lambda_from_func(merge)
if finish is not None:
finish_exp = _get_lambda_from_func(finish)
return Column.invoke_anonymous_function(col, "AGGREGATE", initialValue, Column(merge_exp), Column(finish_exp))
return Column.invoke_anonymous_function(
col, "AGGREGATE", initialValue, Column(merge_exp), Column(finish_exp)
)
return Column.invoke_anonymous_function(col, "AGGREGATE", initialValue, Column(merge_exp))
def transform(
col: ColumnOrName, f: t.Union[t.Callable[[Column], Column], t.Callable[[Column, Column], Column]]
col: ColumnOrName,
f: t.Union[t.Callable[[Column], Column], t.Callable[[Column, Column], Column]],
) -> Column:
f_expression = _get_lambda_from_func(f)
return Column.invoke_anonymous_function(col, "TRANSFORM", Column(f_expression))
@ -1124,12 +1168,17 @@ def forall(col: ColumnOrName, f: t.Callable[[Column], Column]) -> Column:
return Column.invoke_anonymous_function(col, "FORALL", Column(f_expression))
def filter(col: ColumnOrName, f: t.Union[t.Callable[[Column], Column], t.Callable[[Column, Column], Column]]) -> Column:
def filter(
col: ColumnOrName,
f: t.Union[t.Callable[[Column], Column], t.Callable[[Column, Column], Column]],
) -> Column:
f_expression = _get_lambda_from_func(f)
return Column.invoke_expression_over_column(col, glotexp.ArrayFilter, expression=f_expression)
def zip_with(left: ColumnOrName, right: ColumnOrName, f: t.Callable[[Column, Column], Column]) -> Column:
def zip_with(
left: ColumnOrName, right: ColumnOrName, f: t.Callable[[Column, Column], Column]
) -> Column:
f_expression = _get_lambda_from_func(f)
return Column.invoke_anonymous_function(left, "ZIP_WITH", right, Column(f_expression))
@ -1163,7 +1212,10 @@ def _lambda_quoted(value: str) -> t.Optional[bool]:
def _get_lambda_from_func(lambda_expression: t.Callable):
variables = [glotexp.to_identifier(x, quoted=_lambda_quoted(x)) for x in lambda_expression.__code__.co_varnames]
variables = [
glotexp.to_identifier(x, quoted=_lambda_quoted(x))
for x in lambda_expression.__code__.co_varnames
]
return glotexp.Lambda(
this=lambda_expression(*[Column(x) for x in variables]).expression,
expressions=variables,

View file

@ -17,7 +17,9 @@ class GroupedData:
self.last_op = last_op
self.group_by_cols = group_by_cols
def _get_function_applied_columns(self, func_name: str, cols: t.Tuple[str, ...]) -> t.List[Column]:
def _get_function_applied_columns(
self, func_name: str, cols: t.Tuple[str, ...]
) -> t.List[Column]:
func_name = func_name.lower()
return [getattr(F, func_name)(name).alias(f"{func_name}({name})") for name in cols]
@ -30,9 +32,9 @@ class GroupedData:
)
cols = self._df._ensure_and_normalize_cols(columns)
expression = self._df.expression.group_by(*[x.expression for x in self.group_by_cols]).select(
*[x.expression for x in self.group_by_cols + cols], append=False
)
expression = self._df.expression.group_by(
*[x.expression for x in self.group_by_cols]
).select(*[x.expression for x in self.group_by_cols + cols], append=False)
return self._df.copy(expression=expression)
def count(self) -> DataFrame:

View file

@ -23,7 +23,9 @@ def normalize(spark: SparkSession, expression_context: exp.Select, expr: t.List[
replace_branch_and_sequence_ids_with_cte_name(spark, expression_context, identifier)
def replace_alias_name_with_cte_name(spark: SparkSession, expression_context: exp.Select, id: exp.Identifier):
def replace_alias_name_with_cte_name(
spark: SparkSession, expression_context: exp.Select, id: exp.Identifier
):
if id.alias_or_name in spark.name_to_sequence_id_mapping:
for cte in reversed(expression_context.ctes):
if cte.args["sequence_id"] in spark.name_to_sequence_id_mapping[id.alias_or_name]:
@ -40,8 +42,12 @@ def replace_branch_and_sequence_ids_with_cte_name(
# id then it keeps that reference. This handles the weird edge case in spark that shouldn't
# be common in practice
if expression_context.args.get("joins") and id.alias_or_name in spark.known_branch_ids:
join_table_aliases = [x.alias_or_name for x in get_tables_from_expression_with_join(expression_context)]
ctes_in_join = [cte for cte in expression_context.ctes if cte.alias_or_name in join_table_aliases]
join_table_aliases = [
x.alias_or_name for x in get_tables_from_expression_with_join(expression_context)
]
ctes_in_join = [
cte for cte in expression_context.ctes if cte.alias_or_name in join_table_aliases
]
if ctes_in_join[0].args["branch_id"] == ctes_in_join[1].args["branch_id"]:
assert len(ctes_in_join) == 2
_set_alias_name(id, ctes_in_join[0].alias_or_name)
@ -58,7 +64,6 @@ def _set_alias_name(id: exp.Identifier, name: str):
def _ensure_expressions(values: t.List[NORMALIZE_INPUT]) -> t.List[exp.Expression]:
values = ensure_list(values)
results = []
for value in values:
if isinstance(value, str):

View file

@ -19,12 +19,19 @@ class DataFrameReader:
from sqlglot.dataframe.sql.dataframe import DataFrame
sqlglot.schema.add_table(tableName)
return DataFrame(self.spark, exp.Select().from_(tableName).select(*sqlglot.schema.column_names(tableName)))
return DataFrame(
self.spark,
exp.Select().from_(tableName).select(*sqlglot.schema.column_names(tableName)),
)
class DataFrameWriter:
def __init__(
self, df: DataFrame, spark: t.Optional[SparkSession] = None, mode: t.Optional[str] = None, by_name: bool = False
self,
df: DataFrame,
spark: t.Optional[SparkSession] = None,
mode: t.Optional[str] = None,
by_name: bool = False,
):
self._df = df
self._spark = spark or df.spark
@ -33,7 +40,10 @@ class DataFrameWriter:
def copy(self, **kwargs) -> DataFrameWriter:
return DataFrameWriter(
**{k[1:] if k.startswith("_") else k: v for k, v in object_to_dict(self, **kwargs).items()}
**{
k[1:] if k.startswith("_") else k: v
for k, v in object_to_dict(self, **kwargs).items()
}
)
def sql(self, **kwargs) -> t.List[str]:

View file

@ -67,13 +67,20 @@ class SparkSession:
data_expressions = [
exp.Tuple(
expressions=list(map(lambda x: F.lit(x).expression, row if not isinstance(row, dict) else row.values()))
expressions=list(
map(
lambda x: F.lit(x).expression,
row if not isinstance(row, dict) else row.values(),
)
)
)
for row in data
]
sel_columns = [
F.col(name).cast(data_type).alias(name).expression if data_type is not None else F.col(name).expression
F.col(name).cast(data_type).alias(name).expression
if data_type is not None
else F.col(name).expression
for name, data_type in column_mapping.items()
]
@ -106,10 +113,12 @@ class SparkSession:
select_expression.set("with", expression.args.get("with"))
expression.set("with", None)
del expression.args["expression"]
df = DataFrame(self, select_expression, output_expression_container=expression)
df = DataFrame(self, select_expression, output_expression_container=expression) # type: ignore
df = df._convert_leaf_to_cte()
else:
raise ValueError("Unknown expression type provided in the SQL. Please create an issue with the SQL.")
raise ValueError(
"Unknown expression type provided in the SQL. Please create an issue with the SQL."
)
return df
@property

View file

@ -158,7 +158,11 @@ class MapType(DataType):
class StructField(DataType):
def __init__(
self, name: str, dataType: DataType, nullable: bool = True, metadata: t.Optional[t.Dict[str, t.Any]] = None
self,
name: str,
dataType: DataType,
nullable: bool = True,
metadata: t.Optional[t.Dict[str, t.Any]] = None,
):
self.name = name
self.dataType = dataType

View file

@ -74,8 +74,13 @@ class WindowSpec:
window_spec.expression.args["order"].set("expressions", order_by)
return window_spec
def _calc_start_end(self, start: int, end: int) -> t.Dict[str, t.Optional[t.Union[str, exp.Expression]]]:
kwargs: t.Dict[str, t.Optional[t.Union[str, exp.Expression]]] = {"start_side": None, "end_side": None}
def _calc_start_end(
self, start: int, end: int
) -> t.Dict[str, t.Optional[t.Union[str, exp.Expression]]]:
kwargs: t.Dict[str, t.Optional[t.Union[str, exp.Expression]]] = {
"start_side": None,
"end_side": None,
}
if start == Window.currentRow:
kwargs["start"] = "CURRENT ROW"
else:
@ -83,7 +88,9 @@ class WindowSpec:
**kwargs,
**{
"start_side": "PRECEDING",
"start": "UNBOUNDED" if start <= Window.unboundedPreceding else F.lit(start).expression,
"start": "UNBOUNDED"
if start <= Window.unboundedPreceding
else F.lit(start).expression,
},
}
if end == Window.currentRow:
@ -93,7 +100,9 @@ class WindowSpec:
**kwargs,
**{
"end_side": "FOLLOWING",
"end": "UNBOUNDED" if end >= Window.unboundedFollowing else F.lit(end).expression,
"end": "UNBOUNDED"
if end >= Window.unboundedFollowing
else F.lit(end).expression,
},
}
return kwargs
@ -103,7 +112,10 @@ class WindowSpec:
spec = self._calc_start_end(start, end)
spec["kind"] = "ROWS"
window_spec.expression.set(
"spec", exp.WindowSpec(**{**window_spec.expression.args.get("spec", exp.WindowSpec()).args, **spec})
"spec",
exp.WindowSpec(
**{**window_spec.expression.args.get("spec", exp.WindowSpec()).args, **spec}
),
)
return window_spec
@ -112,6 +124,9 @@ class WindowSpec:
spec = self._calc_start_end(start, end)
spec["kind"] = "RANGE"
window_spec.expression.set(
"spec", exp.WindowSpec(**{**window_spec.expression.args.get("spec", exp.WindowSpec()).args, **spec})
"spec",
exp.WindowSpec(
**{**window_spec.expression.args.get("spec", exp.WindowSpec()).args, **spec}
),
)
return window_spec

View file

@ -1,21 +1,21 @@
from sqlglot import exp
from __future__ import annotations
from sqlglot import exp, generator, parser, tokens
from sqlglot.dialects.dialect import (
Dialect,
inline_array_sql,
no_ilike_sql,
rename_func,
)
from sqlglot.generator import Generator
from sqlglot.helper import list_get
from sqlglot.parser import Parser
from sqlglot.tokens import Tokenizer, TokenType
from sqlglot.helper import seq_get
from sqlglot.tokens import TokenType
def _date_add(expression_class):
def func(args):
interval = list_get(args, 1)
interval = seq_get(args, 1)
return expression_class(
this=list_get(args, 0),
this=seq_get(args, 0),
expression=interval.this,
unit=interval.args.get("unit"),
)
@ -23,6 +23,13 @@ def _date_add(expression_class):
return func
def _date_trunc(args):
unit = seq_get(args, 1)
if isinstance(unit, exp.Column):
unit = exp.Var(this=unit.name)
return exp.DateTrunc(this=seq_get(args, 0), expression=unit)
def _date_add_sql(data_type, kind):
def func(self, expression):
this = self.sql(expression, "this")
@ -40,7 +47,8 @@ def _derived_table_values_to_unnest(self, expression):
structs = []
for row in rows:
aliases = [
exp.alias_(value, column_name) for value, column_name in zip(row, expression.args["alias"].args["columns"])
exp.alias_(value, column_name)
for value, column_name in zip(row, expression.args["alias"].args["columns"])
]
structs.append(exp.Struct(expressions=aliases))
unnest_exp = exp.Unnest(expressions=[exp.Array(expressions=structs)])
@ -89,18 +97,19 @@ class BigQuery(Dialect):
"%j": "%-j",
}
class Tokenizer(Tokenizer):
class Tokenizer(tokens.Tokenizer):
QUOTES = [
(prefix + quote, quote) if prefix else quote
for quote in ["'", '"', '"""', "'''"]
for prefix in ["", "r", "R"]
]
COMMENTS = ["--", "#", ("/*", "*/")]
IDENTIFIERS = ["`"]
ESCAPE = "\\"
ESCAPES = ["\\"]
HEX_STRINGS = [("0x", ""), ("0X", "")]
KEYWORDS = {
**Tokenizer.KEYWORDS,
**tokens.Tokenizer.KEYWORDS,
"CURRENT_DATETIME": TokenType.CURRENT_DATETIME,
"CURRENT_TIME": TokenType.CURRENT_TIME,
"GEOGRAPHY": TokenType.GEOGRAPHY,
@ -111,35 +120,40 @@ class BigQuery(Dialect):
"WINDOW": TokenType.WINDOW,
"NOT DETERMINISTIC": TokenType.VOLATILE,
}
KEYWORDS.pop("DIV")
class Parser(Parser):
class Parser(parser.Parser):
FUNCTIONS = {
**Parser.FUNCTIONS,
**parser.Parser.FUNCTIONS,
"DATE_TRUNC": _date_trunc,
"DATE_ADD": _date_add(exp.DateAdd),
"DATETIME_ADD": _date_add(exp.DatetimeAdd),
"DIV": lambda args: exp.IntDiv(this=seq_get(args, 0), expression=seq_get(args, 1)),
"TIME_ADD": _date_add(exp.TimeAdd),
"TIMESTAMP_ADD": _date_add(exp.TimestampAdd),
"DATE_SUB": _date_add(exp.DateSub),
"DATETIME_SUB": _date_add(exp.DatetimeSub),
"TIME_SUB": _date_add(exp.TimeSub),
"TIMESTAMP_SUB": _date_add(exp.TimestampSub),
"PARSE_TIMESTAMP": lambda args: exp.StrToTime(this=list_get(args, 1), format=list_get(args, 0)),
"PARSE_TIMESTAMP": lambda args: exp.StrToTime(
this=seq_get(args, 1), format=seq_get(args, 0)
),
}
NO_PAREN_FUNCTIONS = {
**Parser.NO_PAREN_FUNCTIONS,
**parser.Parser.NO_PAREN_FUNCTIONS,
TokenType.CURRENT_DATETIME: exp.CurrentDatetime,
TokenType.CURRENT_TIME: exp.CurrentTime,
}
NESTED_TYPE_TOKENS = {
*Parser.NESTED_TYPE_TOKENS,
*parser.Parser.NESTED_TYPE_TOKENS,
TokenType.TABLE,
}
class Generator(Generator):
class Generator(generator.Generator):
TRANSFORMS = {
**Generator.TRANSFORMS,
**generator.Generator.TRANSFORMS,
exp.Array: inline_array_sql,
exp.ArraySize: rename_func("ARRAY_LENGTH"),
exp.DateAdd: _date_add_sql("DATE", "ADD"),
@ -148,6 +162,7 @@ class BigQuery(Dialect):
exp.DatetimeSub: _date_add_sql("DATETIME", "SUB"),
exp.DateDiff: lambda self, e: f"DATE_DIFF({self.sql(e, 'this')}, {self.sql(e, 'expression')}, {self.sql(e.args.get('unit', 'DAY'))})",
exp.ILike: no_ilike_sql,
exp.IntDiv: rename_func("DIV"),
exp.StrToTime: lambda self, e: f"PARSE_TIMESTAMP({self.format_time(e)}, {self.sql(e, 'this')})",
exp.TimeAdd: _date_add_sql("TIME", "ADD"),
exp.TimeSub: _date_add_sql("TIME", "SUB"),
@ -157,11 +172,13 @@ class BigQuery(Dialect):
exp.Values: _derived_table_values_to_unnest,
exp.ReturnsProperty: _returnsproperty_sql,
exp.Create: _create_sql,
exp.VolatilityProperty: lambda self, e: f"DETERMINISTIC" if e.name == "IMMUTABLE" else "NOT DETERMINISTIC",
exp.VolatilityProperty: lambda self, e: f"DETERMINISTIC"
if e.name == "IMMUTABLE"
else "NOT DETERMINISTIC",
}
TYPE_MAPPING = {
**Generator.TYPE_MAPPING,
**generator.Generator.TYPE_MAPPING,
exp.DataType.Type.TINYINT: "INT64",
exp.DataType.Type.SMALLINT: "INT64",
exp.DataType.Type.INT: "INT64",

View file

@ -1,8 +1,9 @@
from sqlglot import exp
from __future__ import annotations
from sqlglot import exp, generator, parser, tokens
from sqlglot.dialects.dialect import Dialect, inline_array_sql, var_map_sql
from sqlglot.generator import Generator
from sqlglot.parser import Parser, parse_var_map
from sqlglot.tokens import Tokenizer, TokenType
from sqlglot.parser import parse_var_map
from sqlglot.tokens import TokenType
def _lower_func(sql):
@ -14,11 +15,12 @@ class ClickHouse(Dialect):
normalize_functions = None
null_ordering = "nulls_are_last"
class Tokenizer(Tokenizer):
class Tokenizer(tokens.Tokenizer):
COMMENTS = ["--", "#", "#!", ("/*", "*/")]
IDENTIFIERS = ['"', "`"]
KEYWORDS = {
**Tokenizer.KEYWORDS,
**tokens.Tokenizer.KEYWORDS,
"FINAL": TokenType.FINAL,
"DATETIME64": TokenType.DATETIME,
"INT8": TokenType.TINYINT,
@ -30,9 +32,9 @@ class ClickHouse(Dialect):
"TUPLE": TokenType.STRUCT,
}
class Parser(Parser):
class Parser(parser.Parser):
FUNCTIONS = {
**Parser.FUNCTIONS,
**parser.Parser.FUNCTIONS,
"MAP": parse_var_map,
}
@ -44,11 +46,11 @@ class ClickHouse(Dialect):
return this
class Generator(Generator):
class Generator(generator.Generator):
STRUCT_DELIMITER = ("(", ")")
TYPE_MAPPING = {
**Generator.TYPE_MAPPING,
**generator.Generator.TYPE_MAPPING,
exp.DataType.Type.NULLABLE: "Nullable",
exp.DataType.Type.DATETIME: "DateTime64",
exp.DataType.Type.MAP: "Map",
@ -63,7 +65,7 @@ class ClickHouse(Dialect):
}
TRANSFORMS = {
**Generator.TRANSFORMS,
**generator.Generator.TRANSFORMS,
exp.Array: inline_array_sql,
exp.StrPosition: lambda self, e: f"position({self.format_args(e.this, e.args.get('substr'), e.args.get('position'))})",
exp.Final: lambda self, e: f"{self.sql(e, 'this')} FINAL",

View file

@ -1,3 +1,5 @@
from __future__ import annotations
from sqlglot import exp
from sqlglot.dialects.dialect import parse_date_delta
from sqlglot.dialects.spark import Spark
@ -15,7 +17,7 @@ class Databricks(Spark):
class Generator(Spark.Generator):
TRANSFORMS = {
**Spark.Generator.TRANSFORMS,
**Spark.Generator.TRANSFORMS, # type: ignore
exp.DateAdd: generate_date_delta_with_unit_sql,
exp.DateDiff: generate_date_delta_with_unit_sql,
}

View file

@ -1,8 +1,11 @@
from __future__ import annotations
import typing as t
from enum import Enum
from sqlglot import exp
from sqlglot.generator import Generator
from sqlglot.helper import flatten, list_get
from sqlglot.helper import flatten, seq_get
from sqlglot.parser import Parser
from sqlglot.time import format_time
from sqlglot.tokens import Tokenizer
@ -32,7 +35,7 @@ class Dialects(str, Enum):
class _Dialect(type):
classes = {}
classes: t.Dict[str, Dialect] = {}
@classmethod
def __getitem__(cls, key):
@ -56,19 +59,30 @@ class _Dialect(type):
klass.generator_class = getattr(klass, "Generator", Generator)
klass.quote_start, klass.quote_end = list(klass.tokenizer_class._QUOTES.items())[0]
klass.identifier_start, klass.identifier_end = list(klass.tokenizer_class._IDENTIFIERS.items())[0]
klass.identifier_start, klass.identifier_end = list(
klass.tokenizer_class._IDENTIFIERS.items()
)[0]
if klass.tokenizer_class._BIT_STRINGS and exp.BitString not in klass.generator_class.TRANSFORMS:
if (
klass.tokenizer_class._BIT_STRINGS
and exp.BitString not in klass.generator_class.TRANSFORMS
):
bs_start, bs_end = list(klass.tokenizer_class._BIT_STRINGS.items())[0]
klass.generator_class.TRANSFORMS[
exp.BitString
] = lambda self, e: f"{bs_start}{int(self.sql(e, 'this')):b}{bs_end}"
if klass.tokenizer_class._HEX_STRINGS and exp.HexString not in klass.generator_class.TRANSFORMS:
if (
klass.tokenizer_class._HEX_STRINGS
and exp.HexString not in klass.generator_class.TRANSFORMS
):
hs_start, hs_end = list(klass.tokenizer_class._HEX_STRINGS.items())[0]
klass.generator_class.TRANSFORMS[
exp.HexString
] = lambda self, e: f"{hs_start}{int(self.sql(e, 'this')):X}{hs_end}"
if klass.tokenizer_class._BYTE_STRINGS and exp.ByteString not in klass.generator_class.TRANSFORMS:
if (
klass.tokenizer_class._BYTE_STRINGS
and exp.ByteString not in klass.generator_class.TRANSFORMS
):
be_start, be_end = list(klass.tokenizer_class._BYTE_STRINGS.items())[0]
klass.generator_class.TRANSFORMS[
exp.ByteString
@ -81,13 +95,13 @@ class Dialect(metaclass=_Dialect):
index_offset = 0
unnest_column_only = False
alias_post_tablesample = False
normalize_functions = "upper"
normalize_functions: t.Optional[str] = "upper"
null_ordering = "nulls_are_small"
date_format = "'%Y-%m-%d'"
dateint_format = "'%Y%m%d'"
time_format = "'%Y-%m-%d %H:%M:%S'"
time_mapping = {}
time_mapping: t.Dict[str, str] = {}
# autofilled
quote_start = None
@ -167,7 +181,7 @@ class Dialect(metaclass=_Dialect):
"quote_end": self.quote_end,
"identifier_start": self.identifier_start,
"identifier_end": self.identifier_end,
"escape": self.tokenizer_class.ESCAPE,
"escape": self.tokenizer_class.ESCAPES[0],
"index_offset": self.index_offset,
"time_mapping": self.inverse_time_mapping,
"time_trie": self.inverse_time_trie,
@ -195,7 +209,9 @@ def approx_count_distinct_sql(self, expression):
def if_sql(self, expression):
expressions = self.format_args(expression.this, expression.args.get("true"), expression.args.get("false"))
expressions = self.format_args(
expression.this, expression.args.get("true"), expression.args.get("false")
)
return f"IF({expressions})"
@ -298,9 +314,9 @@ def format_time_lambda(exp_class, dialect, default=None):
def _format_time(args):
return exp_class(
this=list_get(args, 0),
this=seq_get(args, 0),
format=Dialect[dialect].format_time(
list_get(args, 1) or (Dialect[dialect].time_format if default is True else default)
seq_get(args, 1) or (Dialect[dialect].time_format if default is True else default)
),
)
@@ -328,7 +344,9 @@ def create_with_partitions_sql(self, expression):
"expressions",
[e for e in schema.expressions if e not in partitions],
)
prop.replace(exp.PartitionedByProperty(this=prop.this, value=exp.Schema(expressions=partitions)))
prop.replace(
exp.PartitionedByProperty(this=prop.this, value=exp.Schema(expressions=partitions))
)
expression.set("this", schema)
return self.create_sql(expression)
@@ -337,9 +355,9 @@ def create_with_partitions_sql(self, expression):
def parse_date_delta(exp_class, unit_mapping=None):
def inner_func(args):
unit_based = len(args) == 3
this = list_get(args, 2) if unit_based else list_get(args, 0)
expression = list_get(args, 1) if unit_based else list_get(args, 1)
unit = list_get(args, 0) if unit_based else exp.Literal.string("DAY")
this = seq_get(args, 2) if unit_based else seq_get(args, 0)
expression = seq_get(args, 1) if unit_based else seq_get(args, 1)
unit = seq_get(args, 0) if unit_based else exp.Literal.string("DAY")
unit = unit_mapping.get(unit.name.lower(), unit) if unit_mapping else unit
return exp_class(this=this, expression=expression, unit=unit)
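For illustration, the argument reordering above can be exercised by calling the returned builder directly. This is a minimal sketch; the `unit_mapping` entries here are hypothetical, not part of the diff:

```python
from sqlglot import exp
from sqlglot.dialects.dialect import parse_date_delta

# Builder for a T-SQL-style DATEADD(unit, increment, date).
date_add = parse_date_delta(exp.DateAdd, unit_mapping={"dd": "day", "yy": "year"})

# With three arguments the unit is taken from args[0] and normalized through
# the mapping, while the date expression itself comes from args[2].
node = date_add(
    [exp.Literal.string("dd"), exp.Literal.number(7), exp.Literal.string("2020-01-01")]
)
assert node.text("unit") == "day"
```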

View file

@@ -1,4 +1,6 @@
from sqlglot import exp
from __future__ import annotations
from sqlglot import exp, generator, parser, tokens
from sqlglot.dialects.dialect import (
Dialect,
approx_count_distinct_sql,
@@ -12,10 +14,8 @@ from sqlglot.dialects.dialect import (
rename_func,
str_position_sql,
)
from sqlglot.generator import Generator
from sqlglot.helper import list_get
from sqlglot.parser import Parser
from sqlglot.tokens import Tokenizer, TokenType
from sqlglot.helper import seq_get
from sqlglot.tokens import TokenType
def _unix_to_time(self, expression):
@@ -61,11 +61,14 @@ def _sort_array_sql(self, expression):
def _sort_array_reverse(args):
return exp.SortArray(this=list_get(args, 0), asc=exp.FALSE)
return exp.SortArray(this=seq_get(args, 0), asc=exp.FALSE)
def _struct_pack_sql(self, expression):
args = [self.binary(e, ":=") if isinstance(e, exp.EQ) else self.sql(e) for e in expression.expressions]
args = [
self.binary(e, ":=") if isinstance(e, exp.EQ) else self.sql(e)
for e in expression.expressions
]
return f"STRUCT_PACK({', '.join(args)})"
@@ -76,15 +79,15 @@ def _datatype_sql(self, expression):
class DuckDB(Dialect):
class Tokenizer(Tokenizer):
class Tokenizer(tokens.Tokenizer):
KEYWORDS = {
**Tokenizer.KEYWORDS,
**tokens.Tokenizer.KEYWORDS,
":=": TokenType.EQ,
}
class Parser(Parser):
class Parser(parser.Parser):
FUNCTIONS = {
**Parser.FUNCTIONS,
**parser.Parser.FUNCTIONS,
"APPROX_COUNT_DISTINCT": exp.ApproxDistinct.from_arg_list,
"ARRAY_LENGTH": exp.ArraySize.from_arg_list,
"ARRAY_SORT": exp.SortArray.from_arg_list,
@@ -92,7 +95,7 @@ class DuckDB(Dialect):
"EPOCH": exp.TimeToUnix.from_arg_list,
"EPOCH_MS": lambda args: exp.UnixToTime(
this=exp.Div(
this=list_get(args, 0),
this=seq_get(args, 0),
expression=exp.Literal.number(1000),
)
),
@@ -112,11 +115,11 @@ class DuckDB(Dialect):
"UNNEST": exp.Explode.from_arg_list,
}
class Generator(Generator):
class Generator(generator.Generator):
STRUCT_DELIMITER = ("(", ")")
TRANSFORMS = {
**Generator.TRANSFORMS,
**generator.Generator.TRANSFORMS,
exp.ApproxDistinct: approx_count_distinct_sql,
exp.Array: rename_func("LIST_VALUE"),
exp.ArraySize: rename_func("ARRAY_LENGTH"),
@@ -160,7 +163,7 @@ class DuckDB(Dialect):
}
TYPE_MAPPING = {
**Generator.TYPE_MAPPING,
**generator.Generator.TYPE_MAPPING,
exp.DataType.Type.VARCHAR: "TEXT",
exp.DataType.Type.NVARCHAR: "TEXT",
}

View file

@@ -1,4 +1,6 @@
from sqlglot import exp, transforms
from __future__ import annotations
from sqlglot import exp, generator, parser, tokens, transforms
from sqlglot.dialects.dialect import (
Dialect,
approx_count_distinct_sql,
@@ -13,10 +15,8 @@ from sqlglot.dialects.dialect import (
struct_extract_sql,
var_map_sql,
)
from sqlglot.generator import Generator
from sqlglot.helper import list_get
from sqlglot.parser import Parser, parse_var_map
from sqlglot.tokens import Tokenizer
from sqlglot.helper import seq_get
from sqlglot.parser import parse_var_map
# (FuncType, Multiplier)
DATE_DELTA_INTERVAL = {
@@ -34,7 +34,9 @@ def _add_date_sql(self, expression):
unit = expression.text("unit").upper()
func, multiplier = DATE_DELTA_INTERVAL.get(unit, ("DATE_ADD", 1))
modified_increment = (
int(expression.text("expression")) * multiplier if expression.expression.is_number else expression.expression
int(expression.text("expression")) * multiplier
if expression.expression.is_number
else expression.expression
)
modified_increment = exp.Literal.number(modified_increment)
return f"{func}({self.format_args(expression.this, modified_increment.this)})"
@@ -165,10 +167,10 @@ class Hive(Dialect):
dateint_format = "'yyyyMMdd'"
time_format = "'yyyy-MM-dd HH:mm:ss'"
class Tokenizer(Tokenizer):
class Tokenizer(tokens.Tokenizer):
QUOTES = ["'", '"']
IDENTIFIERS = ["`"]
ESCAPE = "\\"
ESCAPES = ["\\"]
ENCODE = "utf-8"
NUMERIC_LITERALS = {
@@ -180,40 +182,44 @@ class Hive(Dialect):
"BD": "DECIMAL",
}
class Parser(Parser):
class Parser(parser.Parser):
STRICT_CAST = False
FUNCTIONS = {
**Parser.FUNCTIONS,
**parser.Parser.FUNCTIONS,
"APPROX_COUNT_DISTINCT": exp.ApproxDistinct.from_arg_list,
"COLLECT_LIST": exp.ArrayAgg.from_arg_list,
"DATE_ADD": lambda args: exp.TsOrDsAdd(
this=list_get(args, 0),
expression=list_get(args, 1),
this=seq_get(args, 0),
expression=seq_get(args, 1),
unit=exp.Literal.string("DAY"),
),
"DATEDIFF": lambda args: exp.DateDiff(
this=exp.TsOrDsToDate(this=list_get(args, 0)),
expression=exp.TsOrDsToDate(this=list_get(args, 1)),
this=exp.TsOrDsToDate(this=seq_get(args, 0)),
expression=exp.TsOrDsToDate(this=seq_get(args, 1)),
),
"DATE_SUB": lambda args: exp.TsOrDsAdd(
this=list_get(args, 0),
this=seq_get(args, 0),
expression=exp.Mul(
this=list_get(args, 1),
this=seq_get(args, 1),
expression=exp.Literal.number(-1),
),
unit=exp.Literal.string("DAY"),
),
"DATE_FORMAT": format_time_lambda(exp.TimeToStr, "hive"),
"DAY": lambda args: exp.Day(this=exp.TsOrDsToDate(this=list_get(args, 0))),
"DAY": lambda args: exp.Day(this=exp.TsOrDsToDate(this=seq_get(args, 0))),
"FROM_UNIXTIME": format_time_lambda(exp.UnixToStr, "hive", True),
"GET_JSON_OBJECT": exp.JSONExtractScalar.from_arg_list,
"LOCATE": lambda args: exp.StrPosition(
this=list_get(args, 1),
substr=list_get(args, 0),
position=list_get(args, 2),
this=seq_get(args, 1),
substr=seq_get(args, 0),
position=seq_get(args, 2),
),
"LOG": (
lambda args: exp.Log.from_arg_list(args)
if len(args) > 1
else exp.Ln.from_arg_list(args)
),
"LOG": (lambda args: exp.Log.from_arg_list(args) if len(args) > 1 else exp.Ln.from_arg_list(args)),
"MAP": parse_var_map,
"MONTH": lambda args: exp.Month(this=exp.TsOrDsToDate.from_arg_list(args)),
"PERCENTILE": exp.Quantile.from_arg_list,
@@ -226,15 +232,16 @@ class Hive(Dialect):
"YEAR": lambda args: exp.Year(this=exp.TsOrDsToDate.from_arg_list(args)),
}
class Generator(Generator):
class Generator(generator.Generator):
TYPE_MAPPING = {
**Generator.TYPE_MAPPING,
**generator.Generator.TYPE_MAPPING,
exp.DataType.Type.TEXT: "STRING",
exp.DataType.Type.VARBINARY: "BINARY",
}
TRANSFORMS = {
**Generator.TRANSFORMS,
**transforms.UNALIAS_GROUP,
**generator.Generator.TRANSFORMS,
**transforms.UNALIAS_GROUP, # type: ignore
exp.AnonymousProperty: _property_sql,
exp.ApproxDistinct: approx_count_distinct_sql,
exp.ArrayAgg: rename_func("COLLECT_LIST"),

View file
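The reformatted `LOG` entry in the Hive parser above dispatches on arity; a quick sketch of the intended behavior, assuming the one-argument form maps to natural log as the lambda suggests:

```python
import sqlglot
from sqlglot import exp

# One-argument LOG parses as a natural logarithm, two arguments as a base logarithm.
assert isinstance(sqlglot.parse_one("LOG(x)", read="hive"), exp.Ln)
assert isinstance(sqlglot.parse_one("LOG(10, x)", read="hive"), exp.Log)
```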

@@ -1,4 +1,8 @@
from sqlglot import exp
from __future__ import annotations
import typing as t
from sqlglot import exp, generator, parser, tokens
from sqlglot.dialects.dialect import (
Dialect,
no_ilike_sql,
@@ -6,42 +10,47 @@ from sqlglot.dialects.dialect import (
no_tablesample_sql,
no_trycast_sql,
)
from sqlglot.generator import Generator
from sqlglot.helper import list_get
from sqlglot.parser import Parser
from sqlglot.tokens import Tokenizer, TokenType
from sqlglot.helper import seq_get
from sqlglot.tokens import TokenType
def _show_parser(*args, **kwargs):
def _parse(self):
return self._parse_show_mysql(*args, **kwargs)
return _parse
def _date_trunc_sql(self, expression):
unit = expression.text("unit").lower()
unit = expression.name.lower()
this = self.sql(expression.this)
expr = self.sql(expression.expression)
if unit == "day":
return f"DATE({this})"
return f"DATE({expr})"
if unit == "week":
concat = f"CONCAT(YEAR({this}), ' ', WEEK({this}, 1), ' 1')"
concat = f"CONCAT(YEAR({expr}), ' ', WEEK({expr}, 1), ' 1')"
date_format = "%Y %u %w"
elif unit == "month":
concat = f"CONCAT(YEAR({this}), ' ', MONTH({this}), ' 1')"
concat = f"CONCAT(YEAR({expr}), ' ', MONTH({expr}), ' 1')"
date_format = "%Y %c %e"
elif unit == "quarter":
concat = f"CONCAT(YEAR({this}), ' ', QUARTER({this}) * 3 - 2, ' 1')"
concat = f"CONCAT(YEAR({expr}), ' ', QUARTER({expr}) * 3 - 2, ' 1')"
date_format = "%Y %c %e"
elif unit == "year":
concat = f"CONCAT(YEAR({this}), ' 1 1')"
concat = f"CONCAT(YEAR({expr}), ' 1 1')"
date_format = "%Y %c %e"
else:
self.unsupported(f"Unexpected interval unit: {unit}")
return f"DATE({this})"
return f"DATE({expr})"
return f"STR_TO_DATE({concat}, '{date_format}')"
def _str_to_date(args):
date_format = MySQL.format_time(list_get(args, 1))
return exp.StrToDate(this=list_get(args, 0), format=date_format)
date_format = MySQL.format_time(seq_get(args, 1))
return exp.StrToDate(this=seq_get(args, 0), format=date_format)
def _str_to_date_sql(self, expression):
@@ -66,9 +75,9 @@ def _trim_sql(self, expression):
def _date_add(expression_class):
def func(args):
interval = list_get(args, 1)
interval = seq_get(args, 1)
return expression_class(
this=list_get(args, 0),
this=seq_get(args, 0),
expression=interval.this,
unit=exp.Literal.string(interval.text("unit").lower()),
)
@@ -101,15 +110,16 @@ class MySQL(Dialect):
"%l": "%-I",
}
class Tokenizer(Tokenizer):
class Tokenizer(tokens.Tokenizer):
QUOTES = ["'", '"']
COMMENTS = ["--", "#", ("/*", "*/")]
IDENTIFIERS = ["`"]
ESCAPES = ["'", "\\"]
BIT_STRINGS = [("b'", "'"), ("B'", "'"), ("0b", "")]
HEX_STRINGS = [("x'", "'"), ("X'", "'"), ("0x", "")]
KEYWORDS = {
**Tokenizer.KEYWORDS,
**tokens.Tokenizer.KEYWORDS,
"SEPARATOR": TokenType.SEPARATOR,
"_ARMSCII8": TokenType.INTRODUCER,
"_ASCII": TokenType.INTRODUCER,
@@ -156,20 +166,23 @@ class MySQL(Dialect):
"_UTF32": TokenType.INTRODUCER,
"_UTF8MB3": TokenType.INTRODUCER,
"_UTF8MB4": TokenType.INTRODUCER,
"@@": TokenType.SESSION_PARAMETER,
}
class Parser(Parser):
COMMANDS = tokens.Tokenizer.COMMANDS - {TokenType.SET, TokenType.SHOW}
class Parser(parser.Parser):
STRICT_CAST = False
FUNCTIONS = {
**Parser.FUNCTIONS,
**parser.Parser.FUNCTIONS,
"DATE_ADD": _date_add(exp.DateAdd),
"DATE_SUB": _date_add(exp.DateSub),
"STR_TO_DATE": _str_to_date,
}
FUNCTION_PARSERS = {
**Parser.FUNCTION_PARSERS,
**parser.Parser.FUNCTION_PARSERS,
"GROUP_CONCAT": lambda self: self.expression(
exp.GroupConcat,
this=self._parse_lambda(),
@@ -178,15 +191,212 @@ class MySQL(Dialect):
}
PROPERTY_PARSERS = {
**Parser.PROPERTY_PARSERS,
**parser.Parser.PROPERTY_PARSERS,
TokenType.ENGINE: lambda self: self._parse_property_assignment(exp.EngineProperty),
}
class Generator(Generator):
STATEMENT_PARSERS = {
**parser.Parser.STATEMENT_PARSERS,
TokenType.SHOW: lambda self: self._parse_show(),
TokenType.SET: lambda self: self._parse_set(),
}
SHOW_PARSERS = {
"BINARY LOGS": _show_parser("BINARY LOGS"),
"MASTER LOGS": _show_parser("BINARY LOGS"),
"BINLOG EVENTS": _show_parser("BINLOG EVENTS"),
"CHARACTER SET": _show_parser("CHARACTER SET"),
"CHARSET": _show_parser("CHARACTER SET"),
"COLLATION": _show_parser("COLLATION"),
"FULL COLUMNS": _show_parser("COLUMNS", target="FROM", full=True),
"COLUMNS": _show_parser("COLUMNS", target="FROM"),
"CREATE DATABASE": _show_parser("CREATE DATABASE", target=True),
"CREATE EVENT": _show_parser("CREATE EVENT", target=True),
"CREATE FUNCTION": _show_parser("CREATE FUNCTION", target=True),
"CREATE PROCEDURE": _show_parser("CREATE PROCEDURE", target=True),
"CREATE TABLE": _show_parser("CREATE TABLE", target=True),
"CREATE TRIGGER": _show_parser("CREATE TRIGGER", target=True),
"CREATE VIEW": _show_parser("CREATE VIEW", target=True),
"DATABASES": _show_parser("DATABASES"),
"ENGINE": _show_parser("ENGINE", target=True),
"STORAGE ENGINES": _show_parser("ENGINES"),
"ENGINES": _show_parser("ENGINES"),
"ERRORS": _show_parser("ERRORS"),
"EVENTS": _show_parser("EVENTS"),
"FUNCTION CODE": _show_parser("FUNCTION CODE", target=True),
"FUNCTION STATUS": _show_parser("FUNCTION STATUS"),
"GRANTS": _show_parser("GRANTS", target="FOR"),
"INDEX": _show_parser("INDEX", target="FROM"),
"MASTER STATUS": _show_parser("MASTER STATUS"),
"OPEN TABLES": _show_parser("OPEN TABLES"),
"PLUGINS": _show_parser("PLUGINS"),
"PROCEDURE CODE": _show_parser("PROCEDURE CODE", target=True),
"PROCEDURE STATUS": _show_parser("PROCEDURE STATUS"),
"PRIVILEGES": _show_parser("PRIVILEGES"),
"FULL PROCESSLIST": _show_parser("PROCESSLIST", full=True),
"PROCESSLIST": _show_parser("PROCESSLIST"),
"PROFILE": _show_parser("PROFILE"),
"PROFILES": _show_parser("PROFILES"),
"RELAYLOG EVENTS": _show_parser("RELAYLOG EVENTS"),
"REPLICAS": _show_parser("REPLICAS"),
"SLAVE HOSTS": _show_parser("REPLICAS"),
"REPLICA STATUS": _show_parser("REPLICA STATUS"),
"SLAVE STATUS": _show_parser("REPLICA STATUS"),
"GLOBAL STATUS": _show_parser("STATUS", global_=True),
"SESSION STATUS": _show_parser("STATUS"),
"STATUS": _show_parser("STATUS"),
"TABLE STATUS": _show_parser("TABLE STATUS"),
"FULL TABLES": _show_parser("TABLES", full=True),
"TABLES": _show_parser("TABLES"),
"TRIGGERS": _show_parser("TRIGGERS"),
"GLOBAL VARIABLES": _show_parser("VARIABLES", global_=True),
"SESSION VARIABLES": _show_parser("VARIABLES"),
"VARIABLES": _show_parser("VARIABLES"),
"WARNINGS": _show_parser("WARNINGS"),
}
SET_PARSERS = {
"GLOBAL": lambda self: self._parse_set_item_assignment("GLOBAL"),
"PERSIST": lambda self: self._parse_set_item_assignment("PERSIST"),
"PERSIST_ONLY": lambda self: self._parse_set_item_assignment("PERSIST_ONLY"),
"SESSION": lambda self: self._parse_set_item_assignment("SESSION"),
"LOCAL": lambda self: self._parse_set_item_assignment("LOCAL"),
"CHARACTER SET": lambda self: self._parse_set_item_charset("CHARACTER SET"),
"CHARSET": lambda self: self._parse_set_item_charset("CHARACTER SET"),
"NAMES": lambda self: self._parse_set_item_names(),
}
PROFILE_TYPES = {
"ALL",
"BLOCK IO",
"CONTEXT SWITCHES",
"CPU",
"IPC",
"MEMORY",
"PAGE FAULTS",
"SOURCE",
"SWAPS",
}
def _parse_show_mysql(self, this, target=False, full=None, global_=None):
if target:
if isinstance(target, str):
self._match_text(target)
target_id = self._parse_id_var()
else:
target_id = None
log = self._parse_string() if self._match_text("IN") else None
if this in {"BINLOG EVENTS", "RELAYLOG EVENTS"}:
position = self._parse_number() if self._match_text("FROM") else None
db = None
else:
position = None
db = self._parse_id_var() if self._match_text("FROM") else None
channel = self._parse_id_var() if self._match_text("FOR", "CHANNEL") else None
like = self._parse_string() if self._match_text("LIKE") else None
where = self._parse_where()
if this == "PROFILE":
types = self._parse_csv(self._parse_show_profile_type)
query = self._parse_number() if self._match_text("FOR", "QUERY") else None
offset = self._parse_number() if self._match_text("OFFSET") else None
limit = self._parse_number() if self._match_text("LIMIT") else None
else:
types, query = None, None
offset, limit = self._parse_oldstyle_limit()
mutex = True if self._match_text("MUTEX") else None
mutex = False if self._match_text("STATUS") else mutex
return self.expression(
exp.Show,
this=this,
target=target_id,
full=full,
log=log,
position=position,
db=db,
channel=channel,
like=like,
where=where,
types=types,
query=query,
offset=offset,
limit=limit,
mutex=mutex,
**{"global": global_},
)
def _parse_show_profile_type(self):
for type_ in self.PROFILE_TYPES:
if self._match_text(*type_.split(" ")):
return exp.Var(this=type_)
return None
def _parse_oldstyle_limit(self):
limit = None
offset = None
if self._match_text("LIMIT"):
parts = self._parse_csv(self._parse_number)
if len(parts) == 1:
limit = parts[0]
elif len(parts) == 2:
limit = parts[1]
offset = parts[0]
return offset, limit
def _default_parse_set_item(self):
return self._parse_set_item_assignment(kind=None)
def _parse_set_item_assignment(self, kind):
left = self._parse_primary() or self._parse_id_var()
if not self._match(TokenType.EQ):
self.raise_error("Expected =")
right = self._parse_statement() or self._parse_id_var()
this = self.expression(
exp.EQ,
this=left,
expression=right,
)
return self.expression(
exp.SetItem,
this=this,
kind=kind,
)
def _parse_set_item_charset(self, kind):
this = self._parse_string() or self._parse_id_var()
return self.expression(
exp.SetItem,
this=this,
kind=kind,
)
def _parse_set_item_names(self):
charset = self._parse_string() or self._parse_id_var()
if self._match_text("COLLATE"):
collate = self._parse_string() or self._parse_id_var()
else:
collate = None
return self.expression(
exp.SetItem,
this=charset,
collate=collate,
kind="NAMES",
)
class Generator(generator.Generator):
NULL_ORDERING_SUPPORTED = False
TRANSFORMS = {
**Generator.TRANSFORMS,
**generator.Generator.TRANSFORMS,
exp.CurrentDate: no_paren_current_date_sql,
exp.CurrentTimestamp: lambda *_: "CURRENT_TIMESTAMP",
exp.ILike: no_ilike_sql,
@@ -199,6 +409,8 @@ class MySQL(Dialect):
exp.StrToDate: _str_to_date_sql,
exp.StrToTime: _str_to_date_sql,
exp.Trim: _trim_sql,
exp.NullSafeEQ: lambda self, e: self.binary(e, "<=>"),
exp.NullSafeNEQ: lambda self, e: self.not_sql(self.binary(e, "<=>")),
}
ROOT_PROPERTIES = {
@@ -209,4 +421,69 @@ class MySQL(Dialect):
exp.SchemaCommentProperty,
}
WITH_PROPERTIES = {}
WITH_PROPERTIES: t.Set[t.Type[exp.Property]] = set()
def show_sql(self, expression):
this = f" {expression.name}"
full = " FULL" if expression.args.get("full") else ""
global_ = " GLOBAL" if expression.args.get("global") else ""
target = self.sql(expression, "target")
target = f" {target}" if target else ""
if expression.name in {"COLUMNS", "INDEX"}:
target = f" FROM{target}"
elif expression.name == "GRANTS":
target = f" FOR{target}"
db = self._prefixed_sql("FROM", expression, "db")
like = self._prefixed_sql("LIKE", expression, "like")
where = self.sql(expression, "where")
types = self.expressions(expression, key="types")
types = f" {types}" if types else types
query = self._prefixed_sql("FOR QUERY", expression, "query")
if expression.name == "PROFILE":
offset = self._prefixed_sql("OFFSET", expression, "offset")
limit = self._prefixed_sql("LIMIT", expression, "limit")
else:
offset = ""
limit = self._oldstyle_limit_sql(expression)
log = self._prefixed_sql("IN", expression, "log")
position = self._prefixed_sql("FROM", expression, "position")
channel = self._prefixed_sql("FOR CHANNEL", expression, "channel")
if expression.name == "ENGINE":
mutex_or_status = " MUTEX" if expression.args.get("mutex") else " STATUS"
else:
mutex_or_status = ""
return f"SHOW{full}{global_}{this}{target}{types}{db}{query}{log}{position}{channel}{mutex_or_status}{like}{where}{offset}{limit}"
def _prefixed_sql(self, prefix, expression, arg):
sql = self.sql(expression, arg)
if not sql:
return ""
return f" {prefix} {sql}"
def _oldstyle_limit_sql(self, expression):
limit = self.sql(expression, "limit")
offset = self.sql(expression, "offset")
if limit:
limit_offset = f"{offset}, {limit}" if offset else limit
return f" LIMIT {limit_offset}"
return ""
def setitem_sql(self, expression):
kind = self.sql(expression, "kind")
kind = f"{kind} " if kind else ""
this = self.sql(expression, "this")
collate = self.sql(expression, "collate")
collate = f" COLLATE {collate}" if collate else ""
return f"{kind}{this}{collate}"
def set_sql(self, expression):
return f"SET {self.expressions(expression)}"

View file

@@ -1,8 +1,9 @@
from sqlglot import exp, transforms
from __future__ import annotations
from sqlglot import exp, generator, tokens, transforms
from sqlglot.dialects.dialect import Dialect, no_ilike_sql
from sqlglot.generator import Generator
from sqlglot.helper import csv
from sqlglot.tokens import Tokenizer, TokenType
from sqlglot.tokens import TokenType
def _limit_sql(self, expression):
@@ -36,9 +37,9 @@ class Oracle(Dialect):
"YYYY": "%Y", # 2015
}
class Generator(Generator):
class Generator(generator.Generator):
TYPE_MAPPING = {
**Generator.TYPE_MAPPING,
**generator.Generator.TYPE_MAPPING,
exp.DataType.Type.TINYINT: "NUMBER",
exp.DataType.Type.SMALLINT: "NUMBER",
exp.DataType.Type.INT: "NUMBER",
@@ -49,11 +50,12 @@ class Oracle(Dialect):
exp.DataType.Type.NVARCHAR: "NVARCHAR2",
exp.DataType.Type.TEXT: "CLOB",
exp.DataType.Type.BINARY: "BLOB",
exp.DataType.Type.VARBINARY: "BLOB",
}
TRANSFORMS = {
**Generator.TRANSFORMS,
**transforms.UNALIAS_GROUP,
**generator.Generator.TRANSFORMS,
**transforms.UNALIAS_GROUP, # type: ignore
exp.ILike: no_ilike_sql,
exp.Limit: _limit_sql,
exp.StrToTime: lambda self, e: f"TO_TIMESTAMP({self.sql(e, 'this')}, {self.format_time(e)})",
@@ -86,9 +88,9 @@ class Oracle(Dialect):
def table_sql(self, expression):
return super().table_sql(expression, sep=" ")
class Tokenizer(Tokenizer):
class Tokenizer(tokens.Tokenizer):
KEYWORDS = {
**Tokenizer.KEYWORDS,
**tokens.Tokenizer.KEYWORDS,
"TOP": TokenType.TOP,
"VARCHAR2": TokenType.VARCHAR,
"NVARCHAR2": TokenType.NVARCHAR,

View file

@@ -1,4 +1,6 @@
from sqlglot import exp
from __future__ import annotations
from sqlglot import exp, generator, parser, tokens
from sqlglot.dialects.dialect import (
Dialect,
arrow_json_extract_scalar_sql,
@@ -9,9 +11,7 @@ from sqlglot.dialects.dialect import (
no_trycast_sql,
str_position_sql,
)
from sqlglot.generator import Generator
from sqlglot.parser import Parser
from sqlglot.tokens import Tokenizer, TokenType
from sqlglot.tokens import TokenType
from sqlglot.transforms import delegate, preprocess
@@ -160,12 +160,12 @@ class Postgres(Dialect):
"YYYY": "%Y", # 2015
}
class Tokenizer(Tokenizer):
class Tokenizer(tokens.Tokenizer):
BIT_STRINGS = [("b'", "'"), ("B'", "'")]
HEX_STRINGS = [("x'", "'"), ("X'", "'")]
BYTE_STRINGS = [("e'", "'"), ("E'", "'")]
KEYWORDS = {
**Tokenizer.KEYWORDS,
**tokens.Tokenizer.KEYWORDS,
"ALWAYS": TokenType.ALWAYS,
"BY DEFAULT": TokenType.BY_DEFAULT,
"COMMENT ON": TokenType.COMMENT_ON,
@@ -179,31 +179,32 @@ class Postgres(Dialect):
}
QUOTES = ["'", "$$"]
SINGLE_TOKENS = {
**Tokenizer.SINGLE_TOKENS,
**tokens.Tokenizer.SINGLE_TOKENS,
"$": TokenType.PARAMETER,
}
class Parser(Parser):
class Parser(parser.Parser):
STRICT_CAST = False
FUNCTIONS = {
**Parser.FUNCTIONS,
**parser.Parser.FUNCTIONS,
"TO_TIMESTAMP": _to_timestamp,
"TO_CHAR": format_time_lambda(exp.TimeToStr, "postgres"),
}
class Generator(Generator):
class Generator(generator.Generator):
TYPE_MAPPING = {
**Generator.TYPE_MAPPING,
**generator.Generator.TYPE_MAPPING,
exp.DataType.Type.TINYINT: "SMALLINT",
exp.DataType.Type.FLOAT: "REAL",
exp.DataType.Type.DOUBLE: "DOUBLE PRECISION",
exp.DataType.Type.BINARY: "BYTEA",
exp.DataType.Type.VARBINARY: "BYTEA",
exp.DataType.Type.DATETIME: "TIMESTAMP",
}
TRANSFORMS = {
**Generator.TRANSFORMS,
**generator.Generator.TRANSFORMS,
exp.ColumnDef: preprocess(
[
_auto_increment_to_serial,

View file

@@ -1,4 +1,6 @@
from sqlglot import exp, transforms
from __future__ import annotations
from sqlglot import exp, generator, parser, tokens, transforms
from sqlglot.dialects.dialect import (
Dialect,
format_time_lambda,
@@ -10,10 +12,8 @@ from sqlglot.dialects.dialect import (
struct_extract_sql,
)
from sqlglot.dialects.mysql import MySQL
from sqlglot.generator import Generator
from sqlglot.helper import list_get
from sqlglot.parser import Parser
from sqlglot.tokens import Tokenizer, TokenType
from sqlglot.helper import seq_get
from sqlglot.tokens import TokenType
def _approx_distinct_sql(self, expression):
@@ -110,30 +110,29 @@ class Presto(Dialect):
index_offset = 1
null_ordering = "nulls_are_last"
time_format = "'%Y-%m-%d %H:%i:%S'"
time_mapping = MySQL.time_mapping
time_mapping = MySQL.time_mapping # type: ignore
class Tokenizer(Tokenizer):
class Tokenizer(tokens.Tokenizer):
KEYWORDS = {
**Tokenizer.KEYWORDS,
"VARBINARY": TokenType.BINARY,
**tokens.Tokenizer.KEYWORDS,
"ROW": TokenType.STRUCT,
}
class Parser(Parser):
class Parser(parser.Parser):
FUNCTIONS = {
**Parser.FUNCTIONS,
**parser.Parser.FUNCTIONS,
"APPROX_DISTINCT": exp.ApproxDistinct.from_arg_list,
"CARDINALITY": exp.ArraySize.from_arg_list,
"CONTAINS": exp.ArrayContains.from_arg_list,
"DATE_ADD": lambda args: exp.DateAdd(
this=list_get(args, 2),
expression=list_get(args, 1),
unit=list_get(args, 0),
this=seq_get(args, 2),
expression=seq_get(args, 1),
unit=seq_get(args, 0),
),
"DATE_DIFF": lambda args: exp.DateDiff(
this=list_get(args, 2),
expression=list_get(args, 1),
unit=list_get(args, 0),
this=seq_get(args, 2),
expression=seq_get(args, 1),
unit=seq_get(args, 0),
),
"DATE_FORMAT": format_time_lambda(exp.TimeToStr, "presto"),
"DATE_PARSE": format_time_lambda(exp.StrToTime, "presto"),
@@ -143,7 +142,7 @@ class Presto(Dialect):
"APPROX_PERCENTILE": exp.ApproxQuantile.from_arg_list,
}
class Generator(Generator):
class Generator(generator.Generator):
STRUCT_DELIMITER = ("(", ")")
@@ -159,7 +158,7 @@ class Presto(Dialect):
}
TYPE_MAPPING = {
**Generator.TYPE_MAPPING,
**generator.Generator.TYPE_MAPPING,
exp.DataType.Type.INT: "INTEGER",
exp.DataType.Type.FLOAT: "REAL",
exp.DataType.Type.BINARY: "VARBINARY",
@@ -169,8 +168,8 @@ class Presto(Dialect):
}
TRANSFORMS = {
**Generator.TRANSFORMS,
**transforms.UNALIAS_GROUP,
**generator.Generator.TRANSFORMS,
**transforms.UNALIAS_GROUP, # type: ignore
exp.ApproxDistinct: _approx_distinct_sql,
exp.Array: lambda self, e: f"ARRAY[{self.expressions(e, flat=True)}]",
exp.ArrayConcat: rename_func("CONCAT"),

View file

@@ -1,3 +1,5 @@
from __future__ import annotations
from sqlglot import exp
from sqlglot.dialects.postgres import Postgres
from sqlglot.tokens import TokenType
@@ -6,29 +8,30 @@ from sqlglot.tokens import TokenType
class Redshift(Postgres):
time_format = "'YYYY-MM-DD HH:MI:SS'"
time_mapping = {
**Postgres.time_mapping,
**Postgres.time_mapping, # type: ignore
"MON": "%b",
"HH": "%H",
}
class Tokenizer(Postgres.Tokenizer):
ESCAPE = "\\"
ESCAPES = ["\\"]
KEYWORDS = {
**Postgres.Tokenizer.KEYWORDS,
**Postgres.Tokenizer.KEYWORDS, # type: ignore
"GEOMETRY": TokenType.GEOMETRY,
"GEOGRAPHY": TokenType.GEOGRAPHY,
"HLLSKETCH": TokenType.HLLSKETCH,
"SUPER": TokenType.SUPER,
"TIME": TokenType.TIMESTAMP,
"TIMETZ": TokenType.TIMESTAMPTZ,
"VARBYTE": TokenType.BINARY,
"VARBYTE": TokenType.VARBINARY,
"SIMILAR TO": TokenType.SIMILAR_TO,
}
class Generator(Postgres.Generator):
TYPE_MAPPING = {
**Postgres.Generator.TYPE_MAPPING,
**Postgres.Generator.TYPE_MAPPING, # type: ignore
exp.DataType.Type.BINARY: "VARBYTE",
exp.DataType.Type.VARBINARY: "VARBYTE",
exp.DataType.Type.INT: "INTEGER",
}

View file

@@ -1,4 +1,6 @@
from sqlglot import exp
from __future__ import annotations
from sqlglot import exp, generator, parser, tokens
from sqlglot.dialects.dialect import (
Dialect,
format_time_lambda,
@@ -6,10 +8,8 @@ from sqlglot.dialects.dialect import (
rename_func,
)
from sqlglot.expressions import Literal
from sqlglot.generator import Generator
from sqlglot.helper import list_get
from sqlglot.parser import Parser
from sqlglot.tokens import Tokenizer, TokenType
from sqlglot.helper import seq_get
from sqlglot.tokens import TokenType
def _check_int(s):
@@ -28,7 +28,9 @@ def _snowflake_to_timestamp(args):
# case: <numeric_expr> [ , <scale> ]
if second_arg.name not in ["0", "3", "9"]:
raise ValueError(f"Scale for snowflake numeric timestamp is {second_arg}, but should be 0, 3, or 9")
raise ValueError(
f"Scale for snowflake numeric timestamp is {second_arg}, but should be 0, 3, or 9"
)
if second_arg.name == "0":
timescale = exp.UnixToTime.SECONDS
@@ -39,7 +41,7 @@ def _snowflake_to_timestamp(args):
return exp.UnixToTime(this=first_arg, scale=timescale)
first_arg = list_get(args, 0)
first_arg = seq_get(args, 0)
if not isinstance(first_arg, Literal):
# case: <variant_expr>
return format_time_lambda(exp.StrToTime, "snowflake", default=True)(args)
@@ -56,7 +58,7 @@ def _snowflake_to_timestamp(args):
return exp.UnixToTime.from_arg_list(args)
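A brief check of the scale handling above, assuming scale 3 maps to `exp.UnixToTime.MILLIS` in the branch elided from this hunk (only the `0` → `SECONDS` case is shown):

```python
import sqlglot
from sqlglot import exp

node = sqlglot.parse_one("SELECT TO_TIMESTAMP(1659981729000, 3)", read="snowflake")
unix = node.find(exp.UnixToTime)
assert unix is not None and unix.args["scale"] == exp.UnixToTime.MILLIS
```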
def _unix_to_time(self, expression):
def _unix_to_time_sql(self, expression):
scale = expression.args.get("scale")
timestamp = self.sql(expression, "this")
if scale in [None, exp.UnixToTime.SECONDS]:
@@ -132,9 +134,9 @@ class Snowflake(Dialect):
"ff6": "%f",
}
class Parser(Parser):
class Parser(parser.Parser):
FUNCTIONS = {
**Parser.FUNCTIONS,
**parser.Parser.FUNCTIONS,
"ARRAYAGG": exp.ArrayAgg.from_arg_list,
"IFF": exp.If.from_arg_list,
"TO_TIMESTAMP": _snowflake_to_timestamp,
@@ -143,18 +145,18 @@ class Snowflake(Dialect):
}
FUNCTION_PARSERS = {
**Parser.FUNCTION_PARSERS,
**parser.Parser.FUNCTION_PARSERS,
"DATE_PART": _parse_date_part,
}
FUNC_TOKENS = {
*Parser.FUNC_TOKENS,
*parser.Parser.FUNC_TOKENS,
TokenType.RLIKE,
TokenType.TABLE,
}
COLUMN_OPERATORS = {
**Parser.COLUMN_OPERATORS,
**parser.Parser.COLUMN_OPERATORS, # type: ignore
TokenType.COLON: lambda self, this, path: self.expression(
exp.Bracket,
this=this,
@@ -163,21 +165,21 @@ class Snowflake(Dialect):
}
PROPERTY_PARSERS = {
**Parser.PROPERTY_PARSERS,
**parser.Parser.PROPERTY_PARSERS,
TokenType.PARTITION_BY: lambda self: self._parse_partitioned_by(),
}
class Tokenizer(Tokenizer):
class Tokenizer(tokens.Tokenizer):
QUOTES = ["'", "$$"]
ESCAPE = "\\"
ESCAPES = ["\\"]
SINGLE_TOKENS = {
**Tokenizer.SINGLE_TOKENS,
**tokens.Tokenizer.SINGLE_TOKENS,
"$": TokenType.PARAMETER,
}
KEYWORDS = {
**Tokenizer.KEYWORDS,
**tokens.Tokenizer.KEYWORDS,
"QUALIFY": TokenType.QUALIFY,
"DOUBLE PRECISION": TokenType.DOUBLE,
"TIMESTAMP_LTZ": TokenType.TIMESTAMPLTZ,
@@ -187,15 +189,15 @@ class Snowflake(Dialect):
"SAMPLE": TokenType.TABLE_SAMPLE,
}
class Generator(Generator):
class Generator(generator.Generator):
CREATE_TRANSIENT = True
TRANSFORMS = {
**Generator.TRANSFORMS,
**generator.Generator.TRANSFORMS,
exp.ArrayConcat: rename_func("ARRAY_CAT"),
exp.If: rename_func("IFF"),
exp.StrToTime: lambda self, e: f"TO_TIMESTAMP({self.sql(e, 'this')}, {self.format_time(e)})",
exp.UnixToTime: _unix_to_time,
exp.UnixToTime: _unix_to_time_sql,
exp.TimeToUnix: lambda self, e: f"EXTRACT(epoch_second FROM {self.sql(e, 'this')})",
exp.Array: inline_array_sql,
exp.StrPosition: rename_func("POSITION"),
@@ -204,7 +206,7 @@ class Snowflake(Dialect):
}
TYPE_MAPPING = {
**Generator.TYPE_MAPPING,
**generator.Generator.TYPE_MAPPING,
exp.DataType.Type.TIMESTAMP: "TIMESTAMPNTZ",
}

View file

@@ -1,8 +1,9 @@
from sqlglot import exp
from __future__ import annotations
from sqlglot import exp, parser
from sqlglot.dialects.dialect import create_with_partitions_sql, rename_func
from sqlglot.dialects.hive import Hive
from sqlglot.helper import list_get
from sqlglot.parser import Parser
from sqlglot.helper import seq_get
def _create_sql(self, e):
@@ -46,36 +47,36 @@ def _unix_to_time(self, expression):
class Spark(Hive):
class Parser(Hive.Parser):
FUNCTIONS = {
**Hive.Parser.FUNCTIONS,
**Hive.Parser.FUNCTIONS, # type: ignore
"MAP_FROM_ARRAYS": exp.Map.from_arg_list,
"TO_UNIX_TIMESTAMP": exp.StrToUnix.from_arg_list,
"LEFT": lambda args: exp.Substring(
this=list_get(args, 0),
this=seq_get(args, 0),
start=exp.Literal.number(1),
length=list_get(args, 1),
length=seq_get(args, 1),
),
"SHIFTLEFT": lambda args: exp.BitwiseLeftShift(
this=list_get(args, 0),
expression=list_get(args, 1),
this=seq_get(args, 0),
expression=seq_get(args, 1),
),
"SHIFTRIGHT": lambda args: exp.BitwiseRightShift(
this=list_get(args, 0),
expression=list_get(args, 1),
this=seq_get(args, 0),
expression=seq_get(args, 1),
),
"RIGHT": lambda args: exp.Substring(
this=list_get(args, 0),
this=seq_get(args, 0),
start=exp.Sub(
this=exp.Length(this=list_get(args, 0)),
expression=exp.Add(this=list_get(args, 1), expression=exp.Literal.number(1)),
this=exp.Length(this=seq_get(args, 0)),
expression=exp.Add(this=seq_get(args, 1), expression=exp.Literal.number(1)),
),
length=list_get(args, 1),
length=seq_get(args, 1),
),
"APPROX_PERCENTILE": exp.ApproxQuantile.from_arg_list,
"IIF": exp.If.from_arg_list,
}
FUNCTION_PARSERS = {
**Parser.FUNCTION_PARSERS,
**parser.Parser.FUNCTION_PARSERS,
"BROADCAST": lambda self: self._parse_join_hint("BROADCAST"),
"BROADCASTJOIN": lambda self: self._parse_join_hint("BROADCASTJOIN"),
"MAPJOIN": lambda self: self._parse_join_hint("MAPJOIN"),
@@ -88,14 +89,14 @@ class Spark(Hive):
class Generator(Hive.Generator):
TYPE_MAPPING = {
**Hive.Generator.TYPE_MAPPING,
**Hive.Generator.TYPE_MAPPING, # type: ignore
exp.DataType.Type.TINYINT: "BYTE",
exp.DataType.Type.SMALLINT: "SHORT",
exp.DataType.Type.BIGINT: "LONG",
}
TRANSFORMS = {
**{k: v for k, v in Hive.Generator.TRANSFORMS.items() if k not in {exp.ArraySort, exp.ILike}},
**Hive.Generator.TRANSFORMS, # type: ignore
exp.ApproxDistinct: rename_func("APPROX_COUNT_DISTINCT"),
exp.FileFormatProperty: lambda self, e: f"USING {e.text('value').upper()}",
exp.ArraySum: lambda self, e: f"AGGREGATE({self.sql(e, 'this')}, 0, (acc, x) -> acc + x, acc -> acc)",
@@ -114,6 +115,8 @@ class Spark(Hive):
exp.VariancePop: rename_func("VAR_POP"),
exp.DateFromParts: rename_func("MAKE_DATE"),
}
TRANSFORMS.pop(exp.ArraySort)
TRANSFORMS.pop(exp.ILike)
WRAP_DERIVED_VALUES = False
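Popping the inherited entries restores default behavior for those expressions in Spark even though Hive rewrites them. For example (a sketch; the Hive output shape follows from `no_ilike_sql`):

```python
import sqlglot

# Hive lowers ILIKE into LOWER(...) LIKE; Spark, having popped that transform, keeps ILIKE.
print(sqlglot.transpile("SELECT * FROM t WHERE x ILIKE '%a%'", write="hive")[0])
print(sqlglot.transpile("SELECT * FROM t WHERE x ILIKE '%a%'", write="spark")[0])
```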

View file

@@ -1,4 +1,6 @@
from sqlglot import exp
from __future__ import annotations
from sqlglot import exp, generator, parser, tokens
from sqlglot.dialects.dialect import (
Dialect,
arrow_json_extract_scalar_sql,
@@ -8,31 +10,28 @@ from sqlglot.dialects.dialect import (
no_trycast_sql,
rename_func,
)
from sqlglot.generator import Generator
from sqlglot.parser import Parser
from sqlglot.tokens import Tokenizer, TokenType
from sqlglot.tokens import TokenType
class SQLite(Dialect):
class Tokenizer(Tokenizer):
class Tokenizer(tokens.Tokenizer):
IDENTIFIERS = ['"', ("[", "]"), "`"]
HEX_STRINGS = [("x'", "'"), ("X'", "'"), ("0x", ""), ("0X", "")]
KEYWORDS = {
**Tokenizer.KEYWORDS,
"VARBINARY": TokenType.BINARY,
**tokens.Tokenizer.KEYWORDS,
"AUTOINCREMENT": TokenType.AUTO_INCREMENT,
}
class Parser(Parser):
class Parser(parser.Parser):
FUNCTIONS = {
**Parser.FUNCTIONS,
**parser.Parser.FUNCTIONS,
"EDITDIST3": exp.Levenshtein.from_arg_list,
}
class Generator(Generator):
class Generator(generator.Generator):
TYPE_MAPPING = {
**Generator.TYPE_MAPPING,
**generator.Generator.TYPE_MAPPING,
exp.DataType.Type.BOOLEAN: "INTEGER",
exp.DataType.Type.TINYINT: "INTEGER",
exp.DataType.Type.SMALLINT: "INTEGER",
@@ -46,6 +45,7 @@ class SQLite(Dialect):
exp.DataType.Type.VARCHAR: "TEXT",
exp.DataType.Type.NVARCHAR: "TEXT",
exp.DataType.Type.BINARY: "BLOB",
exp.DataType.Type.VARBINARY: "BLOB",
}
TOKEN_MAPPING = {
@@ -53,7 +53,7 @@ class SQLite(Dialect):
}
TRANSFORMS = {
**Generator.TRANSFORMS,
**generator.Generator.TRANSFORMS,
exp.ILike: no_ilike_sql,
exp.JSONExtract: arrow_json_extract_sql,
exp.JSONExtractScalar: arrow_json_extract_scalar_sql,

View file

@@ -1,10 +1,12 @@
from __future__ import annotations
from sqlglot import exp
from sqlglot.dialects.dialect import arrow_json_extract_sql, rename_func
from sqlglot.dialects.mysql import MySQL
class StarRocks(MySQL):
class Generator(MySQL.Generator):
class Generator(MySQL.Generator): # type: ignore
TYPE_MAPPING = {
**MySQL.Generator.TYPE_MAPPING,
exp.DataType.Type.TEXT: "STRING",
@@ -13,7 +15,7 @@ class StarRocks(MySQL):
}
TRANSFORMS = {
**MySQL.Generator.TRANSFORMS,
**MySQL.Generator.TRANSFORMS, # type: ignore
exp.JSONExtractScalar: arrow_json_extract_sql,
exp.JSONExtract: arrow_json_extract_sql,
exp.DateDiff: rename_func("DATEDIFF"),
@@ -22,3 +24,4 @@ class StarRocks(MySQL):
exp.UnixToStr: lambda self, e: f"FROM_UNIXTIME({self.sql(e, 'this')}, {self.format_time(e)})",
exp.UnixToTime: rename_func("FROM_UNIXTIME"),
}
TRANSFORMS.pop(exp.DateTrunc)

View file

@@ -1,7 +1,7 @@
from sqlglot import exp
from __future__ import annotations
from sqlglot import exp, generator, parser
from sqlglot.dialects.dialect import Dialect
from sqlglot.generator import Generator
from sqlglot.parser import Parser
def _if_sql(self, expression):
@@ -20,17 +20,17 @@ def _count_sql(self, expression):
class Tableau(Dialect):
class Generator(Generator):
class Generator(generator.Generator):
TRANSFORMS = {
**Generator.TRANSFORMS,
**generator.Generator.TRANSFORMS, # type: ignore
exp.If: _if_sql,
exp.Coalesce: _coalesce_sql,
exp.Count: _count_sql,
}
class Parser(Parser):
class Parser(parser.Parser):
FUNCTIONS = {
**Parser.FUNCTIONS,
**parser.Parser.FUNCTIONS,
"IFNULL": exp.Coalesce.from_arg_list,
"COUNTD": lambda args: exp.Count(this=exp.Distinct(expressions=args)),
}

View file

@@ -1,3 +1,5 @@
from __future__ import annotations
from sqlglot import exp
from sqlglot.dialects.presto import Presto
@@ -5,7 +7,7 @@ from sqlglot.dialects.presto import Presto
class Trino(Presto):
class Generator(Presto.Generator):
TRANSFORMS = {
**Presto.Generator.TRANSFORMS,
**Presto.Generator.TRANSFORMS, # type: ignore
exp.ArraySum: lambda self, e: f"REDUCE({self.sql(e, 'this')}, 0, (acc, x) -> acc + x, acc -> acc)",
}

View file

@@ -1,15 +1,22 @@
from __future__ import annotations
import re
from sqlglot import exp
from sqlglot import exp, generator, parser, tokens
from sqlglot.dialects.dialect import Dialect, parse_date_delta, rename_func
from sqlglot.expressions import DataType
from sqlglot.generator import Generator
from sqlglot.helper import list_get
from sqlglot.parser import Parser
from sqlglot.helper import seq_get
from sqlglot.time import format_time
from sqlglot.tokens import Tokenizer, TokenType
from sqlglot.tokens import TokenType
FULL_FORMAT_TIME_MAPPING = {"weekday": "%A", "dw": "%A", "w": "%A", "month": "%B", "mm": "%B", "m": "%B"}
FULL_FORMAT_TIME_MAPPING = {
"weekday": "%A",
"dw": "%A",
"w": "%A",
"month": "%B",
"mm": "%B",
"m": "%B",
}
DATE_DELTA_INTERVAL = {
"year": "year",
"yyyy": "year",
@@ -37,11 +44,13 @@ TRANSPILE_SAFE_NUMBER_FMT = {"N", "C"}
def tsql_format_time_lambda(exp_class, full_format_mapping=None, default=None):
def _format_time(args):
return exp_class(
this=list_get(args, 1),
this=seq_get(args, 1),
format=exp.Literal.string(
format_time(
list_get(args, 0).name or (TSQL.time_format if default is True else default),
{**TSQL.time_mapping, **FULL_FORMAT_TIME_MAPPING} if full_format_mapping else TSQL.time_mapping,
seq_get(args, 0).name or (TSQL.time_format if default is True else default),
{**TSQL.time_mapping, **FULL_FORMAT_TIME_MAPPING}
if full_format_mapping
else TSQL.time_mapping,
)
),
)
@@ -50,12 +59,12 @@ def tsql_format_time_lambda(exp_class, full_format_mapping=None, default=None):
def parse_format(args):
fmt = list_get(args, 1)
fmt = seq_get(args, 1)
number_fmt = fmt.name in TRANSPILE_SAFE_NUMBER_FMT or not DATE_FMT_RE.search(fmt.this)
if number_fmt:
return exp.NumberToStr(this=list_get(args, 0), format=fmt)
return exp.NumberToStr(this=seq_get(args, 0), format=fmt)
return exp.TimeToStr(
this=list_get(args, 0),
this=seq_get(args, 0),
format=exp.Literal.string(
format_time(fmt.name, TSQL.format_time_mapping)
if len(fmt.name) == 1
@@ -188,11 +197,11 @@ class TSQL(Dialect):
"Y": "%a %Y",
}
class Tokenizer(Tokenizer):
class Tokenizer(tokens.Tokenizer):
IDENTIFIERS = ['"', ("[", "]")]
KEYWORDS = {
**Tokenizer.KEYWORDS,
**tokens.Tokenizer.KEYWORDS,
"BIT": TokenType.BOOLEAN,
"REAL": TokenType.FLOAT,
"NTEXT": TokenType.TEXT,
@@ -200,7 +209,6 @@ class TSQL(Dialect):
"DATETIME2": TokenType.DATETIME,
"DATETIMEOFFSET": TokenType.TIMESTAMPTZ,
"TIME": TokenType.TIMESTAMP,
"VARBINARY": TokenType.BINARY,
"IMAGE": TokenType.IMAGE,
"MONEY": TokenType.MONEY,
"SMALLMONEY": TokenType.SMALLMONEY,
@@ -213,9 +221,9 @@ class TSQL(Dialect):
"TOP": TokenType.TOP,
}
class Parser(Parser):
class Parser(parser.Parser):
FUNCTIONS = {
**Parser.FUNCTIONS,
**parser.Parser.FUNCTIONS,
"CHARINDEX": exp.StrPosition.from_arg_list,
"ISNULL": exp.Coalesce.from_arg_list,
"DATEADD": parse_date_delta(exp.DateAdd, unit_mapping=DATE_DELTA_INTERVAL),
@@ -243,14 +251,16 @@ class TSQL(Dialect):
this = self._parse_column()
# Retrieve length of datatype and override to default if not specified
if list_get(to.expressions, 0) is None and to.this in self.VAR_LENGTH_DATATYPES:
if seq_get(to.expressions, 0) is None and to.this in self.VAR_LENGTH_DATATYPES:
to = exp.DataType.build(to.this, expressions=[exp.Literal.number(30)], nested=False)
# Check whether a conversion with format is applicable
if self._match(TokenType.COMMA):
format_val = self._parse_number().name
if format_val not in TSQL.convert_format_mapping:
raise ValueError(f"CONVERT function at T-SQL does not support format style {format_val}")
raise ValueError(
f"CONVERT function at T-SQL does not support format style {format_val}"
)
format_norm = exp.Literal.string(TSQL.convert_format_mapping[format_val])
# Check whether the convert entails a string to date format
@@ -272,9 +282,9 @@ class TSQL(Dialect):
# Entails a simple cast without any format requirement
return self.expression(exp.Cast if strict else exp.TryCast, this=this, to=to)
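To see `_parse_convert` end to end, the sketch below transpiles a style-112 conversion; the exact output depends on `TSQL.convert_format_mapping`, so it is printed rather than asserted:

```python
import sqlglot

# Style 112 is the ISO yyyymmdd format; CONVERT becomes a time-format expression.
print(sqlglot.transpile("SELECT CONVERT(VARCHAR(8), dt, 112) FROM t", read="tsql", write="spark")[0])
```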
class Generator(Generator):
class Generator(generator.Generator):
TYPE_MAPPING = {
**Generator.TYPE_MAPPING,
**generator.Generator.TYPE_MAPPING,
exp.DataType.Type.BOOLEAN: "BIT",
exp.DataType.Type.INT: "INTEGER",
exp.DataType.Type.DECIMAL: "NUMERIC",
@@ -283,7 +293,7 @@ class TSQL(Dialect):
}
TRANSFORMS = {
**Generator.TRANSFORMS,
**generator.Generator.TRANSFORMS, # type: ignore
exp.DateAdd: generate_date_delta_with_unit_sql,
exp.DateDiff: generate_date_delta_with_unit_sql,
exp.CurrentDate: rename_func("GETDATE"),

View file

@@ -4,7 +4,7 @@ from heapq import heappop, heappush
from sqlglot import Dialect
from sqlglot import expressions as exp
from sqlglot.helper import ensure_list
from sqlglot.helper import ensure_collection
@dataclass(frozen=True)
@@ -116,7 +118,9 @@ class ChangeDistiller:
source_node = self._source_index[kept_source_node_id]
target_node = self._target_index[kept_target_node_id]
if not isinstance(source_node, LEAF_EXPRESSION_TYPES) or source_node == target_node:
edit_script.extend(self._generate_move_edits(source_node, target_node, matching_set))
edit_script.extend(
self._generate_move_edits(source_node, target_node, matching_set)
)
edit_script.append(Keep(source_node, target_node))
else:
edit_script.append(Update(source_node, target_node))
@@ -158,13 +160,16 @@ class ChangeDistiller:
max_leaves_num = max(len(source_leaf_ids), len(target_leaf_ids))
if max_leaves_num:
common_leaves_num = sum(
1 if s in source_leaf_ids and t in target_leaf_ids else 0 for s, t in leaves_matching_set
1 if s in source_leaf_ids and t in target_leaf_ids else 0
for s, t in leaves_matching_set
)
leaf_similarity_score = common_leaves_num / max_leaves_num
else:
leaf_similarity_score = 0.0
adjusted_t = self.t if min(len(source_leaf_ids), len(target_leaf_ids)) > 4 else 0.4
adjusted_t = (
self.t if min(len(source_leaf_ids), len(target_leaf_ids)) > 4 else 0.4
)
if leaf_similarity_score >= 0.8 or (
leaf_similarity_score >= adjusted_t
@@ -201,7 +206,10 @@ class ChangeDistiller:
matching_set = set()
while candidate_matchings:
_, _, source_leaf, target_leaf = heappop(candidate_matchings)
if id(source_leaf) in self._unmatched_source_nodes and id(target_leaf) in self._unmatched_target_nodes:
if (
id(source_leaf) in self._unmatched_source_nodes
and id(target_leaf) in self._unmatched_target_nodes
):
matching_set.add((id(source_leaf), id(target_leaf)))
self._unmatched_source_nodes.remove(id(source_leaf))
self._unmatched_target_nodes.remove(id(target_leaf))
@@ -241,8 +249,7 @@ def _get_leaves(expression):
has_child_exprs = False
for a in expression.args.values():
nodes = ensure_list(a)
for node in nodes:
for node in ensure_collection(a):
if isinstance(node, exp.Expression):
has_child_exprs = True
yield from _get_leaves(node)
@@ -268,7 +275,7 @@ def _expression_only_args(expression):
args = []
if expression:
for a in expression.args.values():
args.extend(ensure_list(a))
args.extend(ensure_collection(a))
return [a for a in args if isinstance(a, exp.Expression)]
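For reference, the matching heuristics above are exercised through the public `diff` entry point; a minimal sketch:

```python
from sqlglot import diff, parse_one

# Matched leaves come back as Keep edits; the extra column shows up as an Insert.
edits = diff(parse_one("SELECT a, b FROM t"), parse_one("SELECT a, b, c FROM t"))
print([type(e).__name__ for e in edits])
```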

View file

@@ -1,3 +1,6 @@
from __future__ import annotations
import typing as t
from enum import auto
from sqlglot.helper import AutoName
@@ -30,7 +33,11 @@ class OptimizeError(SqlglotError):
pass
def concat_errors(errors, maximum):
class SchemaError(SqlglotError):
pass
def concat_errors(errors: t.Sequence[t.Any], maximum: int) -> str:
msg = [str(e) for e in errors[:maximum]]
remaining = len(errors) - maximum
if remaining > 0:

View file

@@ -19,6 +19,7 @@ class Context:
env (Optional[dict]): dictionary of functions within the execution context
"""
self.tables = tables
self._table = None
self.range_readers = {name: table.range_reader for name, table in self.tables.items()}
self.row_readers = {name: table.reader for name, table in tables.items()}
self.env = {**(env or {}), "scope": self.row_readers}
@@ -29,8 +30,27 @@ class Context:
def eval_tuple(self, codes):
return tuple(self.eval(code) for code in codes)
@property
def table(self):
if self._table is None:
self._table = list(self.tables.values())[0]
for other in self.tables.values():
if self._table.columns != other.columns:
raise Exception(f"Columns are different.")
if len(self._table.rows) != len(other.rows):
raise Exception(f"Rows are different.")
return self._table
@property
def columns(self):
return self.table.columns
def __iter__(self):
return self.table_iter(list(self.tables)[0])
self.env["scope"] = self.row_readers
for i in range(len(self.table.rows)):
for table in self.tables.values():
reader = table[i]
yield reader, self
def table_iter(self, table):
self.env["scope"] = self.row_readers
@@ -38,8 +58,8 @@ class Context:
for reader in self.tables[table]:
yield reader, self
def sort(self, table, key):
table = self.tables[table]
def sort(self, key):
table = self.table
def sort_key(row):
table.reader.row = row
@@ -47,20 +67,20 @@ class Context:
table.rows.sort(key=sort_key)
def set_row(self, table, row):
self.row_readers[table].row = row
def set_row(self, row):
for table in self.tables.values():
table.reader.row = row
self.env["scope"] = self.row_readers
def set_index(self, table, index):
self.row_readers[table].row = self.tables[table].rows[index]
def set_index(self, index):
for table in self.tables.values():
table[index]
self.env["scope"] = self.row_readers
def set_range(self, table, start, end):
self.range_readers[table].range = range(start, end)
def set_range(self, start, end):
for name in self.tables:
self.range_readers[name].range = range(start, end)
self.env["scope"] = self.range_readers
def __getitem__(self, table):
return self.env["scope"][table]
def __contains__(self, table):
return table in self.tables
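The reworked `Context` treats all of its tables as views over the same rows, which is why the `table` property asserts that shapes agree. A minimal sketch of row iteration under that model:

```python
from sqlglot.executor.context import Context
from sqlglot.executor.table import Table

t = Table(["a", "b"], rows=[(1, "one"), (2, "two")])
ctx = Context({"t": t})

# Iteration advances every table's reader in lockstep, yielding one reader per row.
for reader, _ in ctx:
    print(reader["a"], reader["b"])
```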

View file

@@ -2,6 +2,8 @@ import datetime
import re
import statistics
from sqlglot.helper import PYTHON_VERSION
class reverse_key:
def __init__(self, obj):
@@ -25,7 +27,7 @@ ENV = {
"str": str,
"desc": reverse_key,
"SUM": sum,
"AVG": statistics.fmean if hasattr(statistics, "fmean") else statistics.mean,
"AVG": statistics.fmean if PYTHON_VERSION >= (3, 8) else statistics.mean, # type: ignore
"COUNT": lambda acc: sum(1 for e in acc if e is not None),
"MAX": max,
"MIN": min,

View file

@@ -1,15 +1,14 @@
import ast
import collections
import itertools
import math
from sqlglot import exp, planner
from sqlglot import exp, generator, planner, tokens
from sqlglot.dialects.dialect import Dialect, inline_array_sql
from sqlglot.executor.context import Context
from sqlglot.executor.env import ENV
from sqlglot.executor.table import Table
from sqlglot.generator import Generator
from sqlglot.helper import csv_reader
from sqlglot.tokens import Tokenizer
class PythonExecutor:
@@ -26,7 +25,11 @@ class PythonExecutor:
while queue:
node = queue.pop()
context = self.context(
{name: table for dep in node.dependencies for name, table in contexts[dep].tables.items()}
{
name: table
for dep in node.dependencies
for name, table in contexts[dep].tables.items()
}
)
running.add(node)
@@ -76,13 +79,10 @@ class PythonExecutor:
return Table(expression.alias_or_name for expression in expressions)
def scan(self, step, context):
if hasattr(step, "source"):
source = step.source
if isinstance(source, exp.Expression):
source = source.name or source.alias
else:
source = step.name
condition = self.generate(step.condition)
projections = self.generate_tuple(step.projections)
@@ -96,14 +96,12 @@ class PythonExecutor:
if projections:
sink = self.table(step.projections)
elif source in context:
sink = Table(context[source].columns)
else:
sink = None
for reader, ctx in table_iter:
if sink is None:
sink = Table(ctx[source].columns)
sink = Table(reader.columns)
if condition and not ctx.eval(condition):
continue
@@ -135,119 +133,102 @@ class PythonExecutor:
types.append(type(ast.literal_eval(v)))
except (ValueError, SyntaxError):
types.append(str)
context.set_row(alias, tuple(t(v) for t, v in zip(types, row)))
yield context[alias], context
context.set_row(tuple(t(v) for t, v in zip(types, row)))
yield context.table.reader, context
def join(self, step, context):
source = step.name
join_context = self.context({source: context.tables[source]})
def merge_context(ctx, table):
# create a new context where all existing tables are mapped to a new one
return self.context({name: table for name in ctx.tables})
source_table = context.tables[source]
source_context = self.context({source: source_table})
column_ranges = {source: range(0, len(source_table.columns))}
for name, join in step.joins.items():
join_context = self.context({**join_context.tables, name: context.tables[name]})
table = context.tables[name]
start = max(r.stop for r in column_ranges.values())
column_ranges[name] = range(start, len(table.columns) + start)
join_context = self.context({name: table})
if join.get("source_key"):
table = self.hash_join(join, source, name, join_context)
table = self.hash_join(join, source_context, join_context)
else:
table = self.nested_loop_join(join, source, name, join_context)
table = self.nested_loop_join(join, source_context, join_context)
join_context = merge_context(join_context, table)
source_context = self.context(
{
name: Table(table.columns, table.rows, column_range)
for name, column_range in column_ranges.items()
}
)
# apply projections or conditions
context = self.scan(step, join_context)
condition = self.generate(step.condition)
projections = self.generate_tuple(step.projections)
# use the scan context since it returns a single table
# otherwise there are no projections so all other tables are still in scope
if step.projections:
return context
if not condition or not projections:
return source_context
return merge_context(join_context, context.tables[source])
sink = self.table(step.projections if projections else source_context.columns)
def nested_loop_join(self, _join, a, b, context):
table = Table(context.tables[a].columns + context.tables[b].columns)
for reader, ctx in join_context:
if condition and not ctx.eval(condition):
continue
for reader_a, _ in context.table_iter(a):
for reader_b, _ in context.table_iter(b):
if projections:
sink.append(ctx.eval_tuple(projections))
else:
sink.append(reader.row)
if len(sink) >= step.limit:
break
return self.context({step.name: sink})
def nested_loop_join(self, _join, source_context, join_context):
table = Table(source_context.columns + join_context.columns)
for reader_a, _ in source_context:
for reader_b, _ in join_context:
table.append(reader_a.row + reader_b.row)
return table
def hash_join(self, join, a, b, context):
a_key = self.generate_tuple(join["source_key"])
b_key = self.generate_tuple(join["join_key"])
def hash_join(self, join, source_context, join_context):
source_key = self.generate_tuple(join["source_key"])
join_key = self.generate_tuple(join["join_key"])
results = collections.defaultdict(lambda: ([], []))
for reader, ctx in context.table_iter(a):
results[ctx.eval_tuple(a_key)][0].append(reader.row)
for reader, ctx in context.table_iter(b):
results[ctx.eval_tuple(b_key)][1].append(reader.row)
for reader, ctx in source_context:
results[ctx.eval_tuple(source_key)][0].append(reader.row)
for reader, ctx in join_context:
results[ctx.eval_tuple(join_key)][1].append(reader.row)
table = Table(source_context.columns + join_context.columns)
table = Table(context.tables[a].columns + context.tables[b].columns)
for a_group, b_group in results.values():
for a_row, b_row in itertools.product(a_group, b_group):
table.append(a_row + b_row)
return table
def sort_merge_join(self, join, a, b, context):
a_key = self.generate_tuple(join["source_key"])
b_key = self.generate_tuple(join["join_key"])
context.sort(a, a_key)
context.sort(b, b_key)
a_i = 0
b_i = 0
a_n = len(context.tables[a])
b_n = len(context.tables[b])
table = Table(context.tables[a].columns + context.tables[b].columns)
def get_key(source, key, i):
context.set_index(source, i)
return context.eval_tuple(key)
while a_i < a_n and b_i < b_n:
key = min(get_key(a, a_key, a_i), get_key(b, b_key, b_i))
a_group = []
while a_i < a_n and key == get_key(a, a_key, a_i):
a_group.append(context[a].row)
a_i += 1
b_group = []
while b_i < b_n and key == get_key(b, b_key, b_i):
b_group.append(context[b].row)
b_i += 1
for a_row, b_row in itertools.product(a_group, b_group):
table.append(a_row + b_row)
return table
def aggregate(self, step, context):
source = step.source
group_by = self.generate_tuple(step.group)
aggregations = self.generate_tuple(step.aggregations)
operands = self.generate_tuple(step.operands)
context.sort(source, group_by)
if step.operands:
if operands:
source_table = context.tables[source]
operand_table = Table(source_table.columns + self.table(step.operands).columns)
for reader, ctx in context:
operand_table.append(reader.row + ctx.eval_tuple(operands))
context = self.context({source: operand_table})
context = self.context(
{None: operand_table, **{table: operand_table for table in context.tables}}
)
context.sort(group_by)
group = None
start = 0
@@ -256,15 +237,15 @@ class PythonExecutor:
table = self.table(step.group + step.aggregations)
for i in range(length):
context.set_index(source, i)
context.set_index(i)
key = context.eval_tuple(group_by)
group = key if group is None else group
end += 1
if i == length - 1:
context.set_range(source, start, end - 1)
context.set_range(start, end - 1)
elif key != group:
context.set_range(source, start, end - 2)
context.set_range(start, end - 2)
else:
continue
@@ -272,13 +253,32 @@ class PythonExecutor:
group = key
start = end - 2
return self.scan(step, self.context({source: table}))
context = self.context({step.name: table, **{name: table for name in context.tables}})
if step.projections:
return self.scan(step, context)
return context
def sort(self, step, context):
table = list(context.tables)[0]
key = self.generate_tuple(step.key)
context.sort(table, key)
return self.scan(step, context)
projections = self.generate_tuple(step.projections)
sink = self.table(step.projections)
for reader, ctx in context:
sink.append(ctx.eval_tuple(projections))
context = self.context(
{
None: sink,
**{table: sink for table in context.tables},
}
)
context.sort(self.generate_tuple(step.key))
if not math.isinf(step.limit):
context.table.rows = context.table.rows[0 : step.limit]
return self.context({step.name: context.table})
def _cast_py(self, expression):
@@ -293,7 +293,7 @@ def _cast_py(self, expression):
def _column_py(self, expression):
table = self.sql(expression, "table")
table = self.sql(expression, "table") or None
this = self.sql(expression, "this")
return f"scope[{table}][{this}]"
@@ -319,10 +319,10 @@ def _ordered_py(self, expression):
class Python(Dialect):
class Tokenizer(Tokenizer):
ESCAPE = "\\"
class Tokenizer(tokens.Tokenizer):
ESCAPES = ["\\"]
class Generator(Generator):
class Generator(generator.Generator):
TRANSFORMS = {
exp.Alias: lambda self, e: self.sql(e.this),
exp.Array: inline_array_sql,

View file

@@ -1,10 +1,12 @@
class Table:
def __init__(self, *columns, rows=None):
self.columns = tuple(columns if isinstance(columns[0], str) else columns[0])
def __init__(self, columns, rows=None, column_range=None):
self.columns = tuple(columns)
self.column_range = column_range
self.reader = RowReader(self.columns, self.column_range)
self.rows = rows or []
if rows:
assert len(rows[0]) == len(self.columns)
self.reader = RowReader(self.columns)
self.range_reader = RangeReader(self)
def append(self, row):
@@ -29,15 +31,22 @@ class Table:
return self.reader
def __repr__(self):
widths = {column: len(column) for column in self.columns}
lines = [" ".join(column for column in self.columns)]
columns = tuple(
column
for i, column in enumerate(self.columns)
if not self.column_range or i in self.column_range
)
widths = {column: len(column) for column in columns}
lines = [" ".join(column for column in columns)]
for i, row in enumerate(self):
if i > 10:
break
lines.append(
" ".join(str(row[column]).rjust(widths[column])[0 : widths[column]] for column in self.columns)
" ".join(
str(row[column]).rjust(widths[column])[0 : widths[column]] for column in columns
)
)
return "\n".join(lines)
@ -70,8 +79,10 @@ class RangeReader:
class RowReader:
def __init__(self, columns):
self.columns = {column: i for i, column in enumerate(columns)}
def __init__(self, columns, column_range=None):
self.columns = {
column: i for i, column in enumerate(columns) if not column_range or i in column_range
}
self.row = None
def __getitem__(self, column):
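`Table` now takes its column names as a single iterable (plus an optional `column_range`) instead of `*columns`. A quick sketch of the new call shape, assuming the class lives at `sqlglot.executor.table` as in the released package:

```python
from sqlglot.executor.table import Table

# Columns are passed as one iterable now; each row must match their width.
t = Table(["a", "b"], rows=[(1, 2), (3, 4)])
print(t.columns)  # ('a', 'b')
print(t)          # rendered by the __repr__ shown above
```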
@ -1,6 +1,9 @@
from __future__ import annotations
import datetime
import numbers
import re
import typing as t
from collections import deque
from copy import deepcopy
from enum import auto
@ -9,12 +12,15 @@ from sqlglot.errors import ParseError
from sqlglot.helper import (
AutoName,
camel_to_snake_case,
ensure_list,
list_get,
ensure_collection,
seq_get,
split_num_words,
subclasses,
)
if t.TYPE_CHECKING:
from sqlglot.dialects.dialect import Dialect
class _Expression(type):
def __new__(cls, clsname, bases, attrs):
@ -35,27 +41,30 @@ class Expression(metaclass=_Expression):
or optional (False).
"""
key = None
key = "Expression"
arg_types = {"this": True}
__slots__ = ("args", "parent", "arg_key", "type")
__slots__ = ("args", "parent", "arg_key", "type", "comment")
def __init__(self, **args):
self.args = args
self.parent = None
self.arg_key = None
self.type = None
self.comment = None
for arg_key, value in self.args.items():
self._set_parent(arg_key, value)
def __eq__(self, other):
def __eq__(self, other) -> bool:
return type(self) is type(other) and _norm_args(self) == _norm_args(other)
def __hash__(self):
def __hash__(self) -> int:
return hash(
(
self.key,
tuple((k, tuple(v) if isinstance(v, list) else v) for k, v in _norm_args(self).items()),
tuple(
(k, tuple(v) if isinstance(v, list) else v) for k, v in _norm_args(self).items()
),
)
)
@ -79,6 +88,19 @@ class Expression(metaclass=_Expression):
return field.this
return ""
def find_comment(self, key: str) -> str:
"""
Finds the comment that is attached to a specified child node.
Args:
key: the key of the target child node (e.g. "this", "expression", etc).
Returns:
The comment attached to the child node, or an empty string if it doesn't exist.
"""
field = self.args.get(key)
return field.comment if isinstance(field, Expression) else ""
@property
def is_string(self):
return isinstance(self, Literal) and self.args["is_string"]
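With the new `comment` slot (and the `find_comment` helper above), a comment parsed from the input survives a parse/generate round trip. A small sketch of the feature this enables:

```python
import sqlglot

expr = sqlglot.parse_one("SELECT a /* keep me */ FROM t")
print(expr.sql())  # SELECT a /* keep me */ FROM t
# The raw comment text is attached to the selected expression itself.
print(repr(expr.expressions[0].comment))  # ' keep me '
```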
@ -114,7 +136,10 @@ class Expression(metaclass=_Expression):
return self.alias or self.name
def __deepcopy__(self, memo):
return self.__class__(**deepcopy(self.args))
copy = self.__class__(**deepcopy(self.args))
copy.comment = self.comment
copy.type = self.type
return copy
def copy(self):
new = deepcopy(self)
@ -249,9 +274,7 @@ class Expression(metaclass=_Expression):
return
for k, v in self.args.items():
nodes = ensure_list(v)
for node in nodes:
for node in ensure_collection(v):
if isinstance(node, Expression):
yield from node.dfs(self, k, prune)
@ -274,9 +297,7 @@ class Expression(metaclass=_Expression):
if isinstance(item, Expression):
for k, v in item.args.items():
nodes = ensure_list(v)
for node in nodes:
for node in ensure_collection(v):
if isinstance(node, Expression):
queue.append((node, item, k))
@ -319,7 +340,7 @@ class Expression(metaclass=_Expression):
def __repr__(self):
return self.to_s()
def sql(self, dialect=None, **opts):
def sql(self, dialect: Dialect | str | None = None, **opts) -> str:
"""
Returns SQL string representation of this tree.
@ -335,7 +356,7 @@ class Expression(metaclass=_Expression):
return Dialect.get_or_raise(dialect)().generate(self, **opts)
def to_s(self, hide_missing=True, level=0):
def to_s(self, hide_missing: bool = True, level: int = 0) -> str:
indent = "" if not level else "\n"
indent += "".join([" "] * level)
left = f"({self.key.upper()} "
@ -343,11 +364,13 @@ class Expression(metaclass=_Expression):
args = {
k: ", ".join(
v.to_s(hide_missing=hide_missing, level=level + 1) if hasattr(v, "to_s") else str(v)
for v in ensure_list(vs)
for v in ensure_collection(vs)
if v is not None
)
for k, vs in self.args.items()
}
args["comment"] = self.comment
args["type"] = self.type
args = {k: v for k, v in args.items() if v or not hide_missing}
right = ", ".join(f"{k}: {v}" for k, v in args.items())
@ -578,17 +601,6 @@ class UDTF(DerivedTable, Unionable):
pass
class Annotation(Expression):
arg_types = {
"this": True,
"expression": True,
}
@property
def alias(self):
return self.expression.alias_or_name
class Cache(Expression):
arg_types = {
"with": False,
@ -623,6 +635,38 @@ class Describe(Expression):
pass
class Set(Expression):
arg_types = {"expressions": True}
class SetItem(Expression):
arg_types = {
"this": True,
"kind": False,
"collate": False, # MySQL SET NAMES statement
}
class Show(Expression):
arg_types = {
"this": True,
"target": False,
"offset": False,
"limit": False,
"like": False,
"where": False,
"db": False,
"full": False,
"mutex": False,
"query": False,
"channel": False,
"global": False,
"log": False,
"position": False,
"types": False,
}
class UserDefinedFunction(Expression):
arg_types = {"this": True, "expressions": False}
@ -864,18 +908,20 @@ class Literal(Condition):
def __eq__(self, other):
return (
isinstance(other, Literal) and self.this == other.this and self.args["is_string"] == other.args["is_string"]
isinstance(other, Literal)
and self.this == other.this
and self.args["is_string"] == other.args["is_string"]
)
def __hash__(self):
return hash((self.key, self.this, self.args["is_string"]))
@classmethod
def number(cls, number):
def number(cls, number) -> Literal:
return cls(this=str(number), is_string=False)
@classmethod
def string(cls, string):
def string(cls, string) -> Literal:
return cls(this=str(string), is_string=True)
@ -1087,7 +1133,7 @@ class Properties(Expression):
}
@classmethod
def from_dict(cls, properties_dict):
def from_dict(cls, properties_dict) -> Properties:
expressions = []
for key, value in properties_dict.items():
property_cls = cls.PROPERTY_KEY_MAPPING.get(key.upper(), AnonymousProperty)
@ -1323,7 +1369,7 @@ class Select(Subqueryable):
**QUERY_MODIFIERS,
}
def from_(self, *expressions, append=True, dialect=None, copy=True, **opts):
def from_(self, *expressions, append=True, dialect=None, copy=True, **opts) -> Select:
"""
Set the FROM expression.
@ -1356,7 +1402,7 @@ class Select(Subqueryable):
**opts,
)
def group_by(self, *expressions, append=True, dialect=None, copy=True, **opts):
def group_by(self, *expressions, append=True, dialect=None, copy=True, **opts) -> Select:
"""
Set the GROUP BY expression.
@ -1392,7 +1438,7 @@ class Select(Subqueryable):
**opts,
)
def order_by(self, *expressions, append=True, dialect=None, copy=True, **opts):
def order_by(self, *expressions, append=True, dialect=None, copy=True, **opts) -> Select:
"""
Set the ORDER BY expression.
@ -1425,7 +1471,7 @@ class Select(Subqueryable):
**opts,
)
def sort_by(self, *expressions, append=True, dialect=None, copy=True, **opts):
def sort_by(self, *expressions, append=True, dialect=None, copy=True, **opts) -> Select:
"""
Set the SORT BY expression.
@ -1458,7 +1504,7 @@ class Select(Subqueryable):
**opts,
)
def cluster_by(self, *expressions, append=True, dialect=None, copy=True, **opts):
def cluster_by(self, *expressions, append=True, dialect=None, copy=True, **opts) -> Select:
"""
Set the CLUSTER BY expression.
@ -1491,7 +1537,7 @@ class Select(Subqueryable):
**opts,
)
def limit(self, expression, dialect=None, copy=True, **opts):
def limit(self, expression, dialect=None, copy=True, **opts) -> Select:
"""
Set the LIMIT expression.
@ -1522,7 +1568,7 @@ class Select(Subqueryable):
**opts,
)
def offset(self, expression, dialect=None, copy=True, **opts):
def offset(self, expression, dialect=None, copy=True, **opts) -> Select:
"""
Set the OFFSET expression.
@ -1553,7 +1599,7 @@ class Select(Subqueryable):
**opts,
)
def select(self, *expressions, append=True, dialect=None, copy=True, **opts):
def select(self, *expressions, append=True, dialect=None, copy=True, **opts) -> Select:
"""
Append to or set the SELECT expressions.
@ -1583,7 +1629,7 @@ class Select(Subqueryable):
**opts,
)
def lateral(self, *expressions, append=True, dialect=None, copy=True, **opts):
def lateral(self, *expressions, append=True, dialect=None, copy=True, **opts) -> Select:
"""
Append to or set the LATERAL expressions.
@ -1626,7 +1672,7 @@ class Select(Subqueryable):
dialect=None,
copy=True,
**opts,
):
) -> Select:
"""
Append to or set the JOIN expressions.
@ -1672,7 +1718,7 @@ class Select(Subqueryable):
join.this.replace(join.this.subquery())
if join_type:
natural, side, kind = maybe_parse(join_type, into="JOIN_TYPE", **parse_args)
natural, side, kind = maybe_parse(join_type, into="JOIN_TYPE", **parse_args) # type: ignore
if natural:
join.set("natural", True)
if side:
@ -1681,12 +1727,12 @@ class Select(Subqueryable):
join.set("kind", kind.text)
if on:
on = and_(*ensure_list(on), dialect=dialect, **opts)
on = and_(*ensure_collection(on), dialect=dialect, **opts)
join.set("on", on)
if using:
join = _apply_list_builder(
*ensure_list(using),
*ensure_collection(using),
instance=join,
arg="using",
append=append,
@ -1705,7 +1751,7 @@ class Select(Subqueryable):
**opts,
)
def where(self, *expressions, append=True, dialect=None, copy=True, **opts):
def where(self, *expressions, append=True, dialect=None, copy=True, **opts) -> Select:
"""
Append to or set the WHERE expressions.
@ -1737,7 +1783,7 @@ class Select(Subqueryable):
**opts,
)
def having(self, *expressions, append=True, dialect=None, copy=True, **opts):
def having(self, *expressions, append=True, dialect=None, copy=True, **opts) -> Select:
"""
Append to or set the HAVING expressions.
@ -1769,7 +1815,7 @@ class Select(Subqueryable):
**opts,
)
def distinct(self, distinct=True, copy=True):
def distinct(self, distinct=True, copy=True) -> Select:
"""
Set the OFFSET expression.
@ -1788,7 +1834,7 @@ class Select(Subqueryable):
instance.set("distinct", Distinct() if distinct else None)
return instance
def ctas(self, table, properties=None, dialect=None, copy=True, **opts):
def ctas(self, table, properties=None, dialect=None, copy=True, **opts) -> Create:
"""
Convert this expression to a CREATE TABLE AS statement.
@ -1826,11 +1872,11 @@ class Select(Subqueryable):
)
@property
def named_selects(self):
def named_selects(self) -> t.List[str]:
return [e.alias_or_name for e in self.expressions if e.alias_or_name]
@property
def selects(self):
def selects(self) -> t.List[Expression]:
return self.expressions
@ -1910,12 +1956,16 @@ class Parameter(Expression):
pass
class SessionParameter(Expression):
arg_types = {"this": True, "kind": False}
class Placeholder(Expression):
arg_types = {"this": False}
class Null(Condition):
arg_types = {}
arg_types: t.Dict[str, t.Any] = {}
class Boolean(Condition):
@ -1936,6 +1986,7 @@ class DataType(Expression):
NVARCHAR = auto()
TEXT = auto()
BINARY = auto()
VARBINARY = auto()
INT = auto()
TINYINT = auto()
SMALLINT = auto()
@ -1975,7 +2026,7 @@ class DataType(Expression):
UNKNOWN = auto() # Sentinel value, useful for type annotation
@classmethod
def build(cls, dtype, **kwargs):
def build(cls, dtype, **kwargs) -> DataType:
return DataType(
this=dtype if isinstance(dtype, DataType.Type) else DataType.Type[dtype.upper()],
**kwargs,
@ -2077,6 +2128,18 @@ class EQ(Binary, Predicate):
pass
class NullSafeEQ(Binary, Predicate):
pass
class NullSafeNEQ(Binary, Predicate):
pass
class Distance(Binary):
pass
class Escape(Binary):
pass
@ -2101,18 +2164,14 @@ class Is(Binary, Predicate):
pass
class Kwarg(Binary):
"""Kwarg in special functions like func(kwarg => y)."""
class Like(Binary, Predicate):
pass
class SimilarTo(Binary, Predicate):
pass
class Distance(Binary):
pass
class LT(Binary, Predicate):
pass
@ -2133,6 +2192,10 @@ class NEQ(Binary, Predicate):
pass
class SimilarTo(Binary, Predicate):
pass
class Sub(Binary):
pass
@ -2189,7 +2252,13 @@ class Distinct(Expression):
class In(Predicate):
arg_types = {"this": True, "expressions": False, "query": False, "unnest": False, "field": False}
arg_types = {
"this": True,
"expressions": False,
"query": False,
"unnest": False,
"field": False,
}
class TimeUnit(Expression):
@ -2255,7 +2324,9 @@ class Func(Condition):
@classmethod
def sql_names(cls):
if cls is Func:
raise NotImplementedError("SQL name is only supported by concrete function implementations")
raise NotImplementedError(
"SQL name is only supported by concrete function implementations"
)
if not hasattr(cls, "_sql_names"):
cls._sql_names = [camel_to_snake_case(cls.__name__)]
return cls._sql_names
@ -2408,8 +2479,8 @@ class DateDiff(Func, TimeUnit):
arg_types = {"this": True, "expression": True, "unit": False}
class DateTrunc(Func, TimeUnit):
arg_types = {"this": True, "unit": True, "zone": False}
class DateTrunc(Func):
arg_types = {"this": True, "expression": True, "zone": False}
class DatetimeAdd(Func, TimeUnit):
@ -2791,6 +2862,10 @@ class Year(Func):
pass
class Use(Expression):
pass
def _norm_args(expression):
args = {}
@ -2822,7 +2897,7 @@ def maybe_parse(
dialect=None,
prefix=None,
**opts,
):
) -> t.Optional[Expression]:
"""Gracefully handle a possible string or expression.
Example:
@ -3073,7 +3148,7 @@ def except_(left, right, distinct=True, dialect=None, **opts):
return Except(this=left, expression=right, distinct=distinct)
def select(*expressions, dialect=None, **opts):
def select(*expressions, dialect=None, **opts) -> Select:
"""
Initializes a syntax tree from one or multiple SELECT expressions.
@ -3095,7 +3170,7 @@ def select(*expressions, dialect=None, **opts):
return Select().select(*expressions, dialect=dialect, **opts)
def from_(*expressions, dialect=None, **opts):
def from_(*expressions, dialect=None, **opts) -> Select:
"""
Initializes a syntax tree from a FROM expression.
@ -3117,7 +3192,7 @@ def from_(*expressions, dialect=None, **opts):
return Select().from_(*expressions, dialect=dialect, **opts)
def update(table, properties, where=None, from_=None, dialect=None, **opts):
def update(table, properties, where=None, from_=None, dialect=None, **opts) -> Update:
"""
Creates an update statement.
@ -3139,7 +3214,10 @@ def update(table, properties, where=None, from_=None, dialect=None, **opts):
update = Update(this=maybe_parse(table, into=Table, dialect=dialect))
update.set(
"expressions",
[EQ(this=maybe_parse(k, dialect=dialect, **opts), expression=convert(v)) for k, v in properties.items()],
[
EQ(this=maybe_parse(k, dialect=dialect, **opts), expression=convert(v))
for k, v in properties.items()
],
)
if from_:
update.set("from", maybe_parse(from_, into=From, dialect=dialect, prefix="FROM", **opts))
@ -3150,7 +3228,7 @@ def update(table, properties, where=None, from_=None, dialect=None, **opts):
return update
def delete(table, where=None, dialect=None, **opts):
def delete(table, where=None, dialect=None, **opts) -> Delete:
"""
Builds a delete statement.
@ -3174,7 +3252,7 @@ def delete(table, where=None, dialect=None, **opts):
)
def condition(expression, dialect=None, **opts):
def condition(expression, dialect=None, **opts) -> Condition:
"""
Initialize a logical condition expression.
@ -3199,7 +3277,7 @@ def condition(expression, dialect=None, **opts):
Returns:
Condition: the expression
"""
return maybe_parse(
return maybe_parse( # type: ignore
expression,
into=Condition,
dialect=dialect,
@ -3207,7 +3285,7 @@ def condition(expression, dialect=None, **opts):
)
def and_(*expressions, dialect=None, **opts):
def and_(*expressions, dialect=None, **opts) -> And:
"""
Combine multiple conditions with an AND logical operator.
@ -3227,7 +3305,7 @@ def and_(*expressions, dialect=None, **opts):
return _combine(expressions, And, dialect, **opts)
def or_(*expressions, dialect=None, **opts):
def or_(*expressions, dialect=None, **opts) -> Or:
"""
Combine multiple conditions with an OR logical operator.
@ -3247,7 +3325,7 @@ def or_(*expressions, dialect=None, **opts):
return _combine(expressions, Or, dialect, **opts)
def not_(expression, dialect=None, **opts):
def not_(expression, dialect=None, **opts) -> Not:
"""
Wrap a condition with a NOT operator.
@ -3272,14 +3350,14 @@ def not_(expression, dialect=None, **opts):
return Not(this=_wrap_operator(this))
def paren(expression):
def paren(expression) -> Paren:
return Paren(this=expression)
SAFE_IDENTIFIER_RE = re.compile(r"^[a-zA-Z][\w]*$")
def to_identifier(alias, quoted=None):
def to_identifier(alias, quoted=None) -> t.Optional[Identifier]:
if alias is None:
return None
if isinstance(alias, Identifier):
@ -3293,16 +3371,16 @@ def to_identifier(alias, quoted=None):
return identifier
def to_table(sql_path: str, **kwargs) -> Table:
def to_table(sql_path: t.Optional[str | Table], **kwargs) -> t.Optional[Table]:
"""
Create a table expression from a `[catalog].[schema].[table]` sql path. Catalog and schema are optional.
If a table is passed in then that table is returned.
Args:
sql_path(str|Table): `[catalog].[schema].[table]` string
sql_path: a `[catalog].[schema].[table]` string.
Returns:
Table: A table expression
A table expression.
"""
if sql_path is None or isinstance(sql_path, Table):
return sql_path
@ -3393,7 +3471,7 @@ def subquery(expression, alias=None, dialect=None, **opts):
return Select().from_(expression, dialect=dialect, **opts)
def column(col, table=None, quoted=None):
def column(col, table=None, quoted=None) -> Column:
"""
Build a Column.
Args:
@ -3408,7 +3486,7 @@ def column(col, table=None, quoted=None):
)
def table_(table, db=None, catalog=None, quoted=None, alias=None):
def table_(table, db=None, catalog=None, quoted=None, alias=None) -> Table:
"""Build a Table.
Args:
@ -3427,7 +3505,7 @@ def table_(table, db=None, catalog=None, quoted=None, alias=None):
)
def values(values, alias=None):
def values(values, alias=None) -> Values:
"""Build VALUES statement.
Example:
@ -3449,7 +3527,7 @@ def values(values, alias=None):
)
def convert(value):
def convert(value) -> Expression:
"""Convert a python value into an expression object.
Raises an error if a conversion is not possible.
@ -3500,15 +3578,14 @@ def replace_children(expression, fun):
for cn in child_nodes:
if isinstance(cn, Expression):
cns = ensure_list(fun(cn))
for child_node in cns:
for child_node in ensure_collection(fun(cn)):
new_child_nodes.append(child_node)
child_node.parent = expression
child_node.arg_key = k
else:
new_child_nodes.append(cn)
expression.args[k] = new_child_nodes if is_list_arg else list_get(new_child_nodes, 0)
expression.args[k] = new_child_nodes if is_list_arg else seq_get(new_child_nodes, 0)
def column_table_names(expression):
@ -3529,7 +3606,7 @@ def column_table_names(expression):
return list(dict.fromkeys(column.table for column in expression.find_all(Column)))
def table_name(table):
def table_name(table) -> str:
"""Get the full name of a table as a string.
Args:
@ -3546,6 +3623,9 @@ def table_name(table):
table = maybe_parse(table, into=Table)
if not table:
raise ValueError(f"Cannot parse {table}")
return ".".join(
part
for part in (
@ -1,4 +1,8 @@
from __future__ import annotations
import logging
import re
import typing as t
from sqlglot import exp
from sqlglot.errors import ErrorLevel, UnsupportedError, concat_errors
@ -8,6 +12,8 @@ from sqlglot.tokens import TokenType
logger = logging.getLogger("sqlglot")
NEWLINE_RE = re.compile("\r\n?|\n")
class Generator:
"""
@ -47,8 +53,7 @@ class Generator:
The default is on the smaller end because the length only represents a segment and not the true
line length.
Default: 80
annotations: Whether or not to show annotations in the SQL when `pretty` is True.
Annotations can only be shown in pretty mode otherwise they may clobber resulting sql.
comments: Whether or not to preserve comments in the output SQL code.
Default: True
"""
@ -65,14 +70,16 @@ class Generator:
exp.VolatilityProperty: lambda self, e: self.sql(e.name),
}
# whether 'CREATE ... TRANSIENT ... TABLE' is allowed
# can override in dialects
# Whether 'CREATE ... TRANSIENT ... TABLE' is allowed
CREATE_TRANSIENT = False
# whether or not null ordering is supported in order by
# Whether or not null ordering is supported in order by
NULL_ORDERING_SUPPORTED = True
# always do union distinct or union all
# Always do union distinct or union all
EXPLICIT_UNION = False
# wrap derived values in parens, usually standard but spark doesn't support it
# Wrap derived values in parens, usually standard but spark doesn't support it
WRAP_DERIVED_VALUES = True
TYPE_MAPPING = {
@ -80,7 +87,7 @@ class Generator:
exp.DataType.Type.NVARCHAR: "VARCHAR",
}
TOKEN_MAPPING = {}
TOKEN_MAPPING: t.Dict[TokenType, str] = {}
STRUCT_DELIMITER = ("<", ">")
@ -96,6 +103,8 @@ class Generator:
exp.TableFormatProperty,
}
WITH_SEPARATED_COMMENTS = (exp.Select,)
__slots__ = (
"time_mapping",
"time_trie",
@ -122,7 +131,7 @@ class Generator:
"_escaped_quote_end",
"_leading_comma",
"_max_text_width",
"_annotations",
"_comments",
)
def __init__(
@ -148,7 +157,7 @@ class Generator:
max_unsupported=3,
leading_comma=False,
max_text_width=80,
annotations=True,
comments=True,
):
import sqlglot
@ -177,7 +186,7 @@ class Generator:
self._escaped_quote_end = self.escape + self.quote_end
self._leading_comma = leading_comma
self._max_text_width = max_text_width
self._annotations = annotations
self._comments = comments
def generate(self, expression):
"""
@ -204,7 +213,6 @@ class Generator:
return sql
def unsupported(self, message):
if self.unsupported_level == ErrorLevel.IMMEDIATE:
raise UnsupportedError(message)
self.unsupported_messages.append(message)
@ -215,9 +223,31 @@ class Generator:
def seg(self, sql, sep=" "):
return f"{self.sep(sep)}{sql}"
def maybe_comment(self, sql, expression, single_line=False):
comment = expression.comment if self._comments else None
if not comment:
return sql
comment = " " + comment if comment[0].strip() else comment
comment = comment + " " if comment[-1].strip() else comment
if isinstance(expression, self.WITH_SEPARATED_COMMENTS):
return f"/*{comment}*/{self.sep()}{sql}"
if not self.pretty:
return f"{sql} /*{comment}*/"
if not NEWLINE_RE.search(comment):
return f"{sql} --{comment.rstrip()}" if single_line else f"{sql} /*{comment}*/"
return f"/*{comment}*/\n{sql}"
def wrap(self, expression):
this_sql = self.indent(
self.sql(expression) if isinstance(expression, (exp.Select, exp.Union)) else self.sql(expression, "this"),
self.sql(expression)
if isinstance(expression, (exp.Select, exp.Union))
else self.sql(expression, "this"),
level=1,
pad=0,
)
@ -251,7 +281,7 @@ class Generator:
for i, line in enumerate(lines)
)
def sql(self, expression, key=None):
def sql(self, expression, key=None, comment=True):
if not expression:
return ""
@ -264,29 +294,24 @@ class Generator:
transform = self.TRANSFORMS.get(expression.__class__)
if callable(transform):
return transform(self, expression)
if transform:
return transform
sql = transform(self, expression)
elif transform:
sql = transform
elif isinstance(expression, exp.Expression):
exp_handler_name = f"{expression.key}_sql"
if not isinstance(expression, exp.Expression):
if hasattr(self, exp_handler_name):
sql = getattr(self, exp_handler_name)(expression)
elif isinstance(expression, exp.Func):
sql = self.function_fallback_sql(expression)
elif isinstance(expression, exp.Property):
sql = self.property_sql(expression)
else:
raise ValueError(f"Unsupported expression type {expression.__class__.__name__}")
else:
raise ValueError(f"Expected an Expression. Received {type(expression)}: {expression}")
exp_handler_name = f"{expression.key}_sql"
if hasattr(self, exp_handler_name):
return getattr(self, exp_handler_name)(expression)
if isinstance(expression, exp.Func):
return self.function_fallback_sql(expression)
if isinstance(expression, exp.Property):
return self.property_sql(expression)
raise ValueError(f"Unsupported expression type {expression.__class__.__name__}")
def annotation_sql(self, expression):
if self._annotations and self.pretty:
return f"{self.sql(expression, 'expression')} # {expression.name}"
return self.sql(expression, "expression")
return self.maybe_comment(sql, expression) if self._comments and comment else sql
def uncache_sql(self, expression):
table = self.sql(expression, "this")
@ -371,7 +396,9 @@ class Generator:
expression_sql = self.sql(expression, "expression")
expression_sql = f"AS{self.sep()}{expression_sql}" if expression_sql else ""
temporary = " TEMPORARY" if expression.args.get("temporary") else ""
transient = " TRANSIENT" if self.CREATE_TRANSIENT and expression.args.get("transient") else ""
transient = (
" TRANSIENT" if self.CREATE_TRANSIENT and expression.args.get("transient") else ""
)
replace = " OR REPLACE" if expression.args.get("replace") else ""
exists_sql = " IF NOT EXISTS" if expression.args.get("exists") else ""
unique = " UNIQUE" if expression.args.get("unique") else ""
@ -434,7 +461,9 @@ class Generator:
def delete_sql(self, expression):
this = self.sql(expression, "this")
using_sql = (
f" USING {self.expressions(expression, 'using', sep=', USING ')}" if expression.args.get("using") else ""
f" USING {self.expressions(expression, 'using', sep=', USING ')}"
if expression.args.get("using")
else ""
)
where_sql = self.sql(expression, "where")
sql = f"DELETE FROM {this}{using_sql}{where_sql}"
@ -481,15 +510,18 @@ class Generator:
return f"{this} ON {table} {columns}"
def identifier_sql(self, expression):
value = expression.name
value = value.lower() if self.normalize else value
text = expression.name
text = text.lower() if self.normalize else text
if expression.args.get("quoted") or self.identify:
return f"{self.identifier_start}{value}{self.identifier_end}"
return value
text = f"{self.identifier_start}{text}{self.identifier_end}"
return text
def partition_sql(self, expression):
keys = csv(
*[f"{k.args['this']}='{v.args['this']}'" if v else k.args["this"] for k, v in expression.args.get("this")]
*[
f"""{prop.name}='{prop.text("value")}'""" if prop.text("value") else prop.name
for prop in expression.this
]
)
return f"PARTITION({keys})"
@ -504,9 +536,9 @@ class Generator:
elif p_class in self.ROOT_PROPERTIES:
root_properties.append(p)
return self.root_properties(exp.Properties(expressions=root_properties)) + self.with_properties(
exp.Properties(expressions=with_properties)
)
return self.root_properties(
exp.Properties(expressions=root_properties)
) + self.with_properties(exp.Properties(expressions=with_properties))
def root_properties(self, properties):
if properties.expressions:
@ -551,7 +583,9 @@ class Generator:
this = f"{this}{self.sql(expression, 'this')}"
exists = " IF EXISTS " if expression.args.get("exists") else " "
partition_sql = self.sql(expression, "partition") if expression.args.get("partition") else ""
partition_sql = (
self.sql(expression, "partition") if expression.args.get("partition") else ""
)
expression_sql = self.sql(expression, "expression")
sep = self.sep() if partition_sql else ""
sql = f"INSERT {this}{exists}{partition_sql}{sep}{expression_sql}"
@ -669,7 +703,9 @@ class Generator:
def group_sql(self, expression):
group_by = self.op_expressions("GROUP BY", expression)
grouping_sets = self.expressions(expression, key="grouping_sets", indent=False)
grouping_sets = f"{self.seg('GROUPING SETS')} {self.wrap(grouping_sets)}" if grouping_sets else ""
grouping_sets = (
f"{self.seg('GROUPING SETS')} {self.wrap(grouping_sets)}" if grouping_sets else ""
)
cube = self.expressions(expression, key="cube", indent=False)
cube = f"{self.seg('CUBE')} {self.wrap(cube)}" if cube else ""
rollup = self.expressions(expression, key="rollup", indent=False)
@ -711,10 +747,10 @@ class Generator:
this_sql = self.sql(expression, "this")
return f"{expression_sql}{op_sql} {this_sql}{on_sql}"
def lambda_sql(self, expression):
def lambda_sql(self, expression, arrow_sep="->"):
args = self.expressions(expression, flat=True)
args = f"({args})" if len(args.split(",")) > 1 else args
return self.no_identify(lambda: f"{args} -> {self.sql(expression, 'this')}")
return self.no_identify(lambda: f"{args} {arrow_sep} {self.sql(expression, 'this')}")
def lateral_sql(self, expression):
this = self.sql(expression, "this")
@ -748,7 +784,7 @@ class Generator:
if self._replace_backslash:
text = text.replace("\\", "\\\\")
text = text.replace(self.quote_end, self._escaped_quote_end)
return f"{self.quote_start}{text}{self.quote_end}"
text = f"{self.quote_start}{text}{self.quote_end}"
return text
def loaddata_sql(self, expression):
@ -796,13 +832,21 @@ class Generator:
sort_order = " DESC" if desc else ""
nulls_sort_change = ""
if nulls_first and ((asc and nulls_are_large) or (desc and nulls_are_small) or nulls_are_last):
if nulls_first and (
(asc and nulls_are_large) or (desc and nulls_are_small) or nulls_are_last
):
nulls_sort_change = " NULLS FIRST"
elif nulls_last and ((asc and nulls_are_small) or (desc and nulls_are_large)) and not nulls_are_last:
elif (
nulls_last
and ((asc and nulls_are_small) or (desc and nulls_are_large))
and not nulls_are_last
):
nulls_sort_change = " NULLS LAST"
if nulls_sort_change and not self.NULL_ORDERING_SUPPORTED:
self.unsupported("Sorting in an ORDER BY on NULLS FIRST/NULLS LAST is not supported by this dialect")
self.unsupported(
"Sorting in an ORDER BY on NULLS FIRST/NULLS LAST is not supported by this dialect"
)
nulls_sort_change = ""
return f"{self.sql(expression, 'this')}{sort_order}{nulls_sort_change}"
@ -835,7 +879,7 @@ class Generator:
sql = self.query_modifiers(
expression,
f"SELECT{hint}{distinct}{expressions}",
self.sql(expression, "from"),
self.sql(expression, "from", comment=False),
)
return self.prepend_ctes(expression, sql)
@ -858,6 +902,13 @@ class Generator:
def parameter_sql(self, expression):
return f"@{self.sql(expression, 'this')}"
def sessionparameter_sql(self, expression):
this = self.sql(expression, "this")
kind = expression.text("kind")
if kind:
kind = f"{kind}."
return f"@@{kind}{this}"
def placeholder_sql(self, expression):
return f":{expression.name}" if expression.name else "?"
@ -931,7 +982,10 @@ class Generator:
def window_spec_sql(self, expression):
kind = self.sql(expression, "kind")
start = csv(self.sql(expression, "start"), self.sql(expression, "start_side"), sep=" ")
end = csv(self.sql(expression, "end"), self.sql(expression, "end_side"), sep=" ") or "CURRENT ROW"
end = (
csv(self.sql(expression, "end"), self.sql(expression, "end_side"), sep=" ")
or "CURRENT ROW"
)
return f"{kind} BETWEEN {start} AND {end}"
def withingroup_sql(self, expression):
@ -1020,7 +1074,9 @@ class Generator:
return f"UNIQUE ({columns})"
def if_sql(self, expression):
return self.case_sql(exp.Case(ifs=[expression.copy()], default=expression.args.get("false")))
return self.case_sql(
exp.Case(ifs=[expression.copy()], default=expression.args.get("false"))
)
def in_sql(self, expression):
query = expression.args.get("query")
@ -1196,6 +1252,12 @@ class Generator:
def neq_sql(self, expression):
return self.binary(expression, "<>")
def nullsafeeq_sql(self, expression):
return self.binary(expression, "IS NOT DISTINCT FROM")
def nullsafeneq_sql(self, expression):
return self.binary(expression, "IS DISTINCT FROM")
def or_sql(self, expression):
return self.connector_sql(expression, "OR")
@ -1205,6 +1267,9 @@ class Generator:
def trycast_sql(self, expression):
return f"TRY_CAST({self.sql(expression, 'this')} AS {self.sql(expression, 'to')})"
def use_sql(self, expression):
return f"USE {self.sql(expression, 'this')}"
def binary(self, expression, op):
return f"{self.sql(expression, 'this')} {op} {self.sql(expression, 'expression')}"
@ -1240,17 +1305,27 @@ class Generator:
if flat:
return sep.join(self.sql(e) for e in expressions)
sql = (self.sql(e) for e in expressions)
# the only time leading_comma changes the output is if pretty print is enabled
if self._leading_comma and self.pretty:
pad = " " * self.pad
expressions = "\n".join(f"{sep}{s}" if i > 0 else f"{pad}{s}" for i, s in enumerate(sql))
else:
expressions = self.sep(sep).join(sql)
num_sqls = len(expressions)
if indent:
return self.indent(expressions, skip_first=False)
return expressions
# These are calculated once in case we have the leading_comma / pretty option set, correspondingly
pad = " " * self.pad
stripped_sep = sep.strip()
result_sqls = []
for i, e in enumerate(expressions):
sql = self.sql(e, comment=False)
comment = self.maybe_comment("", e, single_line=True)
if self.pretty:
if self._leading_comma:
result_sqls.append(f"{sep if i > 0 else pad}{sql}{comment}")
else:
result_sqls.append(f"{sql}{stripped_sep if i + 1 < num_sqls else ''}{comment}")
else:
result_sqls.append(f"{sql}{comment}{sep if i + 1 < num_sqls else ''}")
result_sqls = "\n".join(result_sqls) if self.pretty else "".join(result_sqls)
return self.indent(result_sqls, skip_first=False) if indent else result_sqls
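In pretty mode the rewritten `expressions` method places single-line comments after the separator, so trailing commas stay valid. For example (output shape inferred from the logic above):

```python
import sqlglot

print(sqlglot.transpile("SELECT a /* c1 */, b /* c2 */ FROM t", pretty=True)[0])
# SELECT
#   a, -- c1
#   b -- c2
# FROM t
```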
def op_expressions(self, op, expression, flat=False):
expressions_sql = self.expressions(expression, flat=flat)
@ -1264,7 +1339,9 @@ class Generator:
def set_operation(self, expression, op):
this = self.sql(expression, "this")
op = self.seg(op)
return self.query_modifiers(expression, f"{this}{op}{self.sep()}{self.sql(expression, 'expression')}")
return self.query_modifiers(
expression, f"{this}{op}{self.sep()}{self.sql(expression, 'expression')}"
)
def token_sql(self, token_type):
return self.TOKEN_MAPPING.get(token_type, token_type.name)
@ -1283,3 +1360,6 @@ class Generator:
this = self.sql(expression, "this")
expressions = self.expressions(expression, flat=True)
return f"{this}({expressions})"
def kwarg_sql(self, expression):
return self.binary(expression, "=>")
@ -1,48 +1,125 @@
from __future__ import annotations
import inspect
import logging
import re
import sys
import typing as t
from collections.abc import Collection
from contextlib import contextmanager
from copy import copy
from enum import Enum
if t.TYPE_CHECKING:
from sqlglot.expressions import Expression, Table
T = t.TypeVar("T")
E = t.TypeVar("E", bound=Expression)
CAMEL_CASE_PATTERN = re.compile("(?<!^)(?=[A-Z])")
PYTHON_VERSION = sys.version_info[:2]
logger = logging.getLogger("sqlglot")
class AutoName(Enum):
def _generate_next_value_(name, _start, _count, _last_values):
"""This is used for creating enum classes where `auto()` is the string form of the corresponding value's name."""
def _generate_next_value_(name, _start, _count, _last_values): # type: ignore
return name
def list_get(arr, index):
def seq_get(seq: t.Sequence[T], index: int) -> t.Optional[T]:
"""Returns the value in `seq` at position `index`, or `None` if `index` is out of bounds."""
try:
return arr[index]
return seq[index]
except IndexError:
return None
@t.overload
def ensure_list(value: t.Collection[T]) -> t.List[T]:
...
@t.overload
def ensure_list(value: T) -> t.List[T]:
...
def ensure_list(value):
"""
Ensures that a value is a list, otherwise casts or wraps it into one.
Args:
value: the value of interest.
Returns:
The value cast as a list if it's a list or a tuple, or else the value wrapped in a list.
"""
if value is None:
return []
return value if isinstance(value, (list, tuple, set)) else [value]
elif isinstance(value, (list, tuple)):
return list(value)
return [value]
def csv(*args, sep=", "):
@t.overload
def ensure_collection(value: t.Collection[T]) -> t.Collection[T]:
...
@t.overload
def ensure_collection(value: T) -> t.Collection[T]:
...
def ensure_collection(value):
"""
Ensures that a value is a collection (excluding `str` and `bytes`), otherwise wraps it into a list.
Args:
value: the value of interest.
Returns:
The value if it's a collection, or else the value wrapped in a list.
"""
if value is None:
return []
return (
value if isinstance(value, Collection) and not isinstance(value, (str, bytes)) else [value]
)
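A few quick checks of the renamed and added helpers above, matching their docstrings:

```python
from sqlglot.helper import ensure_collection, ensure_list, seq_get

assert seq_get([1, 2], 5) is None            # out of bounds -> None
assert ensure_list((1, 2)) == [1, 2]         # tuples are now cast to lists
assert ensure_collection({1, 2}) == {1, 2}   # sets pass through untouched
assert ensure_collection("ab") == ["ab"]     # str is wrapped, not iterated
```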
def csv(*args, sep: str = ", ") -> str:
"""
Formats any number of string arguments as CSV.
Args:
args: the string arguments to format.
sep: the argument separator.
Returns:
The arguments formatted as a CSV string.
"""
return sep.join(arg for arg in args if arg)
def subclasses(module_name, classes, exclude=()):
def subclasses(
module_name: str,
classes: t.Type | t.Tuple[t.Type, ...],
exclude: t.Type | t.Tuple[t.Type, ...] = (),
) -> t.List[t.Type]:
"""
Returns a list of all subclasses for a specified class set, posibly excluding some of them.
Returns all subclasses for a collection of classes, possibly excluding some of them.
Args:
module_name (str): The name of the module to search for subclasses in.
classes (type|tuple[type]): Class(es) we want to find the subclasses of.
exclude (type|tuple[type]): Class(es) we want to exclude from the returned list.
module_name: the name of the module to search for subclasses in.
classes: class(es) we want to find the subclasses of.
exclude: class(es) we want to exclude from the returned list.
Returns:
A list of all the target subclasses.
The target subclasses.
"""
return [
obj
@ -53,7 +130,18 @@ def subclasses(module_name, classes, exclude=()):
]
def apply_index_offset(expressions, offset):
def apply_index_offset(expressions: t.List[E], offset: int) -> t.List[E]:
"""
Applies an offset to a given integer literal expression.
Args:
expressions: the expression the offset will be applied to, wrapped in a list.
offset: the offset that will be applied.
Returns:
The original expression with the offset applied to it, wrapped in a list. If the provided
`expressions` argument contains more than one expression, it's returned unaffected.
"""
if not offset or len(expressions) != 1:
return expressions
@ -64,14 +152,28 @@ def apply_index_offset(expressions, offset):
logger.warning("Applying array index offset (%s)", offset)
expression.args["this"] = str(int(expression.args["this"]) + offset)
return [expression]
return expressions
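A hedged sketch of `apply_index_offset` on a single integer literal (it also logs a warning when the offset is applied):

```python
from sqlglot import exp
from sqlglot.helper import apply_index_offset

# A lone integer literal gets the offset; longer lists come back as-is.
print(apply_index_offset([exp.Literal.number(0)], 1)[0].sql())  # 1
```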
def camel_to_snake_case(name):
def camel_to_snake_case(name: str) -> str:
"""Converts `name` from camelCase to snake_case and returns the result."""
return CAMEL_CASE_PATTERN.sub("_", name).upper()
def while_changing(expression, func):
def while_changing(
expression: t.Optional[Expression], func: t.Callable[[t.Optional[Expression]], E]
) -> E:
"""
Applies a transformation to a given expression until a fixed point is reached.
Args:
expression: the expression to be transformed.
func: the transformation to be applied.
Returns:
The transformed expression.
"""
while True:
start = hash(expression)
expression = func(expression)
@ -80,10 +182,19 @@ def while_changing(expression, func):
return expression
def tsort(dag):
def tsort(dag: t.Dict[T, t.List[T]]) -> t.List[T]:
"""
Sorts a given directed acyclic graph in topological order.
Args:
dag: the graph to be sorted.
Returns:
A list that contains all of the graph's nodes in topological order.
"""
result = []
def visit(node, visited):
def visit(node: T, visited: t.Set[T]) -> None:
if node in result:
return
if node in visited:
@ -103,10 +214,8 @@ def tsort(dag):
return result
def open_file(file_name):
"""
Open a file that may be compressed as gzip and return in newline mode.
"""
def open_file(file_name: str) -> t.TextIO:
"""Open a file that may be compressed as gzip and return it in universal newline mode."""
with open(file_name, "rb") as f:
gzipped = f.read(2) == b"\x1f\x8b"
@ -119,14 +228,14 @@ def open_file(file_name):
@contextmanager
def csv_reader(table):
def csv_reader(table: Table) -> t.Any:
"""
Returns a csv reader given the expression READ_CSV(name, ['delimiter', '|', ...])
Returns a csv reader given the expression `READ_CSV(name, ['delimiter', '|', ...])`.
Args:
table (exp.Table): A table expression with an anonymous function READ_CSV in it
table: a `Table` expression with an anonymous function `READ_CSV` in it.
Returns:
Yields:
A python csv reader.
"""
file, *args = table.this.expressions
@ -147,13 +256,16 @@ def csv_reader(table):
file.close()
def find_new_name(taken, base):
def find_new_name(taken: t.Sequence[str], base: str) -> str:
"""
Searches for a new name.
Args:
taken (Sequence[str]): set of taken names
base (str): base name to alter
taken: a collection of taken names.
base: base name to alter.
Returns:
The new, available name.
"""
if base not in taken:
return base
@ -163,22 +275,26 @@ def find_new_name(taken, base):
while new in taken:
i += 1
new = f"{base}_{i}"
return new
def object_to_dict(obj, **kwargs):
def object_to_dict(obj: t.Any, **kwargs) -> t.Dict:
"""Returns a dictionary created from an object's attributes."""
return {**{k: copy(v) for k, v in vars(obj).copy().items()}, **kwargs}
def split_num_words(value: str, sep: str, min_num_words: int, fill_from_start: bool = True) -> t.List[t.Optional[str]]:
def split_num_words(
value: str, sep: str, min_num_words: int, fill_from_start: bool = True
) -> t.List[t.Optional[str]]:
"""
Perform a split on a value and return N words as a result with None used for words that don't exist.
Perform a split on a value and return N words as a result with `None` used for words that don't exist.
Args:
value: The value to be split
sep: The value to use to split on
min_num_words: The minimum number of words that are going to be in the result
fill_from_start: Indicates that if None values should be inserted at the start or end of the list
value: the value to be split.
sep: the value to use to split on.
min_num_words: the minimum number of words that are going to be in the result.
fill_from_start: indicates whether `None` values should be inserted at the start or end of the list.
Examples:
>>> split_num_words("db.table", ".", 3)
@ -187,6 +303,9 @@ def split_num_words(value: str, sep: str, min_num_words: int, fill_from_start: b
['db', 'table', None]
>>> split_num_words("db.table", ".", 1)
['db', 'table']
Returns:
The list of words returned by `split`, possibly augmented by a number of `None` values.
"""
words = value.split(sep)
if fill_from_start:
@ -196,7 +315,7 @@ def split_num_words(value: str, sep: str, min_num_words: int, fill_from_start: b
def is_iterable(value: t.Any) -> bool:
"""
Checks if the value is an iterable but does not include strings and bytes
Checks if the value is an iterable, excluding the types `str` and `bytes`.
Examples:
>>> is_iterable([1,2])
@ -205,28 +324,30 @@ def is_iterable(value: t.Any) -> bool:
False
Args:
value: The value to check if it is an interable
value: the value to check if it is an iterable.
Returns: Bool indicating if it is an iterable
Returns:
A `bool` value indicating if it is an iterable.
"""
return hasattr(value, "__iter__") and not isinstance(value, (str, bytes))
def flatten(values: t.Iterable[t.Union[t.Iterable[t.Any], t.Any]]) -> t.Generator[t.Any, None, None]:
def flatten(values: t.Iterable[t.Iterable[t.Any] | t.Any]) -> t.Generator[t.Any, None, None]:
"""
Flattens a list that can contain both iterables and non-iterable elements
Flattens an iterable that can contain both iterable and non-iterable elements. Objects of
type `str` and `bytes` are not regarded as iterables.
Examples:
>>> list(flatten([[1, 2], 3]))
[1, 2, 3]
>>> list(flatten([[1, 2], 3, {4}, (5, "bla")]))
[1, 2, 3, 4, 5, 'bla']
>>> list(flatten([1, 2, 3]))
[1, 2, 3]
Args:
values: The value to be flattened
values: the value to be flattened.
Returns:
Yields non-iterable elements (not including str or byte as iterable)
Yields:
Non-iterable elements in `values`.
"""
for value in values:
if is_iterable(value):
@ -1,5 +1,5 @@
from sqlglot import exp
from sqlglot.helper import ensure_list, subclasses
from sqlglot.helper import ensure_collection, ensure_list, subclasses
from sqlglot.optimizer.scope import Scope, traverse_scope
from sqlglot.schema import ensure_schema
@ -48,35 +48,65 @@ class TypeAnnotator:
exp.Boolean: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.BOOLEAN),
exp.Null: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.NULL),
exp.Anonymous: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.UNKNOWN),
exp.ApproxDistinct: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.BIGINT),
exp.ApproxDistinct: lambda self, expr: self._annotate_with_type(
expr, exp.DataType.Type.BIGINT
),
exp.Avg: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
exp.Ceil: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT),
exp.Count: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.BIGINT),
exp.CurrentDate: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DATE),
exp.CurrentDatetime: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DATETIME),
exp.CurrentTime: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.TIMESTAMP),
exp.CurrentTimestamp: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.TIMESTAMP),
exp.CurrentDatetime: lambda self, expr: self._annotate_with_type(
expr, exp.DataType.Type.DATETIME
),
exp.CurrentTime: lambda self, expr: self._annotate_with_type(
expr, exp.DataType.Type.TIMESTAMP
),
exp.CurrentTimestamp: lambda self, expr: self._annotate_with_type(
expr, exp.DataType.Type.TIMESTAMP
),
exp.DateAdd: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DATE),
exp.DateSub: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DATE),
exp.DateDiff: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT),
exp.DatetimeAdd: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DATETIME),
exp.DatetimeSub: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DATETIME),
exp.DatetimeAdd: lambda self, expr: self._annotate_with_type(
expr, exp.DataType.Type.DATETIME
),
exp.DatetimeSub: lambda self, expr: self._annotate_with_type(
expr, exp.DataType.Type.DATETIME
),
exp.DatetimeDiff: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT),
exp.Extract: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT),
exp.TimestampAdd: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.TIMESTAMP),
exp.TimestampSub: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.TIMESTAMP),
exp.TimestampAdd: lambda self, expr: self._annotate_with_type(
expr, exp.DataType.Type.TIMESTAMP
),
exp.TimestampSub: lambda self, expr: self._annotate_with_type(
expr, exp.DataType.Type.TIMESTAMP
),
exp.TimestampDiff: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT),
exp.TimeAdd: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.TIMESTAMP),
exp.TimeSub: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.TIMESTAMP),
exp.TimeDiff: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT),
exp.DateStrToDate: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DATE),
exp.DateToDateStr: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.VARCHAR),
exp.DateStrToDate: lambda self, expr: self._annotate_with_type(
expr, exp.DataType.Type.DATE
),
exp.DateToDateStr: lambda self, expr: self._annotate_with_type(
expr, exp.DataType.Type.VARCHAR
),
exp.DateToDi: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT),
exp.Day: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.TINYINT),
exp.DiToDate: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DATE),
exp.Exp: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
exp.Floor: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT),
exp.If: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.BOOLEAN),
exp.Case: lambda self, expr: self._annotate_by_args(expr, "default", "ifs"),
exp.If: lambda self, expr: self._annotate_by_args(expr, "true", "false"),
exp.Coalesce: lambda self, expr: self._annotate_by_args(expr, "this", "expressions"),
exp.IfNull: lambda self, expr: self._annotate_by_args(expr, "this", "expression"),
exp.ConcatWs: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.VARCHAR),
exp.GroupConcat: lambda self, expr: self._annotate_with_type(
expr, exp.DataType.Type.VARCHAR
),
exp.ArrayConcat: lambda self, expr: self._annotate_with_type(
expr, exp.DataType.Type.VARCHAR
),
exp.Initcap: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.VARCHAR),
exp.Length: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.BIGINT),
exp.Levenshtein: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT),
@ -88,32 +118,52 @@ class TypeAnnotator:
exp.Month: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.TINYINT),
exp.Pow: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
exp.Quantile: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
exp.ApproxQuantile: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
exp.RegexpLike: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.BOOLEAN),
exp.ApproxQuantile: lambda self, expr: self._annotate_with_type(
expr, exp.DataType.Type.DOUBLE
),
exp.RegexpLike: lambda self, expr: self._annotate_with_type(
expr, exp.DataType.Type.BOOLEAN
),
exp.Round: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
exp.SafeDivide: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
exp.Substring: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.VARCHAR),
exp.StrPosition: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT),
exp.StrToDate: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DATE),
exp.StrToTime: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.TIMESTAMP),
exp.StrToTime: lambda self, expr: self._annotate_with_type(
expr, exp.DataType.Type.TIMESTAMP
),
exp.Sqrt: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
exp.Stddev: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
exp.StddevPop: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
exp.StddevSamp: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
exp.TimeToStr: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.VARCHAR),
exp.TimeToTimeStr: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.VARCHAR),
exp.TimeStrToDate: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DATE),
exp.TimeStrToTime: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.TIMESTAMP),
exp.TimeToTimeStr: lambda self, expr: self._annotate_with_type(
expr, exp.DataType.Type.VARCHAR
),
exp.TimeStrToDate: lambda self, expr: self._annotate_with_type(
expr, exp.DataType.Type.DATE
),
exp.TimeStrToTime: lambda self, expr: self._annotate_with_type(
expr, exp.DataType.Type.TIMESTAMP
),
exp.Trim: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.VARCHAR),
exp.TsOrDsToDateStr: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.VARCHAR),
exp.TsOrDsToDateStr: lambda self, expr: self._annotate_with_type(
expr, exp.DataType.Type.VARCHAR
),
exp.TsOrDsToDate: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DATE),
exp.TsOrDiToDi: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT),
exp.UnixToStr: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.VARCHAR),
exp.UnixToTime: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.TIMESTAMP),
exp.UnixToTimeStr: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.VARCHAR),
exp.UnixToTime: lambda self, expr: self._annotate_with_type(
expr, exp.DataType.Type.TIMESTAMP
),
exp.UnixToTimeStr: lambda self, expr: self._annotate_with_type(
expr, exp.DataType.Type.VARCHAR
),
exp.Upper: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.VARCHAR),
exp.Variance: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
exp.VariancePop: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
exp.VariancePop: lambda self, expr: self._annotate_with_type(
expr, exp.DataType.Type.DOUBLE
),
exp.Week: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.TINYINT),
exp.Year: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.TINYINT),
}
@ -124,7 +174,11 @@ class TypeAnnotator:
exp.DataType.Type.TEXT: set(),
exp.DataType.Type.NVARCHAR: {exp.DataType.Type.TEXT},
exp.DataType.Type.VARCHAR: {exp.DataType.Type.NVARCHAR, exp.DataType.Type.TEXT},
exp.DataType.Type.NCHAR: {exp.DataType.Type.VARCHAR, exp.DataType.Type.NVARCHAR, exp.DataType.Type.TEXT},
exp.DataType.Type.NCHAR: {
exp.DataType.Type.VARCHAR,
exp.DataType.Type.NVARCHAR,
exp.DataType.Type.TEXT,
},
exp.DataType.Type.CHAR: {
exp.DataType.Type.NCHAR,
exp.DataType.Type.VARCHAR,
@ -135,7 +189,11 @@ class TypeAnnotator:
exp.DataType.Type.DOUBLE: set(),
exp.DataType.Type.FLOAT: {exp.DataType.Type.DOUBLE},
exp.DataType.Type.DECIMAL: {exp.DataType.Type.FLOAT, exp.DataType.Type.DOUBLE},
exp.DataType.Type.BIGINT: {exp.DataType.Type.DECIMAL, exp.DataType.Type.FLOAT, exp.DataType.Type.DOUBLE},
exp.DataType.Type.BIGINT: {
exp.DataType.Type.DECIMAL,
exp.DataType.Type.FLOAT,
exp.DataType.Type.DOUBLE,
},
exp.DataType.Type.INT: {
exp.DataType.Type.BIGINT,
exp.DataType.Type.DECIMAL,
@ -160,7 +218,10 @@ class TypeAnnotator:
# DATE < DATETIME < TIMESTAMP < TIMESTAMPTZ < TIMESTAMPLTZ
exp.DataType.Type.TIMESTAMPLTZ: set(),
exp.DataType.Type.TIMESTAMPTZ: {exp.DataType.Type.TIMESTAMPLTZ},
exp.DataType.Type.TIMESTAMP: {exp.DataType.Type.TIMESTAMPTZ, exp.DataType.Type.TIMESTAMPLTZ},
exp.DataType.Type.TIMESTAMP: {
exp.DataType.Type.TIMESTAMPTZ,
exp.DataType.Type.TIMESTAMPLTZ,
},
exp.DataType.Type.DATETIME: {
exp.DataType.Type.TIMESTAMP,
exp.DataType.Type.TIMESTAMPTZ,
@ -219,7 +280,7 @@ class TypeAnnotator:
def _annotate_args(self, expression):
for value in expression.args.values():
for v in ensure_list(value):
for v in ensure_collection(value):
self._maybe_annotate(v)
return expression
@ -243,7 +304,9 @@ class TypeAnnotator:
if left_type == exp.DataType.Type.NULL and right_type == exp.DataType.Type.NULL:
expression.type = exp.DataType.Type.NULL
elif exp.DataType.Type.NULL in (left_type, right_type):
expression.type = exp.DataType.build("NULLABLE", expressions=exp.DataType.build("BOOLEAN"))
expression.type = exp.DataType.build(
"NULLABLE", expressions=exp.DataType.build("BOOLEAN")
)
else:
expression.type = exp.DataType.Type.BOOLEAN
elif isinstance(expression, (exp.Condition, exp.Predicate)):
@ -276,3 +339,17 @@ class TypeAnnotator:
def _annotate_with_type(self, expression, target_type):
expression.type = target_type
return self._annotate_args(expression)
def _annotate_by_args(self, expression, *args):
self._annotate_args(expression)
expressions = []
for arg in args:
arg_expr = expression.args.get(arg)
expressions.extend(expr for expr in ensure_list(arg_expr) if expr)
last_datatype = None
for expr in expressions:
last_datatype = self._maybe_coerce(last_datatype or expr.type, expr.type)
expression.type = last_datatype or exp.DataType.Type.UNKNOWN
return expression
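`_annotate_by_args` is what lets `CASE`, `IF`, `COALESCE` and friends inherit their arguments' types, per the changelog. A small sketch (without a schema, integer literals annotate as INT):

```python
import sqlglot
from sqlglot import exp
from sqlglot.optimizer.annotate_types import annotate_types

expr = annotate_types(sqlglot.parse_one("SELECT IF(y, 1, 2) FROM t"))
print(expr.find(exp.If).type)  # Type.INT, inherited from the branches
```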
@ -60,7 +60,9 @@ def _join_is_used(scope, join, alias):
on_clause_columns = set(id(column) for column in on.find_all(exp.Column))
else:
on_clause_columns = set()
return any(column for column in scope.source_columns(alias) if id(column) not in on_clause_columns)
return any(
column for column in scope.source_columns(alias) if id(column) not in on_clause_columns
)
def _is_joined_on_all_unique_outputs(scope, join):
@ -45,7 +45,13 @@ def eliminate_subqueries(expression):
# All table names are taken
for scope in root.traverse():
taken.update({source.name: source for _, source in scope.sources.items() if isinstance(source, exp.Table)})
taken.update(
{
source.name: source
for _, source in scope.sources.items()
if isinstance(source, exp.Table)
}
)
# Map of Expression->alias
# Existing CTES in the root expression. We'll use this for deduplication.
@ -70,7 +76,9 @@ def eliminate_subqueries(expression):
new_ctes.append(cte_scope.expression.parent)
# Now append the rest
for scope in itertools.chain(root.union_scopes, root.subquery_scopes, root.derived_table_scopes):
for scope in itertools.chain(
root.union_scopes, root.subquery_scopes, root.derived_table_scopes
):
for child_scope in scope.traverse():
new_cte = _eliminate(child_scope, existing_ctes, taken)
if new_cte:
@ -122,7 +122,9 @@ def _mergeable(outer_scope, inner_select, leave_tables_isolated, from_or_join):
unmergable_window_columns = [
column
for column in outer_scope.columns
if column.find_ancestor(exp.Where, exp.Group, exp.Order, exp.Join, exp.Having, exp.AggFunc)
if column.find_ancestor(
exp.Where, exp.Group, exp.Order, exp.Join, exp.Having, exp.AggFunc
)
]
window_expressions_in_unmergable = [
column
@ -147,7 +149,9 @@ def _mergeable(outer_scope, inner_select, leave_tables_isolated, from_or_join):
and not (
isinstance(from_or_join, exp.From)
and inner_select.args.get("where")
and any(j.side in {"FULL", "RIGHT"} for j in outer_scope.expression.args.get("joins", []))
and any(
j.side in {"FULL", "RIGHT"} for j in outer_scope.expression.args.get("joins", [])
)
)
and not _is_a_window_expression_in_unmergable_operation()
)
@ -203,7 +207,9 @@ def _merge_from(outer_scope, inner_scope, node_to_replace, alias):
if table.alias_or_name == node_to_replace.alias_or_name:
table.set("this", exp.to_identifier(new_subquery.alias_or_name))
outer_scope.remove_source(alias)
outer_scope.add_source(new_subquery.alias_or_name, inner_scope.sources[new_subquery.alias_or_name])
outer_scope.add_source(
new_subquery.alias_or_name, inner_scope.sources[new_subquery.alias_or_name]
)
def _merge_joins(outer_scope, inner_scope, from_or_join):
@@ -296,7 +302,9 @@ def _merge_order(outer_scope, inner_scope):
inner_scope (sqlglot.optimizer.scope.Scope)
"""
if (
any(outer_scope.expression.args.get(arg) for arg in ["group", "distinct", "having", "order"])
any(
outer_scope.expression.args.get(arg) for arg in ["group", "distinct", "having", "order"]
)
or len(outer_scope.selected_sources) != 1
or any(expression.find(exp.AggFunc) for expression in outer_scope.expression.expressions)
):

View file

@@ -50,7 +50,9 @@ def normalization_distance(expression, dnf=False):
Returns:
        int: difference in the number of predicates between the normalized form and the input
"""
return sum(_predicate_lengths(expression, dnf)) - (len(list(expression.find_all(exp.Connector))) + 1)
return sum(_predicate_lengths(expression, dnf)) - (
len(list(expression.find_all(exp.Connector))) + 1
)
def _predicate_lengths(expression, dnf):
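
A quick worked example of the arithmetic (a sketch; the value follows from the definitions above): `(a AND b) OR c` contains two connectors, so the baseline is 2 + 1 = 3 predicate leaves, while its CNF `(a OR c) AND (b OR c)` has four, giving a distance of 1.

```python
import sqlglot
from sqlglot.optimizer.normalize import normalization_distance

expression = sqlglot.parse_one("(a AND b) OR c")
assert normalization_distance(expression) == 1  # 4 CNF leaves - 3 current leaves
```
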

View file

@@ -68,4 +68,8 @@ def normalize(expression):
def other_table_names(join, exclude):
return [name for name in (exp.column_table_names(join.args.get("on") or exp.TRUE)) if name != exclude]
return [
name
for name in (exp.column_table_names(join.args.get("on") or exp.TRUE))
if name != exclude
]

View file

@@ -58,6 +58,8 @@ def optimize(expression, schema=None, db=None, catalog=None, rules=RULES, **kwar
# Find any additional rule parameters, beyond `expression`
rule_params = rule.__code__.co_varnames
rule_kwargs = {param: possible_kwargs[param] for param in rule_params if param in possible_kwargs}
rule_kwargs = {
param: possible_kwargs[param] for param in rule_params if param in possible_kwargs
}
expression = rule(expression, **rule_kwargs)
return expression
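
This bit of introspection is what lets `optimize` thread parameters such as `schema` only into the rules whose signatures actually declare them. A hedged usage sketch:

```python
import sqlglot
from sqlglot.optimizer import optimize

# `schema` is forwarded to rules like qualify_columns that name it in
# their signature; rules without a `schema` parameter never receive it.
optimized = optimize(sqlglot.parse_one("SELECT x FROM t"), schema={"t": {"x": "INT"}})
print(optimized.sql())
```
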

View file

@@ -45,7 +45,11 @@ def pushdown(condition, sources, scope_ref_count):
condition = condition.replace(simplify(condition))
cnf_like = normalized(condition) or not normalized(condition, dnf=True)
predicates = list(condition.flatten() if isinstance(condition, exp.And if cnf_like else exp.Or) else [condition])
predicates = list(
condition.flatten()
if isinstance(condition, exp.And if cnf_like else exp.Or)
else [condition]
)
if cnf_like:
pushdown_cnf(predicates, sources, scope_ref_count)
@@ -104,11 +108,17 @@ def pushdown_dnf(predicates, scope, scope_ref_count):
for column in predicate.find_all(exp.Column):
if column.table == table:
condition = column.find_ancestor(exp.Condition)
predicate_condition = exp.and_(predicate_condition, condition) if predicate_condition else condition
predicate_condition = (
exp.and_(predicate_condition, condition)
if predicate_condition
else condition
)
if predicate_condition:
conditions[table] = (
exp.or_(conditions[table], predicate_condition) if table in conditions else predicate_condition
exp.or_(conditions[table], predicate_condition)
if table in conditions
else predicate_condition
)
for name, node in nodes.items():
@@ -146,10 +156,16 @@ def nodes_for_predicate(predicate, sources, scope_ref_count):
nodes[table] = node
elif isinstance(node, exp.Select) and len(tables) == 1:
# We can't push down window expressions
has_window_expression = any(select for select in node.selects if select.find(exp.Window))
has_window_expression = any(
select for select in node.selects if select.find(exp.Window)
)
            # We can't push down predicates to select statements if they are referenced in
            # multiple places.
if not node.args.get("group") and scope_ref_count[id(source)] < 2 and not has_window_expression:
if (
not node.args.get("group")
and scope_ref_count[id(source)] < 2
and not has_window_expression
):
nodes[table] = node
return nodes
@@ -165,7 +181,7 @@ def replace_aliases(source, predicate):
def _replace_alias(column):
if isinstance(column, exp.Column) and column.name in aliases:
return aliases[column.name]
return aliases[column.name].copy()
return column
return predicate.transform(_replace_alias)
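
The added `.copy()` is a behavioral fix, not formatting: when an alias is referenced more than once, each reference must be replaced with its own copy of the aliased expression, otherwise two spots in the tree would share a single AST node and later mutations would bleed between them. A simplified, hypothetical stand-in for `replace_aliases`:

```python
from sqlglot import exp, parse_one

definition = parse_one("SELECT x + 1 AS y FROM t").selects[0].this  # x + 1
predicate = parse_one("SELECT * FROM t WHERE y > 5 AND y < 10").args["where"].this

def _replace_alias(node):
    if isinstance(node, exp.Column) and node.name == "y":
        return definition.copy()  # a fresh copy per reference, as in the fix
    return node

print(predicate.transform(_replace_alias).sql())  # x + 1 > 5 AND x + 1 < 10
```
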

View file

@@ -98,7 +98,9 @@ def _remove_unused_selections(scope, parent_selections):
def _remove_indexed_selections(scope, indexes_to_remove):
new_selections = [selection for i, selection in enumerate(scope.selects) if i not in indexes_to_remove]
new_selections = [
selection for i, selection in enumerate(scope.selects) if i not in indexes_to_remove
]
if not new_selections:
new_selections.append(DEFAULT_SELECTION)
scope.expression.set("expressions", new_selections)

View file

@@ -215,13 +215,21 @@ def _qualify_columns(scope, resolver):
# Determine whether each reference in the order by clause is to a column or an alias.
for ordered in scope.find_all(exp.Ordered):
for column in ordered.find_all(exp.Column):
if not column.table and column.parent is not ordered and column.name in resolver.all_columns:
if (
not column.table
and column.parent is not ordered
and column.name in resolver.all_columns
):
columns_missing_from_scope.append(column)
# Determine whether each reference in the having clause is to a column or an alias.
for having in scope.find_all(exp.Having):
for column in having.find_all(exp.Column):
if not column.table and column.find_ancestor(exp.AggFunc) and column.name in resolver.all_columns:
if (
not column.table
and column.find_ancestor(exp.AggFunc)
and column.name in resolver.all_columns
):
columns_missing_from_scope.append(column)
for column in columns_missing_from_scope:
@@ -295,7 +303,9 @@ def _qualify_outputs(scope):
"""Ensure all output columns are aliased"""
new_selections = []
for i, (selection, aliased_column) in enumerate(itertools.zip_longest(scope.selects, scope.outer_column_list)):
for i, (selection, aliased_column) in enumerate(
itertools.zip_longest(scope.selects, scope.outer_column_list)
):
if isinstance(selection, exp.Column):
# convoluted setter because a simple selection.replace(alias) would require a copy
alias_ = alias(exp.column(""), alias=selection.name)
@@ -343,14 +353,18 @@ class _Resolver:
(str) table name
"""
if self._unambiguous_columns is None:
self._unambiguous_columns = self._get_unambiguous_columns(self._get_all_source_columns())
self._unambiguous_columns = self._get_unambiguous_columns(
self._get_all_source_columns()
)
return self._unambiguous_columns.get(column_name)
@property
def all_columns(self):
"""All available columns of all sources in this scope"""
if self._all_columns is None:
self._all_columns = set(column for columns in self._get_all_source_columns().values() for column in columns)
self._all_columns = set(
column for columns in self._get_all_source_columns().values() for column in columns
)
return self._all_columns
def get_source_columns(self, name, only_visible=False):
@@ -377,7 +391,9 @@ class _Resolver:
def _get_all_source_columns(self):
if self._source_columns is None:
self._source_columns = {k: self.get_source_columns(k) for k in self.scope.selected_sources}
self._source_columns = {
k: self.get_source_columns(k) for k in self.scope.selected_sources
}
return self._source_columns
def _get_unambiguous_columns(self, source_columns):
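
The `find_ancestor(exp.AggFunc)` condition above implements the changelog item about the HAVING clause: only column references nested inside aggregates get qualified, while references to output aliases are left untouched. A sketch of the intended behavior (the exact output shape is an assumption):

```python
import sqlglot
from sqlglot.optimizer.qualify_columns import qualify_columns
from sqlglot.schema import ensure_schema

schema = ensure_schema({"t": {"x": "INT", "y": "INT"}})
sql = "SELECT SUM(x) AS x FROM t GROUP BY y HAVING x > 0 AND SUM(x) > 0"
# The x inside SUM(x) should become t.x, while the bare x in HAVING refers
# to the output alias and should stay unqualified.
print(qualify_columns(sqlglot.parse_one(sql), schema).sql())
```
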

View file

@@ -226,7 +226,9 @@ class Scope:
self._ensure_collected()
columns = self._raw_columns
external_columns = [column for scope in self.subquery_scopes for column in scope.external_columns]
external_columns = [
column for scope in self.subquery_scopes for column in scope.external_columns
]
named_outputs = {e.alias_or_name for e in self.expression.expressions}
@@ -278,7 +280,11 @@ class Scope:
Returns:
dict[str, Scope]: Mapping of source alias to Scope
"""
return {alias: scope for alias, scope in self.sources.items() if isinstance(scope, Scope) and scope.is_cte}
return {
alias: scope
for alias, scope in self.sources.items()
if isinstance(scope, Scope) and scope.is_cte
}
@property
def selects(self):
@@ -307,7 +313,9 @@ class Scope:
sources in the current scope.
"""
if self._external_columns is None:
self._external_columns = [c for c in self.columns if c.table not in self.selected_sources]
self._external_columns = [
c for c in self.columns if c.table not in self.selected_sources
]
return self._external_columns
@property

View file

@@ -229,7 +229,9 @@ def simplify_literals(expression):
operands.append(a)
if len(operands) < size:
return functools.reduce(lambda a, b: expression.__class__(this=a, expression=b), operands)
return functools.reduce(
lambda a, b: expression.__class__(this=a, expression=b), operands
)
elif isinstance(expression, exp.Neg):
this = expression.this
if this.is_number:
@@ -255,6 +257,12 @@ def _simplify_binary(expression, a, b):
return TRUE if not_ else FALSE
if a == NULL:
return FALSE if not_ else TRUE
elif isinstance(expression, exp.NullSafeEQ):
if a == b:
return TRUE
elif isinstance(expression, exp.NullSafeNEQ):
if a == b:
return FALSE
elif NULL in (a, b):
return NULL
@@ -357,7 +365,7 @@ def extract_date(cast):
def extract_interval(interval):
try:
from dateutil.relativedelta import relativedelta
from dateutil.relativedelta import relativedelta # type: ignore
except ModuleNotFoundError:
return None
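
The `NullSafeEQ`/`NullSafeNEQ` branches run before the generic `NULL in (a, b)` fold, which is what gives `<=>` its null-safe semantics during simplification: `NULL = NULL` still folds to `NULL`, but `NULL <=> NULL` compares the two sides structurally first. A small check (outputs hedged to the behavior described above):

```python
import sqlglot
from sqlglot.optimizer.simplify import simplify

assert simplify(sqlglot.parse_one("NULL = NULL")).sql() == "NULL"
assert simplify(sqlglot.parse_one("NULL <=> NULL")).sql() == "TRUE"
```
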

View file

@@ -89,7 +89,11 @@ def decorrelate(select, parent_select, external_columns, sequence):
return
if isinstance(predicate, exp.Binary):
key = predicate.right if any(node is column for node, *_ in predicate.left.walk()) else predicate.left
key = (
predicate.right
if any(node is column for node, *_ in predicate.left.walk())
else predicate.left
)
else:
return
@@ -145,7 +149,9 @@ def decorrelate(select, parent_select, external_columns, sequence):
else:
parent_predicate = _replace(parent_predicate, "TRUE")
elif isinstance(parent_predicate, exp.All):
parent_predicate = _replace(parent_predicate.parent, f"ARRAY_ALL({alias}, _x -> _x = {other})")
parent_predicate = _replace(
parent_predicate.parent, f"ARRAY_ALL({alias}, _x -> _x = {other})"
)
elif isinstance(parent_predicate, exp.Any):
if value.this in group_by:
parent_predicate = _replace(parent_predicate.parent, f"{other} = {alias}")
@@ -168,7 +174,9 @@ def decorrelate(select, parent_select, external_columns, sequence):
if key in group_by:
key.replace(nested)
parent_predicate = _replace(parent_predicate, f"({parent_predicate} AND NOT {nested} IS NULL)")
parent_predicate = _replace(
parent_predicate, f"({parent_predicate} AND NOT {nested} IS NULL)"
)
elif isinstance(predicate, exp.EQ):
parent_predicate = _replace(
parent_predicate,

View file

@@ -1,9 +1,13 @@
from __future__ import annotations
import logging
import typing as t
from sqlglot import exp
from sqlglot.errors import ErrorLevel, ParseError, concat_errors
from sqlglot.helper import apply_index_offset, ensure_list, list_get
from sqlglot.helper import apply_index_offset, ensure_collection, seq_get
from sqlglot.tokens import Token, Tokenizer, TokenType
from sqlglot.trie import in_trie, new_trie
logger = logging.getLogger("sqlglot")
@@ -20,7 +24,15 @@ def parse_var_map(args):
)
class Parser:
class _Parser(type):
def __new__(cls, clsname, bases, attrs):
klass = super().__new__(cls, clsname, bases, attrs)
klass._show_trie = new_trie(key.split(" ") for key in klass.SHOW_PARSERS)
klass._set_trie = new_trie(key.split(" ") for key in klass.SET_PARSERS)
return klass
class Parser(metaclass=_Parser):
"""
Parser consumes a list of tokens produced by the :class:`~sqlglot.tokens.Tokenizer`
and produces a parsed syntax tree.
@@ -45,16 +57,16 @@ class Parser:
FUNCTIONS = {
**{name: f.from_arg_list for f in exp.ALL_FUNCTIONS for name in f.sql_names()},
"DATE_TO_DATE_STR": lambda args: exp.Cast(
this=list_get(args, 0),
this=seq_get(args, 0),
to=exp.DataType(this=exp.DataType.Type.TEXT),
),
"TIME_TO_TIME_STR": lambda args: exp.Cast(
this=list_get(args, 0),
this=seq_get(args, 0),
to=exp.DataType(this=exp.DataType.Type.TEXT),
),
"TS_OR_DS_TO_DATE_STR": lambda args: exp.Substring(
this=exp.Cast(
this=list_get(args, 0),
this=seq_get(args, 0),
to=exp.DataType(this=exp.DataType.Type.TEXT),
),
start=exp.Literal.number(1),
@@ -90,6 +102,7 @@ class Parser:
TokenType.NVARCHAR,
TokenType.TEXT,
TokenType.BINARY,
TokenType.VARBINARY,
TokenType.JSON,
TokenType.INTERVAL,
TokenType.TIMESTAMP,
@@ -243,6 +256,7 @@ class Parser:
EQUALITY = {
TokenType.EQ: exp.EQ,
TokenType.NEQ: exp.NEQ,
TokenType.NULLSAFE_EQ: exp.NullSafeEQ,
}
COMPARISON = {
@@ -298,6 +312,21 @@ class Parser:
TokenType.ANTI,
}
LAMBDAS = {
TokenType.ARROW: lambda self, expressions: self.expression(
exp.Lambda,
this=self._parse_conjunction().transform(
self._replace_lambda, {node.name for node in expressions}
),
expressions=expressions,
),
TokenType.FARROW: lambda self, expressions: self.expression(
exp.Kwarg,
this=exp.Var(this=expressions[0].name),
expression=self._parse_conjunction(),
),
}
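
Lambda and kwarg parsing is now table-driven: `->` builds an `exp.Lambda` whose body has references to the parameters rewritten into lambda variables, and `=>` builds an `exp.Kwarg`. Roughly, with a made-up function name:

```python
import sqlglot

# MY_AGG is arbitrary and falls back to exp.Anonymous; the second argument
# parses into an exp.Lambda via the ARROW entry above.
tree = sqlglot.parse_one("SELECT MY_AGG(xs, x -> x + 1) FROM t")
print(tree.sql())  # SELECT MY_AGG(xs, x -> x + 1) FROM t
```
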
COLUMN_OPERATORS = {
TokenType.DOT: None,
TokenType.DCOLON: lambda self, this, to: self.expression(
@@ -362,20 +391,30 @@ class Parser:
TokenType.DELETE: lambda self: self._parse_delete(),
TokenType.CACHE: lambda self: self._parse_cache(),
TokenType.UNCACHE: lambda self: self._parse_uncache(),
TokenType.USE: lambda self: self._parse_use(),
}
PRIMARY_PARSERS = {
TokenType.STRING: lambda _, token: exp.Literal.string(token.text),
TokenType.NUMBER: lambda _, token: exp.Literal.number(token.text),
TokenType.STAR: lambda self, _: exp.Star(**{"except": self._parse_except(), "replace": self._parse_replace()}),
TokenType.NULL: lambda *_: exp.Null(),
TokenType.TRUE: lambda *_: exp.Boolean(this=True),
TokenType.FALSE: lambda *_: exp.Boolean(this=False),
TokenType.PARAMETER: lambda self, _: exp.Parameter(this=self._parse_var() or self._parse_primary()),
TokenType.BIT_STRING: lambda _, token: exp.BitString(this=token.text),
TokenType.HEX_STRING: lambda _, token: exp.HexString(this=token.text),
TokenType.BYTE_STRING: lambda _, token: exp.ByteString(this=token.text),
TokenType.STRING: lambda self, token: self.expression(
exp.Literal, this=token.text, is_string=True
),
TokenType.NUMBER: lambda self, token: self.expression(
exp.Literal, this=token.text, is_string=False
),
TokenType.STAR: lambda self, _: self.expression(
exp.Star, **{"except": self._parse_except(), "replace": self._parse_replace()}
),
TokenType.NULL: lambda self, _: self.expression(exp.Null),
TokenType.TRUE: lambda self, _: self.expression(exp.Boolean, this=True),
TokenType.FALSE: lambda self, _: self.expression(exp.Boolean, this=False),
TokenType.PARAMETER: lambda self, _: self.expression(
exp.Parameter, this=self._parse_var() or self._parse_primary()
),
TokenType.BIT_STRING: lambda self, token: self.expression(exp.BitString, this=token.text),
TokenType.HEX_STRING: lambda self, token: self.expression(exp.HexString, this=token.text),
TokenType.BYTE_STRING: lambda self, token: self.expression(exp.ByteString, this=token.text),
TokenType.INTRODUCER: lambda self, token: self._parse_introducer(token),
TokenType.SESSION_PARAMETER: lambda self, _: self._parse_session_parameter(),
}
RANGE_PARSERS = {
@@ -411,16 +450,24 @@ class Parser:
TokenType.COLLATE: lambda self: self._parse_property_assignment(exp.CollateProperty),
TokenType.COMMENT: lambda self: self._parse_property_assignment(exp.SchemaCommentProperty),
TokenType.FORMAT: lambda self: self._parse_property_assignment(exp.FileFormatProperty),
TokenType.TABLE_FORMAT: lambda self: self._parse_property_assignment(exp.TableFormatProperty),
TokenType.TABLE_FORMAT: lambda self: self._parse_property_assignment(
exp.TableFormatProperty
),
TokenType.USING: lambda self: self._parse_property_assignment(exp.TableFormatProperty),
TokenType.LANGUAGE: lambda self: self._parse_property_assignment(exp.LanguageProperty),
TokenType.EXECUTE: lambda self: self._parse_execute_as(),
TokenType.DETERMINISTIC: lambda self: self.expression(
exp.VolatilityProperty, this=exp.Literal.string("IMMUTABLE")
),
TokenType.IMMUTABLE: lambda self: self.expression(exp.VolatilityProperty, this=exp.Literal.string("IMMUTABLE")),
TokenType.STABLE: lambda self: self.expression(exp.VolatilityProperty, this=exp.Literal.string("STABLE")),
TokenType.VOLATILE: lambda self: self.expression(exp.VolatilityProperty, this=exp.Literal.string("VOLATILE")),
TokenType.IMMUTABLE: lambda self: self.expression(
exp.VolatilityProperty, this=exp.Literal.string("IMMUTABLE")
),
TokenType.STABLE: lambda self: self.expression(
exp.VolatilityProperty, this=exp.Literal.string("STABLE")
),
TokenType.VOLATILE: lambda self: self.expression(
exp.VolatilityProperty, this=exp.Literal.string("VOLATILE")
),
}
CONSTRAINT_PARSERS = {
@@ -450,7 +497,8 @@ class Parser:
"group": lambda self: self._parse_group(),
"having": lambda self: self._parse_having(),
"qualify": lambda self: self._parse_qualify(),
"window": lambda self: self._match(TokenType.WINDOW) and self._parse_window(self._parse_id_var(), alias=True),
"window": lambda self: self._match(TokenType.WINDOW)
and self._parse_window(self._parse_id_var(), alias=True),
"distribute": lambda self: self._parse_sort(TokenType.DISTRIBUTE_BY, exp.Distribute),
"sort": lambda self: self._parse_sort(TokenType.SORT_BY, exp.Sort),
"cluster": lambda self: self._parse_sort(TokenType.CLUSTER_BY, exp.Cluster),
@@ -459,6 +507,9 @@ class Parser:
"offset": lambda self: self._parse_offset(),
}
SHOW_PARSERS: t.Dict[str, t.Callable] = {}
SET_PARSERS: t.Dict[str, t.Callable] = {}
MODIFIABLES = (exp.Subquery, exp.Subqueryable, exp.Table)
CREATABLES = {
@@ -488,7 +539,9 @@ class Parser:
"_curr",
"_next",
"_prev",
"_greedy_subqueries",
"_prev_comment",
"_show_trie",
"_set_trie",
)
def __init__(
@@ -519,7 +572,7 @@ class Parser:
self._curr = None
self._next = None
self._prev = None
self._greedy_subqueries = False
self._prev_comment = None
def parse(self, raw_tokens, sql=None):
"""
@@ -533,10 +586,12 @@ class Parser:
Returns
the list of syntax trees (:class:`~sqlglot.expressions.Expression`).
"""
return self._parse(parse_method=self.__class__._parse_statement, raw_tokens=raw_tokens, sql=sql)
return self._parse(
parse_method=self.__class__._parse_statement, raw_tokens=raw_tokens, sql=sql
)
def parse_into(self, expression_types, raw_tokens, sql=None):
for expression_type in ensure_list(expression_types):
for expression_type in ensure_collection(expression_types):
parser = self.EXPRESSION_PARSERS.get(expression_type)
if not parser:
raise TypeError(f"No parser registered for {expression_type}")
@@ -597,6 +652,9 @@ class Parser:
def expression(self, exp_class, **kwargs):
instance = exp_class(**kwargs)
if self._prev_comment:
instance.comment = self._prev_comment
self._prev_comment = None
self.validate_expression(instance)
return instance
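
`expression()` is the single funnel through which every parsed node is built, so attaching the pending comment here (captured by `_advance` below) is what makes comments survive a parse/generate round trip:

```python
import sqlglot

# The comment token is stashed in _prev_comment and attached to the next
# expression built, so it should reappear in the generated SQL.
print(sqlglot.parse_one("SELECT a /* keep me */ FROM t").sql())
```
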
@@ -633,14 +691,16 @@ class Parser:
return index
def _get_token(self, index):
return list_get(self._tokens, index)
def _advance(self, times=1):
self._index += times
self._curr = self._get_token(self._index)
self._next = self._get_token(self._index + 1)
self._prev = self._get_token(self._index - 1) if self._index > 0 else None
self._curr = seq_get(self._tokens, self._index)
self._next = seq_get(self._tokens, self._index + 1)
if self._index > 0:
self._prev = self._tokens[self._index - 1]
self._prev_comment = self._prev.comment
else:
self._prev = None
self._prev_comment = None
def _retreat(self, index):
self._advance(index - self._index)
@@ -661,6 +721,7 @@ class Parser:
expression = self._parse_expression()
expression = self._parse_set_operations(expression) if expression else self._parse_select()
self._parse_query_modifiers(expression)
return expression
@@ -682,7 +743,11 @@ class Parser:
)
def _parse_exists(self, not_=False):
return self._match(TokenType.IF) and (not not_ or self._match(TokenType.NOT)) and self._match(TokenType.EXISTS)
return (
self._match(TokenType.IF)
and (not not_ or self._match(TokenType.NOT))
and self._match(TokenType.EXISTS)
)
def _parse_create(self):
replace = self._match(TokenType.OR) and self._match(TokenType.REPLACE)
@@ -931,7 +996,9 @@ class Parser:
return self.expression(
exp.Delete,
this=self._parse_table(schema=True),
using=self._parse_csv(lambda: self._match(TokenType.USING) and self._parse_table(schema=True)),
using=self._parse_csv(
lambda: self._match(TokenType.USING) and self._parse_table(schema=True)
),
where=self._parse_where(),
)
@@ -983,11 +1050,13 @@ class Parser:
return None
def parse_values():
k = self._parse_var()
key = self._parse_var()
value = None
if self._match(TokenType.EQ):
v = self._parse_string()
return (k, v)
return (k, None)
value = self._parse_string()
return exp.Property(this=key, value=value)
self._match_l_paren()
values = self._parse_csv(parse_values)
@@ -1019,6 +1088,8 @@ class Parser:
self.raise_error(f"{this.key} does not support CTE")
this = cte
elif self._match(TokenType.SELECT):
comment = self._prev_comment
hint = self._parse_hint()
all_ = self._match(TokenType.ALL)
distinct = self._match(TokenType.DISTINCT)
@@ -1033,7 +1104,7 @@ class Parser:
self.raise_error("Cannot specify both ALL and DISTINCT after SELECT")
limit = self._parse_limit(top=True)
expressions = self._parse_csv(lambda: self._parse_annotation(self._parse_expression()))
expressions = self._parse_csv(self._parse_expression)
this = self.expression(
exp.Select,
@@ -1042,6 +1113,7 @@ class Parser:
expressions=expressions,
limit=limit,
)
this.comment = comment
from_ = self._parse_from()
if from_:
this.set("from", from_)
@@ -1072,8 +1144,10 @@ class Parser:
while True:
expressions.append(self._parse_cte())
if not self._match(TokenType.COMMA):
if not self._match(TokenType.COMMA) and not self._match(TokenType.WITH):
break
else:
self._match(TokenType.WITH)
return self.expression(
exp.With,
@@ -1111,11 +1185,7 @@ class Parser:
if not alias and not columns:
return None
return self.expression(
exp.TableAlias,
this=alias,
columns=columns,
)
return self.expression(exp.TableAlias, this=alias, columns=columns)
def _parse_subquery(self, this):
return self.expression(
@@ -1150,12 +1220,6 @@ class Parser:
if expression:
this.set(key, expression)
def _parse_annotation(self, expression):
if self._match(TokenType.ANNOTATION):
return self.expression(exp.Annotation, this=self._prev.text.strip(), expression=expression)
return expression
def _parse_hint(self):
if self._match(TokenType.HINT):
hints = self._parse_csv(self._parse_function)
@@ -1295,7 +1359,9 @@ class Parser:
if not table:
self.raise_error("Expected table name")
this = self.expression(exp.Table, this=table, db=db, catalog=catalog, pivots=self._parse_pivots())
this = self.expression(
exp.Table, this=table, db=db, catalog=catalog, pivots=self._parse_pivots()
)
if schema:
return self._parse_schema(this=this)
@@ -1500,7 +1566,9 @@ class Parser:
if not skip_order_token and not self._match(TokenType.ORDER_BY):
return this
return self.expression(exp.Order, this=this, expressions=self._parse_csv(self._parse_ordered))
return self.expression(
exp.Order, this=this, expressions=self._parse_csv(self._parse_ordered)
)
def _parse_sort(self, token_type, exp_class):
if not self._match(token_type):
@@ -1521,7 +1589,8 @@ class Parser:
if (
not explicitly_null_ordered
and (
(asc and self.null_ordering == "nulls_are_small") or (desc and self.null_ordering != "nulls_are_small")
(asc and self.null_ordering == "nulls_are_small")
or (desc and self.null_ordering != "nulls_are_small")
)
and self.null_ordering != "nulls_are_last"
):
@@ -1606,6 +1675,9 @@ class Parser:
def _parse_is(self, this):
negate = self._match(TokenType.NOT)
if self._match(TokenType.DISTINCT_FROM):
klass = exp.NullSafeEQ if negate else exp.NullSafeNEQ
return self.expression(klass, this=this, expression=self._parse_expression())
this = self.expression(
exp.Is,
this=this,
@@ -1653,9 +1725,13 @@ class Parser:
expression=self._parse_term(),
)
elif self._match_pair(TokenType.LT, TokenType.LT):
this = self.expression(exp.BitwiseLeftShift, this=this, expression=self._parse_term())
this = self.expression(
exp.BitwiseLeftShift, this=this, expression=self._parse_term()
)
elif self._match_pair(TokenType.GT, TokenType.GT):
this = self.expression(exp.BitwiseRightShift, this=this, expression=self._parse_term())
this = self.expression(
exp.BitwiseRightShift, this=this, expression=self._parse_term()
)
else:
break
@@ -1685,7 +1761,7 @@ class Parser:
)
index = self._index
type_token = self._parse_types()
type_token = self._parse_types(check_func=True)
this = self._parse_column()
if type_token:
@@ -1698,7 +1774,7 @@ class Parser:
return this
def _parse_types(self):
def _parse_types(self, check_func=False):
index = self._index
if not self._match_set(self.TYPE_TOKENS):
@@ -1708,10 +1784,13 @@ class Parser:
nested = type_token in self.NESTED_TYPE_TOKENS
is_struct = type_token == TokenType.STRUCT
expressions = None
maybe_func = False
if not nested and self._match_pair(TokenType.L_BRACKET, TokenType.R_BRACKET):
return exp.DataType(
this=exp.DataType.Type.ARRAY, expressions=[exp.DataType.build(type_token.value)], nested=True
this=exp.DataType.Type.ARRAY,
expressions=[exp.DataType.build(type_token.value)],
nested=True,
)
if self._match(TokenType.L_BRACKET):
@@ -1731,6 +1810,7 @@ class Parser:
return None
self._match_r_paren()
maybe_func = True
if nested and self._match(TokenType.LT):
if is_struct:
@@ -1741,26 +1821,47 @@ class Parser:
if not self._match(TokenType.GT):
self.raise_error("Expecting >")
value = None
if type_token in self.TIMESTAMPS:
tz = self._match(TokenType.WITH_TIME_ZONE) or type_token == TokenType.TIMESTAMPTZ
if tz:
return exp.DataType(
if self._match(TokenType.WITH_TIME_ZONE) or type_token == TokenType.TIMESTAMPTZ:
value = exp.DataType(
this=exp.DataType.Type.TIMESTAMPTZ,
expressions=expressions,
)
ltz = self._match(TokenType.WITH_LOCAL_TIME_ZONE) or type_token == TokenType.TIMESTAMPLTZ
if ltz:
return exp.DataType(
elif (
self._match(TokenType.WITH_LOCAL_TIME_ZONE) or type_token == TokenType.TIMESTAMPLTZ
):
value = exp.DataType(
this=exp.DataType.Type.TIMESTAMPLTZ,
expressions=expressions,
)
self._match(TokenType.WITHOUT_TIME_ZONE)
return exp.DataType(
elif self._match(TokenType.WITHOUT_TIME_ZONE):
value = exp.DataType(
this=exp.DataType.Type.TIMESTAMP,
expressions=expressions,
)
maybe_func = maybe_func and value is None
if value is None:
value = exp.DataType(
this=exp.DataType.Type.TIMESTAMP,
expressions=expressions,
)
if maybe_func and check_func:
index2 = self._index
peek = self._parse_string()
if not peek:
self._retreat(index)
return None
self._retreat(index2)
if value:
return value
return exp.DataType(
this=exp.DataType.Type[type_token.value.upper()],
expressions=expressions,
@@ -1826,22 +1927,29 @@ class Parser:
return exp.Literal.number(f"0.{self._prev.text}")
if self._match(TokenType.L_PAREN):
comment = self._prev_comment
query = self._parse_select()
if query:
expressions = [query]
else:
expressions = self._parse_csv(lambda: self._parse_alias(self._parse_conjunction(), explicit=True))
expressions = self._parse_csv(
lambda: self._parse_alias(self._parse_conjunction(), explicit=True)
)
this = list_get(expressions, 0)
this = seq_get(expressions, 0)
self._parse_query_modifiers(this)
self._match_r_paren()
if isinstance(this, exp.Subqueryable):
return self._parse_set_operations(self._parse_subquery(this))
if len(expressions) > 1:
return self.expression(exp.Tuple, expressions=expressions)
return self.expression(exp.Paren, this=this)
this = self._parse_set_operations(self._parse_subquery(this))
elif len(expressions) > 1:
this = self.expression(exp.Tuple, expressions=expressions)
else:
this = self.expression(exp.Paren, this=this)
if comment:
this.comment = comment
return this
return None
@@ -1894,7 +2002,8 @@ class Parser:
self.validate_expression(this, args)
else:
this = self.expression(exp.Anonymous, this=this, expressions=args)
self._match_r_paren()
self._match_r_paren(this)
return self._parse_window(this)
def _parse_user_defined_function(self):
@@ -1920,6 +2029,18 @@ class Parser:
return self.expression(exp.Identifier, this=token.text)
def _parse_session_parameter(self):
kind = None
this = self._parse_id_var() or self._parse_primary()
if self._match(TokenType.DOT):
kind = this.name
this = self._parse_var() or self._parse_primary()
return self.expression(
exp.SessionParameter,
this=this,
kind=kind,
)
def _parse_udf_kwarg(self):
this = self._parse_id_var()
kind = self._parse_types()
@@ -1938,11 +2059,15 @@ class Parser:
else:
expressions = [self._parse_id_var()]
if not self._match(TokenType.ARROW):
if self._match_set(self.LAMBDAS):
return self.LAMBDAS[self._prev.token_type](self, expressions)
self._retreat(index)
if self._match(TokenType.DISTINCT):
this = self.expression(exp.Distinct, expressions=self._parse_csv(self._parse_conjunction))
this = self.expression(
exp.Distinct, expressions=self._parse_csv(self._parse_conjunction)
)
else:
this = self._parse_conjunction()
@@ -1953,20 +2078,15 @@ class Parser:
return self._parse_alias(self._parse_limit(self._parse_order(this)))
conjunction = self._parse_conjunction().transform(self._replace_lambda, {node.name for node in expressions})
return self.expression(
exp.Lambda,
this=conjunction,
expressions=expressions,
)
def _parse_schema(self, this=None):
index = self._index
if not self._match(TokenType.L_PAREN) or self._match(TokenType.SELECT):
self._retreat(index)
return this
args = self._parse_csv(lambda: self._parse_constraint() or self._parse_column_def(self._parse_field(True)))
args = self._parse_csv(
lambda: self._parse_constraint() or self._parse_column_def(self._parse_field(True))
)
self._match_r_paren()
return self.expression(exp.Schema, this=this, expressions=args)
@@ -2104,6 +2224,7 @@ class Parser:
if not self._match(TokenType.R_BRACKET):
self.raise_error("Expected ]")
this.comment = self._prev_comment
return self._parse_bracket(this)
def _parse_case(self):
@@ -2124,7 +2245,9 @@ class Parser:
if not self._match(TokenType.END):
self.raise_error("Expected END after CASE", self._prev)
return self._parse_window(self.expression(exp.Case, this=expression, ifs=ifs, default=default))
return self._parse_window(
self.expression(exp.Case, this=expression, ifs=ifs, default=default)
)
def _parse_if(self):
if self._match(TokenType.L_PAREN):
@@ -2331,7 +2454,9 @@ class Parser:
self._match(TokenType.BETWEEN)
return {
"value": (self._match_set((TokenType.UNBOUNDED, TokenType.CURRENT_ROW)) and self._prev.text)
"value": (
self._match_set((TokenType.UNBOUNDED, TokenType.CURRENT_ROW)) and self._prev.text
)
or self._parse_bitwise(),
"side": self._match_set((TokenType.PRECEDING, TokenType.FOLLOWING)) and self._prev.text,
}
@@ -2348,7 +2473,7 @@ class Parser:
this=this,
expressions=self._parse_csv(lambda: self._parse_id_var(any_token)),
)
self._match_r_paren()
self._match_r_paren(aliases)
return aliases
alias = self._parse_id_var(any_token)
@@ -2365,28 +2490,29 @@ class Parser:
return identifier
if any_token and self._curr and self._curr.token_type not in self.RESERVED_KEYWORDS:
return self._advance() or exp.Identifier(this=self._prev.text, quoted=False)
return self._match_set(tokens or self.ID_VAR_TOKENS) and exp.Identifier(this=self._prev.text, quoted=False)
self._advance()
elif not self._match_set(tokens or self.ID_VAR_TOKENS):
return None
return exp.Identifier(this=self._prev.text, quoted=False)
def _parse_string(self):
if self._match(TokenType.STRING):
return exp.Literal.string(self._prev.text)
return self.PRIMARY_PARSERS[TokenType.STRING](self, self._prev)
return self._parse_placeholder()
def _parse_number(self):
if self._match(TokenType.NUMBER):
return exp.Literal.number(self._prev.text)
return self.PRIMARY_PARSERS[TokenType.NUMBER](self, self._prev)
return self._parse_placeholder()
def _parse_identifier(self):
if self._match(TokenType.IDENTIFIER):
return exp.Identifier(this=self._prev.text, quoted=True)
return self.expression(exp.Identifier, this=self._prev.text, quoted=True)
return self._parse_placeholder()
def _parse_var(self):
if self._match(TokenType.VAR):
return exp.Var(this=self._prev.text)
return self.expression(exp.Var, this=self._prev.text)
return self._parse_placeholder()
def _parse_var_or_string(self):
@@ -2394,27 +2520,27 @@ class Parser:
def _parse_null(self):
if self._match(TokenType.NULL):
return exp.Null()
return self.PRIMARY_PARSERS[TokenType.NULL](self, self._prev)
return None
def _parse_boolean(self):
if self._match(TokenType.TRUE):
return exp.Boolean(this=True)
return self.PRIMARY_PARSERS[TokenType.TRUE](self, self._prev)
if self._match(TokenType.FALSE):
return exp.Boolean(this=False)
return self.PRIMARY_PARSERS[TokenType.FALSE](self, self._prev)
return None
def _parse_star(self):
if self._match(TokenType.STAR):
return exp.Star(**{"except": self._parse_except(), "replace": self._parse_replace()})
return self.PRIMARY_PARSERS[TokenType.STAR](self, self._prev)
return None
def _parse_placeholder(self):
if self._match(TokenType.PLACEHOLDER):
return exp.Placeholder()
return self.expression(exp.Placeholder)
elif self._match(TokenType.COLON):
self._advance()
return exp.Placeholder(this=self._prev.text)
return self.expression(exp.Placeholder, this=self._prev.text)
return None
def _parse_except(self):
@@ -2432,22 +2558,27 @@ class Parser:
self._match_r_paren()
return columns
def _parse_csv(self, parse):
parse_result = parse()
def _parse_csv(self, parse_method):
parse_result = parse_method()
items = [parse_result] if parse_result is not None else []
while self._match(TokenType.COMMA):
parse_result = parse()
if parse_result and self._prev_comment is not None:
parse_result.comment = self._prev_comment
parse_result = parse_method()
if parse_result is not None:
items.append(parse_result)
return items
def _parse_tokens(self, parse, expressions):
this = parse()
def _parse_tokens(self, parse_method, expressions):
this = parse_method()
while self._match_set(expressions):
this = self.expression(expressions[self._prev.token_type], this=this, expression=parse())
this = self.expression(
expressions[self._prev.token_type], this=this, expression=parse_method()
)
return this
@@ -2460,6 +2591,47 @@ class Parser:
def _parse_select_or_expression(self):
return self._parse_select() or self._parse_expression()
def _parse_use(self):
return self.expression(exp.Use, this=self._parse_id_var())
def _parse_show(self):
parser = self._find_parser(self.SHOW_PARSERS, self._show_trie)
if parser:
return parser(self)
self._advance()
return self.expression(exp.Show, this=self._prev.text.upper())
def _default_parse_set_item(self):
return self.expression(
exp.SetItem,
this=self._parse_statement(),
)
def _parse_set_item(self):
parser = self._find_parser(self.SET_PARSERS, self._set_trie)
return parser(self) if parser else self._default_parse_set_item()
def _parse_set(self):
return self.expression(exp.Set, expressions=self._parse_csv(self._parse_set_item))
def _find_parser(self, parsers, trie):
index = self._index
this = []
while True:
# The current token might be multiple words
curr = self._curr.text.upper()
key = curr.split(" ")
this.append(curr)
self._advance()
result, trie = in_trie(trie, key)
if result == 0:
break
if result == 2:
subparser = parsers[" ".join(this)]
return subparser
self._retreat(index)
return None
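
`_find_parser` pairs with the `_Parser` metaclass above, which pre-builds tries from the `SHOW_PARSERS` and `SET_PARSERS` keys; `in_trie` distinguishes "no match", "prefix of a longer key", and "exact key", so multi-word commands can be consumed greedily. A small sketch with made-up keys:

```python
from sqlglot.trie import in_trie, new_trie

trie = new_trie(key.split(" ") for key in ("BINARY LOGS", "MASTER STATUS"))
assert in_trie(trie, ["ENGINES"])[0] == 0          # not a registered key
assert in_trie(trie, ["BINARY"])[0] == 1           # prefix of "BINARY LOGS"
assert in_trie(trie, ["BINARY", "LOGS"])[0] == 2   # exact match
```
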
def _match(self, token_type):
if not self._curr:
return None
@@ -2491,13 +2663,17 @@ class Parser:
return None
def _match_l_paren(self):
def _match_l_paren(self, expression=None):
if not self._match(TokenType.L_PAREN):
self.raise_error("Expecting (")
if expression and self._prev_comment:
expression.comment = self._prev_comment
def _match_r_paren(self):
def _match_r_paren(self, expression=None):
if not self._match(TokenType.R_PAREN):
self.raise_error("Expecting )")
if expression and self._prev_comment:
expression.comment = self._prev_comment
def _match_text(self, *texts):
index = self._index

View file

@@ -72,7 +72,9 @@ class Step:
if from_:
from_ = from_.expressions
if len(from_) > 1:
raise UnsupportedError("Multi-from statements are unsupported. Run it through the optimizer")
raise UnsupportedError(
"Multi-from statements are unsupported. Run it through the optimizer"
)
step = Scan.from_expression(from_[0], ctes)
else:
@@ -102,7 +104,7 @@ class Step:
continue
if operand not in operands:
operands[operand] = f"_a_{next(sequence)}"
operand.replace(exp.column(operands[operand], step.name, quoted=True))
operand.replace(exp.column(operands[operand], quoted=True))
else:
projections.append(e)
@@ -117,9 +119,11 @@ class Step:
aggregate = Aggregate()
aggregate.source = step.name
aggregate.name = step.name
aggregate.operands = tuple(alias(operand, alias_) for operand, alias_ in operands.items())
aggregate.operands = tuple(
alias(operand, alias_) for operand, alias_ in operands.items()
)
aggregate.aggregations = aggregations
aggregate.group = [exp.column(e.alias_or_name, step.name, quoted=True) for e in group.expressions]
aggregate.group = group.expressions
aggregate.add_dependency(step)
step = aggregate
@@ -136,9 +140,6 @@ class Step:
sort.key = order.expressions
sort.add_dependency(step)
step = sort
for k in sort.key + projections:
for column in k.find_all(exp.Column):
column.set("table", exp.to_identifier(step.name, quoted=True))
step.projections = projections
@@ -203,7 +204,9 @@ class Scan(Step):
alias_ = expression.alias
if not alias_:
raise UnsupportedError("Tables/Subqueries must be aliased. Run it through the optimizer")
raise UnsupportedError(
"Tables/Subqueries must be aliased. Run it through the optimizer"
)
if isinstance(expression, exp.Subquery):
table = expression.this

sqlglot/py.typed Normal file
View file

View file

@@ -1,44 +1,60 @@
from __future__ import annotations
import abc
import typing as t
from sqlglot import expressions as exp
from sqlglot.errors import OptimizeError
from sqlglot.errors import SchemaError
from sqlglot.helper import csv_reader
from sqlglot.trie import in_trie, new_trie
if t.TYPE_CHECKING:
from sqlglot.dataframe.sql.types import StructType
ColumnMapping = t.Union[t.Dict, str, StructType, t.List]
TABLE_ARGS = ("this", "db", "catalog")
class Schema(abc.ABC):
"""Abstract base class for database schemas"""
@abc.abstractmethod
def add_table(self, table, column_mapping=None):
def add_table(
self, table: exp.Table | str, column_mapping: t.Optional[ColumnMapping] = None
) -> None:
"""
Register or update a table. Some implementing classes may require column information to also be provided
Register or update a table. Some implementing classes may require column information to also be provided.
Args:
table (sqlglot.expressions.Table|str): Table expression instance or string representing the table
column_mapping (dict|str|sqlglot.dataframe.sql.types.StructType|list): A column mapping that describes the structure of the table
table: table expression instance or string representing the table.
column_mapping: a column mapping that describes the structure of the table.
"""
@abc.abstractmethod
def column_names(self, table, only_visible=False):
def column_names(self, table: exp.Table | str, only_visible: bool = False) -> t.List[str]:
"""
Get the column names for a table.
Args:
table (sqlglot.expressions.Table): Table expression instance
only_visible (bool): Whether to include invisible columns
table: the `Table` expression instance.
only_visible: whether to include invisible columns.
Returns:
list[str]: list of column names
The list of column names.
"""
@abc.abstractmethod
def get_column_type(self, table, column):
def get_column_type(self, table: exp.Table | str, column: exp.Column) -> exp.DataType.Type:
"""
Get the exp.DataType type of a column in the schema.
Get the :class:`sqlglot.exp.DataType` type of a column in the schema.
Args:
table (sqlglot.expressions.Table): The source table.
column (sqlglot.expressions.Column): The target column.
table: the source table.
column: the target column.
Returns:
sqlglot.expressions.DataType.Type: The resulting column type.
The resulting column type.
"""
@@ -60,132 +76,179 @@ class MappingSchema(Schema):
dialect (str): The dialect to be used for custom type mappings.
"""
def __init__(self, schema=None, visible=None, dialect=None):
def __init__(
self,
schema: t.Optional[t.Dict] = None,
visible: t.Optional[t.Dict] = None,
dialect: t.Optional[str] = None,
) -> None:
self.schema = schema or {}
self.visible = visible
self.visible = visible or {}
self.schema_trie = self._build_trie(self.schema)
self.dialect = dialect
self._type_mapping_cache = {}
self.supported_table_args = []
self.forbidden_table_args = set()
if self.schema:
self._initialize_supported_args()
self._type_mapping_cache: t.Dict[str, exp.DataType.Type] = {}
self._supported_table_args: t.Tuple[str, ...] = tuple()
@classmethod
def from_mapping_schema(cls, mapping_schema):
def from_mapping_schema(cls, mapping_schema: MappingSchema) -> MappingSchema:
return MappingSchema(
schema=mapping_schema.schema, visible=mapping_schema.visible, dialect=mapping_schema.dialect
schema=mapping_schema.schema,
visible=mapping_schema.visible,
dialect=mapping_schema.dialect,
)
def copy(self, **kwargs):
return MappingSchema(**{"schema": self.schema.copy(), **kwargs})
def copy(self, **kwargs) -> MappingSchema:
return MappingSchema(
**{ # type: ignore
"schema": self.schema.copy(),
"visible": self.visible.copy(),
"dialect": self.dialect,
**kwargs,
}
)
def add_table(self, table, column_mapping=None):
@property
def supported_table_args(self):
if not self._supported_table_args and self.schema:
depth = _dict_depth(self.schema)
if not depth or depth == 1: # {}
self._supported_table_args = tuple()
elif 2 <= depth <= 4:
self._supported_table_args = TABLE_ARGS[: depth - 1]
else:
raise SchemaError(f"Invalid schema shape. Depth: {depth}")
return self._supported_table_args
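
The depth of the nested mapping now decides which table parts (`this`, `db`, `catalog`) the schema understands, and the trie over reversed table paths is what makes the schema more lenient: a lookup succeeds with fewer parts as long as it is unambiguous. For example (a sketch):

```python
from sqlglot.schema import MappingSchema

schema = MappingSchema({"db": {"t": {"x": "INT", "y": "TEXT"}}})
assert schema.supported_table_args == ("this", "db")  # depth 3
# "t" is unambiguous, so the db qualifier is optional:
assert schema.column_names("t") == ["x", "y"]
assert schema.column_names("db.t") == ["x", "y"]
```
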
def add_table(
self, table: exp.Table | str, column_mapping: t.Optional[ColumnMapping] = None
) -> None:
"""
Register or update a table. Updates are only performed if a new column mapping is provided.
Args:
table (sqlglot.expressions.Table|str): Table expression instance or string representing the table
column_mapping (dict|str|sqlglot.dataframe.sql.types.StructType|list): A column mapping that describes the structure of the table
table: the `Table` expression instance or string representing the table.
column_mapping: a column mapping that describes the structure of the table.
"""
table = exp.to_table(table)
self._validate_table(table)
table_ = self._ensure_table(table)
column_mapping = ensure_column_mapping(column_mapping)
table_args = [table.text(p) for p in self.supported_table_args or self._get_table_args_from_table(table)]
existing_column_mapping = _nested_get(
self.schema, *zip(self.supported_table_args, table_args), raise_on_missing=False
)
if existing_column_mapping and not column_mapping:
schema = self.find_schema(table_, raise_on_missing=False)
if schema and not column_mapping:
return
_nested_set(
self.schema,
[table.text(p) for p in self.supported_table_args or self._get_table_args_from_table(table)],
list(reversed(self.table_parts(table_))),
column_mapping,
)
self._initialize_supported_args()
self.schema_trie = self._build_trie(self.schema)
def _get_table_args_from_table(self, table):
if table.args.get("catalog") is not None:
return "catalog", "db", "this"
if table.args.get("db") is not None:
return "db", "this"
return ("this",)
def _ensure_table(self, table: exp.Table | str) -> exp.Table:
table_ = exp.to_table(table)
def _validate_table(self, table):
if not self.supported_table_args and isinstance(table, exp.Table):
return
for forbidden in self.forbidden_table_args:
if table.text(forbidden):
raise ValueError(f"Schema doesn't support {forbidden}. Received: {table.sql()}")
for expected in self.supported_table_args:
if not table.text(expected):
raise ValueError(f"Table is expected to have {expected}. Received: {table.sql()} ")
if not table_:
raise SchemaError(f"Not a valid table '{table}'")
def column_names(self, table, only_visible=False):
table = exp.to_table(table)
if not isinstance(table.this, exp.Identifier):
return fs_get(table)
return table_
args = tuple(table.text(p) for p in self.supported_table_args)
def table_parts(self, table: exp.Table) -> t.List[str]:
return [table.text(part) for part in TABLE_ARGS if table.text(part)]
for forbidden in self.forbidden_table_args:
if table.text(forbidden):
raise ValueError(f"Schema doesn't support {forbidden}. Received: {table.sql()}")
def column_names(self, table: exp.Table | str, only_visible: bool = False) -> t.List[str]:
table_ = self._ensure_table(table)
if not isinstance(table_.this, exp.Identifier):
return fs_get(table) # type: ignore
schema = self.find_schema(table_)
if schema is None:
raise SchemaError(f"Could not find table schema {table}")
columns = list(_nested_get(self.schema, *zip(self.supported_table_args, args)))
if not only_visible or not self.visible:
return columns
return list(schema)
visible = _nested_get(self.visible, *zip(self.supported_table_args, args))
return [col for col in columns if col in visible]
visible = self._nested_get(self.table_parts(table_), self.visible)
return [col for col in schema if col in visible] # type: ignore
def get_column_type(self, table, column):
try:
schema_type = self.schema.get(table.name, {}).get(column.name).upper()
def find_schema(
self, table: exp.Table, trie: t.Optional[t.Dict] = None, raise_on_missing: bool = True
) -> t.Optional[t.Dict[str, str]]:
parts = self.table_parts(table)[0 : len(self.supported_table_args)]
value, trie = in_trie(self.schema_trie if trie is None else trie, parts)
if value == 0:
if raise_on_missing:
raise SchemaError(f"Cannot find schema for {table}.")
else:
return None
elif value == 1:
possibilities = flatten_schema(trie)
if len(possibilities) == 1:
parts.extend(possibilities[0])
else:
message = ", ".join(".".join(parts) for parts in possibilities)
if raise_on_missing:
raise SchemaError(f"Ambiguous schema for {table}: {message}.")
return None
return self._nested_get(parts, raise_on_missing=raise_on_missing)
def get_column_type(
self, table: exp.Table | str, column: exp.Column | str
) -> exp.DataType.Type:
column_name = column if isinstance(column, str) else column.name
table_ = exp.to_table(table)
if table_:
table_schema = self.find_schema(table_)
schema_type = table_schema.get(column_name).upper() # type: ignore
return self._convert_type(schema_type)
except:
raise OptimizeError(f"Failed to get type for column {column.sql()}")
raise SchemaError(f"Could not convert table '{table}'")
def _convert_type(self, schema_type):
def _convert_type(self, schema_type: str) -> exp.DataType.Type:
"""
Convert a type represented as a string to the corresponding exp.DataType.Type object.
Convert a type represented as a string to the corresponding :class:`sqlglot.exp.DataType` object.
Args:
schema_type (str): The type we want to convert.
schema_type: the type we want to convert.
Returns:
sqlglot.expressions.DataType.Type: The resulting expression type.
The resulting expression type.
"""
if schema_type not in self._type_mapping_cache:
try:
self._type_mapping_cache[schema_type] = exp.maybe_parse(
schema_type, into=exp.DataType, dialect=self.dialect
).this
expression = exp.maybe_parse(schema_type, into=exp.DataType, dialect=self.dialect)
if expression is None:
raise ValueError(f"Could not parse {schema_type}")
self._type_mapping_cache[schema_type] = expression.this
except AttributeError:
raise OptimizeError(f"Failed to convert type {schema_type}")
raise SchemaError(f"Failed to convert type {schema_type}")
return self._type_mapping_cache[schema_type]
def _initialize_supported_args(self):
if not self.supported_table_args:
depth = _dict_depth(self.schema)
def _build_trie(self, schema: t.Dict):
return new_trie(tuple(reversed(t)) for t in flatten_schema(schema))
all_args = ["this", "db", "catalog"]
if not depth or depth == 1: # {}
self.supported_table_args = []
elif 2 <= depth <= 4:
self.supported_table_args = tuple(reversed(all_args[: depth - 1]))
else:
raise OptimizeError(f"Invalid schema shape. Depth: {depth}")
self.forbidden_table_args = {"catalog", "db", "this"} - set(self.supported_table_args)
def _nested_get(
self, parts: t.Sequence[str], d: t.Optional[t.Dict] = None, raise_on_missing=True
) -> t.Optional[t.Any]:
return _nested_get(
d or self.schema,
*zip(self.supported_table_args, reversed(parts)),
raise_on_missing=raise_on_missing,
)
def ensure_schema(schema):
def ensure_schema(schema: t.Any) -> Schema:
if isinstance(schema, Schema):
return schema
return MappingSchema(schema)
def ensure_column_mapping(mapping):
def ensure_column_mapping(mapping: t.Optional[ColumnMapping]):
if isinstance(mapping, dict):
return mapping
elif isinstance(mapping, str):
@@ -196,7 +259,7 @@ def ensure_column_mapping(mapping):
}
# Check if mapping looks like a DataFrame StructType
elif hasattr(mapping, "simpleString"):
return {struct_field.name: struct_field.dataType.simpleString() for struct_field in mapping}
return {struct_field.name: struct_field.dataType.simpleString() for struct_field in mapping} # type: ignore
elif isinstance(mapping, list):
return {x.strip(): None for x in mapping}
elif mapping is None:
@@ -204,7 +267,20 @@ def ensure_column_mapping(mapping):
raise ValueError(f"Invalid mapping provided: {type(mapping)}")
def fs_get(table):
def flatten_schema(schema: t.Dict, keys: t.Optional[t.List[str]] = None) -> t.List[t.List[str]]:
tables = []
keys = keys or []
depth = _dict_depth(schema)
for k, v in schema.items():
if depth >= 3:
tables.extend(flatten_schema(v, keys + [k]))
elif depth == 2:
tables.append(keys + [k])
return tables
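
`flatten_schema` enumerates the table paths of a nested mapping; reversing each path before building the trie keys it by table name first, which is what enables the partial lookups above. A quick sketch:

```python
from sqlglot.schema import flatten_schema
from sqlglot.trie import new_trie

paths = flatten_schema({"cat": {"db": {"t": {"x": "INT"}}}})
assert paths == [["cat", "db", "t"]]
# Keyed as ("t", "db", "cat"), so a lookup can start from the table name alone.
trie = new_trie(tuple(reversed(path)) for path in paths)
```
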
def fs_get(table: exp.Table) -> t.List[str]:
name = table.this.name
if name.upper() == "READ_CSV":
@@ -214,21 +290,23 @@ def fs_get(table):
raise ValueError(f"Cannot read schema for {table}")
def _nested_get(d, *path, raise_on_missing=True):
def _nested_get(
d: t.Dict, *path: t.Tuple[str, str], raise_on_missing: bool = True
) -> t.Optional[t.Any]:
"""
    Get a value from a nested dictionary.
Args:
d (dict): dictionary
*path (tuple[str, str]): tuples of (name, key)
d: the dictionary to search.
*path: tuples of (name, key), where:
`key` is the key in the dictionary to get.
`name` is a string to use in the error if `key` isn't found.
Returns:
The value or None if it doesn't exist
The value or None if it doesn't exist.
"""
for name, key in path:
d = d.get(key)
d = d.get(key) # type: ignore
if d is None:
if raise_on_missing:
name = "table" if name == "this" else name
@@ -237,36 +315,44 @@ def _nested_get(d, *path, raise_on_missing=True):
return d
def _nested_set(d, keys, value):
def _nested_set(d: t.Dict, keys: t.List[str], value: t.Any) -> t.Dict:
"""
    Set a value for a nested dictionary, in place.
Ex:
Example:
>>> _nested_set({}, ["top_key", "second_key"], "value")
{'top_key': {'second_key': 'value'}}
>>> _nested_set({"top_key": {"third_key": "third_value"}}, ["top_key", "second_key"], "value")
{'top_key': {'third_key': 'third_value', 'second_key': 'value'}}
d (dict): dictionary
keys (Iterable[str]): ordered iterable of keys that makeup path to value
value (Any): The value to set in the dictionary for the given key path
Args:
d: dictionary to update.
        keys: the keys that make up the path to `value`.
value: the value to set in the dictionary for the given key path.
Returns:
The (possibly) updated dictionary.
"""
if not keys:
return
return d
if len(keys) == 1:
d[keys[0]] = value
return
return d
subd = d
for key in keys[:-1]:
if key not in subd:
subd = subd.setdefault(key, {})
else:
subd = subd[key]
subd[keys[-1]] = value
return d
def _dict_depth(d):
def _dict_depth(d: t.Dict) -> int:
"""
Get the nesting depth of a dictionary.

View file

@@ -1,9 +1,13 @@
# the generic time format is based on python time.strftime
import typing as t
# The generic time format is based on python time.strftime.
# https://docs.python.org/3/library/time.html#time.strftime
from sqlglot.trie import in_trie, new_trie
def format_time(string, mapping, trie=None):
def format_time(
string: str, mapping: t.Dict[str, str], trie: t.Optional[t.Dict] = None
) -> t.Optional[str]:
"""
Converts a time string given a mapping.
@@ -11,11 +15,16 @@ def format_time(string, mapping, trie=None):
>>> format_time("%Y", {"%Y": "YYYY"})
'YYYY'
mapping: Dictionary of time format to target time format
trie: Optional trie, can be passed in for performance
Args:
mapping: dictionary of time format to target time format.
trie: optional trie, can be passed in for performance.
Returns:
The converted time string.
"""
if not string:
return None
start = 0
end = 1
size = len(string)

View file

@@ -1,3 +1,6 @@
from __future__ import annotations
import typing as t
from enum import auto
from sqlglot.helper import AutoName
@@ -27,6 +30,7 @@ class TokenType(AutoName):
NOT = auto()
EQ = auto()
NEQ = auto()
NULLSAFE_EQ = auto()
AND = auto()
OR = auto()
AMP = auto()
@@ -36,12 +40,14 @@ class TokenType(AutoName):
TILDA = auto()
ARROW = auto()
DARROW = auto()
FARROW = auto()
HASH = auto()
HASH_ARROW = auto()
DHASH_ARROW = auto()
LR_ARROW = auto()
ANNOTATION = auto()
DOLLAR = auto()
PARAMETER = auto()
SESSION_PARAMETER = auto()
SPACE = auto()
BREAK = auto()
@@ -73,7 +79,7 @@ class TokenType(AutoName):
NVARCHAR = auto()
TEXT = auto()
BINARY = auto()
BYTEA = auto()
VARBINARY = auto()
JSON = auto()
TIMESTAMP = auto()
TIMESTAMPTZ = auto()
@@ -142,6 +148,7 @@ class TokenType(AutoName):
DESCRIBE = auto()
DETERMINISTIC = auto()
DISTINCT = auto()
DISTINCT_FROM = auto()
DISTRIBUTE_BY = auto()
DIV = auto()
DROP = auto()
@@ -238,6 +245,7 @@ class TokenType(AutoName):
RETURNS = auto()
RIGHT = auto()
RLIKE = auto()
ROLLBACK = auto()
ROLLUP = auto()
ROW = auto()
ROWS = auto()
@@ -287,37 +295,49 @@ class TokenType(AutoName):
class Token:
__slots__ = ("token_type", "text", "line", "col")
__slots__ = ("token_type", "text", "line", "col", "comment")
@classmethod
def number(cls, number):
def number(cls, number: int) -> Token:
"""Returns a NUMBER token with `number` as its text."""
return cls(TokenType.NUMBER, str(number))
@classmethod
def string(cls, string):
def string(cls, string: str) -> Token:
"""Returns a STRING token with `string` as its text."""
return cls(TokenType.STRING, string)
@classmethod
def identifier(cls, identifier):
def identifier(cls, identifier: str) -> Token:
"""Returns an IDENTIFIER token with `identifier` as its text."""
return cls(TokenType.IDENTIFIER, identifier)
@classmethod
def var(cls, var):
def var(cls, var: str) -> Token:
"""Returns an VAR token with `var` as its text."""
return cls(TokenType.VAR, var)
def __init__(self, token_type, text, line=1, col=1):
def __init__(
self,
token_type: TokenType,
text: str,
line: int = 1,
col: int = 1,
comment: t.Optional[str] = None,
) -> None:
self.token_type = token_type
self.text = text
self.line = line
self.col = max(col - len(text), 1)
self.comment = comment
def __repr__(self):
def __repr__(self) -> str:
attributes = ", ".join(f"{k}: {getattr(self, k)}" for k in self.__slots__)
return f"<Token {attributes}>"
class _Tokenizer(type):
def __new__(cls, clsname, bases, attrs):
def __new__(cls, clsname, bases, attrs): # type: ignore
klass = super().__new__(cls, clsname, bases, attrs)
klass._QUOTES = cls._delimeter_list_to_dict(klass.QUOTES)
@@ -325,27 +345,29 @@ class _Tokenizer(type):
klass._HEX_STRINGS = cls._delimeter_list_to_dict(klass.HEX_STRINGS)
klass._BYTE_STRINGS = cls._delimeter_list_to_dict(klass.BYTE_STRINGS)
klass._IDENTIFIERS = cls._delimeter_list_to_dict(klass.IDENTIFIERS)
klass._ESCAPES = set(klass.ESCAPES)
klass._COMMENTS = dict(
(comment, None) if isinstance(comment, str) else (comment[0], comment[1]) for comment in klass.COMMENTS
(comment, None) if isinstance(comment, str) else (comment[0], comment[1])
for comment in klass.COMMENTS
)
klass.KEYWORD_TRIE = new_trie(
key.upper()
for key, value in {
for key in {
**klass.KEYWORDS,
**{comment: TokenType.COMMENT for comment in klass._COMMENTS},
**{quote: TokenType.QUOTE for quote in klass._QUOTES},
**{bit_string: TokenType.BIT_STRING for bit_string in klass._BIT_STRINGS},
**{hex_string: TokenType.HEX_STRING for hex_string in klass._HEX_STRINGS},
**{byte_string: TokenType.BYTE_STRING for byte_string in klass._BYTE_STRINGS},
}.items()
}
if " " in key or any(single in key for single in klass.SINGLE_TOKENS)
)
return klass
@staticmethod
def _delimeter_list_to_dict(list):
def _delimeter_list_to_dict(list: t.List[str | t.Tuple[str, str]]) -> t.Dict[str, str]:
return dict((item, item) if isinstance(item, str) else (item[0], item[1]) for item in list)
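
Only keywords that span multiple words, or that contain a single-character token, end up in `KEYWORD_TRIE`; the tokenizer walks it character by character so that, for example, `DISTINCT FROM` is matched as one keyword. Illustratively, on a toy keyword set:

```python
from sqlglot.trie import in_trie, new_trie

trie = new_trie(["DISTINCT", "DISTINCT FROM"])
assert in_trie(trie, "DISTINCT")[0] == 2      # a complete keyword...
assert in_trie(trie, "DISTINCT FR")[0] == 1   # ...and a prefix of a longer one
assert in_trie(trie, "ROLLUPS")[0] == 0
```
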
@@ -375,26 +397,26 @@ class Tokenizer(metaclass=_Tokenizer):
"*": TokenType.STAR,
"~": TokenType.TILDA,
"?": TokenType.PLACEHOLDER,
"#": TokenType.ANNOTATION,
"@": TokenType.PARAMETER,
# used for breaking a var like x'y' but nothing else
# the token type doesn't matter
"'": TokenType.QUOTE,
"`": TokenType.IDENTIFIER,
'"': TokenType.IDENTIFIER,
"#": TokenType.HASH,
}
QUOTES = ["'"]
QUOTES: t.List[t.Tuple[str, str] | str] = ["'"]
BIT_STRINGS = []
BIT_STRINGS: t.List[str | t.Tuple[str, str]] = []
HEX_STRINGS = []
HEX_STRINGS: t.List[str | t.Tuple[str, str]] = []
BYTE_STRINGS = []
BYTE_STRINGS: t.List[str | t.Tuple[str, str]] = []
IDENTIFIERS = ['"']
IDENTIFIERS: t.List[str | t.Tuple[str, str]] = ['"']
ESCAPE = "'"
ESCAPES = ["'"]
KEYWORDS = {
"/*+": TokenType.HINT,
@@ -406,8 +428,10 @@ class Tokenizer(metaclass=_Tokenizer):
"<=": TokenType.LTE,
"<>": TokenType.NEQ,
"!=": TokenType.NEQ,
"<=>": TokenType.NULLSAFE_EQ,
"->": TokenType.ARROW,
"->>": TokenType.DARROW,
"=>": TokenType.FARROW,
"#>": TokenType.HASH_ARROW,
"#>>": TokenType.DHASH_ARROW,
"<->": TokenType.LR_ARROW,
@ -454,6 +478,7 @@ class Tokenizer(metaclass=_Tokenizer):
"DESCRIBE": TokenType.DESCRIBE,
"DETERMINISTIC": TokenType.DETERMINISTIC,
"DISTINCT": TokenType.DISTINCT,
"DISTINCT FROM": TokenType.DISTINCT_FROM,
"DISTRIBUTE BY": TokenType.DISTRIBUTE_BY,
"DIV": TokenType.DIV,
"DROP": TokenType.DROP,
@ -543,6 +568,7 @@ class Tokenizer(metaclass=_Tokenizer):
"RETURNS": TokenType.RETURNS,
"RIGHT": TokenType.RIGHT,
"RLIKE": TokenType.RLIKE,
"ROLLBACK": TokenType.ROLLBACK,
"ROLLUP": TokenType.ROLLUP,
"ROW": TokenType.ROW,
"ROWS": TokenType.ROWS,
@ -622,8 +648,9 @@ class Tokenizer(metaclass=_Tokenizer):
"TEXT": TokenType.TEXT,
"CLOB": TokenType.TEXT,
"BINARY": TokenType.BINARY,
"BLOB": TokenType.BINARY,
"BYTEA": TokenType.BINARY,
"BLOB": TokenType.VARBINARY,
"BYTEA": TokenType.VARBINARY,
"VARBINARY": TokenType.VARBINARY,
"TIMESTAMP": TokenType.TIMESTAMP,
"TIMESTAMPTZ": TokenType.TIMESTAMPTZ,
"TIMESTAMPLTZ": TokenType.TIMESTAMPLTZ,
@ -655,13 +682,13 @@ class Tokenizer(metaclass=_Tokenizer):
TokenType.SET,
TokenType.SHOW,
TokenType.TRUNCATE,
TokenType.USE,
TokenType.VACUUM,
TokenType.ROLLBACK,
}
# handle numeric literals like in hive (3L = BIGINT)
NUMERIC_LITERALS = {}
ENCODE = None
NUMERIC_LITERALS: t.Dict[str, str] = {}
ENCODE: t.Optional[str] = None
COMMENTS = ["--", ("/*", "*/")]
KEYWORD_TRIE = None # autofilled
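The `NUMERIC_LITERALS` hook above is what drives Hive-style typed literals; a hedged sketch of the observable behavior (exact output formatting may vary by version):

```python
# Hedged sketch: Hive numeric literal suffixes are tokenized as NUMBER + '::'
# + a type keyword, so they transpile to explicit CASTs in other dialects.
import sqlglot

print(sqlglot.transpile("SELECT 1L", read="hive")[0])
# e.g. SELECT CAST(1 AS BIGINT)
```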
@ -674,33 +701,39 @@ class Tokenizer(metaclass=_Tokenizer):
"_current",
"_line",
"_col",
"_comment",
"_char",
"_end",
"_peek",
"_prev_token_line",
"_prev_token_comment",
"_prev_token_type",
"_replace_backslash",
)
def __init__(self):
"""
Tokenizer consumes a sql string and produces an array of :class:`~sqlglot.tokens.Token`
"""
def __init__(self) -> None:
self._replace_backslash = "\\" in self._ESCAPES # type: ignore
self.reset()
def reset(self):
def reset(self) -> None:
self.sql = ""
self.size = 0
self.tokens = []
self.tokens: t.List[Token] = []
self._start = 0
self._current = 0
self._line = 1
self._col = 1
self._comment = None
self._char = None
self._end = None
self._peek = None
self._prev_token_line = -1
self._prev_token_comment = None
self._prev_token_type = None
def tokenize(self, sql):
def tokenize(self, sql: str) -> t.List[Token]:
"""Returns a list of tokens corresponding to the SQL string `sql`."""
self.reset()
self.sql = sql
self.size = len(sql)
@ -712,14 +745,14 @@ class Tokenizer(metaclass=_Tokenizer):
if not self._char:
break
white_space = self.WHITE_SPACE.get(self._char)
identifier_end = self._IDENTIFIERS.get(self._char)
white_space = self.WHITE_SPACE.get(self._char) # type: ignore
identifier_end = self._IDENTIFIERS.get(self._char) # type: ignore
if white_space:
if white_space == TokenType.BREAK:
self._col = 1
self._line += 1
elif self._char.isdigit():
elif self._char.isdigit(): # type:ignore
self._scan_number()
elif identifier_end:
self._scan_identifier(identifier_end)
@ -727,38 +760,51 @@ class Tokenizer(metaclass=_Tokenizer):
self._scan_keywords()
return self.tokens
def _chars(self, size):
def _chars(self, size: int) -> str:
if size == 1:
return self._char
return self._char # type: ignore
start = self._current - 1
end = start + size
if end <= self.size:
return self.sql[start:end]
return ""
def _advance(self, i=1):
def _advance(self, i: int = 1) -> None:
self._col += i
self._current += i
self._end = self._current >= self.size
self._char = self.sql[self._current - 1]
self._peek = self.sql[self._current] if self._current < self.size else ""
self._end = self._current >= self.size # type: ignore
self._char = self.sql[self._current - 1] # type: ignore
self._peek = self.sql[self._current] if self._current < self.size else "" # type: ignore
@property
def _text(self):
def _text(self) -> str:
return self.sql[self._start : self._current]
def _add(self, token_type, text=None):
self._prev_token_type = token_type
self.tokens.append(Token(token_type, self._text if text is None else text, self._line, self._col))
def _add(self, token_type: TokenType, text: t.Optional[str] = None) -> None:
self._prev_token_line = self._line
self._prev_token_comment = self._comment
self._prev_token_type = token_type # type: ignore
self.tokens.append(
Token(
token_type,
self._text if text is None else text,
self._line,
self._col,
self._comment,
)
)
self._comment = None
if token_type in self.COMMANDS and (len(self.tokens) == 1 or self.tokens[-2].token_type == TokenType.SEMICOLON):
if token_type in self.COMMANDS and (
len(self.tokens) == 1 or self.tokens[-2].token_type == TokenType.SEMICOLON
):
self._start = self._current
while not self._end and self._peek != ";":
self._advance()
if self._start < self._current:
self._add(TokenType.STRING)
def _scan_keywords(self):
def _scan_keywords(self) -> None:
size = 0
word = None
chars = self._text
@ -771,7 +817,7 @@ class Tokenizer(metaclass=_Tokenizer):
if skip:
result = 1
else:
result, trie = in_trie(trie, char.upper())
result, trie = in_trie(trie, char.upper()) # type: ignore
if result == 0:
break
@ -793,15 +839,11 @@ class Tokenizer(metaclass=_Tokenizer):
else:
skip = True
else:
chars = None
chars = None # type: ignore
if not word:
if self._char in self.SINGLE_TOKENS:
token = self.SINGLE_TOKENS[self._char]
if token == TokenType.ANNOTATION:
self._scan_annotation()
return
self._add(token)
self._add(self.SINGLE_TOKENS[self._char]) # type: ignore
return
self._scan_var()
return
@ -816,31 +858,41 @@ class Tokenizer(metaclass=_Tokenizer):
self._advance(size - 1)
self._add(self.KEYWORDS[word.upper()])
def _scan_comment(self, comment_start):
if comment_start not in self._COMMENTS:
def _scan_comment(self, comment_start: str) -> bool:
if comment_start not in self._COMMENTS: # type: ignore
return False
comment_end = self._COMMENTS[comment_start]
comment_start_line = self._line
comment_start_size = len(comment_start)
comment_end = self._COMMENTS[comment_start] # type: ignore
if comment_end:
comment_end_size = len(comment_end)
while not self._end and self._chars(comment_end_size) != comment_end:
self._advance()
self._comment = self._text[comment_start_size : -comment_end_size + 1] # type: ignore
self._advance(comment_end_size - 1)
else:
while not self._end and self.WHITE_SPACE.get(self._peek) != TokenType.BREAK:
while not self._end and self.WHITE_SPACE.get(self._peek) != TokenType.BREAK: # type: ignore
self._advance()
self._comment = self._text[comment_start_size:] # type: ignore
# Leading comment is attached to the succeeding token, whilst trailing comment to the preceding. If both
# types of comment can be attached to a token, the trailing one is discarded in favour of the leading one.
if comment_start_line == self._prev_token_line:
if self._prev_token_comment is None:
self.tokens[-1].comment = self._comment
self._comment = None
return True
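The attachment rule described in the comment above is observable directly on the tokenizer's output; a minimal sketch, assuming the 10.x `Token.comment` attribute (the exact comment text, including surrounding whitespace, is an implementation detail):

```python
# Hedged sketch: how comments attach to tokens after this change.
from sqlglot.tokens import Tokenizer

tokens = Tokenizer().tokenize("SELECT a -- note")
# The trailing comment starts on the same line as the previous token ('a'),
# so it is attached to that token.
print(tokens[-1].text, repr(tokens[-1].comment))  # e.g. a ' note'

tokens = Tokenizer().tokenize("/* hint */ SELECT a")
# A leading comment is carried forward and attached to the next token (SELECT).
print(tokens[0].text, repr(tokens[0].comment))  # e.g. SELECT ' hint '
```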
def _scan_annotation(self):
while not self._end and self.WHITE_SPACE.get(self._peek) != TokenType.BREAK and self._peek != ",":
self._advance()
self._add(TokenType.ANNOTATION, self._text[1:])
def _scan_number(self):
def _scan_number(self) -> None:
if self._char == "0":
peek = self._peek.upper()
peek = self._peek.upper() # type: ignore
if peek == "B":
return self._scan_bits()
elif peek == "X":
@ -850,7 +902,7 @@ class Tokenizer(metaclass=_Tokenizer):
scientific = 0
while True:
if self._peek.isdigit():
if self._peek.isdigit(): # type: ignore
self._advance()
elif self._peek == "." and not decimal:
decimal = True
@ -858,25 +910,25 @@ class Tokenizer(metaclass=_Tokenizer):
elif self._peek in ("-", "+") and scientific == 1:
scientific += 1
self._advance()
elif self._peek.upper() == "E" and not scientific:
elif self._peek.upper() == "E" and not scientific: # type: ignore
scientific += 1
self._advance()
elif self._peek.isalpha():
elif self._peek.isalpha(): # type: ignore
self._add(TokenType.NUMBER)
literal = []
while self._peek.isalpha():
literal.append(self._peek.upper())
while self._peek.isalpha(): # type: ignore
literal.append(self._peek.upper()) # type: ignore
self._advance()
literal = "".join(literal)
token_type = self.KEYWORDS.get(self.NUMERIC_LITERALS.get(literal))
literal = "".join(literal) # type: ignore
token_type = self.KEYWORDS.get(self.NUMERIC_LITERALS.get(literal)) # type: ignore
if token_type:
self._add(TokenType.DCOLON, "::")
return self._add(token_type, literal)
return self._add(token_type, literal) # type: ignore
return self._advance(-len(literal))
else:
return self._add(TokenType.NUMBER)
def _scan_bits(self):
def _scan_bits(self) -> None:
self._advance()
value = self._extract_value()
try:
@ -884,7 +936,7 @@ class Tokenizer(metaclass=_Tokenizer):
except ValueError:
self._add(TokenType.IDENTIFIER)
def _scan_hex(self):
def _scan_hex(self) -> None:
self._advance()
value = self._extract_value()
try:
@ -892,9 +944,9 @@ class Tokenizer(metaclass=_Tokenizer):
except ValueError:
self._add(TokenType.IDENTIFIER)
def _extract_value(self):
def _extract_value(self) -> str:
while True:
char = self._peek.strip()
char = self._peek.strip() # type: ignore
if char and char not in self.SINGLE_TOKENS:
self._advance()
else:
@ -902,31 +954,30 @@ class Tokenizer(metaclass=_Tokenizer):
return self._text
def _scan_string(self, quote):
quote_end = self._QUOTES.get(quote)
def _scan_string(self, quote: str) -> bool:
quote_end = self._QUOTES.get(quote) # type: ignore
if quote_end is None:
return False
self._advance(len(quote))
text = self._extract_string(quote_end)
text = text.encode(self.ENCODE).decode(self.ENCODE) if self.ENCODE else text
text = text.replace("\\\\", "\\") if self.ESCAPE == "\\" else text
text = text.encode(self.ENCODE).decode(self.ENCODE) if self.ENCODE else text # type: ignore
text = text.replace("\\\\", "\\") if self._replace_backslash else text
self._add(TokenType.STRING, text)
return True
# X'1234, b'0110', E'\\\\\' etc.
def _scan_formatted_string(self, string_start):
if string_start in self._HEX_STRINGS:
delimiters = self._HEX_STRINGS
def _scan_formatted_string(self, string_start: str) -> bool:
if string_start in self._HEX_STRINGS: # type: ignore
delimiters = self._HEX_STRINGS # type: ignore
token_type = TokenType.HEX_STRING
base = 16
elif string_start in self._BIT_STRINGS:
delimiters = self._BIT_STRINGS
elif string_start in self._BIT_STRINGS: # type: ignore
delimiters = self._BIT_STRINGS # type: ignore
token_type = TokenType.BIT_STRING
base = 2
elif string_start in self._BYTE_STRINGS:
delimiters = self._BYTE_STRINGS
elif string_start in self._BYTE_STRINGS: # type: ignore
delimiters = self._BYTE_STRINGS # type: ignore
token_type = TokenType.BYTE_STRING
base = None
else:
@ -942,11 +993,13 @@ class Tokenizer(metaclass=_Tokenizer):
try:
self._add(token_type, f"{int(text, base)}")
except:
raise RuntimeError(f"Numeric string contains invalid characters from {self._line}:{self._start}")
raise RuntimeError(
f"Numeric string contains invalid characters from {self._line}:{self._start}"
)
return True
def _scan_identifier(self, identifier_end):
def _scan_identifier(self, identifier_end: str) -> None:
while self._peek != identifier_end:
if self._end:
raise RuntimeError(f"Missing {identifier_end} from {self._line}:{self._start}")
@ -954,9 +1007,9 @@ class Tokenizer(metaclass=_Tokenizer):
self._advance()
self._add(TokenType.IDENTIFIER, self._text[1:-1])
def _scan_var(self):
def _scan_var(self) -> None:
while True:
char = self._peek.strip()
char = self._peek.strip() # type: ignore
if char and char not in self.SINGLE_TOKENS:
self._advance()
else:
@ -967,12 +1020,12 @@ class Tokenizer(metaclass=_Tokenizer):
else self.KEYWORDS.get(self._text.upper(), TokenType.VAR)
)
def _extract_string(self, delimiter):
def _extract_string(self, delimiter: str) -> str:
text = ""
delim_size = len(delimiter)
while True:
if self._char == self.ESCAPE and self._peek == delimiter:
if self._char in self._ESCAPES and self._peek == delimiter: # type: ignore
text += delimiter
self._advance(2)
else:
@ -983,7 +1036,7 @@ class Tokenizer(metaclass=_Tokenizer):
if self._end:
raise RuntimeError(f"Missing {delimiter} from {self._line}:{self._start}")
text += self._char
text += self._char # type: ignore
self._advance()
return text

View file

@ -1,7 +1,14 @@
from __future__ import annotations
import typing as t
if t.TYPE_CHECKING:
from sqlglot.generator import Generator
from sqlglot import expressions as exp
def unalias_group(expression):
def unalias_group(expression: exp.Expression) -> exp.Expression:
"""
Replace references to select aliases in GROUP BY clauses.
@ -9,6 +16,12 @@ def unalias_group(expression):
>>> import sqlglot
>>> sqlglot.parse_one("SELECT a AS b FROM x GROUP BY b").transform(unalias_group).sql()
'SELECT a AS b FROM x GROUP BY 1'
Args:
expression: the expression that will be transformed.
Returns:
The transformed expression.
"""
if isinstance(expression, exp.Group) and isinstance(expression.parent, exp.Select):
aliased_selects = {
@ -30,18 +43,19 @@ def unalias_group(expression):
return expression
def preprocess(transforms, to_sql):
def preprocess(
transforms: t.List[t.Callable[[exp.Expression], exp.Expression]],
to_sql: t.Callable[[Generator, exp.Expression], str],
) -> t.Callable[[Generator, exp.Expression], str]:
"""
Create a new transform function that can be used as a value in `Generator.TRANSFORMS`
to convert expressions to SQL.
Creates a new transform by chaining a sequence of transformations and converts the resulting
expression to SQL, using an appropriate `Generator.TRANSFORMS` function.
Args:
transforms (list[(exp.Expression) -> exp.Expression]):
Sequence of transform functions. These will be called in order.
to_sql ((sqlglot.generator.Generator, exp.Expression) -> str):
Final transform that converts the resulting expression to a SQL string.
transforms: sequence of transform functions. These will be called in order.
to_sql: final transform that converts the resulting expression to a SQL string.
Returns:
(sqlglot.generator.Generator, exp.Expression) -> str:
Function that can be used as a generator transform.
"""
@ -54,12 +68,10 @@ def preprocess(transforms, to_sql):
return _to_sql
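As a usage sketch, `preprocess` is typically paired with the `delegate` helper shown next; the custom dialect below is illustrative, not part of the library:

```python
# Hedged sketch: wiring preprocess + unalias_group into a Generator's
# TRANSFORMS table, delegating the final rendering to the stock group_sql
# method. MyDialect is a hypothetical example, not a shipped dialect.
from sqlglot import exp, transpile
from sqlglot.dialects.mysql import MySQL
from sqlglot.transforms import delegate, preprocess, unalias_group

class MyDialect(MySQL):
    class Generator(MySQL.Generator):
        TRANSFORMS = {
            **MySQL.Generator.TRANSFORMS,
            # rewrite GROUP BY alias references to positions, then render
            exp.Group: preprocess([unalias_group], delegate("group_sql")),
        }

# Dialect subclasses register themselves under their lowercased class name.
print(transpile("SELECT a AS b FROM x GROUP BY b", write="mydialect")[0])
# e.g. SELECT a AS b FROM x GROUP BY 1
```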
def delegate(attr):
def delegate(attr: str) -> t.Callable:
"""
Create a new method that delegates to `attr`.
This is useful for creating `Generator.TRANSFORMS` functions that delegate
to existing generator methods.
Create a new method that delegates to `attr`. This is useful for creating `Generator.TRANSFORMS`
functions that delegate to existing generator methods.
"""
def _transform(self, *args, **kwargs):

View file

@ -1,5 +1,26 @@
def new_trie(keywords):
trie = {}
import typing as t
key = t.Sequence[t.Hashable]
def new_trie(keywords: t.Iterable[key]) -> t.Dict:
"""
Creates a new trie out of a collection of keywords.
The trie is represented as a sequence of nested dictionaries keyed by either single character
strings, or by 0, which is used to designate that a keyword is in the trie.
Example:
>>> new_trie(["bla", "foo", "blab"])
{'b': {'l': {'a': {0: True, 'b': {0: True}}}}, 'f': {'o': {'o': {0: True}}}}
Args:
keywords: the keywords to create the trie from.
Returns:
The trie corresponding to `keywords`.
"""
trie: t.Dict = {}
for key in keywords:
current = trie
@ -11,7 +32,28 @@ def new_trie(keywords):
return trie
def in_trie(trie, key):
def in_trie(trie: t.Dict, key: key) -> t.Tuple[int, t.Dict]:
"""
Checks whether a key is in a trie.
Examples:
>>> in_trie(new_trie(["cat"]), "bob")
(0, {'c': {'a': {'t': {0: True}}}})
>>> in_trie(new_trie(["cat"]), "ca")
(1, {'t': {0: True}})
>>> in_trie(new_trie(["cat"]), "cat")
(2, {0: True})
Args:
trie: the trie to be searched.
key: the target key.
Returns:
A pair `(value, subtrie)`, where `subtrie` is the sub-trie we get at the point where the search stops, and `value`
is either 0 (search was unsuccessful), 1 (`key` is a prefix of a keyword in `trie`) or 2 (`key` is in `trie`).
"""
if not key:
return (0, trie)
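A minimal sketch of how these two helpers support longest-match keyword scanning, mirroring the loop in `Tokenizer._scan_keywords` (the `longest_keyword` function is illustrative):

```python
# Hedged sketch: step through the input one character at a time, remembering
# the last position where a complete keyword ended.
from sqlglot.trie import in_trie, new_trie

trie = new_trie(["GROUP", "GROUP BY", "ORDER BY"])

def longest_keyword(text):
    node, match = trie, None
    for i, char in enumerate(text.upper()):
        result, node = in_trie(node, char)
        if result == 0:   # no keyword continues with this prefix
            break
        if result == 2:   # a complete keyword ends at this character
            match = text[: i + 1]
    return match

print(longest_keyword("group by x"))  # group by
```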

View file

@ -1,9 +1,9 @@
import sys
import typing as t
import unittest
import warnings
import sqlglot
from sqlglot.helper import PYTHON_VERSION
from tests.helpers import SKIP_INTEGRATION
if t.TYPE_CHECKING:
@ -11,7 +11,8 @@ if t.TYPE_CHECKING:
@unittest.skipIf(
SKIP_INTEGRATION or sys.version_info[:2] > (3, 10), "Skipping Integration Tests since `SKIP_INTEGRATION` is set"
SKIP_INTEGRATION or PYTHON_VERSION > (3, 10),
"Skipping Integration Tests since `SKIP_INTEGRATION` is set",
)
class DataFrameValidator(unittest.TestCase):
spark = None
@ -36,7 +37,12 @@ class DataFrameValidator(unittest.TestCase):
# This is for test `test_branching_root_dataframes`
config = SparkConf().setAll([("spark.sql.analyzer.failAmbiguousSelfJoin", "false")])
cls.spark = SparkSession.builder.master("local[*]").appName("Unit-tests").config(conf=config).getOrCreate()
cls.spark = (
SparkSession.builder.master("local[*]")
.appName("Unit-tests")
.config(conf=config)
.getOrCreate()
)
cls.spark.sparkContext.setLogLevel("ERROR")
cls.sqlglot = SqlglotSparkSession()
cls.spark_employee_schema = types.StructType(
@ -50,7 +56,9 @@ class DataFrameValidator(unittest.TestCase):
)
cls.sqlglot_employee_schema = sqlglotSparkTypes.StructType(
[
sqlglotSparkTypes.StructField("employee_id", sqlglotSparkTypes.IntegerType(), False),
sqlglotSparkTypes.StructField(
"employee_id", sqlglotSparkTypes.IntegerType(), False
),
sqlglotSparkTypes.StructField("fname", sqlglotSparkTypes.StringType(), False),
sqlglotSparkTypes.StructField("lname", sqlglotSparkTypes.StringType(), False),
sqlglotSparkTypes.StructField("age", sqlglotSparkTypes.IntegerType(), False),
@ -64,8 +72,12 @@ class DataFrameValidator(unittest.TestCase):
(4, "Claire", "Littleton", 27, 2),
(5, "Hugo", "Reyes", 29, 100),
]
cls.df_employee = cls.spark.createDataFrame(data=employee_data, schema=cls.spark_employee_schema)
cls.dfs_employee = cls.sqlglot.createDataFrame(data=employee_data, schema=cls.sqlglot_employee_schema)
cls.df_employee = cls.spark.createDataFrame(
data=employee_data, schema=cls.spark_employee_schema
)
cls.dfs_employee = cls.sqlglot.createDataFrame(
data=employee_data, schema=cls.sqlglot_employee_schema
)
cls.df_employee.createOrReplaceTempView("employee")
cls.spark_store_schema = types.StructType(
@ -80,7 +92,9 @@ class DataFrameValidator(unittest.TestCase):
[
sqlglotSparkTypes.StructField("store_id", sqlglotSparkTypes.IntegerType(), False),
sqlglotSparkTypes.StructField("store_name", sqlglotSparkTypes.StringType(), False),
sqlglotSparkTypes.StructField("district_id", sqlglotSparkTypes.IntegerType(), False),
sqlglotSparkTypes.StructField(
"district_id", sqlglotSparkTypes.IntegerType(), False
),
sqlglotSparkTypes.StructField("num_sales", sqlglotSparkTypes.IntegerType(), False),
]
)
@ -89,7 +103,9 @@ class DataFrameValidator(unittest.TestCase):
(2, "Arrow", 2, 2000),
]
cls.df_store = cls.spark.createDataFrame(data=store_data, schema=cls.spark_store_schema)
cls.dfs_store = cls.sqlglot.createDataFrame(data=store_data, schema=cls.sqlglot_store_schema)
cls.dfs_store = cls.sqlglot.createDataFrame(
data=store_data, schema=cls.sqlglot_store_schema
)
cls.df_store.createOrReplaceTempView("store")
cls.spark_district_schema = types.StructType(
@ -101,17 +117,27 @@ class DataFrameValidator(unittest.TestCase):
)
cls.sqlglot_district_schema = sqlglotSparkTypes.StructType(
[
sqlglotSparkTypes.StructField("district_id", sqlglotSparkTypes.IntegerType(), False),
sqlglotSparkTypes.StructField("district_name", sqlglotSparkTypes.StringType(), False),
sqlglotSparkTypes.StructField("manager_name", sqlglotSparkTypes.StringType(), False),
sqlglotSparkTypes.StructField(
"district_id", sqlglotSparkTypes.IntegerType(), False
),
sqlglotSparkTypes.StructField(
"district_name", sqlglotSparkTypes.StringType(), False
),
sqlglotSparkTypes.StructField(
"manager_name", sqlglotSparkTypes.StringType(), False
),
]
)
district_data = [
(1, "Temple", "Dogen"),
(2, "Lighthouse", "Jacob"),
]
cls.df_district = cls.spark.createDataFrame(data=district_data, schema=cls.spark_district_schema)
cls.dfs_district = cls.sqlglot.createDataFrame(data=district_data, schema=cls.sqlglot_district_schema)
cls.df_district = cls.spark.createDataFrame(
data=district_data, schema=cls.spark_district_schema
)
cls.dfs_district = cls.sqlglot.createDataFrame(
data=district_data, schema=cls.sqlglot_district_schema
)
cls.df_district.createOrReplaceTempView("district")
sqlglot.schema.add_table("employee", cls.sqlglot_employee_schema)
sqlglot.schema.add_table("store", cls.sqlglot_store_schema)

View file

@ -41,22 +41,32 @@ class TestDataframeFunc(DataFrameValidator):
def test_alias_with_select(self):
df_employee = self.df_spark_employee.alias("df_employee").select(
self.df_spark_employee["employee_id"], F.col("df_employee.fname"), self.df_spark_employee.lname
self.df_spark_employee["employee_id"],
F.col("df_employee.fname"),
self.df_spark_employee.lname,
)
dfs_employee = self.df_sqlglot_employee.alias("dfs_employee").select(
self.df_sqlglot_employee["employee_id"], SF.col("dfs_employee.fname"), self.df_sqlglot_employee.lname
self.df_sqlglot_employee["employee_id"],
SF.col("dfs_employee.fname"),
self.df_sqlglot_employee.lname,
)
self.compare_spark_with_sqlglot(df_employee, dfs_employee)
def test_case_when_otherwise(self):
df = self.df_spark_employee.select(
F.when((F.col("age") >= F.lit(40)) & (F.col("age") <= F.lit(60)), F.lit("between 40 and 60"))
F.when(
(F.col("age") >= F.lit(40)) & (F.col("age") <= F.lit(60)),
F.lit("between 40 and 60"),
)
.when(F.col("age") < F.lit(40), "less than 40")
.otherwise("greater than 60")
)
dfs = self.df_sqlglot_employee.select(
SF.when((SF.col("age") >= SF.lit(40)) & (SF.col("age") <= SF.lit(60)), SF.lit("between 40 and 60"))
SF.when(
(SF.col("age") >= SF.lit(40)) & (SF.col("age") <= SF.lit(60)),
SF.lit("between 40 and 60"),
)
.when(SF.col("age") < SF.lit(40), "less than 40")
.otherwise("greater than 60")
)
@ -65,15 +75,17 @@ class TestDataframeFunc(DataFrameValidator):
def test_case_when_no_otherwise(self):
df = self.df_spark_employee.select(
F.when((F.col("age") >= F.lit(40)) & (F.col("age") <= F.lit(60)), F.lit("between 40 and 60")).when(
F.col("age") < F.lit(40), "less than 40"
)
F.when(
(F.col("age") >= F.lit(40)) & (F.col("age") <= F.lit(60)),
F.lit("between 40 and 60"),
).when(F.col("age") < F.lit(40), "less than 40")
)
dfs = self.df_sqlglot_employee.select(
SF.when((SF.col("age") >= SF.lit(40)) & (SF.col("age") <= SF.lit(60)), SF.lit("between 40 and 60")).when(
SF.col("age") < SF.lit(40), "less than 40"
)
SF.when(
(SF.col("age") >= SF.lit(40)) & (SF.col("age") <= SF.lit(60)),
SF.lit("between 40 and 60"),
).when(SF.col("age") < SF.lit(40), "less than 40")
)
self.compare_spark_with_sqlglot(df, dfs, skip_schema_compare=True)
@ -84,7 +96,9 @@ class TestDataframeFunc(DataFrameValidator):
self.compare_spark_with_sqlglot(df_employee, dfs_employee)
def test_where_clause_multiple_and(self):
df_employee = self.df_spark_employee.where((F.col("age") == F.lit(37)) & (F.col("fname") == F.lit("Jack")))
df_employee = self.df_spark_employee.where(
(F.col("age") == F.lit(37)) & (F.col("fname") == F.lit("Jack"))
)
dfs_employee = self.df_sqlglot_employee.where(
(SF.col("age") == SF.lit(37)) & (SF.col("fname") == SF.lit("Jack"))
)
@ -106,7 +120,9 @@ class TestDataframeFunc(DataFrameValidator):
self.compare_spark_with_sqlglot(df_employee, dfs_employee)
def test_where_clause_multiple_or(self):
df_employee = self.df_spark_employee.where((F.col("age") == F.lit(37)) | (F.col("fname") == F.lit("Kate")))
df_employee = self.df_spark_employee.where(
(F.col("age") == F.lit(37)) | (F.col("fname") == F.lit("Kate"))
)
dfs_employee = self.df_sqlglot_employee.where(
(SF.col("age") == SF.lit(37)) | (SF.col("fname") == SF.lit("Kate"))
)
@ -172,28 +188,43 @@ class TestDataframeFunc(DataFrameValidator):
dfs_employee = self.df_sqlglot_employee.where(self.df_sqlglot_employee["age"] == SF.lit(37))
self.compare_spark_with_sqlglot(df_employee, dfs_employee)
df_employee = self.df_spark_employee.where(self.df_spark_employee["age"] % F.lit(5) == F.lit(0))
dfs_employee = self.df_sqlglot_employee.where(self.df_sqlglot_employee["age"] % SF.lit(5) == SF.lit(0))
df_employee = self.df_spark_employee.where(
self.df_spark_employee["age"] % F.lit(5) == F.lit(0)
)
dfs_employee = self.df_sqlglot_employee.where(
self.df_sqlglot_employee["age"] % SF.lit(5) == SF.lit(0)
)
self.compare_spark_with_sqlglot(df_employee, dfs_employee)
df_employee = self.df_spark_employee.where(self.df_spark_employee["age"] + F.lit(5) > F.lit(28))
dfs_employee = self.df_sqlglot_employee.where(self.df_sqlglot_employee["age"] + SF.lit(5) > SF.lit(28))
df_employee = self.df_spark_employee.where(
self.df_spark_employee["age"] + F.lit(5) > F.lit(28)
)
dfs_employee = self.df_sqlglot_employee.where(
self.df_sqlglot_employee["age"] + SF.lit(5) > SF.lit(28)
)
self.compare_spark_with_sqlglot(df_employee, dfs_employee)
df_employee = self.df_spark_employee.where(self.df_spark_employee["age"] - F.lit(5) > F.lit(28))
dfs_employee = self.df_sqlglot_employee.where(self.df_sqlglot_employee["age"] - SF.lit(5) > SF.lit(28))
df_employee = self.df_spark_employee.where(
self.df_spark_employee["age"] - F.lit(5) > F.lit(28)
)
dfs_employee = self.df_sqlglot_employee.where(
self.df_sqlglot_employee["age"] - SF.lit(5) > SF.lit(28)
)
self.compare_spark_with_sqlglot(df_employee, dfs_employee)
df_employee = self.df_spark_employee.where(
self.df_spark_employee["age"] * F.lit(0.5) == self.df_spark_employee["age"] / F.lit(2)
)
dfs_employee = self.df_sqlglot_employee.where(
self.df_sqlglot_employee["age"] * SF.lit(0.5) == self.df_sqlglot_employee["age"] / SF.lit(2)
self.df_sqlglot_employee["age"] * SF.lit(0.5)
== self.df_sqlglot_employee["age"] / SF.lit(2)
)
self.compare_spark_with_sqlglot(df_employee, dfs_employee)
def test_join_inner(self):
df_joined = self.df_spark_employee.join(self.df_spark_store, on=["store_id"], how="inner").select(
df_joined = self.df_spark_employee.join(
self.df_spark_store, on=["store_id"], how="inner"
).select(
self.df_spark_employee.employee_id,
self.df_spark_employee["fname"],
F.col("lname"),
@ -202,7 +233,9 @@ class TestDataframeFunc(DataFrameValidator):
self.df_spark_store.store_name,
self.df_spark_store["num_sales"],
)
dfs_joined = self.df_sqlglot_employee.join(self.df_sqlglot_store, on=["store_id"], how="inner").select(
dfs_joined = self.df_sqlglot_employee.join(
self.df_sqlglot_store, on=["store_id"], how="inner"
).select(
self.df_sqlglot_employee.employee_id,
self.df_sqlglot_employee["fname"],
SF.col("lname"),
@ -214,17 +247,27 @@ class TestDataframeFunc(DataFrameValidator):
self.compare_spark_with_sqlglot(df_joined, dfs_joined)
def test_join_inner_no_select(self):
df_joined = self.df_spark_employee.select(F.col("store_id"), F.col("fname"), F.col("lname")).join(
self.df_spark_store.select(F.col("store_id"), F.col("store_name")), on=["store_id"], how="inner"
df_joined = self.df_spark_employee.select(
F.col("store_id"), F.col("fname"), F.col("lname")
).join(
self.df_spark_store.select(F.col("store_id"), F.col("store_name")),
on=["store_id"],
how="inner",
)
dfs_joined = self.df_sqlglot_employee.select(SF.col("store_id"), SF.col("fname"), SF.col("lname")).join(
self.df_sqlglot_store.select(SF.col("store_id"), SF.col("store_name")), on=["store_id"], how="inner"
dfs_joined = self.df_sqlglot_employee.select(
SF.col("store_id"), SF.col("fname"), SF.col("lname")
).join(
self.df_sqlglot_store.select(SF.col("store_id"), SF.col("store_name")),
on=["store_id"],
how="inner",
)
self.compare_spark_with_sqlglot(df_joined, dfs_joined)
def test_join_inner_equality_single(self):
df_joined = self.df_spark_employee.join(
self.df_spark_store, on=self.df_spark_employee.store_id == self.df_spark_store.store_id, how="inner"
self.df_spark_store,
on=self.df_spark_employee.store_id == self.df_spark_store.store_id,
how="inner",
).select(
self.df_spark_employee.employee_id,
self.df_spark_employee["fname"],
@ -235,7 +278,9 @@ class TestDataframeFunc(DataFrameValidator):
self.df_spark_store["num_sales"],
)
dfs_joined = self.df_sqlglot_employee.join(
self.df_sqlglot_store, on=self.df_sqlglot_employee.store_id == self.df_sqlglot_store.store_id, how="inner"
self.df_sqlglot_store,
on=self.df_sqlglot_employee.store_id == self.df_sqlglot_store.store_id,
how="inner",
).select(
self.df_sqlglot_employee.employee_id,
self.df_sqlglot_employee["fname"],
@ -343,7 +388,9 @@ class TestDataframeFunc(DataFrameValidator):
self.compare_spark_with_sqlglot(df_joined, dfs_joined)
def test_join_full_outer(self):
df_joined = self.df_spark_employee.join(self.df_spark_store, on=["store_id"], how="full_outer").select(
df_joined = self.df_spark_employee.join(
self.df_spark_store, on=["store_id"], how="full_outer"
).select(
self.df_spark_employee.employee_id,
self.df_spark_employee["fname"],
F.col("lname"),
@ -352,7 +399,9 @@ class TestDataframeFunc(DataFrameValidator):
self.df_spark_store.store_name,
self.df_spark_store["num_sales"],
)
dfs_joined = self.df_sqlglot_employee.join(self.df_sqlglot_store, on=["store_id"], how="full_outer").select(
dfs_joined = self.df_sqlglot_employee.join(
self.df_sqlglot_store, on=["store_id"], how="full_outer"
).select(
self.df_sqlglot_employee.employee_id,
self.df_sqlglot_employee["fname"],
SF.col("lname"),
@ -365,7 +414,9 @@ class TestDataframeFunc(DataFrameValidator):
def test_triple_join(self):
df = (
self.df_employee.join(self.df_store, on=self.df_employee.employee_id == self.df_store.store_id)
self.df_employee.join(
self.df_store, on=self.df_employee.employee_id == self.df_store.store_id
)
.join(self.df_district, on=self.df_store.store_id == self.df_district.district_id)
.select(
self.df_employee.employee_id,
@ -377,7 +428,9 @@ class TestDataframeFunc(DataFrameValidator):
)
)
dfs = (
self.dfs_employee.join(self.dfs_store, on=self.dfs_employee.employee_id == self.dfs_store.store_id)
self.dfs_employee.join(
self.dfs_store, on=self.dfs_employee.employee_id == self.dfs_store.store_id
)
.join(self.dfs_district, on=self.dfs_store.store_id == self.dfs_district.district_id)
.select(
self.dfs_employee.employee_id,
@ -391,13 +444,13 @@ class TestDataframeFunc(DataFrameValidator):
self.compare_spark_with_sqlglot(df, dfs)
def test_join_select_and_select_start(self):
df = self.df_spark_employee.select(F.col("fname"), F.col("lname"), F.col("age"), F.col("store_id")).join(
self.df_spark_store, "store_id", "inner"
)
df = self.df_spark_employee.select(
F.col("fname"), F.col("lname"), F.col("age"), F.col("store_id")
).join(self.df_spark_store, "store_id", "inner")
dfs = self.df_sqlglot_employee.select(SF.col("fname"), SF.col("lname"), SF.col("age"), SF.col("store_id")).join(
self.df_sqlglot_store, "store_id", "inner"
)
dfs = self.df_sqlglot_employee.select(
SF.col("fname"), SF.col("lname"), SF.col("age"), SF.col("store_id")
).join(self.df_sqlglot_store, "store_id", "inner")
self.compare_spark_with_sqlglot(df, dfs)
@ -485,13 +538,17 @@ class TestDataframeFunc(DataFrameValidator):
dfs_unioned = (
self.df_sqlglot_employee.select(SF.col("employee_id"), SF.col("fname"))
.unionAll(self.df_sqlglot_store.select(SF.col("store_id"), SF.col("store_name")))
.unionAll(self.df_sqlglot_district.select(SF.col("district_id"), SF.col("district_name")))
.unionAll(
self.df_sqlglot_district.select(SF.col("district_id"), SF.col("district_name"))
)
)
self.compare_spark_with_sqlglot(df_unioned, dfs_unioned)
def test_union_by_name(self):
df = self.df_spark_employee.select(F.col("employee_id"), F.col("fname"), F.col("lname")).unionByName(
df = self.df_spark_employee.select(
F.col("employee_id"), F.col("fname"), F.col("lname")
).unionByName(
self.df_spark_store.select(
F.col("store_name").alias("lname"),
F.col("store_id").alias("employee_id"),
@ -499,7 +556,9 @@ class TestDataframeFunc(DataFrameValidator):
)
)
dfs = self.df_sqlglot_employee.select(SF.col("employee_id"), SF.col("fname"), SF.col("lname")).unionByName(
dfs = self.df_sqlglot_employee.select(
SF.col("employee_id"), SF.col("fname"), SF.col("lname")
).unionByName(
self.df_sqlglot_store.select(
SF.col("store_name").alias("lname"),
SF.col("store_id").alias("employee_id"),
@ -537,10 +596,16 @@ class TestDataframeFunc(DataFrameValidator):
self.compare_spark_with_sqlglot(df, dfs)
def test_order_by_default(self):
df = self.df_spark_store.groupBy(F.col("district_id")).agg(F.min("num_sales")).orderBy(F.col("district_id"))
df = (
self.df_spark_store.groupBy(F.col("district_id"))
.agg(F.min("num_sales"))
.orderBy(F.col("district_id"))
)
dfs = (
self.df_sqlglot_store.groupBy(SF.col("district_id")).agg(SF.min("num_sales")).orderBy(SF.col("district_id"))
self.df_sqlglot_store.groupBy(SF.col("district_id"))
.agg(SF.min("num_sales"))
.orderBy(SF.col("district_id"))
)
self.compare_spark_with_sqlglot(df, dfs, skip_schema_compare=True)
@ -594,13 +659,17 @@ class TestDataframeFunc(DataFrameValidator):
df = (
self.df_spark_store.groupBy(F.col("district_id"))
.agg(F.min("num_sales").alias("total_sales"))
.orderBy(F.when(F.col("district_id") == F.lit(2), F.col("district_id")).asc_nulls_last())
.orderBy(
F.when(F.col("district_id") == F.lit(2), F.col("district_id")).asc_nulls_last()
)
)
dfs = (
self.df_sqlglot_store.groupBy(SF.col("district_id"))
.agg(SF.min("num_sales").alias("total_sales"))
.orderBy(SF.when(SF.col("district_id") == SF.lit(2), SF.col("district_id")).asc_nulls_last())
.orderBy(
SF.when(SF.col("district_id") == SF.lit(2), SF.col("district_id")).asc_nulls_last()
)
)
self.compare_spark_with_sqlglot(df, dfs)
@ -609,81 +678,87 @@ class TestDataframeFunc(DataFrameValidator):
df = (
self.df_spark_store.groupBy(F.col("district_id"))
.agg(F.min("num_sales").alias("total_sales"))
.orderBy(F.when(F.col("district_id") == F.lit(1), F.col("district_id")).desc_nulls_first())
.orderBy(
F.when(F.col("district_id") == F.lit(1), F.col("district_id")).desc_nulls_first()
)
)
dfs = (
self.df_sqlglot_store.groupBy(SF.col("district_id"))
.agg(SF.min("num_sales").alias("total_sales"))
.orderBy(SF.when(SF.col("district_id") == SF.lit(1), SF.col("district_id")).desc_nulls_first())
.orderBy(
SF.when(
SF.col("district_id") == SF.lit(1), SF.col("district_id")
).desc_nulls_first()
)
)
self.compare_spark_with_sqlglot(df, dfs)
def test_intersect(self):
df_employee_duplicate = self.df_spark_employee.select(F.col("employee_id"), F.col("store_id")).union(
self.df_spark_employee.select(F.col("employee_id"), F.col("store_id"))
)
df_employee_duplicate = self.df_spark_employee.select(
F.col("employee_id"), F.col("store_id")
).union(self.df_spark_employee.select(F.col("employee_id"), F.col("store_id")))
df_store_duplicate = self.df_spark_store.select(F.col("store_id"), F.col("district_id")).union(
self.df_spark_store.select(F.col("store_id"), F.col("district_id"))
)
df_store_duplicate = self.df_spark_store.select(
F.col("store_id"), F.col("district_id")
).union(self.df_spark_store.select(F.col("store_id"), F.col("district_id")))
df = df_employee_duplicate.intersect(df_store_duplicate)
dfs_employee_duplicate = self.df_sqlglot_employee.select(SF.col("employee_id"), SF.col("store_id")).union(
self.df_sqlglot_employee.select(SF.col("employee_id"), SF.col("store_id"))
)
dfs_employee_duplicate = self.df_sqlglot_employee.select(
SF.col("employee_id"), SF.col("store_id")
).union(self.df_sqlglot_employee.select(SF.col("employee_id"), SF.col("store_id")))
dfs_store_duplicate = self.df_sqlglot_store.select(SF.col("store_id"), SF.col("district_id")).union(
self.df_sqlglot_store.select(SF.col("store_id"), SF.col("district_id"))
)
dfs_store_duplicate = self.df_sqlglot_store.select(
SF.col("store_id"), SF.col("district_id")
).union(self.df_sqlglot_store.select(SF.col("store_id"), SF.col("district_id")))
dfs = dfs_employee_duplicate.intersect(dfs_store_duplicate)
self.compare_spark_with_sqlglot(df, dfs)
def test_intersect_all(self):
df_employee_duplicate = self.df_spark_employee.select(F.col("employee_id"), F.col("store_id")).union(
self.df_spark_employee.select(F.col("employee_id"), F.col("store_id"))
)
df_employee_duplicate = self.df_spark_employee.select(
F.col("employee_id"), F.col("store_id")
).union(self.df_spark_employee.select(F.col("employee_id"), F.col("store_id")))
df_store_duplicate = self.df_spark_store.select(F.col("store_id"), F.col("district_id")).union(
self.df_spark_store.select(F.col("store_id"), F.col("district_id"))
)
df_store_duplicate = self.df_spark_store.select(
F.col("store_id"), F.col("district_id")
).union(self.df_spark_store.select(F.col("store_id"), F.col("district_id")))
df = df_employee_duplicate.intersectAll(df_store_duplicate)
dfs_employee_duplicate = self.df_sqlglot_employee.select(SF.col("employee_id"), SF.col("store_id")).union(
self.df_sqlglot_employee.select(SF.col("employee_id"), SF.col("store_id"))
)
dfs_employee_duplicate = self.df_sqlglot_employee.select(
SF.col("employee_id"), SF.col("store_id")
).union(self.df_sqlglot_employee.select(SF.col("employee_id"), SF.col("store_id")))
dfs_store_duplicate = self.df_sqlglot_store.select(SF.col("store_id"), SF.col("district_id")).union(
self.df_sqlglot_store.select(SF.col("store_id"), SF.col("district_id"))
)
dfs_store_duplicate = self.df_sqlglot_store.select(
SF.col("store_id"), SF.col("district_id")
).union(self.df_sqlglot_store.select(SF.col("store_id"), SF.col("district_id")))
dfs = dfs_employee_duplicate.intersectAll(dfs_store_duplicate)
self.compare_spark_with_sqlglot(df, dfs)
def test_except_all(self):
df_employee_duplicate = self.df_spark_employee.select(F.col("employee_id"), F.col("store_id")).union(
self.df_spark_employee.select(F.col("employee_id"), F.col("store_id"))
)
df_employee_duplicate = self.df_spark_employee.select(
F.col("employee_id"), F.col("store_id")
).union(self.df_spark_employee.select(F.col("employee_id"), F.col("store_id")))
df_store_duplicate = self.df_spark_store.select(F.col("store_id"), F.col("district_id")).union(
self.df_spark_store.select(F.col("store_id"), F.col("district_id"))
)
df_store_duplicate = self.df_spark_store.select(
F.col("store_id"), F.col("district_id")
).union(self.df_spark_store.select(F.col("store_id"), F.col("district_id")))
df = df_employee_duplicate.exceptAll(df_store_duplicate)
dfs_employee_duplicate = self.df_sqlglot_employee.select(SF.col("employee_id"), SF.col("store_id")).union(
self.df_sqlglot_employee.select(SF.col("employee_id"), SF.col("store_id"))
)
dfs_employee_duplicate = self.df_sqlglot_employee.select(
SF.col("employee_id"), SF.col("store_id")
).union(self.df_sqlglot_employee.select(SF.col("employee_id"), SF.col("store_id")))
dfs_store_duplicate = self.df_sqlglot_store.select(SF.col("store_id"), SF.col("district_id")).union(
self.df_sqlglot_store.select(SF.col("store_id"), SF.col("district_id"))
)
dfs_store_duplicate = self.df_sqlglot_store.select(
SF.col("store_id"), SF.col("district_id")
).union(self.df_sqlglot_store.select(SF.col("store_id"), SF.col("district_id")))
dfs = dfs_employee_duplicate.exceptAll(dfs_store_duplicate)
@ -721,7 +796,9 @@ class TestDataframeFunc(DataFrameValidator):
self.compare_spark_with_sqlglot(df, dfs)
def test_drop_na_default(self):
df = self.df_spark_employee.select(F.when(F.col("age") < F.lit(50), F.col("age")).alias("the_age")).dropna()
df = self.df_spark_employee.select(
F.when(F.col("age") < F.lit(50), F.col("age")).alias("the_age")
).dropna()
dfs = self.df_sqlglot_employee.select(
SF.when(SF.col("age") < SF.lit(50), SF.col("age")).alias("the_age")
@ -746,7 +823,9 @@ class TestDataframeFunc(DataFrameValidator):
).dropna(how="any", thresh=2)
dfs = self.df_sqlglot_employee.select(
SF.lit(None), SF.lit(1), SF.when(SF.col("age") < SF.lit(50), SF.col("age")).alias("the_age")
SF.lit(None),
SF.lit(1),
SF.when(SF.col("age") < SF.lit(50), SF.col("age")).alias("the_age"),
).dropna(how="any", thresh=2)
self.compare_spark_with_sqlglot(df, dfs, skip_schema_compare=True)
@ -757,13 +836,17 @@ class TestDataframeFunc(DataFrameValidator):
).dropna(thresh=1, subset="the_age")
dfs = self.df_sqlglot_employee.select(
SF.lit(None), SF.lit(1), SF.when(SF.col("age") < SF.lit(50), SF.col("age")).alias("the_age")
SF.lit(None),
SF.lit(1),
SF.when(SF.col("age") < SF.lit(50), SF.col("age")).alias("the_age"),
).dropna(thresh=1, subset="the_age")
self.compare_spark_with_sqlglot(df, dfs, skip_schema_compare=True)
def test_dropna_na_function(self):
df = self.df_spark_employee.select(F.when(F.col("age") < F.lit(50), F.col("age")).alias("the_age")).na.drop()
df = self.df_spark_employee.select(
F.when(F.col("age") < F.lit(50), F.col("age")).alias("the_age")
).na.drop()
dfs = self.df_sqlglot_employee.select(
SF.when(SF.col("age") < SF.lit(50), SF.col("age")).alias("the_age")
@ -772,7 +855,9 @@ class TestDataframeFunc(DataFrameValidator):
self.compare_spark_with_sqlglot(df, dfs)
def test_fillna_default(self):
df = self.df_spark_employee.select(F.when(F.col("age") < F.lit(50), F.col("age")).alias("the_age")).fillna(100)
df = self.df_spark_employee.select(
F.when(F.col("age") < F.lit(50), F.col("age")).alias("the_age")
).fillna(100)
dfs = self.df_sqlglot_employee.select(
SF.when(SF.col("age") < SF.lit(50), SF.col("age")).alias("the_age")
@ -798,7 +883,9 @@ class TestDataframeFunc(DataFrameValidator):
self.compare_spark_with_sqlglot(df, dfs, skip_schema_compare=True)
def test_fillna_na_func(self):
df = self.df_spark_employee.select(F.when(F.col("age") < F.lit(50), F.col("age")).alias("the_age")).na.fill(100)
df = self.df_spark_employee.select(
F.when(F.col("age") < F.lit(50), F.col("age")).alias("the_age")
).na.fill(100)
dfs = self.df_sqlglot_employee.select(
SF.when(SF.col("age") < SF.lit(50), SF.col("age")).alias("the_age")
@ -807,7 +894,9 @@ class TestDataframeFunc(DataFrameValidator):
self.compare_spark_with_sqlglot(df, dfs)
def test_replace_basic(self):
df = self.df_spark_employee.select(F.col("age"), F.lit(37).alias("test_col")).replace(to_replace=37, value=100)
df = self.df_spark_employee.select(F.col("age"), F.lit(37).alias("test_col")).replace(
to_replace=37, value=100
)
dfs = self.df_sqlglot_employee.select(SF.col("age"), SF.lit(37).alias("test_col")).replace(
to_replace=37, value=100
@ -827,9 +916,13 @@ class TestDataframeFunc(DataFrameValidator):
self.compare_spark_with_sqlglot(df, dfs, skip_schema_compare=True)
def test_replace_mapping(self):
df = self.df_spark_employee.select(F.col("age"), F.lit(37).alias("test_col")).replace({37: 100})
df = self.df_spark_employee.select(F.col("age"), F.lit(37).alias("test_col")).replace(
{37: 100}
)
dfs = self.df_sqlglot_employee.select(SF.col("age"), SF.lit(37).alias("test_col")).replace({37: 100})
dfs = self.df_sqlglot_employee.select(SF.col("age"), SF.lit(37).alias("test_col")).replace(
{37: 100}
)
self.compare_spark_with_sqlglot(df, dfs, skip_schema_compare=True)
@ -849,9 +942,9 @@ class TestDataframeFunc(DataFrameValidator):
to_replace=37, value=100
)
dfs = self.df_sqlglot_employee.select(SF.col("age"), SF.lit(37).alias("test_col")).na.replace(
to_replace=37, value=100
)
dfs = self.df_sqlglot_employee.select(
SF.col("age"), SF.lit(37).alias("test_col")
).na.replace(to_replace=37, value=100)
self.compare_spark_with_sqlglot(df, dfs, skip_schema_compare=True)
@ -881,16 +974,18 @@ class TestDataframeFunc(DataFrameValidator):
"first_name", "first_name_again"
)
dfs = self.df_sqlglot_employee.select(SF.col("fname").alias("first_name")).withColumnRenamed(
"first_name", "first_name_again"
)
dfs = self.df_sqlglot_employee.select(
SF.col("fname").alias("first_name")
).withColumnRenamed("first_name", "first_name_again")
self.compare_spark_with_sqlglot(df, dfs)
def test_drop_column_single(self):
df = self.df_spark_employee.select(F.col("fname"), F.col("lname"), F.col("age")).drop("age")
dfs = self.df_sqlglot_employee.select(SF.col("fname"), SF.col("lname"), SF.col("age")).drop("age")
dfs = self.df_sqlglot_employee.select(SF.col("fname"), SF.col("lname"), SF.col("age")).drop(
"age"
)
self.compare_spark_with_sqlglot(df, dfs)
@ -906,7 +1001,9 @@ class TestDataframeFunc(DataFrameValidator):
df_sqlglot_employee_cols = self.df_sqlglot_employee.select(
SF.col("fname"), SF.col("lname"), SF.col("age"), SF.col("store_id")
)
df_sqlglot_store_cols = self.df_sqlglot_store.select(SF.col("store_id"), SF.col("store_name"))
df_sqlglot_store_cols = self.df_sqlglot_store.select(
SF.col("store_id"), SF.col("store_name")
)
dfs = df_sqlglot_employee_cols.join(df_sqlglot_store_cols, on="store_id", how="inner").drop(
df_sqlglot_employee_cols.age,
)

View file

@ -23,6 +23,14 @@ class TestSessionFunc(DataFrameValidator):
ON
e.store_id = s.store_id
"""
df = self.spark.sql(query).groupBy(F.col("store_id")).agg(F.countDistinct(F.col("employee_id")))
dfs = self.sqlglot.sql(query).groupBy(SF.col("store_id")).agg(SF.countDistinct(SF.col("employee_id")))
df = (
self.spark.sql(query)
.groupBy(F.col("store_id"))
.agg(F.countDistinct(F.col("employee_id")))
)
dfs = (
self.sqlglot.sql(query)
.groupBy(SF.col("store_id"))
.agg(SF.countDistinct(SF.col("employee_id")))
)
self.compare_spark_with_sqlglot(df, dfs, skip_schema_compare=True)

View file

@ -25,11 +25,17 @@ class DataFrameSQLValidator(unittest.TestCase):
(4, "Claire", "Littleton", 27, 2),
(5, "Hugo", "Reyes", 29, 100),
]
self.df_employee = self.spark.createDataFrame(data=employee_data, schema=self.employee_schema)
self.df_employee = self.spark.createDataFrame(
data=employee_data, schema=self.employee_schema
)
def compare_sql(self, df: DataFrame, expected_statements: t.Union[str, t.List[str]], pretty=False):
def compare_sql(
self, df: DataFrame, expected_statements: t.Union[str, t.List[str]], pretty=False
):
actual_sqls = df.sql(pretty=pretty)
expected_statements = [expected_statements] if isinstance(expected_statements, str) else expected_statements
expected_statements = (
[expected_statements] if isinstance(expected_statements, str) else expected_statements
)
self.assertEqual(len(expected_statements), len(actual_sqls))
for expected, actual in zip(expected_statements, actual_sqls):
self.assertEqual(expected, actual)

View file

@ -26,12 +26,14 @@ class TestDataframeColumn(unittest.TestCase):
def test_and(self):
self.assertEqual(
"cola = colb AND colc = cold", ((F.col("cola") == F.col("colb")) & (F.col("colc") == F.col("cold"))).sql()
"cola = colb AND colc = cold",
((F.col("cola") == F.col("colb")) & (F.col("colc") == F.col("cold"))).sql(),
)
def test_or(self):
self.assertEqual(
"cola = colb OR colc = cold", ((F.col("cola") == F.col("colb")) | (F.col("colc") == F.col("cold"))).sql()
"cola = colb OR colc = cold",
((F.col("cola") == F.col("colb")) | (F.col("colc") == F.col("cold"))).sql(),
)
def test_mod(self):
@ -112,7 +114,9 @@ class TestDataframeColumn(unittest.TestCase):
def test_when_otherwise(self):
self.assertEqual("CASE WHEN cola = 1 THEN 2 END", F.when(F.col("cola") == 1, 2).sql())
self.assertEqual("CASE WHEN cola = 1 THEN 2 END", F.col("cola").when(F.col("cola") == 1, 2).sql())
self.assertEqual(
"CASE WHEN cola = 1 THEN 2 END", F.col("cola").when(F.col("cola") == 1, 2).sql()
)
self.assertEqual(
"CASE WHEN cola = 1 THEN 2 WHEN colb = 2 THEN 3 END",
(F.when(F.col("cola") == 1, 2).when(F.col("colb") == 2, 3)).sql(),
@ -148,7 +152,9 @@ class TestDataframeColumn(unittest.TestCase):
self.assertEqual(
"cola BETWEEN CAST('2022-01-01 01:01:01.000000' AS TIMESTAMP) "
"AND CAST('2022-03-01 01:01:01.000000' AS TIMESTAMP)",
F.col("cola").between(datetime.datetime(2022, 1, 1, 1, 1, 1), datetime.datetime(2022, 3, 1, 1, 1, 1)).sql(),
F.col("cola")
.between(datetime.datetime(2022, 1, 1, 1, 1, 1), datetime.datetime(2022, 3, 1, 1, 1, 1))
.sql(),
)
def test_over(self):

View file

@ -9,7 +9,9 @@ class TestDataframe(DataFrameSQLValidator):
self.assertEqual("t17051", DataFrame._create_hash_from_expression(expression))
def test_columns(self):
self.assertEqual(["employee_id", "fname", "lname", "age", "store_id"], self.df_employee.columns)
self.assertEqual(
["employee_id", "fname", "lname", "age", "store_id"], self.df_employee.columns
)
def test_cache(self):
df = self.df_employee.select("fname").cache()

View file

@ -925,12 +925,17 @@ class TestFunctions(unittest.TestCase):
col = SF.window(SF.col("cola"), "10 minutes")
self.assertEqual("WINDOW(cola, '10 minutes')", col.sql())
col_all_values = SF.window("cola", "2 minutes 30 seconds", "30 seconds", "15 seconds")
self.assertEqual("WINDOW(cola, '2 minutes 30 seconds', '30 seconds', '15 seconds')", col_all_values.sql())
self.assertEqual(
"WINDOW(cola, '2 minutes 30 seconds', '30 seconds', '15 seconds')", col_all_values.sql()
)
col_no_start_time = SF.window("cola", "2 minutes 30 seconds", "30 seconds")
self.assertEqual("WINDOW(cola, '2 minutes 30 seconds', '30 seconds')", col_no_start_time.sql())
self.assertEqual(
"WINDOW(cola, '2 minutes 30 seconds', '30 seconds')", col_no_start_time.sql()
)
col_no_slide = SF.window("cola", "2 minutes 30 seconds", startTime="15 seconds")
self.assertEqual(
"WINDOW(cola, '2 minutes 30 seconds', '2 minutes 30 seconds', '15 seconds')", col_no_slide.sql()
"WINDOW(cola, '2 minutes 30 seconds', '2 minutes 30 seconds', '15 seconds')",
col_no_slide.sql(),
)
def test_session_window(self):
@ -1359,9 +1364,13 @@ class TestFunctions(unittest.TestCase):
def test_from_json(self):
col_str = SF.from_json("cola", "cola INT", dict(timestampFormat="dd/MM/yyyy"))
self.assertEqual("FROM_JSON(cola, 'cola INT', MAP('timestampFormat', 'dd/MM/yyyy'))", col_str.sql())
self.assertEqual(
"FROM_JSON(cola, 'cola INT', MAP('timestampFormat', 'dd/MM/yyyy'))", col_str.sql()
)
col = SF.from_json(SF.col("cola"), "cola INT", dict(timestampFormat="dd/MM/yyyy"))
self.assertEqual("FROM_JSON(cola, 'cola INT', MAP('timestampFormat', 'dd/MM/yyyy'))", col.sql())
self.assertEqual(
"FROM_JSON(cola, 'cola INT', MAP('timestampFormat', 'dd/MM/yyyy'))", col.sql()
)
col_no_option = SF.from_json("cola", "cola INT")
self.assertEqual("FROM_JSON(cola, 'cola INT')", col_no_option.sql())
@ -1375,7 +1384,9 @@ class TestFunctions(unittest.TestCase):
def test_schema_of_json(self):
col_str = SF.schema_of_json("cola", dict(timestampFormat="dd/MM/yyyy"))
self.assertEqual("SCHEMA_OF_JSON(cola, MAP('timestampFormat', 'dd/MM/yyyy'))", col_str.sql())
self.assertEqual(
"SCHEMA_OF_JSON(cola, MAP('timestampFormat', 'dd/MM/yyyy'))", col_str.sql()
)
col = SF.schema_of_json(SF.col("cola"), dict(timestampFormat="dd/MM/yyyy"))
self.assertEqual("SCHEMA_OF_JSON(cola, MAP('timestampFormat', 'dd/MM/yyyy'))", col.sql())
col_no_option = SF.schema_of_json("cola")
@ -1429,7 +1440,10 @@ class TestFunctions(unittest.TestCase):
col = SF.array_sort(SF.col("cola"))
self.assertEqual("ARRAY_SORT(cola)", col.sql())
col_comparator = SF.array_sort(
"cola", lambda x, y: SF.when(x.isNull() | y.isNull(), SF.lit(0)).otherwise(SF.length(y) - SF.length(x))
"cola",
lambda x, y: SF.when(x.isNull() | y.isNull(), SF.lit(0)).otherwise(
SF.length(y) - SF.length(x)
),
)
self.assertEqual(
"ARRAY_SORT(cola, (x, y) -> CASE WHEN x IS NULL OR y IS NULL THEN 0 ELSE LENGTH(y) - LENGTH(x) END)",
@ -1504,9 +1518,13 @@ class TestFunctions(unittest.TestCase):
def test_from_csv(self):
col_str = SF.from_csv("cola", "cola INT", dict(timestampFormat="dd/MM/yyyy"))
self.assertEqual("FROM_CSV(cola, 'cola INT', MAP('timestampFormat', 'dd/MM/yyyy'))", col_str.sql())
self.assertEqual(
"FROM_CSV(cola, 'cola INT', MAP('timestampFormat', 'dd/MM/yyyy'))", col_str.sql()
)
col = SF.from_csv(SF.col("cola"), "cola INT", dict(timestampFormat="dd/MM/yyyy"))
self.assertEqual("FROM_CSV(cola, 'cola INT', MAP('timestampFormat', 'dd/MM/yyyy'))", col.sql())
self.assertEqual(
"FROM_CSV(cola, 'cola INT', MAP('timestampFormat', 'dd/MM/yyyy'))", col.sql()
)
col_no_option = SF.from_csv("cola", "cola INT")
self.assertEqual("FROM_CSV(cola, 'cola INT')", col_no_option.sql())
@ -1535,7 +1553,9 @@ class TestFunctions(unittest.TestCase):
self.assertEqual("TRANSFORM(cola, (x, i) -> x * i)", col.sql())
col_custom_names = SF.transform("cola", lambda target, row_count: target * row_count)
self.assertEqual("TRANSFORM(cola, (target, row_count) -> target * row_count)", col_custom_names.sql())
self.assertEqual(
"TRANSFORM(cola, (target, row_count) -> target * row_count)", col_custom_names.sql()
)
def test_exists(self):
col_str = SF.exists("cola", lambda x: x % 2 == 0)
@ -1558,10 +1578,13 @@ class TestFunctions(unittest.TestCase):
self.assertEqual("FILTER(cola, x -> MONTH(TO_DATE(x)) > 6)", col_str.sql())
col = SF.filter(SF.col("cola"), lambda x, i: SF.month(SF.to_date(x)) > SF.lit(i))
self.assertEqual("FILTER(cola, (x, i) -> MONTH(TO_DATE(x)) > i)", col.sql())
col_custom_names = SF.filter("cola", lambda target, row_count: SF.month(SF.to_date(target)) > SF.lit(row_count))
col_custom_names = SF.filter(
"cola", lambda target, row_count: SF.month(SF.to_date(target)) > SF.lit(row_count)
)
self.assertEqual(
"FILTER(cola, (target, row_count) -> MONTH(TO_DATE(target)) > row_count)", col_custom_names.sql()
"FILTER(cola, (target, row_count) -> MONTH(TO_DATE(target)) > row_count)",
col_custom_names.sql(),
)
def test_zip_with(self):
@ -1570,7 +1593,9 @@ class TestFunctions(unittest.TestCase):
col = SF.zip_with(SF.col("cola"), SF.col("colb"), lambda x, y: SF.concat_ws("_", x, y))
self.assertEqual("ZIP_WITH(cola, colb, (x, y) -> CONCAT_WS('_', x, y))", col.sql())
col_custom_names = SF.zip_with("cola", "colb", lambda l, r: SF.concat_ws("_", l, r))
self.assertEqual("ZIP_WITH(cola, colb, (l, r) -> CONCAT_WS('_', l, r))", col_custom_names.sql())
self.assertEqual(
"ZIP_WITH(cola, colb, (l, r) -> CONCAT_WS('_', l, r))", col_custom_names.sql()
)
def test_transform_keys(self):
col_str = SF.transform_keys("cola", lambda k, v: SF.upper(k))
@ -1586,7 +1611,9 @@ class TestFunctions(unittest.TestCase):
col = SF.transform_values(SF.col("cola"), lambda k, v: SF.upper(v))
self.assertEqual("TRANSFORM_VALUES(cola, (k, v) -> UPPER(v))", col.sql())
col_custom_names = SF.transform_values("cola", lambda _, value: SF.upper(value))
self.assertEqual("TRANSFORM_VALUES(cola, (_, value) -> UPPER(value))", col_custom_names.sql())
self.assertEqual(
"TRANSFORM_VALUES(cola, (_, value) -> UPPER(value))", col_custom_names.sql()
)
def test_map_filter(self):
col_str = SF.map_filter("cola", lambda k, v: k > v)

View file

@ -21,9 +21,7 @@ class TestDataframeSession(DataFrameSQLValidator):
def test_cdf_no_schema(self):
df = self.spark.createDataFrame([[1, 2], [3, 4], [None, 6]])
expected = (
"SELECT `a2`.`_1` AS `_1`, `a2`.`_2` AS `_2` FROM (VALUES (1, 2), (3, 4), (NULL, 6)) AS `a2`(`_1`, `_2`)"
)
expected = "SELECT `a2`.`_1` AS `_1`, `a2`.`_2` AS `_2` FROM (VALUES (1, 2), (3, 4), (NULL, 6)) AS `a2`(`_1`, `_2`)"
self.compare_sql(df, expected)
def test_cdf_row_mixed_primitives(self):
@ -77,7 +75,8 @@ class TestDataframeSession(DataFrameSQLValidator):
sqlglot.schema.add_table("table", {"cola": "string", "colb": "string"})
df = self.spark.sql(query)
self.assertIn(
"SELECT `table`.`cola` AS `cola`, `table`.`colb` AS `colb` FROM `table` AS `table`", df.sql(pretty=False)
"SELECT `table`.`cola` AS `cola`, `table`.`colb` AS `colb` FROM `table` AS `table`",
df.sql(pretty=False),
)
@mock.patch("sqlglot.schema", MappingSchema())
@ -104,9 +103,7 @@ class TestDataframeSession(DataFrameSQLValidator):
query = "WITH t1 AS (SELECT cola, colb FROM table) INSERT INTO new_table SELECT cola, colb FROM t1"
sqlglot.schema.add_table("table", {"cola": "string", "colb": "string"})
df = self.spark.sql(query)
expected = (
"INSERT INTO new_table SELECT `table`.`cola` AS `cola`, `table`.`colb` AS `colb` FROM `table` AS `table`"
)
expected = "INSERT INTO new_table SELECT `table`.`cola` AS `cola`, `table`.`colb` AS `colb` FROM `table` AS `table`"
self.compare_sql(df, expected)
def test_session_create_builder_patterns(self):

View file

@ -53,7 +53,10 @@ class TestDataframeTypes(unittest.TestCase):
self.assertEqual("array<int>", types.ArrayType(types.IntegerType()).simpleString())
def test_map(self):
self.assertEqual("map<int, string>", types.MapType(types.IntegerType(), types.StringType()).simpleString())
self.assertEqual(
"map<int, string>",
types.MapType(types.IntegerType(), types.StringType()).simpleString(),
)
def test_struct_field(self):
self.assertEqual("cola:int", types.StructField("cola", types.IntegerType()).simpleString())

View file

@ -39,22 +39,38 @@ class TestDataframeWindow(unittest.TestCase):
def test_window_rows_unbounded(self):
rows_between_unbounded_start = Window.rowsBetween(Window.unboundedPreceding, 2)
self.assertEqual("OVER ( ROWS BETWEEN UNBOUNDED PRECEDING AND 2 FOLLOWING)", rows_between_unbounded_start.sql())
rows_between_unbounded_end = Window.rowsBetween(1, Window.unboundedFollowing)
self.assertEqual("OVER ( ROWS BETWEEN 1 PRECEDING AND UNBOUNDED FOLLOWING)", rows_between_unbounded_end.sql())
rows_between_unbounded_both = Window.rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing)
self.assertEqual(
"OVER ( ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING)", rows_between_unbounded_both.sql()
"OVER ( ROWS BETWEEN UNBOUNDED PRECEDING AND 2 FOLLOWING)",
rows_between_unbounded_start.sql(),
)
rows_between_unbounded_end = Window.rowsBetween(1, Window.unboundedFollowing)
self.assertEqual(
"OVER ( ROWS BETWEEN 1 PRECEDING AND UNBOUNDED FOLLOWING)",
rows_between_unbounded_end.sql(),
)
rows_between_unbounded_both = Window.rowsBetween(
Window.unboundedPreceding, Window.unboundedFollowing
)
self.assertEqual(
"OVER ( ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING)",
rows_between_unbounded_both.sql(),
)
def test_window_range_unbounded(self):
range_between_unbounded_start = Window.rangeBetween(Window.unboundedPreceding, 2)
self.assertEqual(
"OVER ( RANGE BETWEEN UNBOUNDED PRECEDING AND 2 FOLLOWING)", range_between_unbounded_start.sql()
"OVER ( RANGE BETWEEN UNBOUNDED PRECEDING AND 2 FOLLOWING)",
range_between_unbounded_start.sql(),
)
range_between_unbounded_end = Window.rangeBetween(1, Window.unboundedFollowing)
self.assertEqual("OVER ( RANGE BETWEEN 1 PRECEDING AND UNBOUNDED FOLLOWING)", range_between_unbounded_end.sql())
range_between_unbounded_both = Window.rangeBetween(Window.unboundedPreceding, Window.unboundedFollowing)
self.assertEqual(
"OVER ( RANGE BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING)", range_between_unbounded_both.sql()
"OVER ( RANGE BETWEEN 1 PRECEDING AND UNBOUNDED FOLLOWING)",
range_between_unbounded_end.sql(),
)
range_between_unbounded_both = Window.rangeBetween(
Window.unboundedPreceding, Window.unboundedFollowing
)
self.assertEqual(
"OVER ( RANGE BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING)",
range_between_unbounded_both.sql(),
)
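These window tests pin the exact OVER clause text produced by the DataFrame-style builder. A minimal sketch of the same calls (the import path is an assumption based on this test module's layout):

```python
from sqlglot.dataframe.sql.window import Window  # assumed module path

# Unbounded-preceding to bounded-following ROWS frame, as asserted above.
frame = Window.rowsBetween(Window.unboundedPreceding, 2)
print(frame.sql())  # OVER ( ROWS BETWEEN UNBOUNDED PRECEDING AND 2 FOLLOWING)
```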

View file

@@ -157,6 +157,14 @@ class TestBigQuery(Validator):
},
)
self.validate_all(
"DIV(x, y)",
write={
"bigquery": "DIV(x, y)",
"duckdb": "CAST(x / y AS INT)",
},
)
self.validate_identity(
"SELECT ROW() OVER (y ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING) FROM x WINDOW y AS (PARTITION BY CATEGORY)"
)
@@ -284,4 +292,6 @@ class TestBigQuery(Validator):
"CREATE TEMPORARY FUNCTION a(x FLOAT64, y FLOAT64) RETURNS FLOAT64 NOT DETERMINISTIC LANGUAGE js AS 'return x*y;'"
)
self.validate_identity("CREATE TEMPORARY FUNCTION a(x FLOAT64, y FLOAT64) AS ((x + 4) / y)")
self.validate_identity("CREATE TABLE FUNCTION a(x INT64) RETURNS TABLE <q STRING, r INT64> AS SELECT s, t")
self.validate_identity(
"CREATE TABLE FUNCTION a(x INT64) RETURNS TABLE <q STRING, r INT64> AS SELECT s, t"
)
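The new `DIV` case above checks BigQuery's integer division against DuckDB. The same round trip through the public API, as a sketch:

```python
import sqlglot

# BigQuery's DIV(x, y) becomes an explicit integer cast in DuckDB.
print(sqlglot.transpile("DIV(x, y)", read="bigquery", write="duckdb")[0])
# CAST(x / y AS INT)
```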

View file

@@ -18,7 +18,6 @@ class TestClickhouse(Validator):
"spark": "SELECT fname, lname, age FROM person ORDER BY age DESC NULLS FIRST, fname NULLS LAST, lname NULLS LAST",
},
)
self.validate_all(
"CAST(1 AS NULLABLE(Int64))",
write={
@@ -31,3 +30,7 @@ class TestClickhouse(Validator):
"clickhouse": "CAST(1 AS Nullable(DateTime64(6, 'UTC')))",
},
)
self.validate_all(
"SELECT x #! comment",
write={"": "SELECT x /* comment */"},
)
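The added case covers ClickHouse's `#!` line comments, which now survive transpilation as block comments. A sketch against the default output dialect:

```python
from sqlglot import parse_one

# The ClickHouse line comment is kept and normalized to a block comment.
print(parse_one("SELECT x #! comment", read="clickhouse").sql())
# SELECT x /* comment */
```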

View file

@@ -22,7 +22,8 @@ class TestDatabricks(Validator):
},
)
self.validate_all(
"SELECT DATEDIFF('end', 'start')", write={"databricks": "SELECT DATEDIFF(DAY, 'start', 'end')"}
"SELECT DATEDIFF('end', 'start')",
write={"databricks": "SELECT DATEDIFF(DAY, 'start', 'end')"},
)
self.validate_all(
"SELECT DATE_ADD('2020-01-01', 1)",

View file

@@ -1,20 +1,18 @@
import unittest
from sqlglot import (
Dialect,
Dialects,
ErrorLevel,
UnsupportedError,
parse_one,
transpile,
)
from sqlglot import Dialect, Dialects, ErrorLevel, UnsupportedError, parse_one
class Validator(unittest.TestCase):
dialect = None
def validate_identity(self, sql):
self.assertEqual(transpile(sql, read=self.dialect, write=self.dialect)[0], sql)
def parse_one(self, sql):
return parse_one(sql, read=self.dialect)
def validate_identity(self, sql, write_sql=None):
expression = self.parse_one(sql)
self.assertEqual(write_sql or sql, expression.sql(dialect=self.dialect))
return expression
def validate_all(self, sql, read=None, write=None, pretty=False):
"""
@@ -28,12 +26,14 @@ class Validator(unittest.TestCase):
read (dict): Mapping of dialect -> SQL
write (dict): Mapping of dialect -> SQL
"""
expression = parse_one(sql, read=self.dialect)
expression = self.parse_one(sql)
for read_dialect, read_sql in (read or {}).items():
with self.subTest(f"{read_dialect} -> {sql}"):
self.assertEqual(
parse_one(read_sql, read_dialect).sql(self.dialect, unsupported_level=ErrorLevel.IGNORE),
parse_one(read_sql, read_dialect).sql(
self.dialect, unsupported_level=ErrorLevel.IGNORE, pretty=pretty
),
sql,
)
@@ -83,10 +83,6 @@ class TestDialect(Validator):
)
self.validate_all(
"CAST(a AS BINARY(4))",
read={
"presto": "CAST(a AS VARBINARY(4))",
"sqlite": "CAST(a AS VARBINARY(4))",
},
write={
"bigquery": "CAST(a AS BINARY(4))",
"clickhouse": "CAST(a AS BINARY(4))",
@@ -103,6 +99,24 @@
"starrocks": "CAST(a AS BINARY(4))",
},
)
self.validate_all(
"CAST(a AS VARBINARY(4))",
write={
"bigquery": "CAST(a AS VARBINARY(4))",
"clickhouse": "CAST(a AS VARBINARY(4))",
"duckdb": "CAST(a AS VARBINARY(4))",
"mysql": "CAST(a AS VARBINARY(4))",
"hive": "CAST(a AS BINARY(4))",
"oracle": "CAST(a AS BLOB(4))",
"postgres": "CAST(a AS BYTEA(4))",
"presto": "CAST(a AS VARBINARY(4))",
"redshift": "CAST(a AS VARBYTE(4))",
"snowflake": "CAST(a AS VARBINARY(4))",
"sqlite": "CAST(a AS BLOB(4))",
"spark": "CAST(a AS BINARY(4))",
"starrocks": "CAST(a AS VARBINARY(4))",
},
)
self.validate_all(
"CAST(MAP('a', '1') AS MAP(TEXT, TEXT))",
write={
@@ -472,45 +486,57 @@ class TestDialect(Validator):
},
)
self.validate_all(
"DATE_TRUNC(x, 'day')",
"DATE_TRUNC('day', x)",
write={
"mysql": "DATE(x)",
"starrocks": "DATE(x)",
},
)
self.validate_all(
"DATE_TRUNC(x, 'week')",
"DATE_TRUNC('week', x)",
write={
"mysql": "STR_TO_DATE(CONCAT(YEAR(x), ' ', WEEK(x, 1), ' 1'), '%Y %u %w')",
"starrocks": "STR_TO_DATE(CONCAT(YEAR(x), ' ', WEEK(x, 1), ' 1'), '%Y %u %w')",
},
)
self.validate_all(
"DATE_TRUNC(x, 'month')",
"DATE_TRUNC('month', x)",
write={
"mysql": "STR_TO_DATE(CONCAT(YEAR(x), ' ', MONTH(x), ' 1'), '%Y %c %e')",
"starrocks": "STR_TO_DATE(CONCAT(YEAR(x), ' ', MONTH(x), ' 1'), '%Y %c %e')",
},
)
self.validate_all(
"DATE_TRUNC(x, 'quarter')",
"DATE_TRUNC('quarter', x)",
write={
"mysql": "STR_TO_DATE(CONCAT(YEAR(x), ' ', QUARTER(x) * 3 - 2, ' 1'), '%Y %c %e')",
"starrocks": "STR_TO_DATE(CONCAT(YEAR(x), ' ', QUARTER(x) * 3 - 2, ' 1'), '%Y %c %e')",
},
)
self.validate_all(
"DATE_TRUNC(x, 'year')",
"DATE_TRUNC('year', x)",
write={
"mysql": "STR_TO_DATE(CONCAT(YEAR(x), ' 1 1'), '%Y %c %e')",
"starrocks": "STR_TO_DATE(CONCAT(YEAR(x), ' 1 1'), '%Y %c %e')",
},
)
self.validate_all(
"DATE_TRUNC(x, 'millenium')",
"DATE_TRUNC('millenium', x)",
write={
"mysql": UnsupportedError,
"starrocks": UnsupportedError,
},
)
self.validate_all(
"DATE_TRUNC('year', x)",
read={
"starrocks": "DATE_TRUNC('year', x)",
},
write={
"starrocks": "DATE_TRUNC('year', x)",
},
)
self.validate_all(
"DATE_TRUNC(x, year)",
read={
"bigquery": "DATE_TRUNC(x, year)",
},
write={
"bigquery": "DATE_TRUNC(x, year)",
},
)
self.validate_all(
@@ -564,6 +590,22 @@ class TestDialect(Validator):
"spark": "DATE_ADD(CAST('2020-01-01' AS DATE), 1)",
},
)
self.validate_all(
"TIMESTAMP '2022-01-01'",
write={
"mysql": "CAST('2022-01-01' AS TIMESTAMP)",
"starrocks": "CAST('2022-01-01' AS DATETIME)",
"hive": "CAST('2022-01-01' AS TIMESTAMP)",
},
)
self.validate_all(
"TIMESTAMP('2022-01-01')",
write={
"mysql": "TIMESTAMP('2022-01-01')",
"starrocks": "TIMESTAMP('2022-01-01')",
"hive": "TIMESTAMP('2022-01-01')",
},
)
for unit in ("DAY", "MONTH", "YEAR"):
self.validate_all(
@@ -1002,7 +1044,10 @@ class TestDialect(Validator):
)
def test_limit(self):
self.validate_all("SELECT * FROM data LIMIT 10, 20", write={"sqlite": "SELECT * FROM data LIMIT 10 OFFSET 20"})
self.validate_all(
"SELECT * FROM data LIMIT 10, 20",
write={"sqlite": "SELECT * FROM data LIMIT 10 OFFSET 20"},
)
self.validate_all(
"SELECT x FROM y LIMIT 10",
write={
@@ -1132,3 +1177,56 @@ class TestDialect(Validator):
"sqlite": "WITH cte1 AS (SELECT a, b FROM table1), cte2 AS (SELECT c, e AS d FROM table2) SELECT b, d AS dd FROM cte1 AS t JOIN cte2 WHERE cte1.a = cte2.c",
},
)
def test_nullsafe_eq(self):
self.validate_all(
"SELECT a IS NOT DISTINCT FROM b",
read={
"mysql": "SELECT a <=> b",
"postgres": "SELECT a IS NOT DISTINCT FROM b",
},
write={
"mysql": "SELECT a <=> b",
"postgres": "SELECT a IS NOT DISTINCT FROM b",
},
)
def test_nullsafe_neq(self):
self.validate_all(
"SELECT a IS DISTINCT FROM b",
read={
"postgres": "SELECT a IS DISTINCT FROM b",
},
write={
"mysql": "SELECT NOT a <=> b",
"postgres": "SELECT a IS DISTINCT FROM b",
},
)
def test_hash_comments(self):
self.validate_all(
"SELECT 1 /* arbitrary content,,, until end-of-line */",
read={
"mysql": "SELECT 1 # arbitrary content,,, until end-of-line",
"bigquery": "SELECT 1 # arbitrary content,,, until end-of-line",
"clickhouse": "SELECT 1 #! arbitrary content,,, until end-of-line",
},
)
self.validate_all(
"""/* comment1 */
SELECT
x, -- comment2
y -- comment3""",
read={
"mysql": """SELECT # comment1
x, # comment2
y # comment3""",
"bigquery": """SELECT # comment1
x, # comment2
y # comment3""",
"clickhouse": """SELECT # comment1
x, # comment2
y # comment3""",
},
pretty=True,
)
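Two of the behaviors pinned above, sketched through `transpile`: the null-safe equality operator maps between MySQL's `<=>` and the standard `IS NOT DISTINCT FROM`, and `DATE_TRUNC` now takes the unit as its first argument in most dialects (BigQuery keeps the `DATE_TRUNC(x, unit)` order):

```python
import sqlglot

print(sqlglot.transpile("SELECT a <=> b", read="mysql", write="postgres")[0])
# SELECT a IS NOT DISTINCT FROM b

# Unit-first DATE_TRUNC, rewritten to MySQL's DATE() for day truncation.
print(sqlglot.transpile("DATE_TRUNC('day', x)", write="mysql")[0])
# DATE(x)
```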

View file

@@ -1,3 +1,4 @@
from sqlglot import expressions as exp
from tests.dialects.test_dialect import Validator
@@ -20,6 +21,52 @@ class TestMySQL(Validator):
self.validate_identity("SELECT TRIM(TRAILING 'bla' FROM ' XXX ')")
self.validate_identity("SELECT TRIM(BOTH 'bla' FROM ' XXX ')")
self.validate_identity("SELECT TRIM('bla' FROM ' XXX ')")
self.validate_identity("@@GLOBAL.max_connections")
# SET Commands
self.validate_identity("SET @var_name = expr")
self.validate_identity("SET @name = 43")
self.validate_identity("SET @total_tax = (SELECT SUM(tax) FROM taxable_transactions)")
self.validate_identity("SET GLOBAL max_connections = 1000")
self.validate_identity("SET @@GLOBAL.max_connections = 1000")
self.validate_identity("SET SESSION sql_mode = 'TRADITIONAL'")
self.validate_identity("SET LOCAL sql_mode = 'TRADITIONAL'")
self.validate_identity("SET @@SESSION.sql_mode = 'TRADITIONAL'")
self.validate_identity("SET @@LOCAL.sql_mode = 'TRADITIONAL'")
self.validate_identity("SET @@sql_mode = 'TRADITIONAL'")
self.validate_identity("SET sql_mode = 'TRADITIONAL'")
self.validate_identity("SET PERSIST max_connections = 1000")
self.validate_identity("SET @@PERSIST.max_connections = 1000")
self.validate_identity("SET PERSIST_ONLY back_log = 100")
self.validate_identity("SET @@PERSIST_ONLY.back_log = 100")
self.validate_identity("SET @@SESSION.max_join_size = DEFAULT")
self.validate_identity("SET @@SESSION.max_join_size = @@GLOBAL.max_join_size")
self.validate_identity("SET @x = 1, SESSION sql_mode = ''")
self.validate_identity(
"SET GLOBAL sort_buffer_size = 1000000, SESSION sort_buffer_size = 1000000"
)
self.validate_identity(
"SET @@GLOBAL.sort_buffer_size = 1000000, @@LOCAL.sort_buffer_size = 1000000"
)
self.validate_identity("SET GLOBAL max_connections = 1000, sort_buffer_size = 1000000")
self.validate_identity("SET @@GLOBAL.sort_buffer_size = 50000, sort_buffer_size = 1000000")
self.validate_identity("SET CHARACTER SET 'utf8'")
self.validate_identity("SET CHARACTER SET utf8")
self.validate_identity("SET CHARACTER SET DEFAULT")
self.validate_identity("SET NAMES 'utf8'")
self.validate_identity("SET NAMES DEFAULT")
self.validate_identity("SET NAMES 'utf8' COLLATE 'utf8_unicode_ci'")
self.validate_identity("SET NAMES utf8 COLLATE utf8_unicode_ci")
self.validate_identity("SET autocommit = ON")
def test_escape(self):
self.validate_all(
r"'a \' b '' '",
write={
"mysql": r"'a '' b '' '",
"spark": r"'a \' b \' '",
},
)
def test_introducers(self):
self.validate_all(
@@ -115,14 +162,6 @@ class TestMySQL(Validator):
},
)
def test_hash_comments(self):
self.validate_all(
"SELECT 1 # arbitrary content,,, until end-of-line",
write={
"mysql": "SELECT 1",
},
)
def test_mysql(self):
self.validate_all(
"GROUP_CONCAT(DISTINCT x ORDER BY y DESC)",
@@ -174,3 +213,242 @@ COMMENT='客户账户表'"""
},
pretty=True,
)
def test_show_simple(self):
for key, write_key in [
("BINARY LOGS", "BINARY LOGS"),
("MASTER LOGS", "BINARY LOGS"),
("STORAGE ENGINES", "ENGINES"),
("ENGINES", "ENGINES"),
("EVENTS", "EVENTS"),
("MASTER STATUS", "MASTER STATUS"),
("PLUGINS", "PLUGINS"),
("PRIVILEGES", "PRIVILEGES"),
("PROFILES", "PROFILES"),
("REPLICAS", "REPLICAS"),
("SLAVE HOSTS", "REPLICAS"),
]:
show = self.validate_identity(f"SHOW {key}", f"SHOW {write_key}")
self.assertIsInstance(show, exp.Show)
self.assertEqual(show.name, write_key)
def test_show_events(self):
for key in ["BINLOG", "RELAYLOG"]:
show = self.validate_identity(f"SHOW {key} EVENTS")
self.assertIsInstance(show, exp.Show)
self.assertEqual(show.name, f"{key} EVENTS")
show = self.validate_identity(f"SHOW {key} EVENTS IN 'log' FROM 1 LIMIT 2, 3")
self.assertEqual(show.text("log"), "log")
self.assertEqual(show.text("position"), "1")
self.assertEqual(show.text("limit"), "3")
self.assertEqual(show.text("offset"), "2")
show = self.validate_identity(f"SHOW {key} EVENTS LIMIT 1")
self.assertEqual(show.text("limit"), "1")
self.assertIsNone(show.args.get("offset"))
def test_show_like_or_where(self):
for key, write_key in [
("CHARSET", "CHARACTER SET"),
("CHARACTER SET", "CHARACTER SET"),
("COLLATION", "COLLATION"),
("DATABASES", "DATABASES"),
("FUNCTION STATUS", "FUNCTION STATUS"),
("PROCEDURE STATUS", "PROCEDURE STATUS"),
("GLOBAL STATUS", "GLOBAL STATUS"),
("SESSION STATUS", "STATUS"),
("STATUS", "STATUS"),
("GLOBAL VARIABLES", "GLOBAL VARIABLES"),
("SESSION VARIABLES", "VARIABLES"),
("VARIABLES", "VARIABLES"),
]:
expected_name = write_key.strip("GLOBAL").strip()
template = "SHOW {}"
show = self.validate_identity(template.format(key), template.format(write_key))
self.assertIsInstance(show, exp.Show)
self.assertEqual(show.name, expected_name)
template = "SHOW {} LIKE '%foo%'"
show = self.validate_identity(template.format(key), template.format(write_key))
self.assertIsInstance(show, exp.Show)
self.assertIsInstance(show.args["like"], exp.Literal)
self.assertEqual(show.text("like"), "%foo%")
template = "SHOW {} WHERE Column_name LIKE '%foo%'"
show = self.validate_identity(template.format(key), template.format(write_key))
self.assertIsInstance(show, exp.Show)
self.assertIsInstance(show.args["where"], exp.Where)
self.assertEqual(show.args["where"].sql(), "WHERE Column_name LIKE '%foo%'")
def test_show_columns(self):
show = self.validate_identity("SHOW COLUMNS FROM tbl_name")
self.assertIsInstance(show, exp.Show)
self.assertEqual(show.name, "COLUMNS")
self.assertEqual(show.text("target"), "tbl_name")
self.assertFalse(show.args["full"])
show = self.validate_identity("SHOW FULL COLUMNS FROM tbl_name FROM db_name LIKE '%foo%'")
self.assertIsInstance(show, exp.Show)
self.assertEqual(show.text("target"), "tbl_name")
self.assertTrue(show.args["full"])
self.assertEqual(show.text("db"), "db_name")
self.assertIsInstance(show.args["like"], exp.Literal)
self.assertEqual(show.text("like"), "%foo%")
def test_show_name(self):
for key in [
"CREATE DATABASE",
"CREATE EVENT",
"CREATE FUNCTION",
"CREATE PROCEDURE",
"CREATE TABLE",
"CREATE TRIGGER",
"CREATE VIEW",
"FUNCTION CODE",
"PROCEDURE CODE",
]:
show = self.validate_identity(f"SHOW {key} foo")
self.assertIsInstance(show, exp.Show)
self.assertEqual(show.name, key)
self.assertEqual(show.text("target"), "foo")
def test_show_grants(self):
show = self.validate_identity(f"SHOW GRANTS FOR foo")
self.assertIsInstance(show, exp.Show)
self.assertEqual(show.name, "GRANTS")
self.assertEqual(show.text("target"), "foo")
def test_show_engine(self):
show = self.validate_identity("SHOW ENGINE foo STATUS")
self.assertIsInstance(show, exp.Show)
self.assertEqual(show.name, "ENGINE")
self.assertEqual(show.text("target"), "foo")
self.assertFalse(show.args["mutex"])
show = self.validate_identity("SHOW ENGINE foo MUTEX")
self.assertEqual(show.name, "ENGINE")
self.assertEqual(show.text("target"), "foo")
self.assertTrue(show.args["mutex"])
def test_show_errors(self):
for key in ["ERRORS", "WARNINGS"]:
show = self.validate_identity(f"SHOW {key}")
self.assertIsInstance(show, exp.Show)
self.assertEqual(show.name, key)
show = self.validate_identity(f"SHOW {key} LIMIT 2, 3")
self.assertEqual(show.text("limit"), "3")
self.assertEqual(show.text("offset"), "2")
def test_show_index(self):
show = self.validate_identity("SHOW INDEX FROM foo")
self.assertIsInstance(show, exp.Show)
self.assertEqual(show.name, "INDEX")
self.assertEqual(show.text("target"), "foo")
show = self.validate_identity("SHOW INDEX FROM foo FROM bar")
self.assertEqual(show.text("db"), "bar")
def test_show_db_like_or_where_sql(self):
for key in [
"OPEN TABLES",
"TABLE STATUS",
"TRIGGERS",
]:
show = self.validate_identity(f"SHOW {key}")
self.assertIsInstance(show, exp.Show)
self.assertEqual(show.name, key)
show = self.validate_identity(f"SHOW {key} FROM db_name")
self.assertEqual(show.name, key)
self.assertEqual(show.text("db"), "db_name")
show = self.validate_identity(f"SHOW {key} LIKE '%foo%'")
self.assertEqual(show.name, key)
self.assertIsInstance(show.args["like"], exp.Literal)
self.assertEqual(show.text("like"), "%foo%")
show = self.validate_identity(f"SHOW {key} WHERE Column_name LIKE '%foo%'")
self.assertEqual(show.name, key)
self.assertIsInstance(show.args["where"], exp.Where)
self.assertEqual(show.args["where"].sql(), "WHERE Column_name LIKE '%foo%'")
def test_show_processlist(self):
show = self.validate_identity("SHOW PROCESSLIST")
self.assertIsInstance(show, exp.Show)
self.assertEqual(show.name, "PROCESSLIST")
self.assertFalse(show.args["full"])
show = self.validate_identity("SHOW FULL PROCESSLIST")
self.assertEqual(show.name, "PROCESSLIST")
self.assertTrue(show.args["full"])
def test_show_profile(self):
show = self.validate_identity("SHOW PROFILE")
self.assertIsInstance(show, exp.Show)
self.assertEqual(show.name, "PROFILE")
show = self.validate_identity("SHOW PROFILE BLOCK IO")
self.assertEqual(show.args["types"][0].name, "BLOCK IO")
show = self.validate_identity(
"SHOW PROFILE BLOCK IO, PAGE FAULTS FOR QUERY 1 OFFSET 2 LIMIT 3"
)
self.assertEqual(show.args["types"][0].name, "BLOCK IO")
self.assertEqual(show.args["types"][1].name, "PAGE FAULTS")
self.assertEqual(show.text("query"), "1")
self.assertEqual(show.text("offset"), "2")
self.assertEqual(show.text("limit"), "3")
def test_show_replica_status(self):
show = self.validate_identity("SHOW REPLICA STATUS")
self.assertIsInstance(show, exp.Show)
self.assertEqual(show.name, "REPLICA STATUS")
show = self.validate_identity("SHOW SLAVE STATUS", "SHOW REPLICA STATUS")
self.assertIsInstance(show, exp.Show)
self.assertEqual(show.name, "REPLICA STATUS")
show = self.validate_identity("SHOW REPLICA STATUS FOR CHANNEL channel_name")
self.assertEqual(show.text("channel"), "channel_name")
def test_show_tables(self):
show = self.validate_identity("SHOW TABLES")
self.assertIsInstance(show, exp.Show)
self.assertEqual(show.name, "TABLES")
show = self.validate_identity("SHOW FULL TABLES FROM db_name LIKE '%foo%'")
self.assertTrue(show.args["full"])
self.assertEqual(show.text("db"), "db_name")
self.assertIsInstance(show.args["like"], exp.Literal)
self.assertEqual(show.text("like"), "%foo%")
def test_set_variable(self):
cmd = self.parse_one("SET SESSION x = 1")
item = cmd.expressions[0]
self.assertEqual(item.text("kind"), "SESSION")
self.assertIsInstance(item.this, exp.EQ)
self.assertEqual(item.this.left.name, "x")
self.assertEqual(item.this.right.name, "1")
cmd = self.parse_one("SET @@GLOBAL.x = @@GLOBAL.y")
item = cmd.expressions[0]
self.assertEqual(item.text("kind"), "")
self.assertIsInstance(item.this, exp.EQ)
self.assertIsInstance(item.this.left, exp.SessionParameter)
self.assertIsInstance(item.this.right, exp.SessionParameter)
cmd = self.parse_one("SET NAMES 'charset_name' COLLATE 'collation_name'")
item = cmd.expressions[0]
self.assertEqual(item.text("kind"), "NAMES")
self.assertEqual(item.name, "charset_name")
self.assertEqual(item.text("collate"), "collation_name")
cmd = self.parse_one("SET CHARSET DEFAULT")
item = cmd.expressions[0]
self.assertEqual(item.text("kind"), "CHARACTER SET")
self.assertEqual(item.this.name, "DEFAULT")
cmd = self.parse_one("SET x = 1, y = 2")
self.assertEqual(len(cmd.expressions), 2)
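The SHOW statements exercised above now parse into structured `exp.Show` nodes instead of opaque commands, so their parts are addressable. A sketch using the same accessors as the tests:

```python
import sqlglot
from sqlglot import exp

show = sqlglot.parse_one("SHOW FULL TABLES FROM db_name LIKE '%foo%'", read="mysql")
assert isinstance(show, exp.Show)
print(show.text("db"))    # db_name
print(show.text("like"))  # %foo%
print(show.args["full"])  # True
```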

View file

@@ -8,7 +8,9 @@ class TestPostgres(Validator):
def test_ddl(self):
self.validate_all(
"CREATE TABLE products (product_no INT UNIQUE, name TEXT, price DECIMAL)",
write={"postgres": "CREATE TABLE products (product_no INT UNIQUE, name TEXT, price DECIMAL)"},
write={
"postgres": "CREATE TABLE products (product_no INT UNIQUE, name TEXT, price DECIMAL)"
},
)
self.validate_all(
"CREATE TABLE products (product_no INT CONSTRAINT must_be_different UNIQUE, name TEXT CONSTRAINT present NOT NULL, price DECIMAL)",
@@ -59,15 +61,27 @@ class TestPostgres(Validator):
def test_postgres(self):
self.validate_identity("SELECT CASE WHEN SUBSTRING('abcdefg') IN ('ab') THEN 1 ELSE 0 END")
self.validate_identity("SELECT CASE WHEN SUBSTRING('abcdefg' FROM 1) IN ('ab') THEN 1 ELSE 0 END")
self.validate_identity("SELECT CASE WHEN SUBSTRING('abcdefg' FROM 1 FOR 2) IN ('ab') THEN 1 ELSE 0 END")
self.validate_identity('SELECT * FROM "x" WHERE SUBSTRING("x"."foo" FROM 1 FOR 2) IN (\'mas\')')
self.validate_identity(
"SELECT CASE WHEN SUBSTRING('abcdefg' FROM 1) IN ('ab') THEN 1 ELSE 0 END"
)
self.validate_identity(
"SELECT CASE WHEN SUBSTRING('abcdefg' FROM 1 FOR 2) IN ('ab') THEN 1 ELSE 0 END"
)
self.validate_identity(
'SELECT * FROM "x" WHERE SUBSTRING("x"."foo" FROM 1 FOR 2) IN (\'mas\')'
)
self.validate_identity("SELECT * FROM x WHERE SUBSTRING('Thomas' FROM '...$') IN ('mas')")
self.validate_identity("SELECT * FROM x WHERE SUBSTRING('Thomas' FROM '%#\"o_a#\"_' FOR '#') IN ('mas')")
self.validate_identity("SELECT SUBSTRING('bla' + 'foo' || 'bar' FROM 3 - 1 + 5 FOR 4 + SOME_FUNC(arg1, arg2))")
self.validate_identity(
"SELECT * FROM x WHERE SUBSTRING('Thomas' FROM '%#\"o_a#\"_' FOR '#') IN ('mas')"
)
self.validate_identity(
"SELECT SUBSTRING('bla' + 'foo' || 'bar' FROM 3 - 1 + 5 FOR 4 + SOME_FUNC(arg1, arg2))"
)
self.validate_identity("SELECT TRIM(' X' FROM ' XXX ')")
self.validate_identity("SELECT TRIM(LEADING 'bla' FROM ' XXX ' COLLATE utf8_bin)")
self.validate_identity("SELECT TO_TIMESTAMP(1284352323.5), TO_TIMESTAMP('05 Dec 2000', 'DD Mon YYYY')")
self.validate_identity(
"SELECT TO_TIMESTAMP(1284352323.5), TO_TIMESTAMP('05 Dec 2000', 'DD Mon YYYY')"
)
self.validate_identity("COMMENT ON TABLE mytable IS 'this'")
self.validate_identity("SELECT e'\\xDEADBEEF'")
self.validate_identity("SELECT CAST(e'\\176' AS BYTEA)")
@@ -75,7 +89,7 @@ class TestPostgres(Validator):
self.validate_all(
"CREATE TABLE x (a UUID, b BYTEA)",
write={
"duckdb": "CREATE TABLE x (a UUID, b BINARY)",
"duckdb": "CREATE TABLE x (a UUID, b VARBINARY)",
"presto": "CREATE TABLE x (a UUID, b VARBINARY)",
"hive": "CREATE TABLE x (a UUID, b BINARY)",
"spark": "CREATE TABLE x (a UUID, b BINARY)",
@@ -153,7 +167,9 @@ class TestPostgres(Validator):
)
self.validate_all(
"SELECT * FROM foo, LATERAL (SELECT * FROM bar WHERE bar.id = foo.bar_id) AS ss",
read={"postgres": "SELECT * FROM foo, LATERAL (SELECT * FROM bar WHERE bar.id = foo.bar_id) AS ss"},
read={
"postgres": "SELECT * FROM foo, LATERAL (SELECT * FROM bar WHERE bar.id = foo.bar_id) AS ss"
},
)
self.validate_all(
"SELECT m.name FROM manufacturers AS m LEFT JOIN LATERAL GET_PRODUCT_NAMES(m.id) AS pname ON TRUE WHERE pname IS NULL",
@@ -169,11 +185,15 @@ class TestPostgres(Validator):
)
self.validate_all(
"SELECT id, email, CAST(deleted AS TEXT) FROM users WHERE NOT deleted IS NULL",
read={"postgres": "SELECT id, email, CAST(deleted AS TEXT) FROM users WHERE deleted NOTNULL"},
read={
"postgres": "SELECT id, email, CAST(deleted AS TEXT) FROM users WHERE deleted NOTNULL"
},
)
self.validate_all(
"SELECT id, email, CAST(deleted AS TEXT) FROM users WHERE NOT deleted IS NULL",
read={"postgres": "SELECT id, email, CAST(deleted AS TEXT) FROM users WHERE NOT deleted ISNULL"},
read={
"postgres": "SELECT id, email, CAST(deleted AS TEXT) FROM users WHERE NOT deleted ISNULL"
},
)
self.validate_all(
"'[1,2,3]'::json->2",
@@ -184,7 +204,8 @@ class TestPostgres(Validator):
write={"postgres": """CAST('{"a":1,"b":2}' AS JSON)->'b'"""},
)
self.validate_all(
"""'{"x": {"y": 1}}'::json->'x'->'y'""", write={"postgres": """CAST('{"x": {"y": 1}}' AS JSON)->'x'->'y'"""}
"""'{"x": {"y": 1}}'::json->'x'->'y'""",
write={"postgres": """CAST('{"x": {"y": 1}}' AS JSON)->'x'->'y'"""},
)
self.validate_all(
"""'{"x": {"y": 1}}'::json->'x'::json->'y'""",

View file

@@ -61,4 +61,6 @@ class TestRedshift(Validator):
"SELECT caldate + INTERVAL '1 second' AS dateplus FROM date WHERE caldate = '12-31-2008'"
)
self.validate_identity("CREATE TABLE datetable (start_date DATE, end_date DATE)")
self.validate_identity("SELECT COUNT(*) FROM event WHERE eventname LIKE '%Ring%' OR eventname LIKE '%Die%'")
self.validate_identity(
"SELECT COUNT(*) FROM event WHERE eventname LIKE '%Ring%' OR eventname LIKE '%Die%'"
)

View file

@@ -336,7 +336,8 @@ class TestSnowflake(Validator):
def test_table_literal(self):
# All examples from https://docs.snowflake.com/en/sql-reference/literals-table.html
self.validate_all(
r"""SELECT * FROM TABLE('MYTABLE')""", write={"snowflake": r"""SELECT * FROM TABLE('MYTABLE')"""}
r"""SELECT * FROM TABLE('MYTABLE')""",
write={"snowflake": r"""SELECT * FROM TABLE('MYTABLE')"""},
)
self.validate_all(
@@ -352,15 +353,123 @@ class TestSnowflake(Validator):
write={"snowflake": r"""SELECT * FROM TABLE('MYDB. "MYSCHEMA"."MYTABLE"')"""},
)
self.validate_all(r"""SELECT * FROM TABLE($MYVAR)""", write={"snowflake": r"""SELECT * FROM TABLE($MYVAR)"""})
self.validate_all(r"""SELECT * FROM TABLE(?)""", write={"snowflake": r"""SELECT * FROM TABLE(?)"""})
self.validate_all(
r"""SELECT * FROM TABLE($MYVAR)""",
write={"snowflake": r"""SELECT * FROM TABLE($MYVAR)"""},
)
self.validate_all(
r"""SELECT * FROM TABLE(:BINDING)""", write={"snowflake": r"""SELECT * FROM TABLE(:BINDING)"""}
r"""SELECT * FROM TABLE(?)""", write={"snowflake": r"""SELECT * FROM TABLE(?)"""}
)
self.validate_all(
r"""SELECT * FROM TABLE(:BINDING)""",
write={"snowflake": r"""SELECT * FROM TABLE(:BINDING)"""},
)
self.validate_all(
r"""SELECT * FROM TABLE($MYVAR) WHERE COL1 = 10""",
write={"snowflake": r"""SELECT * FROM TABLE($MYVAR) WHERE COL1 = 10"""},
)
def test_flatten(self):
self.validate_all(
"""
select
dag_report.acct_id,
dag_report.report_date,
dag_report.report_uuid,
dag_report.airflow_name,
dag_report.dag_id,
f.value::varchar as operator
from cs.telescope.dag_report,
table(flatten(input=>split(operators, ','))) f
""",
write={
"snowflake": """SELECT
dag_report.acct_id,
dag_report.report_date,
dag_report.report_uuid,
dag_report.airflow_name,
dag_report.dag_id,
CAST(f.value AS VARCHAR) AS operator
FROM cs.telescope.dag_report, TABLE(FLATTEN(input => SPLIT(operators, ','))) AS f"""
},
pretty=True,
)
# All examples from https://docs.snowflake.com/en/sql-reference/functions/flatten.html#syntax
self.validate_all(
"SELECT * FROM TABLE(FLATTEN(input => parse_json('[1, ,77]'))) f",
write={
"snowflake": "SELECT * FROM TABLE(FLATTEN(input => PARSE_JSON('[1, ,77]'))) AS f"
},
)
self.validate_all(
"""SELECT * FROM TABLE(FLATTEN(input => parse_json('{"a":1, "b":[77,88]}'), outer => true)) f""",
write={
"snowflake": """SELECT * FROM TABLE(FLATTEN(input => PARSE_JSON('{"a":1, "b":[77,88]}'), outer => TRUE)) AS f"""
},
)
self.validate_all(
"""SELECT * FROM TABLE(FLATTEN(input => parse_json('{"a":1, "b":[77,88]}'), path => 'b')) f""",
write={
"snowflake": """SELECT * FROM TABLE(FLATTEN(input => PARSE_JSON('{"a":1, "b":[77,88]}'), path => 'b')) AS f"""
},
)
self.validate_all(
"""SELECT * FROM TABLE(FLATTEN(input => parse_json('[]'))) f""",
write={"snowflake": """SELECT * FROM TABLE(FLATTEN(input => PARSE_JSON('[]'))) AS f"""},
)
self.validate_all(
"""SELECT * FROM TABLE(FLATTEN(input => parse_json('[]'), outer => true)) f""",
write={
"snowflake": """SELECT * FROM TABLE(FLATTEN(input => PARSE_JSON('[]'), outer => TRUE)) AS f"""
},
)
self.validate_all(
"""SELECT * FROM TABLE(FLATTEN(input => parse_json('{"a":1, "b":[77,88], "c": {"d":"X"}}'))) f""",
write={
"snowflake": """SELECT * FROM TABLE(FLATTEN(input => PARSE_JSON('{"a":1, "b":[77,88], "c": {"d":"X"}}'))) AS f"""
},
)
self.validate_all(
"""SELECT * FROM TABLE(FLATTEN(input => parse_json('{"a":1, "b":[77,88], "c": {"d":"X"}}'), recursive => true)) f""",
write={
"snowflake": """SELECT * FROM TABLE(FLATTEN(input => PARSE_JSON('{"a":1, "b":[77,88], "c": {"d":"X"}}'), recursive => TRUE)) AS f"""
},
)
self.validate_all(
"""SELECT * FROM TABLE(FLATTEN(input => parse_json('{"a":1, "b":[77,88], "c": {"d":"X"}}'), recursive => true, mode => 'object')) f""",
write={
"snowflake": """SELECT * FROM TABLE(FLATTEN(input => PARSE_JSON('{"a":1, "b":[77,88], "c": {"d":"X"}}'), recursive => TRUE, mode => 'object')) AS f"""
},
)
self.validate_all(
"""
SELECT id as "ID",
f.value AS "Contact",
f1.value:type AS "Type",
f1.value:content AS "Details"
FROM persons p,
lateral flatten(input => p.c, path => 'contact') f,
lateral flatten(input => f.value:business) f1
""",
write={
"snowflake": """SELECT
id AS "ID",
f.value AS "Contact",
f1.value['type'] AS "Type",
f1.value['content'] AS "Details"
FROM persons AS p, LATERAL FLATTEN(input => p.c, path => 'contact') f, LATERAL FLATTEN(input => f.value['business']) f1""",
},
pretty=True,
)
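The FLATTEN cases normalize Snowflake table functions: function names are uppercased and the trailing alias gains an explicit `AS`. A round-trip sketch:

```python
import sqlglot

sql = "SELECT * FROM TABLE(FLATTEN(input => parse_json('[1, ,77]'))) f"
print(sqlglot.transpile(sql, read="snowflake", write="snowflake")[0])
# SELECT * FROM TABLE(FLATTEN(input => PARSE_JSON('[1, ,77]'))) AS f
```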

View file

@@ -284,4 +284,6 @@ TBLPROPERTIES (
)
def test_iif(self):
self.validate_all("SELECT IIF(cond, 'True', 'False')", write={"spark": "SELECT IF(cond, 'True', 'False')"})
self.validate_all(
"SELECT IIF(cond, 'True', 'False')", write={"spark": "SELECT IF(cond, 'True', 'False')"}
)

View file

@@ -6,3 +6,6 @@ class TestMySQL(Validator):
def test_identity(self):
self.validate_identity("SELECT CAST(`a`.`b` AS INT) FROM foo")
def test_time(self):
self.validate_identity("TIMESTAMP('2022-01-01')")

View file

@@ -278,12 +278,19 @@ class TestTSQL(Validator):
def test_add_date(self):
self.validate_identity("SELECT DATEADD(year, 1, '2017/08/25')")
self.validate_all(
"SELECT DATEADD(year, 1, '2017/08/25')", write={"spark": "SELECT ADD_MONTHS('2017/08/25', 12)"}
"SELECT DATEADD(year, 1, '2017/08/25')",
write={"spark": "SELECT ADD_MONTHS('2017/08/25', 12)"},
)
self.validate_all(
"SELECT DATEADD(qq, 1, '2017/08/25')",
write={"spark": "SELECT ADD_MONTHS('2017/08/25', 3)"},
)
self.validate_all("SELECT DATEADD(qq, 1, '2017/08/25')", write={"spark": "SELECT ADD_MONTHS('2017/08/25', 3)"})
self.validate_all(
"SELECT DATEADD(wk, 1, '2017/08/25')",
write={"spark": "SELECT DATE_ADD('2017/08/25', 7)", "databricks": "SELECT DATEADD(week, 1, '2017/08/25')"},
write={
"spark": "SELECT DATE_ADD('2017/08/25', 7)",
"databricks": "SELECT DATEADD(week, 1, '2017/08/25')",
},
)
def test_date_diff(self):
@@ -370,13 +377,21 @@ class TestTSQL(Validator):
"SELECT FORMAT(1000000.01,'###,###.###')",
write={"spark": "SELECT FORMAT_NUMBER(1000000.01, '###,###.###')"},
)
self.validate_all("SELECT FORMAT(1234567, 'f')", write={"spark": "SELECT FORMAT_NUMBER(1234567, 'f')"})
self.validate_all(
"SELECT FORMAT(1234567, 'f')", write={"spark": "SELECT FORMAT_NUMBER(1234567, 'f')"}
)
self.validate_all(
"SELECT FORMAT('01-01-1991', 'dd.mm.yyyy')",
write={"spark": "SELECT DATE_FORMAT('01-01-1991', 'dd.mm.yyyy')"},
)
self.validate_all(
"SELECT FORMAT(date_col, 'dd.mm.yyyy')", write={"spark": "SELECT DATE_FORMAT(date_col, 'dd.mm.yyyy')"}
"SELECT FORMAT(date_col, 'dd.mm.yyyy')",
write={"spark": "SELECT DATE_FORMAT(date_col, 'dd.mm.yyyy')"},
)
self.validate_all(
"SELECT FORMAT(date_col, 'm')",
write={"spark": "SELECT DATE_FORMAT(date_col, 'MMMM d')"},
)
self.validate_all(
"SELECT FORMAT(num_col, 'c')", write={"spark": "SELECT FORMAT_NUMBER(num_col, 'c')"}
)
self.validate_all("SELECT FORMAT(date_col, 'm')", write={"spark": "SELECT DATE_FORMAT(date_col, 'MMMM d')"})
self.validate_all("SELECT FORMAT(num_col, 'c')", write={"spark": "SELECT FORMAT_NUMBER(num_col, 'c')"})

View file

@@ -523,6 +523,8 @@ DROP VIEW a.b
DROP VIEW IF EXISTS a
DROP VIEW IF EXISTS a.b
SHOW TABLES
USE db
ROLLBACK
EXPLAIN SELECT * FROM x
INSERT INTO x SELECT * FROM y
INSERT INTO x (SELECT * FROM y)
@@ -569,3 +571,13 @@ SELECT * FROM (tbl1 LEFT JOIN tbl2 ON 1 = 1)
SELECT * FROM (tbl1 JOIN tbl2 JOIN tbl3)
SELECT * FROM (tbl1 JOIN (tbl2 JOIN tbl3) ON bla = foo)
SELECT * FROM (tbl1 JOIN LATERAL (SELECT * FROM bla) AS tbl)
SELECT CAST(x AS INT) /* comment */ FROM foo
SELECT a /* x */, b /* x */
SELECT * FROM foo /* x */, bla /* x */
SELECT 1 /* comment */ + 1
SELECT 1 /* c1 */ + 2 /* c2 */
SELECT 1 /* c1 */ + 2 /* c2 */ + 3 /* c3 */
SELECT 1 /* c1 */ + 2 /* c2 */, 3 /* c3 */
SELECT x FROM a.b.c /* x */, e.f.g /* x */
SELECT FOO(x /* c */) /* FOO */, b /* b */
SELECT FOO(x /* c1 */ + y /* c2 */ + BLA(5 /* c3 */)) FROM VALUES (1 /* c4 */, "test" /* c5 */) /* c6 */
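Each line in this fixture must survive a parse/generate round trip unchanged, comments included. A sketch of the check the suite effectively performs:

```python
import sqlglot

sql = "SELECT 1 /* c1 */ + 2 /* c2 */, 3 /* c3 */"
assert sqlglot.transpile(sql)[0] == sql
```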

View file

@@ -104,6 +104,16 @@ SELECT x.b AS b, x.a AS a FROM x AS x LEFT JOIN y AS y ON x.b = y.b QUALIFY ROW_
SELECT AGGREGATE(ARRAY(a, x.b), 0, (x, acc) -> x + acc + a) AS sum_agg FROM x;
SELECT AGGREGATE(ARRAY(x.a, x.b), 0, (x, acc) -> x + acc + x.a) AS sum_agg FROM x AS x;
# dialect: starrocks
# execute: false
SELECT DATE_TRUNC('week', a) AS a FROM x;
SELECT DATE_TRUNC('week', x.a) AS a FROM x AS x;
# dialect: bigquery
# execute: false
SELECT DATE_TRUNC(a, MONTH) AS a FROM x;
SELECT DATE_TRUNC(x.a, MONTH) AS a FROM x AS x;
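These fixture pairs (input, then expected output) show the column inside DATE_TRUNC being qualified for both argument orders. A sketch of running the rules directly; the schema and the qualify_tables pre-pass are assumptions mirroring how the fixture suite is driven:

```python
from sqlglot import parse_one
from sqlglot.optimizer.qualify_tables import qualify_tables
from sqlglot.optimizer.qualify_columns import qualify_columns

expression = parse_one("SELECT DATE_TRUNC('week', a) AS a FROM x", read="starrocks")
expression = qualify_tables(expression)  # gives the table its alias: FROM x AS x
expression = qualify_columns(expression, schema={"x": {"a": "INT", "b": "INT"}})  # assumed schema
print(expression.sql(dialect="starrocks"))
# SELECT DATE_TRUNC('week', x.a) AS a FROM x AS x
```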
--------------------------------------
-- Derived tables
--------------------------------------

View file

@@ -79,6 +79,15 @@ NULL;
NULL = NULL;
NULL;
NULL <=> NULL;
TRUE;
a IS NOT DISTINCT FROM a;
TRUE;
NULL IS DISTINCT FROM NULL;
FALSE;
NOT (NOT TRUE);
TRUE;
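The new simplify fixtures fold null-safe comparisons into constants. A sketch of running the rule on the same inputs:

```python
from sqlglot import parse_one
from sqlglot.optimizer.simplify import simplify

print(simplify(parse_one("a IS NOT DISTINCT FROM a")).sql())    # TRUE
print(simplify(parse_one("NULL IS DISTINCT FROM NULL")).sql())  # FALSE
print(simplify(parse_one("NULL <=> NULL")).sql())               # TRUE
```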

View file

@@ -287,3 +287,31 @@ SELECT
"fffffff"
)
);
/*
multi
line
comment
*/
SELECT * FROM foo;
/*
multi
line
comment
*/
SELECT
*
FROM foo;
SELECT x FROM a.b.c /*x*/, e.f.g /*x*/;
SELECT
x
FROM a.b.c /* x */, e.f.g /* x */;
SELECT x FROM (SELECT * FROM bla /*x*/WHERE id = 1) /*x*/;
SELECT
x
FROM (
SELECT
*
FROM bla /* x */
WHERE
id = 1
) /* x */;

View file

@@ -100,15 +100,21 @@ class TestBuild(unittest.TestCase):
"SELECT x FROM tbl LEFT OUTER JOIN tbl2",
),
(
lambda: select("x").from_("tbl").join(exp.Table(this="tbl2"), join_type="left outer"),
lambda: select("x")
.from_("tbl")
.join(exp.Table(this="tbl2"), join_type="left outer"),
"SELECT x FROM tbl LEFT OUTER JOIN tbl2",
),
(
lambda: select("x").from_("tbl").join(exp.Table(this="tbl2"), join_type="left outer", join_alias="foo"),
lambda: select("x")
.from_("tbl")
.join(exp.Table(this="tbl2"), join_type="left outer", join_alias="foo"),
"SELECT x FROM tbl LEFT OUTER JOIN tbl2 AS foo",
),
(
lambda: select("x").from_("tbl").join(select("y").from_("tbl2"), join_type="left outer"),
lambda: select("x")
.from_("tbl")
.join(select("y").from_("tbl2"), join_type="left outer"),
"SELECT x FROM tbl LEFT OUTER JOIN (SELECT y FROM tbl2)",
),
(
@@ -131,7 +137,9 @@ class TestBuild(unittest.TestCase):
"SELECT x FROM tbl LEFT OUTER JOIN (SELECT y FROM tbl2) AS aliased",
),
(
lambda: select("x").from_("tbl").join(parse_one("left join x", into=exp.Join), on="a=b"),
lambda: select("x")
.from_("tbl")
.join(parse_one("left join x", into=exp.Join), on="a=b"),
"SELECT x FROM tbl LEFT JOIN x ON a = b",
),
(
@@ -139,7 +147,9 @@ class TestBuild(unittest.TestCase):
"SELECT x FROM tbl LEFT JOIN x ON a = b",
),
(
lambda: select("x").from_("tbl").join("select b from tbl2", on="a=b", join_type="left"),
lambda: select("x")
.from_("tbl")
.join("select b from tbl2", on="a=b", join_type="left"),
"SELECT x FROM tbl LEFT JOIN (SELECT b FROM tbl2) ON a = b",
),
(
@@ -162,7 +172,10 @@ class TestBuild(unittest.TestCase):
(
lambda: select("x", "y", "z")
.from_("merged_df")
.join("vte_diagnosis_df", using=[exp.to_identifier("patient_id"), exp.to_identifier("encounter_id")]),
.join(
"vte_diagnosis_df",
using=[exp.to_identifier("patient_id"), exp.to_identifier("encounter_id")],
),
"SELECT x, y, z FROM merged_df JOIN vte_diagnosis_df USING (patient_id, encounter_id)",
),
(
@@ -222,7 +235,10 @@ class TestBuild(unittest.TestCase):
"SELECT x, y, z, a FROM tbl ORDER BY x, y, z, a",
),
(
lambda: select("x", "y", "z", "a").from_("tbl").cluster_by("x, y", "z").cluster_by("a"),
lambda: select("x", "y", "z", "a")
.from_("tbl")
.cluster_by("x, y", "z")
.cluster_by("a"),
"SELECT x, y, z, a FROM tbl CLUSTER BY x, y, z, a",
),
(
@@ -239,7 +255,9 @@ class TestBuild(unittest.TestCase):
"WITH tbl AS (SELECT x FROM tbl2) SELECT x FROM tbl",
),
(
lambda: select("x").from_("tbl").with_("tbl", as_="SELECT x FROM tbl2", recursive=True),
lambda: select("x")
.from_("tbl")
.with_("tbl", as_="SELECT x FROM tbl2", recursive=True),
"WITH RECURSIVE tbl AS (SELECT x FROM tbl2) SELECT x FROM tbl",
),
(
@@ -247,7 +265,9 @@ class TestBuild(unittest.TestCase):
"WITH tbl AS (SELECT x FROM tbl2) SELECT x FROM tbl",
),
(
lambda: select("x").from_("tbl").with_("tbl (x, y)", as_=select("x", "y").from_("tbl2")),
lambda: select("x")
.from_("tbl")
.with_("tbl (x, y)", as_=select("x", "y").from_("tbl2")),
"WITH tbl(x, y) AS (SELECT x, y FROM tbl2) SELECT x FROM tbl",
),
(
@@ -258,7 +278,10 @@ class TestBuild(unittest.TestCase):
"WITH tbl AS (SELECT x FROM tbl2), tbl2 AS (SELECT x FROM tbl3) SELECT x FROM tbl",
),
(
lambda: select("x").from_("tbl").with_("tbl", as_=select("x", "y").from_("tbl2")).select("y"),
lambda: select("x")
.from_("tbl")
.with_("tbl", as_=select("x", "y").from_("tbl2"))
.select("y"),
"WITH tbl AS (SELECT x, y FROM tbl2) SELECT x, y FROM tbl",
),
(
@@ -266,35 +289,59 @@ class TestBuild(unittest.TestCase):
"WITH tbl AS (SELECT x FROM tbl2) SELECT x FROM tbl",
),
(
lambda: select("x").with_("tbl", as_=select("x").from_("tbl2")).from_("tbl").group_by("x"),
lambda: select("x")
.with_("tbl", as_=select("x").from_("tbl2"))
.from_("tbl")
.group_by("x"),
"WITH tbl AS (SELECT x FROM tbl2) SELECT x FROM tbl GROUP BY x",
),
(
lambda: select("x").with_("tbl", as_=select("x").from_("tbl2")).from_("tbl").order_by("x"),
lambda: select("x")
.with_("tbl", as_=select("x").from_("tbl2"))
.from_("tbl")
.order_by("x"),
"WITH tbl AS (SELECT x FROM tbl2) SELECT x FROM tbl ORDER BY x",
),
(
lambda: select("x").with_("tbl", as_=select("x").from_("tbl2")).from_("tbl").limit(10),
lambda: select("x")
.with_("tbl", as_=select("x").from_("tbl2"))
.from_("tbl")
.limit(10),
"WITH tbl AS (SELECT x FROM tbl2) SELECT x FROM tbl LIMIT 10",
),
(
lambda: select("x").with_("tbl", as_=select("x").from_("tbl2")).from_("tbl").offset(10),
lambda: select("x")
.with_("tbl", as_=select("x").from_("tbl2"))
.from_("tbl")
.offset(10),
"WITH tbl AS (SELECT x FROM tbl2) SELECT x FROM tbl OFFSET 10",
),
(
lambda: select("x").with_("tbl", as_=select("x").from_("tbl2")).from_("tbl").join("tbl3"),
lambda: select("x")
.with_("tbl", as_=select("x").from_("tbl2"))
.from_("tbl")
.join("tbl3"),
"WITH tbl AS (SELECT x FROM tbl2) SELECT x FROM tbl JOIN tbl3",
),
(
lambda: select("x").with_("tbl", as_=select("x").from_("tbl2")).from_("tbl").distinct(),
lambda: select("x")
.with_("tbl", as_=select("x").from_("tbl2"))
.from_("tbl")
.distinct(),
"WITH tbl AS (SELECT x FROM tbl2) SELECT DISTINCT x FROM tbl",
),
(
lambda: select("x").with_("tbl", as_=select("x").from_("tbl2")).from_("tbl").where("x > 10"),
lambda: select("x")
.with_("tbl", as_=select("x").from_("tbl2"))
.from_("tbl")
.where("x > 10"),
"WITH tbl AS (SELECT x FROM tbl2) SELECT x FROM tbl WHERE x > 10",
),
(
lambda: select("x").with_("tbl", as_=select("x").from_("tbl2")).from_("tbl").having("x > 20"),
lambda: select("x")
.with_("tbl", as_=select("x").from_("tbl2"))
.from_("tbl")
.having("x > 20"),
"WITH tbl AS (SELECT x FROM tbl2) SELECT x FROM tbl HAVING x > 20",
),
(lambda: select("x").from_("tbl").subquery(), "(SELECT x FROM tbl)"),
@@ -354,7 +401,9 @@ class TestBuild(unittest.TestCase):
"SELECT x FROM (SELECT x FROM tbl) AS foo WHERE x > 0",
),
(
lambda: exp.subquery("select x from tbl UNION select x from bar", "unioned").select("x"),
lambda: exp.subquery("select x from tbl UNION select x from bar", "unioned").select(
"x"
),
"SELECT x FROM (SELECT x FROM tbl UNION SELECT x FROM bar) AS unioned",
),
(
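Most of the churn above is formatter reflow; the substantive addition is the `using=` keyword on the join builder. A sketch:

```python
from sqlglot import exp, select

query = (
    select("x", "y", "z")
    .from_("merged_df")
    .join(
        "vte_diagnosis_df",
        using=[exp.to_identifier("patient_id"), exp.to_identifier("encounter_id")],
    )
)
print(query.sql())
# SELECT x, y, z FROM merged_df JOIN vte_diagnosis_df USING (patient_id, encounter_id)
```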

View file

@@ -33,7 +33,10 @@ class TestExecutor(unittest.TestCase):
)
cls.cache = {}
cls.sqls = [(sql, expected) for _, sql, expected in load_sql_fixture_pairs("optimizer/tpc-h/tpc-h.sql")]
cls.sqls = [
(sql, expected)
for _, sql, expected in load_sql_fixture_pairs("optimizer/tpc-h/tpc-h.sql")
]
@classmethod
def tearDownClass(cls):
@@ -63,7 +66,9 @@ class TestExecutor(unittest.TestCase):
def test_execute_tpch(self):
def to_csv(expression):
if isinstance(expression, exp.Table):
return parse_one(f"READ_CSV('{DIR}{expression.name}.csv.gz', 'delimiter', '|') AS {expression.name}")
return parse_one(
f"READ_CSV('{DIR}{expression.name}.csv.gz', 'delimiter', '|') AS {expression.name}"
)
return expression
for sql, _ in self.sqls[0:3]:

View file

@@ -30,7 +30,9 @@ class TestExpressions(unittest.TestCase):
self.assertEqual(parse_one("TO_DATE(x)", read="hive"), parse_one("ts_or_ds_to_date(x)"))
self.assertEqual(exp.Table(pivots=[]), exp.Table())
self.assertNotEqual(exp.Table(pivots=[None]), exp.Table())
self.assertEqual(exp.DataType.build("int"), exp.DataType(this=exp.DataType.Type.INT, nested=False))
self.assertEqual(
exp.DataType.build("int"), exp.DataType(this=exp.DataType.Type.INT, nested=False)
)
def test_find(self):
expression = parse_one("CREATE TABLE x STORED AS PARQUET AS SELECT * FROM y")
@@ -89,7 +91,9 @@ class TestExpressions(unittest.TestCase):
self.assertIsNone(column.find_ancestor(exp.Join))
def test_alias_or_name(self):
expression = parse_one("SELECT a, b AS B, c + d AS e, *, 'zz', 'zz' AS z FROM foo as bar, baz")
expression = parse_one(
"SELECT a, b AS B, c + d AS e, *, 'zz', 'zz' AS z FROM foo as bar, baz"
)
self.assertEqual(
[e.alias_or_name for e in expression.expressions],
["a", "B", "e", "*", "zz", "z"],
@@ -166,7 +170,9 @@ class TestExpressions(unittest.TestCase):
"SELECT * FROM foo WHERE ? > 100",
)
self.assertEqual(
exp.replace_placeholders(parse_one("select * from :name WHERE ? > 100"), another_name="bla").sql(),
exp.replace_placeholders(
parse_one("select * from :name WHERE ? > 100"), another_name="bla"
).sql(),
"SELECT * FROM :name WHERE ? > 100",
)
self.assertEqual(
@@ -183,7 +189,9 @@ class TestExpressions(unittest.TestCase):
)
def test_named_selects(self):
expression = parse_one("SELECT a, b AS B, c + d AS e, *, 'zz', 'zz' AS z FROM foo as bar, baz")
expression = parse_one(
"SELECT a, b AS B, c + d AS e, *, 'zz', 'zz' AS z FROM foo as bar, baz"
)
self.assertEqual(expression.named_selects, ["a", "B", "e", "*", "zz", "z"])
expression = parse_one(
@@ -367,7 +375,9 @@ class TestExpressions(unittest.TestCase):
self.assertEqual(len(list(expression.walk())), 9)
self.assertEqual(len(list(expression.walk(bfs=False))), 9)
self.assertTrue(all(isinstance(e, exp.Expression) for e, _, _ in expression.walk()))
self.assertTrue(all(isinstance(e, exp.Expression) for e, _, _ in expression.walk(bfs=False)))
self.assertTrue(
all(isinstance(e, exp.Expression) for e, _, _ in expression.walk(bfs=False))
)
def test_functions(self):
self.assertIsInstance(parse_one("ABS(a)"), exp.Abs)
@@ -512,14 +522,21 @@ class TestExpressions(unittest.TestCase):
),
exp.Properties(
expressions=[
exp.FileFormatProperty(this=exp.Literal.string("FORMAT"), value=exp.Literal.string("parquet")),
exp.FileFormatProperty(
this=exp.Literal.string("FORMAT"), value=exp.Literal.string("parquet")
),
exp.PartitionedByProperty(
this=exp.Literal.string("PARTITIONED_BY"),
value=exp.Tuple(expressions=[exp.to_identifier("a"), exp.to_identifier("b")]),
value=exp.Tuple(
expressions=[exp.to_identifier("a"), exp.to_identifier("b")]
),
),
exp.AnonymousProperty(
this=exp.Literal.string("custom"), value=exp.Literal.number(1)
),
exp.AnonymousProperty(this=exp.Literal.string("custom"), value=exp.Literal.number(1)),
exp.TableFormatProperty(
this=exp.Literal.string("TABLE_FORMAT"), value=exp.to_identifier("test_format")
this=exp.Literal.string("TABLE_FORMAT"),
value=exp.to_identifier("test_format"),
),
exp.EngineProperty(this=exp.Literal.string("ENGINE"), value=exp.NULL),
exp.CollateProperty(this=exp.Literal.string("COLLATE"), value=exp.TRUE),
@@ -538,7 +555,10 @@ class TestExpressions(unittest.TestCase):
((1, "2", None), "(1, '2', NULL)"),
([1, "2", None], "ARRAY(1, '2', NULL)"),
({"x": None}, "MAP('x', NULL)"),
(datetime.datetime(2022, 10, 1, 1, 1, 1), "TIME_STR_TO_TIME('2022-10-01 01:01:01.000000')"),
(
datetime.datetime(2022, 10, 1, 1, 1, 1),
"TIME_STR_TO_TIME('2022-10-01 01:01:01.000000')",
),
(
datetime.datetime(2022, 10, 1, 1, 1, 1, tzinfo=datetime.timezone.utc),
"TIME_STR_TO_TIME('2022-10-01 01:01:01.000000+0000')",
@@ -548,30 +568,48 @@ class TestExpressions(unittest.TestCase):
with self.subTest(value):
self.assertEqual(exp.convert(value).sql(), expected)
def test_annotation_alias(self):
sql = "SELECT a, b AS B, c # comment, d AS D # another_comment FROM foo"
def test_comment_alias(self):
sql = """
SELECT
a,
b AS B,
c, /*comment*/
d AS D, -- another comment
CAST(x AS INT) -- final comment
FROM foo
"""
expression = parse_one(sql)
self.assertEqual(
[e.alias_or_name for e in expression.expressions],
["a", "B", "c", "D"],
["a", "B", "c", "D", "x"],
)
self.assertEqual(expression.sql(), "SELECT a, b AS B, c, d AS D")
self.assertEqual(expression.expressions[2].name, "comment")
self.assertEqual(
expression.sql(pretty=True, annotations=False),
expression.sql(),
"SELECT a, b AS B, c /* comment */, d AS D /* another comment */, CAST(x AS INT) /* final comment */ FROM foo",
)
self.assertEqual(
expression.sql(comments=False),
"SELECT a, b AS B, c, d AS D, CAST(x AS INT) FROM foo",
)
self.assertEqual(
expression.sql(pretty=True, comments=False),
"""SELECT
a,
b AS B,
c,
d AS D""",
d AS D,
CAST(x AS INT)
FROM foo""",
)
self.assertEqual(
expression.sql(pretty=True),
"""SELECT
a,
b AS B,
c # comment,
d AS D # another_comment FROM foo""",
c, -- comment
d AS D, -- another comment
CAST(x AS INT) -- final comment
FROM foo""",
)
def test_to_table(self):
@@ -605,5 +643,9 @@ class TestExpressions(unittest.TestCase):
self.assertIsInstance(expression, exp.Union)
self.assertEqual(expression.named_selects, ["cola", "colb"])
self.assertEqual(
expression.selects, [exp.Column(this=exp.to_identifier("cola")), exp.Column(this=exp.to_identifier("colb"))]
expression.selects,
[
exp.Column(this=exp.to_identifier("cola")),
exp.Column(this=exp.to_identifier("colb")),
],
)
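`test_comment_alias` above replaces the old annotation test: comments now ride on the expressions they follow and can be switched off at generation time. A sketch:

```python
from sqlglot import parse_one

expression = parse_one("SELECT a, b AS B, c /* comment */, d AS D FROM foo")
print(expression.sql())                # SELECT a, b AS B, c /* comment */, d AS D FROM foo
print(expression.sql(comments=False))  # SELECT a, b AS B, c, d AS D FROM foo
```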

View file

@@ -67,7 +67,9 @@ class TestOptimizer(unittest.TestCase):
}
def check_file(self, file, func, pretty=False, execute=False, **kwargs):
for i, (meta, sql, expected) in enumerate(load_sql_fixture_pairs(f"optimizer/{file}.sql"), start=1):
for i, (meta, sql, expected) in enumerate(
load_sql_fixture_pairs(f"optimizer/{file}.sql"), start=1
):
title = meta.get("title") or f"{i}, {sql}"
dialect = meta.get("dialect")
leave_tables_isolated = meta.get("leave_tables_isolated")
@@ -90,7 +92,9 @@ class TestOptimizer(unittest.TestCase):
if string_to_bool(should_execute):
with self.subTest(f"(execute) {title}"):
df1 = self.conn.execute(sqlglot.transpile(sql, read=dialect, write="duckdb")[0]).df()
df1 = self.conn.execute(
sqlglot.transpile(sql, read=dialect, write="duckdb")[0]
).df()
df2 = self.conn.execute(optimized.sql(pretty=pretty, dialect="duckdb")).df()
assert_frame_equal(df1, df2)
@@ -268,7 +272,8 @@ FROM READ_CSV('tests/fixtures/optimizer/tpc-h/nation.csv.gz', 'delimiter', '|')
self.assertEqual(scopes[1].expression.sql(), "SELECT y.b FROM y")
self.assertEqual(scopes[2].expression.sql(), "(VALUES (1, 'test')) AS tab(cola, colb)")
self.assertEqual(
scopes[3].expression.sql(), "SELECT cola, colb FROM (VALUES (1, 'test')) AS tab(cola, colb)"
scopes[3].expression.sql(),
"SELECT cola, colb FROM (VALUES (1, 'test')) AS tab(cola, colb)",
)
self.assertEqual(scopes[4].expression.sql(), "SELECT y.c AS b FROM y")
self.assertEqual(scopes[5].expression.sql(), "SELECT MAX(x.a) FROM x WHERE x.b = s.b")
@@ -287,7 +292,11 @@ FROM READ_CSV('tests/fixtures/optimizer/tpc-h/nation.csv.gz', 'delimiter', '|')
# Check that we can walk in scope from an arbitrary node
self.assertEqual(
{node.sql() for node, *_ in walk_in_scope(expression.find(exp.Where)) if isinstance(node, exp.Column)},
{
node.sql()
for node, *_ in walk_in_scope(expression.find(exp.Where))
if isinstance(node, exp.Column)
},
{"s.b"},
)
@@ -324,7 +333,9 @@ FROM READ_CSV('tests/fixtures/optimizer/tpc-h/nation.csv.gz', 'delimiter', '|')
self.assertEqual(expression.args["to"].expressions[0].type, exp.DataType.Type.INT)
def test_cache_annotation(self):
expression = annotate_types(parse_one("CACHE LAZY TABLE x OPTIONS('storageLevel' = 'value') AS SELECT 1"))
expression = annotate_types(
parse_one("CACHE LAZY TABLE x OPTIONS('storageLevel' = 'value') AS SELECT 1")
)
self.assertEqual(expression.expression.expressions[0].type, exp.DataType.Type.INT)
def test_binary_annotation(self):
@@ -384,7 +395,9 @@ FROM READ_CSV('tests/fixtures/optimizer/tpc-h/nation.csv.gz', 'delimiter', '|')
"""
expression = annotate_types(parse_one(sql), schema=schema)
self.assertEqual(expression.expressions[0].type, exp.DataType.Type.TEXT) # tbl.cola + tbl.colb + 'foo' AS col
self.assertEqual(
expression.expressions[0].type, exp.DataType.Type.TEXT
) # tbl.cola + tbl.colb + 'foo' AS col
outer_addition = expression.expressions[0].this # (tbl.cola + tbl.colb) + 'foo'
self.assertEqual(outer_addition.type, exp.DataType.Type.TEXT)
@@ -396,7 +409,9 @@ FROM READ_CSV('tests/fixtures/optimizer/tpc-h/nation.csv.gz', 'delimiter', '|')
self.assertEqual(inner_addition.right.type, exp.DataType.Type.TEXT)
cte_select = expression.args["with"].expressions[0].this
self.assertEqual(cte_select.expressions[0].type, exp.DataType.Type.VARCHAR) # x.cola + 'bla' AS cola
self.assertEqual(
cte_select.expressions[0].type, exp.DataType.Type.VARCHAR
) # x.cola + 'bla' AS cola
self.assertEqual(cte_select.expressions[1].type, exp.DataType.Type.TEXT) # y.colb AS colb
cte_select_addition = cte_select.expressions[0].this # x.cola + 'bla'
@@ -405,7 +420,9 @@ FROM READ_CSV('tests/fixtures/optimizer/tpc-h/nation.csv.gz', 'delimiter', '|')
self.assertEqual(cte_select_addition.right.type, exp.DataType.Type.VARCHAR)
# Check that x.cola AS cola and y.colb AS colb have types CHAR and TEXT, respectively
for d, t in zip(cte_select.find_all(exp.Subquery), [exp.DataType.Type.CHAR, exp.DataType.Type.TEXT]):
for d, t in zip(
cte_select.find_all(exp.Subquery), [exp.DataType.Type.CHAR, exp.DataType.Type.TEXT]
):
self.assertEqual(d.this.expressions[0].this.type, t)
def test_function_annotation(self):
@@ -421,6 +438,19 @@ FROM READ_CSV('tests/fixtures/optimizer/tpc-h/nation.csv.gz', 'delimiter', '|')
self.assertEqual(concat_expr.right.type, exp.DataType.Type.VARCHAR) # TRIM(x.colb)
self.assertEqual(concat_expr.right.this.type, exp.DataType.Type.CHAR) # x.colb
sql = "SELECT CASE WHEN 1=1 THEN x.cola ELSE x.colb END AS col FROM x AS x"
case_expr_alias = annotate_types(parse_one(sql), schema=schema).expressions[0]
self.assertEqual(case_expr_alias.type, exp.DataType.Type.VARCHAR)
case_expr = case_expr_alias.this
self.assertEqual(case_expr.type, exp.DataType.Type.VARCHAR)
self.assertEqual(case_expr.args["default"].type, exp.DataType.Type.CHAR)
case_ifs_expr = case_expr.args["ifs"][0]
self.assertEqual(case_ifs_expr.type, exp.DataType.Type.VARCHAR)
self.assertEqual(case_ifs_expr.args["true"].type, exp.DataType.Type.VARCHAR)
def test_unknown_annotation(self):
schema = {"x": {"cola": "VARCHAR"}}
sql = "SELECT x.cola || SOME_ANONYMOUS_FUNC(x.cola) AS col FROM x AS x"
@@ -431,8 +461,12 @@ FROM READ_CSV('tests/fixtures/optimizer/tpc-h/nation.csv.gz', 'delimiter', '|')
concat_expr = concat_expr_alias.this
self.assertEqual(concat_expr.type, exp.DataType.Type.UNKNOWN)
self.assertEqual(concat_expr.left.type, exp.DataType.Type.VARCHAR) # x.cola
self.assertEqual(concat_expr.right.type, exp.DataType.Type.UNKNOWN) # SOME_ANONYMOUS_FUNC(x.cola)
self.assertEqual(concat_expr.right.expressions[0].type, exp.DataType.Type.VARCHAR) # x.cola (arg)
self.assertEqual(
concat_expr.right.type, exp.DataType.Type.UNKNOWN
) # SOME_ANONYMOUS_FUNC(x.cola)
self.assertEqual(
concat_expr.right.expressions[0].type, exp.DataType.Type.VARCHAR
) # x.cola (arg)
def test_null_annotation(self):
expression = annotate_types(parse_one("SELECT NULL + 2 AS col")).expressions[0].this

View file

@@ -23,8 +23,6 @@ class TestParser(unittest.TestCase):
def test_float(self):
self.assertEqual(parse_one(".2"), parse_one("0.2"))
self.assertEqual(parse_one("int 1"), parse_one("CAST(1 AS INT)"))
self.assertEqual(parse_one("int.5"), parse_one("CAST(0.5 AS INT)"))
def test_table(self):
tables = [t.sql() for t in parse_one("select * from a, b.c, .d").find_all(exp.Table)]
@@ -33,7 +31,9 @@ class TestParser(unittest.TestCase):
def test_select(self):
self.assertIsNotNone(parse_one("select 1 natural"))
self.assertIsNotNone(parse_one("select * from (select 1) x order by x.y").args["order"])
self.assertIsNotNone(parse_one("select * from x where a = (select 1) order by x.y").args["order"])
self.assertIsNotNone(
parse_one("select * from x where a = (select 1) order by x.y").args["order"]
)
self.assertEqual(len(parse_one("select * from (select 1) x cross join y").args["joins"]), 1)
self.assertEqual(
parse_one("""SELECT * FROM x CROSS JOIN y, z LATERAL VIEW EXPLODE(y)""").sql(),
@@ -125,26 +125,70 @@ class TestParser(unittest.TestCase):
def test_var(self):
self.assertEqual(parse_one("SELECT @JOIN, @'foo'").sql(), "SELECT @JOIN, @'foo'")
def test_annotations(self):
def test_comments(self):
expression = parse_one(
"""
SELECT
a #annotation1,
b as B #annotation2:testing ,
"test#annotation",c#annotation3, d #annotation4,
e #,
f # space
--comment1
SELECT /* this won't be used */
a, --comment2
b as B, --comment3:testing
"test--annotation",
c, --comment4 --foo
e, --
f -- space
FROM foo
"""
)
assert expression.expressions[0].name == "annotation1"
assert expression.expressions[1].name == "annotation2:testing"
assert expression.expressions[2].name == "test#annotation"
assert expression.expressions[3].name == "annotation3"
assert expression.expressions[4].name == "annotation4"
assert expression.expressions[5].name == ""
assert expression.expressions[6].name == "space"
self.assertEqual(expression.comment, "comment1")
self.assertEqual(expression.expressions[0].comment, "comment2")
self.assertEqual(expression.expressions[1].comment, "comment3:testing")
self.assertEqual(expression.expressions[2].comment, None)
self.assertEqual(expression.expressions[3].comment, "comment4 --foo")
self.assertEqual(expression.expressions[4].comment, "")
self.assertEqual(expression.expressions[5].comment, " space")
def test_type_literals(self):
self.assertEqual(parse_one("int 1"), parse_one("CAST(1 AS INT)"))
self.assertEqual(parse_one("int.5"), parse_one("CAST(0.5 AS INT)"))
self.assertEqual(
parse_one("TIMESTAMP '2022-01-01'").sql(), "CAST('2022-01-01' AS TIMESTAMP)"
)
self.assertEqual(
parse_one("TIMESTAMP(1) '2022-01-01'").sql(), "CAST('2022-01-01' AS TIMESTAMP(1))"
)
self.assertEqual(
parse_one("TIMESTAMP WITH TIME ZONE '2022-01-01'").sql(),
"CAST('2022-01-01' AS TIMESTAMPTZ)",
)
self.assertEqual(
parse_one("TIMESTAMP WITH LOCAL TIME ZONE '2022-01-01'").sql(),
"CAST('2022-01-01' AS TIMESTAMPLTZ)",
)
self.assertEqual(
parse_one("TIMESTAMP WITHOUT TIME ZONE '2022-01-01'").sql(),
"CAST('2022-01-01' AS TIMESTAMP)",
)
self.assertEqual(
parse_one("TIMESTAMP(1) WITH TIME ZONE '2022-01-01'").sql(),
"CAST('2022-01-01' AS TIMESTAMPTZ(1))",
)
self.assertEqual(
parse_one("TIMESTAMP(1) WITH LOCAL TIME ZONE '2022-01-01'").sql(),
"CAST('2022-01-01' AS TIMESTAMPLTZ(1))",
)
self.assertEqual(
parse_one("TIMESTAMP(1) WITHOUT TIME ZONE '2022-01-01'").sql(),
"CAST('2022-01-01' AS TIMESTAMP(1))",
)
self.assertEqual(parse_one("TIMESTAMP(1) WITH TIME ZONE").sql(), "TIMESTAMPTZ(1)")
self.assertEqual(parse_one("TIMESTAMP(1) WITH LOCAL TIME ZONE").sql(), "TIMESTAMPLTZ(1)")
self.assertEqual(parse_one("TIMESTAMP(1) WITHOUT TIME ZONE").sql(), "TIMESTAMP(1)")
self.assertEqual(parse_one("""JSON '{"x":"y"}'""").sql(), """CAST('{"x":"y"}' AS JSON)""")
self.assertIsInstance(parse_one("TIMESTAMP(1)"), exp.Func)
self.assertIsInstance(parse_one("TIMESTAMP('2022-01-01')"), exp.Func)
self.assertIsInstance(parse_one("TIMESTAMP()"), exp.Func)
self.assertIsInstance(parse_one("map.x"), exp.Column)
def test_pretty_config_override(self):
self.assertEqual(parse_one("SELECT col FROM x").sql(), "SELECT col FROM x")

View file

@@ -1,281 +1,141 @@
import unittest
from sqlglot import table
from sqlglot.dataframe.sql import types as df_types
from sqlglot import exp, to_table
from sqlglot.errors import SchemaError
from sqlglot.schema import MappingSchema, ensure_schema
class TestSchema(unittest.TestCase):
def assert_column_names(self, schema, *table_results):
for table, result in table_results:
with self.subTest(f"{table} -> {result}"):
self.assertEqual(schema.column_names(to_table(table)), result)
def assert_column_names_raises(self, schema, *tables):
for table in tables:
with self.subTest(table):
with self.assertRaises(SchemaError):
schema.column_names(to_table(table))
def test_schema(self):
    schema = ensure_schema(
        {
            "x": {
                "a": "uint64",
            },
            "y": {
                "b": "uint64",
                "c": "uint64",
            },
        },
    )

    self.assert_column_names(
        schema,
        ("x", ["a"]),
        ("y", ["b", "c"]),
        ("z.x", ["a"]),
        ("z.x.y", ["b", "c"]),
    )

    self.assert_column_names_raises(
        schema,
        "z",
        "z.z",
        "z.z.z",
    )

def test_schema_db(self):
    schema = ensure_schema(
        {
            "d1": {
                "x": {
                    "a": "uint64",
                },
                "y": {
                    "b": "uint64",
                },
            },
            "d2": {
                "x": {
                    "c": "uint64",
                },
            },
        },
    )

    self.assert_column_names(
        schema,
        ("d1.x", ["a"]),
        ("d2.x", ["c"]),
        ("y", ["b"]),
        ("d1.y", ["b"]),
        ("z.d1.y", ["b"]),
    )

    self.assert_column_names_raises(
        schema,
        "x",
        "z.x",
        "z.y",
    )

def test_schema_catalog(self):
    schema = ensure_schema(
        {
            "c1": {
                "d1": {
                    "x": {
                        "a": "uint64",
                    },
                    "y": {
                        "b": "uint64",
                    },
                    "z": {
                        "c": "uint64",
                    },
                },
            },
            "c2": {
                "d1": {
                    "y": {
                        "d": "uint64",
                    },
                    "z": {
                        "e": "uint64",
                    },
                },
                "d2": {
                    "z": {
                        "f": "uint64",
                    },
                },
            },
        }
    )

    self.assert_column_names(
        schema,
        ("x", ["a"]),
        ("d1.x", ["a"]),
        ("c1.d1.x", ["a"]),
        ("c1.d1.y", ["b"]),
        ("c1.d1.z", ["c"]),
        ("c2.d1.y", ["d"]),
        ("c2.d1.z", ["e"]),
        ("d2.z", ["f"]),
        ("c2.d2.z", ["f"]),
    )

    self.assert_column_names_raises(
        schema,
        "q",
        "d2.x",
        "y",
        "z",
        "d1.y",
        "d1.z",
        "a.b.c",
    )
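The lenient lookup these tests cover is easy to try directly; a minimal sketch using only the `ensure_schema`, `to_table`, and `column_names` APIs shown above:

```python
from sqlglot import to_table
from sqlglot.schema import ensure_schema

# An unqualified name resolves as long as exactly one table matches.
schema = ensure_schema({"d1": {"x": {"a": "uint64"}}})
print(schema.column_names(to_table("x")))     # ['a'] -- db can be omitted
print(schema.column_names(to_table("d1.x")))  # ['a']
```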
def test_schema_add_table_with_and_without_mapping(self):
@@ -288,3 +148,34 @@ class TestSchema(unittest.TestCase):
self.assertEqual(schema.column_names("test"), ["x", "y"])
schema.add_table("test")
self.assertEqual(schema.column_names("test"), ["x", "y"])
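A sketch of the behavior this truncated hunk asserts, assuming the initial mapping was added earlier in the test (the `{"x": "string", "y": "int"}` mapping here is illustrative):

```python
from sqlglot.schema import MappingSchema

schema = MappingSchema()
schema.add_table("test", {"x": "string", "y": "int"})
print(schema.column_names("test"))  # ['x', 'y']
schema.add_table("test")  # no column mapping: existing columns survive
print(schema.column_names("test"))  # ['x', 'y']
```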
def test_schema_get_column_type(self):
schema = MappingSchema({"a": {"b": "varchar"}})
self.assertEqual(schema.get_column_type("a", "b"), exp.DataType.Type.VARCHAR)
self.assertEqual(
schema.get_column_type(exp.Table(this="a"), exp.Column(this="b")),
exp.DataType.Type.VARCHAR,
)
self.assertEqual(
schema.get_column_type("a", exp.Column(this="b")), exp.DataType.Type.VARCHAR
)
self.assertEqual(
schema.get_column_type(exp.Table(this="a"), "b"), exp.DataType.Type.VARCHAR
)
schema = MappingSchema({"a": {"b": {"c": "varchar"}}})
self.assertEqual(
schema.get_column_type(exp.Table(this="b", db="a"), exp.Column(this="c")),
exp.DataType.Type.VARCHAR,
)
self.assertEqual(
schema.get_column_type(exp.Table(this="b", db="a"), "c"), exp.DataType.Type.VARCHAR
)
schema = MappingSchema({"a": {"b": {"c": {"d": "varchar"}}}})
self.assertEqual(
schema.get_column_type(exp.Table(this="c", db="b", catalog="a"), exp.Column(this="d")),
exp.DataType.Type.VARCHAR,
)
self.assertEqual(
schema.get_column_type(exp.Table(this="c", db="b", catalog="a"), "d"),
exp.DataType.Type.VARCHAR,
)
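In isolation, the same `get_column_type` lookup accepts either plain strings or expression nodes; a minimal sketch grounded in the assertions above:

```python
from sqlglot import exp
from sqlglot.schema import MappingSchema

# Nested mappings qualify lookups by db (and by catalog, one level deeper).
schema = MappingSchema({"a": {"b": {"c": "varchar"}}})
assert (
    schema.get_column_type(exp.Table(this="b", db="a"), "c")
    == exp.DataType.Type.VARCHAR
)
```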

tests/test_tokens.py Normal file

@@ -0,0 +1,18 @@
import unittest
from sqlglot.tokens import Tokenizer
class TestTokens(unittest.TestCase):
def test_comment_attachment(self):
tokenizer = Tokenizer()
sql_comment = [
("/*comment*/ foo", "comment"),
("/*comment*/ foo --test", "comment"),
("--comment\nfoo --test", "comment"),
("foo --comment", "comment"),
("foo", None),
]
for sql, comment in sql_comment:
self.assertEqual(tokenizer.tokenize(sql)[0].comment, comment)
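A quick way to see this attachment outside the test, using only the `Tokenizer` API exercised above:

```python
from sqlglot.tokens import Tokenizer

# The leading block comment attaches to the first token.
tokens = Tokenizer().tokenize("/*comment*/ foo --test")
print(tokens[0].comment)  # comment
```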

tests/test_transpile.py

@@ -49,6 +49,12 @@ class TestTranspile(unittest.TestCase):
leading_comma=True,
pretty=True,
)
self.validate(
"SELECT FOO, /*x*/\nBAR, /*y*/\nBAZ",
"SELECT\n FOO -- x\n , BAR -- y\n , BAZ",
leading_comma=True,
pretty=True,
)
# without pretty, this should be a no-op
self.validate(
"SELECT FOO, BAR, BAZ",
@@ -63,24 +69,61 @@
self.validate("SELECT 3>=3", "SELECT 3 >= 3")
def test_comments(self):
self.validate("SELECT 1 FROM foo -- comment", "SELECT 1 FROM foo")
self.validate("SELECT 1 /* inline */ FROM foo -- comment", "SELECT 1 FROM foo")
self.validate("SELECT 1 FROM foo -- comment", "SELECT 1 FROM foo /* comment */")
self.validate("SELECT --+5\nx FROM foo", "/* +5 */ SELECT x FROM foo")
self.validate("SELECT --!5\nx FROM foo", "/* !5 */ SELECT x FROM foo")
self.validate(
"SELECT 1 /* inline */ FROM foo -- comment",
"SELECT 1 /* inline */ FROM foo /* comment */",
)
self.validate(
"SELECT FUN(x) /*x*/, [1,2,3] /*y*/", "SELECT FUN(x) /* x */, ARRAY(1, 2, 3) /* y */"
)
self.validate(
"""
SELECT 1 -- comment
FROM foo -- comment
""",
"SELECT 1 FROM foo",
"SELECT 1 /* comment */ FROM foo /* comment */",
)
self.validate(
"""
SELECT 1 /* big comment
like this */
FROM foo -- comment
""",
"SELECT 1 FROM foo",
"""SELECT 1 /* big comment
like this */ FROM foo /* comment */""",
)
self.validate(
"select x from foo -- x",
"SELECT x FROM foo /* x */",
)
self.validate(
"""
/* multi
line
comment
*/
SELECT
tbl.cola /* comment 1 */ + tbl.colb /* comment 2 */,
CAST(x AS INT), # comment 3
y -- comment 4
FROM
bar /* comment 5 */,
tbl # comment 6
""",
"""/* multi
line
comment
*/
SELECT
tbl.cola /* comment 1 */ + tbl.colb /* comment 2 */,
CAST(x AS INT), -- comment 3
y -- comment 4
FROM bar /* comment 5 */, tbl /* comment 6 */""",
read="mysql",
pretty=True,
)
def test_types(self):
@@ -146,6 +189,16 @@ class TestTranspile(unittest.TestCase):
def test_ignore_nulls(self):
self.validate("SELECT COUNT(x RESPECT NULLS)", "SELECT COUNT(x)")
def test_with(self):
self.validate(
"WITH a AS (SELECT 1) WITH b AS (SELECT 2) SELECT *",
"WITH a AS (SELECT 1), b AS (SELECT 2) SELECT *",
)
self.validate(
"WITH a AS (SELECT 1), WITH b AS (SELECT 2) SELECT *",
"WITH a AS (SELECT 1), b AS (SELECT 2) SELECT *",
)
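The consecutive-WITH normalization validated here is visible through a plain transpile call; a minimal sketch:

```python
import sqlglot

# Adjacent WITH clauses are merged into a single clause, as validated above.
print(sqlglot.transpile("WITH a AS (SELECT 1) WITH b AS (SELECT 2) SELECT *")[0])
# WITH a AS (SELECT 1), b AS (SELECT 2) SELECT *
```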
def test_time(self):
self.validate("TIMESTAMP '2020-01-01'", "CAST('2020-01-01' AS TIMESTAMP)")
self.validate("TIMESTAMP WITH TIME ZONE '2020-01-01'", "CAST('2020-01-01' AS TIMESTAMPTZ)")