
Adding upstream version 25.26.0.

Signed-off-by: Daniel Baumann <daniel@debian.org>
This commit is contained in:
Daniel Baumann 2025-02-13 21:55:50 +01:00
parent 7af32ea9ec
commit dfac4c492f
Signed by: daniel
GPG key ID: FBB4F0E80A80222F
117 changed files with 49296 additions and 47316 deletions

CHANGELOG.md

@@ -1,6 +1,72 @@
Changelog
=========
+## [v25.25.1] - 2024-10-15
+### :bug: Bug Fixes
+- [`e6567ae`](https://github.com/tobymao/sqlglot/commit/e6567ae11650834874808a844a19836fbb9ee753) - small overload fix for ensure list taking None *(PR [#4248](https://github.com/tobymao/sqlglot/pull/4248) by [@benfdking](https://github.com/benfdking))*
+
+## [v25.25.0] - 2024-10-14
+### :boom: BREAKING CHANGES
+- due to [`275b64b`](https://github.com/tobymao/sqlglot/commit/275b64b6a28722232a24870e443b249994220d54) - refactor set operation builders so they can work with N expressions *(PR [#4226](https://github.com/tobymao/sqlglot/pull/4226) by [@georgesittas](https://github.com/georgesittas))*:
+  refactor set operation builders so they can work with N expressions (#4226)
+- due to [`aee76da`](https://github.com/tobymao/sqlglot/commit/aee76da1cadec242f7428d23999f1752cb0708ca) - Native annotations for string functions *(PR [#4231](https://github.com/tobymao/sqlglot/pull/4231) by [@VaggelisD](https://github.com/VaggelisD))*:
+  Native annotations for string functions (#4231)
+- due to [`202aaa0`](https://github.com/tobymao/sqlglot/commit/202aaa0e7390142ee3ade41c28e2e77cde31f295) - Native annotations for string functions *(PR [#4234](https://github.com/tobymao/sqlglot/pull/4234) by [@VaggelisD](https://github.com/VaggelisD))*:
+  Native annotations for string functions (#4234)
+- due to [`5741180`](https://github.com/tobymao/sqlglot/commit/5741180e895eaaa75a07af388d36a0d2df97b28c) - produce exp.Column for the RHS of <value> IN <name> *(PR [#4239](https://github.com/tobymao/sqlglot/pull/4239) by [@georgesittas](https://github.com/georgesittas))*:
+  produce exp.Column for the RHS of <value> IN <name> (#4239)
+- due to [`4da2502`](https://github.com/tobymao/sqlglot/commit/4da25029b1c6f1425b4602f42da4fa1bcd3fccdb) - make Explode a UDTF subclass *(PR [#4242](https://github.com/tobymao/sqlglot/pull/4242) by [@georgesittas](https://github.com/georgesittas))*:
+  make Explode a UDTF subclass (#4242)
+
+### :sparkles: New Features
+- [`163e943`](https://github.com/tobymao/sqlglot/commit/163e943cdaf449599640c198f69e73d2398eb323) - **tsql**: SPLIT_PART function and conversion to PARSENAME in tsql *(PR [#4211](https://github.com/tobymao/sqlglot/pull/4211) by [@daihuynh](https://github.com/daihuynh))*
+- [`275b64b`](https://github.com/tobymao/sqlglot/commit/275b64b6a28722232a24870e443b249994220d54) - refactor set operation builders so they can work with N expressions *(PR [#4226](https://github.com/tobymao/sqlglot/pull/4226) by [@georgesittas](https://github.com/georgesittas))*
+- [`3f6ba3e`](https://github.com/tobymao/sqlglot/commit/3f6ba3e69c9ba92429d2b3b00cac33f45518aa56) - **clickhouse**: Support varlen arrays for ARRAY JOIN *(PR [#4229](https://github.com/tobymao/sqlglot/pull/4229) by [@VaggelisD](https://github.com/VaggelisD))*
+  - :arrow_lower_right: *addresses issue [#4227](https://github.com/tobymao/sqlglot/issues/4227) opened by [@brunorpinho](https://github.com/brunorpinho)*
+- [`aee76da`](https://github.com/tobymao/sqlglot/commit/aee76da1cadec242f7428d23999f1752cb0708ca) - **bigquery**: Native annotations for string functions *(PR [#4231](https://github.com/tobymao/sqlglot/pull/4231) by [@VaggelisD](https://github.com/VaggelisD))*
+- [`202aaa0`](https://github.com/tobymao/sqlglot/commit/202aaa0e7390142ee3ade41c28e2e77cde31f295) - **bigquery**: Native annotations for string functions *(PR [#4234](https://github.com/tobymao/sqlglot/pull/4234) by [@VaggelisD](https://github.com/VaggelisD))*
+- [`eeae25e`](https://github.com/tobymao/sqlglot/commit/eeae25e03a883671f9d5e514f9bd3021fb6c0d32) - support EXPLAIN in mysql *(PR [#4235](https://github.com/tobymao/sqlglot/pull/4235) by [@xiaoyu-meng-mxy](https://github.com/xiaoyu-meng-mxy))*
+- [`06748d9`](https://github.com/tobymao/sqlglot/commit/06748d93ccd232528003c37fdda25ae8163f3c18) - **mysql**: add support for operation modifiers like HIGH_PRIORITY *(PR [#4238](https://github.com/tobymao/sqlglot/pull/4238) by [@georgesittas](https://github.com/georgesittas))*
+  - :arrow_lower_right: *addresses issue [#4236](https://github.com/tobymao/sqlglot/issues/4236) opened by [@asdfsx](https://github.com/asdfsx)*
+
+### :bug: Bug Fixes
+- [`dcdec95`](https://github.com/tobymao/sqlglot/commit/dcdec95f986426ae90469baca993b47ac390081b) - Make exp.Update a DML node *(PR [#4223](https://github.com/tobymao/sqlglot/pull/4223) by [@VaggelisD](https://github.com/VaggelisD))*
+  - :arrow_lower_right: *fixes issue [#4221](https://github.com/tobymao/sqlglot/issues/4221) opened by [@rahul-ve](https://github.com/rahul-ve)*
+- [`79caf51`](https://github.com/tobymao/sqlglot/commit/79caf519987718390a086bee19fdc89f6094496c) - **clickhouse**: rename BOOLEAN type to Bool fixes [#4230](https://github.com/tobymao/sqlglot/pull/4230) *(commit by [@georgesittas](https://github.com/georgesittas))*
+- [`b26a3f6`](https://github.com/tobymao/sqlglot/commit/b26a3f67b7113802ba1b4b3b211431e98258dc15) - satisfy mypy *(commit by [@georgesittas](https://github.com/georgesittas))*
+- [`5741180`](https://github.com/tobymao/sqlglot/commit/5741180e895eaaa75a07af388d36a0d2df97b28c) - **parser**: produce exp.Column for the RHS of <value> IN <name> *(PR [#4239](https://github.com/tobymao/sqlglot/pull/4239) by [@georgesittas](https://github.com/georgesittas))*
+  - :arrow_lower_right: *fixes issue [#4237](https://github.com/tobymao/sqlglot/issues/4237) opened by [@rustyconover](https://github.com/rustyconover)*
+- [`daa6e78`](https://github.com/tobymao/sqlglot/commit/daa6e78e4b810eff826f995aa52f9e38197f1b7e) - **optimizer**: handle subquery predicate substitution correctly in de morgan's rule *(PR [#4240](https://github.com/tobymao/sqlglot/pull/4240) by [@georgesittas](https://github.com/georgesittas))*
+- [`c0a8355`](https://github.com/tobymao/sqlglot/commit/c0a83556acffcd77521f69bf51503a07310f749d) - **parser**: parse a column reference for the RHS of the IN clause *(PR [#4241](https://github.com/tobymao/sqlglot/pull/4241) by [@georgesittas](https://github.com/georgesittas))*
+
+### :recycle: Refactors
+- [`0882f03`](https://github.com/tobymao/sqlglot/commit/0882f03d526f593b2d415e85b7d7a7c113721806) - Rename exp.RenameTable to exp.AlterRename *(PR [#4224](https://github.com/tobymao/sqlglot/pull/4224) by [@VaggelisD](https://github.com/VaggelisD))*
+  - :arrow_lower_right: *addresses issue [#4222](https://github.com/tobymao/sqlglot/issues/4222) opened by [@s1101010110](https://github.com/s1101010110)*
+- [`fd42b5c`](https://github.com/tobymao/sqlglot/commit/fd42b5cdaf9421abb11e71d82726536af09e3ae3) - Simplify PARSENAME <-> SPLIT_PART transpilation *(PR [#4225](https://github.com/tobymao/sqlglot/pull/4225) by [@VaggelisD](https://github.com/VaggelisD))*
+- [`4da2502`](https://github.com/tobymao/sqlglot/commit/4da25029b1c6f1425b4602f42da4fa1bcd3fccdb) - make Explode a UDTF subclass *(PR [#4242](https://github.com/tobymao/sqlglot/pull/4242) by [@georgesittas](https://github.com/georgesittas))*
+
+## [v25.24.5] - 2024-10-08
+### :sparkles: New Features
+- [`22a1684`](https://github.com/tobymao/sqlglot/commit/22a16848d80a2fa6d310f99d21f7d81f90eb9440) - **bigquery**: Native annotations for more math functions *(PR [#4212](https://github.com/tobymao/sqlglot/pull/4212) by [@VaggelisD](https://github.com/VaggelisD))*
+- [`354cfff`](https://github.com/tobymao/sqlglot/commit/354cfff13ab30d01c6123fca74eed0669d238aa0) - add builder methods to exp.Update and add with_ arg to exp.update *(PR [#4217](https://github.com/tobymao/sqlglot/pull/4217) by [@brdbry](https://github.com/brdbry))*
+### :bug: Bug Fixes
+- [`2c513b7`](https://github.com/tobymao/sqlglot/commit/2c513b71c7d4b1ff5c7c4e12d6c38694210b1a12) - Attach CTE comments before commas *(PR [#4218](https://github.com/tobymao/sqlglot/pull/4218) by [@VaggelisD](https://github.com/VaggelisD))*
+  - :arrow_lower_right: *fixes issue [#4216](https://github.com/tobymao/sqlglot/issues/4216) opened by [@ajfriend](https://github.com/ajfriend)*
+
## [v25.24.4] - 2024-10-04
### :bug: Bug Fixes
- [`484df7d`](https://github.com/tobymao/sqlglot/commit/484df7d50df5cb314943e1810db18a7d7d5bb3eb) - tsql union with limit *(commit by [@tobymao](https://github.com/tobymao))*

@@ -4969,3 +5035,6 @@ Changelog
[v25.24.2]: https://github.com/tobymao/sqlglot/compare/v25.24.1...v25.24.2
[v25.24.3]: https://github.com/tobymao/sqlglot/compare/v25.24.2...v25.24.3
[v25.24.4]: https://github.com/tobymao/sqlglot/compare/v25.24.3...v25.24.4
+[v25.24.5]: https://github.com/tobymao/sqlglot/compare/v25.24.4...v25.24.5
+[v25.25.0]: https://github.com/tobymao/sqlglot/compare/v25.24.5...v25.25.0
+[v25.25.1]: https://github.com/tobymao/sqlglot/compare/v25.25.0...v25.25.1

File diff suppressed because one or more lines are too long (71 such files)

sqlglot/dialects/bigquery.py

@@ -321,7 +321,24 @@ class BigQuery(Dialect):
            expr_type: lambda self, e: _annotate_math_functions(self, e)
            for expr_type in (exp.Floor, exp.Ceil, exp.Log, exp.Ln, exp.Sqrt, exp.Exp, exp.Round)
        },
+        **{
+            expr_type: lambda self, e: self._annotate_by_args(e, "this")
+            for expr_type in (
+                exp.Left,
+                exp.Right,
+                exp.Lower,
+                exp.Upper,
+                exp.Pad,
+                exp.Trim,
+                exp.RegexpExtract,
+                exp.RegexpReplace,
+                exp.Repeat,
+                exp.Substring,
+            )
+        },
+        exp.Concat: lambda self, e: self._annotate_by_args(e, "expressions"),
        exp.Sign: lambda self, e: self._annotate_by_args(e, "this"),
+        exp.Split: lambda self, e: self._annotate_by_args(e, "this", array=True),
    }

    def normalize_identifier(self, expression: E) -> E:

@@ -716,6 +733,7 @@ class BigQuery(Dialect):
            exp.ILike: no_ilike_sql,
            exp.IntDiv: rename_func("DIV"),
            exp.JSONFormat: rename_func("TO_JSON_STRING"),
+            exp.Levenshtein: rename_func("EDIT_DISTANCE"),
            exp.Max: max_or_greatest,
            exp.MD5: lambda self, e: self.func("TO_HEX", self.func("MD5", e.this)),
            exp.MD5Digest: rename_func("MD5"),
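
A minimal usage sketch (not part of the commit) of the new Levenshtein mapping; the expected output is an assumption based on the transform above:

import sqlglot

# exp.Levenshtein now renders with BigQuery's native EDIT_DISTANCE function.
print(sqlglot.transpile("SELECT LEVENSHTEIN(a, b)", write="bigquery"))
# assumed output: ['SELECT EDIT_DISTANCE(a, b)']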

sqlglot/dialects/clickhouse.py

@@ -603,6 +603,12 @@ class ClickHouse(Dialect):
            if join:
                join.set("global", join.args.pop("method", None))

+                # tbl ARRAY JOIN arr <-- this should be a `Column` reference, not a `Table`
+                # https://clickhouse.com/docs/en/sql-reference/statements/select/array-join
+                if join.kind == "ARRAY":
+                    for table in join.find_all(exp.Table):
+                        table.replace(table.to_column())
+
            return join

        def _parse_function(

@@ -627,15 +633,18 @@ class ClickHouse(Dialect):
            )

            if parts:
-                params = self._parse_func_params(func)
+                anon_func: exp.Anonymous = t.cast(exp.Anonymous, func)
+                params = self._parse_func_params(anon_func)

                kwargs = {
-                    "this": func.this,
-                    "expressions": func.expressions,
+                    "this": anon_func.this,
+                    "expressions": anon_func.expressions,
                }
                if parts[1]:
                    kwargs["parts"] = parts
-                    exp_class = exp.CombinedParameterizedAgg if params else exp.CombinedAggFunc
+                    exp_class: t.Type[exp.Expression] = (
+                        exp.CombinedParameterizedAgg if params else exp.CombinedAggFunc
+                    )
                else:
                    exp_class = exp.ParameterizedAgg if params else exp.AnonymousAggFunc

@@ -825,6 +834,7 @@ class ClickHouse(Dialect):
            **generator.Generator.TYPE_MAPPING,
            **STRING_TYPE_MAPPING,
            exp.DataType.Type.ARRAY: "Array",
+            exp.DataType.Type.BOOLEAN: "Bool",
            exp.DataType.Type.BIGINT: "Int64",
            exp.DataType.Type.DATE32: "Date32",
            exp.DataType.Type.DATETIME: "DateTime",
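A minimal sketch (not from the commit) of the new BOOLEAN mapping; the output is assumed from the TYPE_MAPPING entry above:

import sqlglot

# BOOLEAN now renders as ClickHouse's native Bool type.
print(sqlglot.transpile("CREATE TABLE t (x BOOLEAN)", write="clickhouse"))
# assumed output: ['CREATE TABLE t (x Bool)']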

sqlglot/dialects/dialect.py

@@ -588,6 +588,7 @@ class Dialect(metaclass=_Dialect):
            exp.Stddev,
            exp.StddevPop,
            exp.StddevSamp,
+            exp.ToDouble,
            exp.Variance,
            exp.VariancePop,
        },

@@ -1697,3 +1698,18 @@ def build_regexp_extract(args: t.List, dialect: Dialect) -> exp.RegexpExtract:
        expression=seq_get(args, 1),
        group=seq_get(args, 2) or exp.Literal.number(dialect.REGEXP_EXTRACT_DEFAULT_GROUP),
    )

+def explode_to_unnest_sql(self: Generator, expression: exp.Lateral) -> str:
+    if isinstance(expression.this, exp.Explode):
+        return self.sql(
+            exp.Join(
+                this=exp.Unnest(
+                    expressions=[expression.this.this],
+                    alias=expression.args.get("alias"),
+                    offset=isinstance(expression.this, exp.Posexplode),
+                ),
+                kind="cross",
+            )
+        )
+
+    return self.lateral_sql(expression)
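A hedged sketch of the now-shared helper in action; the exact output string is an assumption:

import sqlglot

# A Spark-style lateral explode should be rewritten as a CROSS JOIN UNNEST
# by dialects that route exp.Lateral through explode_to_unnest_sql.
print(sqlglot.transpile(
    "SELECT col FROM t LATERAL VIEW EXPLODE(arr) x AS col",
    read="spark",
    write="presto",
))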

sqlglot/dialects/duckdb.py

@@ -35,6 +35,7 @@ from sqlglot.dialects.dialect import (
    unit_to_str,
    sha256_sql,
    build_regexp_extract,
+    explode_to_unnest_sql,
)
from sqlglot.helper import seq_get
from sqlglot.tokens import TokenType

@@ -538,6 +539,7 @@ class DuckDB(Dialect):
            exp.JSONExtract: _arrow_json_extract_sql,
            exp.JSONExtractScalar: _arrow_json_extract_sql,
            exp.JSONFormat: _json_format_sql,
+            exp.Lateral: explode_to_unnest_sql,
            exp.LogicalOr: rename_func("BOOL_OR"),
            exp.LogicalAnd: rename_func("BOOL_AND"),
            exp.MD5Digest: lambda self, e: self.func("UNHEX", self.func("MD5", e.this)),

sqlglot/dialects/hive.py

@@ -333,6 +333,9 @@ class Hive(Dialect):
            "TRANSFORM": lambda self: self._parse_transform(),
        }

+        NO_PAREN_FUNCTIONS = parser.Parser.NO_PAREN_FUNCTIONS.copy()
+        NO_PAREN_FUNCTIONS.pop(TokenType.CURRENT_TIME)
+
        PROPERTY_PARSERS = {
            **parser.Parser.PROPERTY_PARSERS,
            "SERDEPROPERTIES": lambda self: exp.SerdeProperties(

sqlglot/dialects/mysql.py

@@ -187,6 +187,9 @@ class MySQL(Dialect):
        KEYWORDS = {
            **tokens.Tokenizer.KEYWORDS,
            "CHARSET": TokenType.CHARACTER_SET,
+            # The DESCRIBE and EXPLAIN statements are synonyms.
+            # https://dev.mysql.com/doc/refman/8.4/en/explain.html
+            "EXPLAIN": TokenType.DESCRIBE,
            "FORCE": TokenType.FORCE,
            "IGNORE": TokenType.IGNORE,
            "KEY": TokenType.KEY,

@@ -453,6 +456,17 @@ class MySQL(Dialect):
            TokenType.SET,
        }

+        # SELECT [ ALL | DISTINCT | DISTINCTROW ] [ <OPERATION_MODIFIERS> ]
+        OPERATION_MODIFIERS = {
+            "HIGH_PRIORITY",
+            "STRAIGHT_JOIN",
+            "SQL_SMALL_RESULT",
+            "SQL_BIG_RESULT",
+            "SQL_BUFFER_RESULT",
+            "SQL_NO_CACHE",
+            "SQL_CALC_FOUND_ROWS",
+        }
+
        LOG_DEFAULTS_TO_LN = True
        STRING_ALIASES = True
        VALUES_FOLLOWED_BY_PAREN = False
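A minimal sketch (not from the commit) exercising both MySQL changes:

import sqlglot

# EXPLAIN now tokenizes like DESCRIBE, so this should parse as a describe statement.
print(repr(sqlglot.parse_one("EXPLAIN SELECT 1", read="mysql")))

# Operation modifiers such as HIGH_PRIORITY and SQL_NO_CACHE are now accepted.
print(sqlglot.transpile(
    "SELECT HIGH_PRIORITY SQL_NO_CACHE a FROM t", read="mysql", write="mysql"
))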

sqlglot/dialects/oracle.py

@@ -15,6 +15,7 @@ from sqlglot.dialects.dialect import (
from sqlglot.helper import seq_get
from sqlglot.parser import OPTIONS_TYPE, build_coalesce
from sqlglot.tokens import TokenType
+from sqlglot.errors import ParseError

if t.TYPE_CHECKING:
    from sqlglot._typing import E

@@ -205,6 +206,57 @@ class Oracle(Dialect):
            )

        def _parse_hint(self) -> t.Optional[exp.Hint]:
+            start_index = self._index
+            should_fallback_to_string = False
+
+            if not self._match(TokenType.HINT):
+                return None
+
+            hints = []
+            try:
+                for hint in iter(
+                    lambda: self._parse_csv(
+                        lambda: self._parse_hint_function_call() or self._parse_var(upper=True),
+                    ),
+                    [],
+                ):
+                    hints.extend(hint)
+            except ParseError:
+                should_fallback_to_string = True
+
+            if not self._match_pair(TokenType.STAR, TokenType.SLASH):
+                should_fallback_to_string = True
+
+            if should_fallback_to_string:
+                self._retreat(start_index)
+                return self._parse_hint_fallback_to_string()
+
+            return self.expression(exp.Hint, expressions=hints)
+
+        def _parse_hint_function_call(self) -> t.Optional[exp.Expression]:
+            if not self._curr or not self._next or self._next.token_type != TokenType.L_PAREN:
+                return None
+
+            this = self._curr.text
+            self._advance(2)
+
+            args = self._parse_hint_args()
+            this = self.expression(exp.Anonymous, this=this, expressions=args)
+
+            self._match_r_paren(this)
+            return this
+
+        def _parse_hint_args(self):
+            args = []
+            result = self._parse_var()
+
+            while result:
+                args.append(result)
+                result = self._parse_var()
+
+            return args
+
+        def _parse_hint_fallback_to_string(self) -> t.Optional[exp.Hint]:
            if self._match(TokenType.HINT):
                start = self._curr
                while self._curr and not self._match_pair(TokenType.STAR, TokenType.SLASH):

@@ -271,6 +323,7 @@ class Oracle(Dialect):
        LAST_DAY_SUPPORTS_DATE_PART = False
        SUPPORTS_SELECT_INTO = True
        TZ_TO_WITH_TIME_ZONE = True
+        QUERY_HINT_SEP = " "

        TYPE_MAPPING = {
            **generator.Generator.TYPE_MAPPING,

@@ -370,3 +423,15 @@ class Oracle(Dialect):
                return f"{self.seg(into)} {self.sql(expression, 'this')}"

            return f"{self.seg(into)} {self.expressions(expression)}"
+
+        def hint_sql(self, expression: exp.Hint) -> str:
+            expressions = []
+
+            for expression in expression.expressions:
+                if isinstance(expression, exp.Anonymous):
+                    formatted_args = self.format_args(*expression.expressions, sep=" ")
+                    expressions.append(f"{self.sql(expression, 'this')}({formatted_args})")
+                else:
+                    expressions.append(self.sql(expression))
+
+            return f" /*+ {self.expressions(sqls=expressions, sep=self.QUERY_HINT_SEP).strip()} */"

sqlglot/dialects/presto.py

@@ -30,6 +30,7 @@ from sqlglot.dialects.dialect import (
    unit_to_str,
    sequence_sql,
    build_regexp_extract,
+    explode_to_unnest_sql,
)
from sqlglot.dialects.hive import Hive
from sqlglot.dialects.mysql import MySQL

@@ -40,21 +41,6 @@ from sqlglot.transforms import unqualify_columns
DATE_ADD_OR_SUB = t.Union[exp.DateAdd, exp.TimestampAdd, exp.DateSub]

-def _explode_to_unnest_sql(self: Presto.Generator, expression: exp.Lateral) -> str:
-    if isinstance(expression.this, exp.Explode):
-        return self.sql(
-            exp.Join(
-                this=exp.Unnest(
-                    expressions=[expression.this.this],
-                    alias=expression.args.get("alias"),
-                    offset=isinstance(expression.this, exp.Posexplode),
-                ),
-                kind="cross",
-            )
-        )
-    return self.lateral_sql(expression)

def _initcap_sql(self: Presto.Generator, expression: exp.Initcap) -> str:
    regex = r"(\w)(\w*)"
    return f"REGEXP_REPLACE({self.sql(expression, 'this')}, '{regex}', x -> UPPER(x[1]) || LOWER(x[2]))"

@@ -340,16 +326,17 @@ class Presto(Dialect):
        TYPE_MAPPING = {
            **generator.Generator.TYPE_MAPPING,
-            exp.DataType.Type.INT: "INTEGER",
-            exp.DataType.Type.FLOAT: "REAL",
            exp.DataType.Type.BINARY: "VARBINARY",
-            exp.DataType.Type.TEXT: "VARCHAR",
-            exp.DataType.Type.TIMETZ: "TIME",
-            exp.DataType.Type.TIMESTAMPTZ: "TIMESTAMP",
-            exp.DataType.Type.STRUCT: "ROW",
+            exp.DataType.Type.BIT: "BOOLEAN",
            exp.DataType.Type.DATETIME: "TIMESTAMP",
            exp.DataType.Type.DATETIME64: "TIMESTAMP",
+            exp.DataType.Type.FLOAT: "REAL",
            exp.DataType.Type.HLLSKETCH: "HYPERLOGLOG",
+            exp.DataType.Type.INT: "INTEGER",
+            exp.DataType.Type.STRUCT: "ROW",
+            exp.DataType.Type.TEXT: "VARCHAR",
+            exp.DataType.Type.TIMESTAMPTZ: "TIMESTAMP",
+            exp.DataType.Type.TIMETZ: "TIME",
        }

        TRANSFORMS = {

@@ -400,9 +387,6 @@ class Presto(Dialect):
            exp.GenerateSeries: sequence_sql,
            exp.GenerateDateArray: sequence_sql,
            exp.Group: transforms.preprocess([transforms.unalias_group]),
-            exp.GroupConcat: lambda self, e: self.func(
-                "ARRAY_JOIN", self.func("ARRAY_AGG", e.this), e.args.get("separator")
-            ),
            exp.If: if_sql(),
            exp.ILike: no_ilike_sql,
            exp.Initcap: _initcap_sql,

@@ -410,7 +394,7 @@ class Presto(Dialect):
            exp.Last: _first_last_sql,
            exp.LastValue: _first_last_sql,
            exp.LastDay: lambda self, e: self.func("LAST_DAY_OF_MONTH", e.this),
-            exp.Lateral: _explode_to_unnest_sql,
+            exp.Lateral: explode_to_unnest_sql,
            exp.Left: left_to_substring_sql,
            exp.Levenshtein: rename_func("LEVENSHTEIN_DISTANCE"),
            exp.LogicalAnd: rename_func("BOOL_AND"),

@@ -694,3 +678,10 @@ class Presto(Dialect):
            expr = "".join(segments)

            return f"{this}{expr}"
+
+        def groupconcat_sql(self, expression: exp.GroupConcat) -> str:
+            return self.func(
+                "ARRAY_JOIN",
+                self.func("ARRAY_AGG", expression.this),
+                expression.args.get("separator"),
+            )
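A minimal sketch (not from the commit); the output is assumed from groupconcat_sql above:

import sqlglot

# GROUP_CONCAT now renders via the overridable groupconcat_sql method as
# ARRAY_JOIN(ARRAY_AGG(...), separator) in Presto.
print(sqlglot.transpile(
    "SELECT GROUP_CONCAT(name SEPARATOR ',') FROM t", read="mysql", write="presto"
))
# assumed: SELECT ARRAY_JOIN(ARRAY_AGG(name), ',') FROM t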

sqlglot/dialects/snowflake.py

@@ -41,15 +41,23 @@ def _build_datetime(
        if isinstance(value, exp.Literal):
            # Converts calls like `TO_TIME('01:02:03')` into casts
            if len(args) == 1 and value.is_string and not int_value:
-                return exp.cast(value, kind)
+                return (
+                    exp.TryCast(this=value, to=exp.DataType.build(kind))
+                    if safe
+                    else exp.cast(value, kind)
+                )

            # Handles `TO_TIMESTAMP(str, fmt)` and `TO_TIMESTAMP(num, scale)` as special
            # cases so we can transpile them, since they're relatively common
            if kind == exp.DataType.Type.TIMESTAMP:
-                if int_value:
+                if int_value and not safe:
+                    # TRY_TO_TIMESTAMP('integer') is not parsed into exp.UnixToTime as
+                    # it's not easily transpilable
                    return exp.UnixToTime(this=value, scale=seq_get(args, 1))
                if not is_float(value.this):
-                    return build_formatted_time(exp.StrToTime, "snowflake")(args)
+                    expr = build_formatted_time(exp.StrToTime, "snowflake")(args)
+                    expr.set("safe", safe)
+                    return expr

        if kind == exp.DataType.Type.DATE and not int_value:
            formatted_exp = build_formatted_time(exp.TsOrDsToDate, "snowflake")(args)

@@ -345,6 +353,9 @@ class Snowflake(Dialect):
            "TIMESTAMP_FROM_PARTS": build_timestamp_from_parts,
            "TRY_PARSE_JSON": lambda args: exp.ParseJSON(this=seq_get(args, 0), safe=True),
            "TRY_TO_DATE": _build_datetime("TRY_TO_DATE", exp.DataType.Type.DATE, safe=True),
+            "TRY_TO_TIMESTAMP": _build_datetime(
+                "TRY_TO_TIMESTAMP", exp.DataType.Type.TIMESTAMP, safe=True
+            ),
            "TO_DATE": _build_datetime("TO_DATE", exp.DataType.Type.DATE),
            "TO_NUMBER": lambda args: exp.ToNumber(
                this=seq_get(args, 0),

@@ -384,7 +395,6 @@ class Snowflake(Dialect):
                expressions=self._parse_csv(self._parse_id_var),
                unset=True,
            ),
-            "SWAP": lambda self: self._parse_alter_table_swap(),
        }

        STATEMENT_PARSERS = {

@@ -654,10 +664,6 @@ class Snowflake(Dialect):
                },
            )

-        def _parse_alter_table_swap(self) -> exp.SwapTable:
-            self._match_text_seq("WITH")
-            return self.expression(exp.SwapTable, this=self._parse_table(schema=True))
-
        def _parse_location_property(self) -> exp.LocationProperty:
            self._match(TokenType.EQ)
            return self.expression(exp.LocationProperty, this=self._parse_location_path())

@@ -828,7 +834,6 @@ class Snowflake(Dialect):
            exp.StrPosition: lambda self, e: self.func(
                "POSITION", e.args.get("substr"), e.this, e.args.get("position")
            ),
-            exp.StrToTime: lambda self, e: self.func("TO_TIMESTAMP", e.this, self.format_time(e)),
            exp.Stuff: rename_func("INSERT"),
            exp.TimeAdd: date_delta_sql("TIMEADD"),
            exp.TimestampDiff: lambda self, e: self.func(

@@ -842,6 +847,7 @@ class Snowflake(Dialect):
            exp.TimeToUnix: lambda self, e: f"EXTRACT(epoch_second FROM {self.sql(e, 'this')})",
            exp.ToArray: rename_func("TO_ARRAY"),
            exp.ToChar: lambda self, e: self.function_fallback_sql(e),
+            exp.ToDouble: rename_func("TO_DOUBLE"),
            exp.TsOrDsAdd: date_delta_sql("DATEADD", cast=True),
            exp.TsOrDsDiff: date_delta_sql("DATEDIFF"),
            exp.TsOrDsToDate: lambda self, e: self.func(

@@ -1036,10 +1042,6 @@ class Snowflake(Dialect):
            increment = f" INCREMENT {increment}" if increment else ""
            return f"AUTOINCREMENT{start}{increment}"

-        def swaptable_sql(self, expression: exp.SwapTable) -> str:
-            this = self.sql(expression, "this")
-            return f"SWAP WITH {this}"
-
        def cluster_sql(self, expression: exp.Cluster) -> str:
            return f"CLUSTER BY ({self.expressions(expression, flat=True)})"

@@ -1074,3 +1076,9 @@ class Snowflake(Dialect):
            tag = f" TAG {tag}" if tag else ""
            return f"SET{exprs}{file_format}{copy_options}{tag}"
+
+        def strtotime_sql(self, expression: exp.StrToTime):
+            safe_prefix = "TRY_" if expression.args.get("safe") else ""
+            return self.func(
+                f"{safe_prefix}TO_TIMESTAMP", expression.this, self.format_time(expression)
+            )
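A minimal sketch (not from the commit) of the safe round trip; the format string is illustrative:

import sqlglot

# TRY_TO_TIMESTAMP(str, fmt) now parses into exp.StrToTime with safe=True,
# and strtotime_sql re-emits the TRY_ prefix on the way back out.
sql = "SELECT TRY_TO_TIMESTAMP(x, 'yyyy-mm-dd')"
print(sqlglot.transpile(sql, read="snowflake", write="snowflake"))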

sqlglot/dialects/trino.py

@@ -10,11 +10,15 @@ class Trino(Presto):
    SUPPORTS_USER_DEFINED_TYPES = False
    LOG_BASE_FIRST = True

+    class Tokenizer(Presto.Tokenizer):
+        HEX_STRINGS = [("X'", "'")]
+
    class Parser(Presto.Parser):
        FUNCTION_PARSERS = {
            **Presto.Parser.FUNCTION_PARSERS,
            "TRIM": lambda self: self._parse_trim(),
            "JSON_QUERY": lambda self: self._parse_json_query(),
+            "LISTAGG": lambda self: self._parse_string_agg(),
        }

        JSON_QUERY_OPTIONS: parser.OPTIONS_TYPE = {

@@ -65,5 +69,14 @@ class Trino(Presto):
            return self.func("JSON_QUERY", expression.this, json_path + option)

-    class Tokenizer(Presto.Tokenizer):
-        HEX_STRINGS = [("X'", "'")]
+        def groupconcat_sql(self, expression: exp.GroupConcat) -> str:
+            this = expression.this
+            separator = expression.args.get("separator") or exp.Literal.string(",")
+
+            if isinstance(this, exp.Order):
+                if this.this:
+                    this = this.this.pop()
+
+                return f"LISTAGG({self.format_args(this, separator)}) WITHIN GROUP ({self.sql(expression.this).lstrip()})"
+
+            return super().groupconcat_sql(expression)
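A hedged sketch of the LISTAGG path; the input query is illustrative:

import sqlglot

# A GROUP_CONCAT carrying an ORDER BY should render in Trino as
# LISTAGG(...) WITHIN GROUP (ORDER BY ...), and LISTAGG now parses back
# through _parse_string_agg.
print(sqlglot.transpile(
    "SELECT GROUP_CONCAT(name ORDER BY name SEPARATOR ',') FROM t",
    read="mysql",
    write="trino",
))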

sqlglot/dialects/tsql.py

@@ -324,6 +324,25 @@ def _build_with_arg_as_text(
    return _parse

+# https://learn.microsoft.com/en-us/sql/t-sql/functions/parsename-transact-sql?view=sql-server-ver16
+def _build_parsename(args: t.List) -> exp.SplitPart | exp.Anonymous:
+    # PARSENAME(...) will be stored into exp.SplitPart if:
+    # - All args are literals
+    # - The part index (2nd arg) is <= 4 (max valid value, otherwise TSQL returns NULL)
+    if len(args) == 2 and all(isinstance(arg, exp.Literal) for arg in args):
+        this = args[0]
+        part_index = args[1]
+        split_count = len(this.name.split("."))
+        if split_count <= 4:
+            return exp.SplitPart(
+                this=this,
+                delimiter=exp.Literal.string("."),
+                part_index=exp.Literal.number(split_count + 1 - part_index.to_py()),
+            )
+
+    return exp.Anonymous(this="PARSENAME", expressions=args)
+
def _build_json_query(args: t.List, dialect: Dialect) -> exp.JSONExtract:
    if len(args) == 1:
        # The default value for path is '$'. As a result, if you don't provide a

@@ -543,6 +562,7 @@ class TSQL(Dialect):
            "LEN": _build_with_arg_as_text(exp.Length),
            "LEFT": _build_with_arg_as_text(exp.Left),
            "RIGHT": _build_with_arg_as_text(exp.Right),
+            "PARSENAME": _build_parsename,
            "REPLICATE": exp.Repeat.from_arg_list,
            "SQUARE": lambda args: exp.Pow(this=seq_get(args, 0), expression=exp.Literal.number(2)),
            "SYSDATETIME": exp.CurrentTimestamp.from_arg_list,
@ -554,6 +574,10 @@ class TSQL(Dialect):
JOIN_HINTS = {"LOOP", "HASH", "MERGE", "REMOTE"} JOIN_HINTS = {"LOOP", "HASH", "MERGE", "REMOTE"}
PROCEDURE_OPTIONS = dict.fromkeys(
("ENCRYPTION", "RECOMPILE", "SCHEMABINDING", "NATIVE_COMPILATION", "EXECUTE"), tuple()
)
RETURNS_TABLE_TOKENS = parser.Parser.ID_VAR_TOKENS - { RETURNS_TABLE_TOKENS = parser.Parser.ID_VAR_TOKENS - {
TokenType.TABLE, TokenType.TABLE,
*parser.Parser.TYPE_TOKENS, *parser.Parser.TYPE_TOKENS,
@ -699,7 +723,11 @@ class TSQL(Dialect):
): ):
return this return this
if not self._match(TokenType.WITH, advance=False):
expressions = self._parse_csv(self._parse_function_parameter) expressions = self._parse_csv(self._parse_function_parameter)
else:
expressions = None
return self.expression(exp.UserDefinedFunction, this=this, expressions=expressions) return self.expression(exp.UserDefinedFunction, this=this, expressions=expressions)
def _parse_id_var( def _parse_id_var(
@ -954,6 +982,27 @@ class TSQL(Dialect):
self.unsupported("LATERAL clause is not supported.") self.unsupported("LATERAL clause is not supported.")
return "LATERAL" return "LATERAL"
def splitpart_sql(self: TSQL.Generator, expression: exp.SplitPart) -> str:
this = expression.this
split_count = len(this.name.split("."))
delimiter = expression.args.get("delimiter")
part_index = expression.args.get("part_index")
if (
not all(isinstance(arg, exp.Literal) for arg in (this, delimiter, part_index))
or (delimiter and delimiter.name != ".")
or not part_index
or split_count > 4
):
self.unsupported(
"SPLIT_PART can be transpiled to PARSENAME only for '.' delimiter and literal values"
)
return ""
return self.func(
"PARSENAME", this, exp.Literal.number(split_count + 1 - part_index.to_py())
)
def timefromparts_sql(self, expression: exp.TimeFromParts) -> str: def timefromparts_sql(self, expression: exp.TimeFromParts) -> str:
nano = expression.args.get("nano") nano = expression.args.get("nano")
if nano is not None: if nano is not None:
@ -1166,7 +1215,7 @@ class TSQL(Dialect):
def alter_sql(self, expression: exp.Alter) -> str: def alter_sql(self, expression: exp.Alter) -> str:
action = seq_get(expression.args.get("actions") or [], 0) action = seq_get(expression.args.get("actions") or [], 0)
if isinstance(action, exp.RenameTable): if isinstance(action, exp.AlterRename):
return f"EXEC sp_rename '{self.sql(expression.this)}', '{action.this.name}'" return f"EXEC sp_rename '{self.sql(expression.this)}', '{action.this.name}'"
return super().alter_sql(expression) return super().alter_sql(expression)
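And the reverse direction, a sketch assuming literal arguments and a '.' delimiter as splitpart_sql requires:

import sqlglot

# SPLIT_PART part 1 of a 3-part string maps back to PARSENAME(..., 3).
print(sqlglot.transpile(
    "SELECT SPLIT_PART('a.b.c', '.', 1)", read="spark", write="tsql"
))
# assumed: SELECT PARSENAME('a.b.c', 3)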

sqlglot/diff.py

@@ -12,7 +12,7 @@ from dataclasses import dataclass
from heapq import heappop, heappush

from sqlglot import Dialect, expressions as exp
-from sqlglot.helper import ensure_list
+from sqlglot.helper import seq_get

if t.TYPE_CHECKING:
    from sqlglot.dialects.dialect import DialectType

@@ -185,7 +185,7 @@ class ChangeDistiller:
        self._unmatched_target_nodes = set(self._target_index) - set(pre_matched_nodes.values())
        self._bigram_histo_cache: t.Dict[int, t.DefaultDict[str, int]] = {}

-        matching_set = self._compute_matching_set() | {(s, t) for s, t in pre_matched_nodes.items()}
+        matching_set = self._compute_matching_set() | set(pre_matched_nodes.items())

        return self._generate_edit_script(matching_set, delta_only)

    def _generate_edit_script(

@@ -201,6 +201,7 @@ class ChangeDistiller:
        for kept_source_node_id, kept_target_node_id in matching_set:
            source_node = self._source_index[kept_source_node_id]
            target_node = self._target_index[kept_target_node_id]
+
            if (
                not isinstance(source_node, UPDATABLE_EXPRESSION_TYPES)
                or source_node == target_node

@@ -208,7 +209,13 @@ class ChangeDistiller:
                edit_script.extend(
                    self._generate_move_edits(source_node, target_node, matching_set)
                )
-                if not delta_only:
+
+                source_non_expression_leaves = dict(_get_non_expression_leaves(source_node))
+                target_non_expression_leaves = dict(_get_non_expression_leaves(target_node))
+
+                if source_non_expression_leaves != target_non_expression_leaves:
+                    edit_script.append(Update(source_node, target_node))
+                elif not delta_only:
                    edit_script.append(Keep(source_node, target_node))
            else:
                edit_script.append(Update(source_node, target_node))

@@ -246,8 +253,8 @@ class ChangeDistiller:
            source_node = self._source_index[source_node_id]
            target_node = self._target_index[target_node_id]
            if _is_same_type(source_node, target_node):
-                source_leaf_ids = {id(l) for l in _get_leaves(source_node)}
-                target_leaf_ids = {id(l) for l in _get_leaves(target_node)}
+                source_leaf_ids = {id(l) for l in _get_expression_leaves(source_node)}
+                target_leaf_ids = {id(l) for l in _get_expression_leaves(target_node)}

                max_leaves_num = max(len(source_leaf_ids), len(target_leaf_ids))
                if max_leaves_num:

@@ -277,10 +284,10 @@ class ChangeDistiller:
    def _compute_leaf_matching_set(self) -> t.Set[t.Tuple[int, int]]:
        candidate_matchings: t.List[t.Tuple[float, int, int, exp.Expression, exp.Expression]] = []
-        source_leaves = list(_get_leaves(self._source))
-        target_leaves = list(_get_leaves(self._target))
-        for source_leaf in source_leaves:
-            for target_leaf in target_leaves:
+        source_expression_leaves = list(_get_expression_leaves(self._source))
+        target_expression_leaves = list(_get_expression_leaves(self._target))
+        for source_leaf in source_expression_leaves:
+            for target_leaf in target_expression_leaves:
                if _is_same_type(source_leaf, target_leaf):
                    similarity_score = self._dice_coefficient(source_leaf, target_leaf)
                    if similarity_score >= self.f:

@@ -338,18 +345,28 @@ class ChangeDistiller:
        return bigram_histo

-def _get_leaves(expression: exp.Expression) -> t.Iterator[exp.Expression]:
+def _get_expression_leaves(expression: exp.Expression) -> t.Iterator[exp.Expression]:
    has_child_exprs = False

    for node in expression.iter_expressions():
        if not isinstance(node, IGNORED_LEAF_EXPRESSION_TYPES):
            has_child_exprs = True
-            yield from _get_leaves(node)
+            yield from _get_expression_leaves(node)

    if not has_child_exprs:
        yield expression

+def _get_non_expression_leaves(expression: exp.Expression) -> t.Iterator[t.Tuple[str, t.Any]]:
+    for arg, value in expression.args.items():
+        if isinstance(value, exp.Expression) or (
+            isinstance(value, list) and isinstance(seq_get(value, 0), exp.Expression)
+        ):
+            continue
+
+        yield (arg, value)
+
def _is_same_type(source: exp.Expression, target: exp.Expression) -> bool:
    if type(source) is type(target):
        if isinstance(source, exp.Join):

@@ -372,16 +389,12 @@ def _parent_similarity_score(
    return 1 + _parent_similarity_score(source.parent, target.parent)

-def _expression_only_args(expression: exp.Expression) -> t.List[exp.Expression]:
-    args: t.List[t.Union[exp.Expression, t.List]] = []
-    if expression:
-        for a in expression.args.values():
-            args.extend(ensure_list(a))
-    return [
-        a
-        for a in args
-        if isinstance(a, exp.Expression) and not isinstance(a, IGNORED_LEAF_EXPRESSION_TYPES)
-    ]
+def _expression_only_args(expression: exp.Expression) -> t.Iterator[exp.Expression]:
+    yield from (
+        arg
+        for arg in expression.iter_expressions()
+        if not isinstance(arg, IGNORED_LEAF_EXPRESSION_TYPES)
+    )

def _lcs(
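A hedged sketch of the public API these internals back; whether a given change surfaces as an Update or a Remove/Insert pair still depends on the matcher's similarity threshold:

from sqlglot import parse_one
from sqlglot.diff import diff

# Matched nodes whose non-expression args differ now produce Update edits
# instead of Keep edits; delta_only drops the Keep entries entirely.
print(diff(parse_one("SELECT a FROM t"), parse_one("SELECT b FROM t"), delta_only=True))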

View file

@ -404,9 +404,9 @@ class Expression(metaclass=_Expression):
def iter_expressions(self, reverse: bool = False) -> t.Iterator[Expression]: def iter_expressions(self, reverse: bool = False) -> t.Iterator[Expression]:
"""Yields the key and expression for all arguments, exploding list args.""" """Yields the key and expression for all arguments, exploding list args."""
# remove tuple when python 3.7 is deprecated # remove tuple when python 3.7 is deprecated
for vs in reversed(tuple(self.args.values())) if reverse else self.args.values(): for vs in reversed(tuple(self.args.values())) if reverse else self.args.values(): # type: ignore
if type(vs) is list: if type(vs) is list:
for v in reversed(vs) if reverse else vs: for v in reversed(vs) if reverse else vs: # type: ignore
if hasattr(v, "parent"): if hasattr(v, "parent"):
yield v yield v
else: else:
@ -1247,7 +1247,7 @@ class Query(Expression):
) )
def union( def union(
self, expression: ExpOrStr, distinct: bool = True, dialect: DialectType = None, **opts self, *expressions: ExpOrStr, distinct: bool = True, dialect: DialectType = None, **opts
) -> Union: ) -> Union:
""" """
Builds a UNION expression. Builds a UNION expression.
@ -1258,8 +1258,8 @@ class Query(Expression):
'SELECT * FROM foo UNION SELECT * FROM bla' 'SELECT * FROM foo UNION SELECT * FROM bla'
Args: Args:
expression: the SQL code string. expressions: the SQL code strings.
If an `Expression` instance is passed, it will be used as-is. If `Expression` instances are passed, they will be used as-is.
distinct: set the DISTINCT flag if and only if this is true. distinct: set the DISTINCT flag if and only if this is true.
dialect: the dialect used to parse the input expression. dialect: the dialect used to parse the input expression.
opts: other options to use to parse the input expressions. opts: other options to use to parse the input expressions.
@ -1267,10 +1267,10 @@ class Query(Expression):
Returns: Returns:
The new Union expression. The new Union expression.
""" """
return union(left=self, right=expression, distinct=distinct, dialect=dialect, **opts) return union(self, *expressions, distinct=distinct, dialect=dialect, **opts)
def intersect( def intersect(
self, expression: ExpOrStr, distinct: bool = True, dialect: DialectType = None, **opts self, *expressions: ExpOrStr, distinct: bool = True, dialect: DialectType = None, **opts
) -> Intersect: ) -> Intersect:
""" """
Builds an INTERSECT expression. Builds an INTERSECT expression.
@ -1281,8 +1281,8 @@ class Query(Expression):
'SELECT * FROM foo INTERSECT SELECT * FROM bla' 'SELECT * FROM foo INTERSECT SELECT * FROM bla'
Args: Args:
expression: the SQL code string. expressions: the SQL code strings.
If an `Expression` instance is passed, it will be used as-is. If `Expression` instances are passed, they will be used as-is.
distinct: set the DISTINCT flag if and only if this is true. distinct: set the DISTINCT flag if and only if this is true.
dialect: the dialect used to parse the input expression. dialect: the dialect used to parse the input expression.
opts: other options to use to parse the input expressions. opts: other options to use to parse the input expressions.
@ -1290,10 +1290,10 @@ class Query(Expression):
Returns: Returns:
The new Intersect expression. The new Intersect expression.
""" """
return intersect(left=self, right=expression, distinct=distinct, dialect=dialect, **opts) return intersect(self, *expressions, distinct=distinct, dialect=dialect, **opts)
def except_( def except_(
self, expression: ExpOrStr, distinct: bool = True, dialect: DialectType = None, **opts self, *expressions: ExpOrStr, distinct: bool = True, dialect: DialectType = None, **opts
) -> Except: ) -> Except:
""" """
Builds an EXCEPT expression. Builds an EXCEPT expression.
@ -1304,8 +1304,8 @@ class Query(Expression):
'SELECT * FROM foo EXCEPT SELECT * FROM bla' 'SELECT * FROM foo EXCEPT SELECT * FROM bla'
Args: Args:
expression: the SQL code string. expressions: the SQL code strings.
If an `Expression` instance is passed, it will be used as-is. If `Expression` instance are passed, they will be used as-is.
distinct: set the DISTINCT flag if and only if this is true. distinct: set the DISTINCT flag if and only if this is true.
dialect: the dialect used to parse the input expression. dialect: the dialect used to parse the input expression.
opts: other options to use to parse the input expressions. opts: other options to use to parse the input expressions.
@ -1313,7 +1313,7 @@ class Query(Expression):
Returns: Returns:
The new Except expression. The new Except expression.
""" """
return except_(left=self, right=expression, distinct=distinct, dialect=dialect, **opts) return except_(self, *expressions, distinct=distinct, dialect=dialect, **opts)
class UDTF(DerivedTable): class UDTF(DerivedTable):
@ -1697,7 +1697,7 @@ class RenameColumn(Expression):
arg_types = {"this": True, "to": True, "exists": False} arg_types = {"this": True, "to": True, "exists": False}
class RenameTable(Expression): class AlterRename(Expression):
pass pass
@ -2400,6 +2400,7 @@ class Join(Expression):
"global": False, "global": False,
"hint": False, "hint": False,
"match_condition": False, # Snowflake "match_condition": False, # Snowflake
"expressions": False,
} }
@property @property
@ -2995,6 +2996,10 @@ class WithSystemVersioningProperty(Property):
} }
class WithProcedureOptions(Property):
arg_types = {"expressions": True}
class Properties(Expression): class Properties(Expression):
arg_types = {"expressions": True} arg_types = {"expressions": True}
@ -3213,10 +3218,18 @@ class Table(Expression):
def to_column(self, copy: bool = True) -> Alias | Column | Dot: def to_column(self, copy: bool = True) -> Alias | Column | Dot:
parts = self.parts parts = self.parts
last_part = parts[-1]
if isinstance(last_part, Identifier):
col = column(*reversed(parts[0:4]), fields=parts[4:], copy=copy) # type: ignore col = column(*reversed(parts[0:4]), fields=parts[4:], copy=copy) # type: ignore
else:
# This branch will be reached if a function or array is wrapped in a `Table`
col = last_part
alias = self.args.get("alias") alias = self.args.get("alias")
if alias: if alias:
col = alias_(col, alias.this, copy=copy) col = alias_(col, alias.this, copy=copy)
return col return col
@ -3278,7 +3291,7 @@ class Intersect(SetOperation):
pass pass
class Update(Expression): class Update(DML):
arg_types = { arg_types = {
"with": False, "with": False,
"this": False, "this": False,
@ -3526,6 +3539,7 @@ class Select(Query):
"distinct": False, "distinct": False,
"into": False, "into": False,
"from": False, "from": False,
"operation_modifiers": False,
**QUERY_MODIFIERS, **QUERY_MODIFIERS,
} }
@ -5184,6 +5198,14 @@ class ToNumber(Func):
} }
# https://docs.snowflake.com/en/sql-reference/functions/to_double
class ToDouble(Func):
arg_types = {
"this": True,
"format": False,
}
class Columns(Func): class Columns(Func):
arg_types = {"this": True, "unpack": False} arg_types = {"this": True, "unpack": False}
@ -5641,7 +5663,7 @@ class Exp(Func):
# https://docs.snowflake.com/en/sql-reference/functions/flatten # https://docs.snowflake.com/en/sql-reference/functions/flatten
class Explode(Func): class Explode(Func, UDTF):
arg_types = {"this": True, "expressions": False} arg_types = {"this": True, "expressions": False}
is_var_len_args = True is_var_len_args = True
@ -6248,6 +6270,11 @@ class Split(Func):
arg_types = {"this": True, "expression": True, "limit": False} arg_types = {"this": True, "expression": True, "limit": False}
# https://spark.apache.org/docs/latest/api/python/reference/pyspark.sql/api/pyspark.sql.functions.split_part.html
class SplitPart(Func):
arg_types = {"this": True, "delimiter": True, "part_index": True}
# Start may be omitted in the case of postgres # Start may be omitted in the case of postgres
# https://www.postgresql.org/docs/9.1/functions-string.html @ Table 9-6 # https://www.postgresql.org/docs/9.1/functions-string.html @ Table 9-6
class Substring(Func): class Substring(Func):
@ -6857,26 +6884,37 @@ def _wrap(expression: E, kind: t.Type[Expression]) -> E | Paren:
return Paren(this=expression) if isinstance(expression, kind) else expression return Paren(this=expression) if isinstance(expression, kind) else expression
def _apply_set_operation(
*expressions: ExpOrStr,
set_operation: t.Type[S],
distinct: bool = True,
dialect: DialectType = None,
copy: bool = True,
**opts,
) -> S:
return reduce(
lambda x, y: set_operation(this=x, expression=y, distinct=distinct),
(maybe_parse(e, dialect=dialect, copy=copy, **opts) for e in expressions),
)
def union( def union(
left: ExpOrStr, *expressions: ExpOrStr,
right: ExpOrStr,
distinct: bool = True, distinct: bool = True,
```diff
     dialect: DialectType = None,
     copy: bool = True,
     **opts,
 ) -> Union:
     """
-    Initializes a syntax tree from one UNION expression.
+    Initializes a syntax tree for the `UNION` operation.
 
     Example:
         >>> union("SELECT * FROM foo", "SELECT * FROM bla").sql()
         'SELECT * FROM foo UNION SELECT * FROM bla'
 
     Args:
-        left: the SQL code string corresponding to the left-hand side.
-            If an `Expression` instance is passed, it will be used as-is.
-        right: the SQL code string corresponding to the right-hand side.
-            If an `Expression` instance is passed, it will be used as-is.
+        expressions: the SQL code strings, corresponding to the `UNION`'s operands.
+            If `Expression` instances are passed, they will be used as-is.
         distinct: set the DISTINCT flag if and only if this is true.
         dialect: the dialect used to parse the input expression.
         copy: whether to copy the expression.
@@ -6885,32 +6923,29 @@ def union(
     Returns:
         The new Union instance.
     """
-    left = maybe_parse(sql_or_expression=left, dialect=dialect, copy=copy, **opts)
-    right = maybe_parse(sql_or_expression=right, dialect=dialect, copy=copy, **opts)
-
-    return Union(this=left, expression=right, distinct=distinct)
+    assert len(expressions) >= 2, "At least two expressions are required by `union`."
+    return _apply_set_operation(
+        *expressions, set_operation=Union, distinct=distinct, dialect=dialect, copy=copy, **opts
+    )
 
 
 def intersect(
-    left: ExpOrStr,
-    right: ExpOrStr,
+    *expressions: ExpOrStr,
     distinct: bool = True,
     dialect: DialectType = None,
     copy: bool = True,
     **opts,
 ) -> Intersect:
     """
-    Initializes a syntax tree from one INTERSECT expression.
+    Initializes a syntax tree for the `INTERSECT` operation.
 
     Example:
         >>> intersect("SELECT * FROM foo", "SELECT * FROM bla").sql()
         'SELECT * FROM foo INTERSECT SELECT * FROM bla'
 
     Args:
-        left: the SQL code string corresponding to the left-hand side.
-            If an `Expression` instance is passed, it will be used as-is.
-        right: the SQL code string corresponding to the right-hand side.
-            If an `Expression` instance is passed, it will be used as-is.
+        expressions: the SQL code strings, corresponding to the `INTERSECT`'s operands.
+            If `Expression` instances are passed, they will be used as-is.
         distinct: set the DISTINCT flag if and only if this is true.
         dialect: the dialect used to parse the input expression.
         copy: whether to copy the expression.
@@ -6919,32 +6954,29 @@ def intersect(
     Returns:
         The new Intersect instance.
     """
-    left = maybe_parse(sql_or_expression=left, dialect=dialect, copy=copy, **opts)
-    right = maybe_parse(sql_or_expression=right, dialect=dialect, copy=copy, **opts)
-
-    return Intersect(this=left, expression=right, distinct=distinct)
+    assert len(expressions) >= 2, "At least two expressions are required by `intersect`."
+    return _apply_set_operation(
+        *expressions, set_operation=Intersect, distinct=distinct, dialect=dialect, copy=copy, **opts
+    )
 
 
 def except_(
-    left: ExpOrStr,
-    right: ExpOrStr,
+    *expressions: ExpOrStr,
     distinct: bool = True,
     dialect: DialectType = None,
     copy: bool = True,
     **opts,
 ) -> Except:
     """
-    Initializes a syntax tree from one EXCEPT expression.
+    Initializes a syntax tree for the `EXCEPT` operation.
 
     Example:
         >>> except_("SELECT * FROM foo", "SELECT * FROM bla").sql()
         'SELECT * FROM foo EXCEPT SELECT * FROM bla'
 
     Args:
-        left: the SQL code string corresponding to the left-hand side.
-            If an `Expression` instance is passed, it will be used as-is.
-        right: the SQL code string corresponding to the right-hand side.
-            If an `Expression` instance is passed, it will be used as-is.
+        expressions: the SQL code strings, corresponding to the `EXCEPT`'s operands.
+            If `Expression` instances are passed, they will be used as-is.
         distinct: set the DISTINCT flag if and only if this is true.
         dialect: the dialect used to parse the input expression.
         copy: whether to copy the expression.
@@ -6953,10 +6985,10 @@ def except_(
     Returns:
         The new Except instance.
     """
-    left = maybe_parse(sql_or_expression=left, dialect=dialect, copy=copy, **opts)
-    right = maybe_parse(sql_or_expression=right, dialect=dialect, copy=copy, **opts)
-
-    return Except(this=left, expression=right, distinct=distinct)
+    assert len(expressions) >= 2, "At least two expressions are required by `except_`."
+    return _apply_set_operation(
+        *expressions, set_operation=Except, distinct=distinct, dialect=dialect, copy=copy, **opts
+    )
 
 
 def select(*expressions: ExpOrStr, dialect: DialectType = None, **opts) -> Select:
```
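The set operation builders are now variadic, so a whole chain of operands can be built in one call. A minimal sketch of the new calling convention (expected output shown as a comment):

```python
from sqlglot import exp

# Any number of operands can be passed; the result nests left-associatively,
# and distinct=False yields UNION ALL instead of UNION.
print(exp.union("SELECT 1", "SELECT 2", "SELECT 3", distinct=False).sql())
# expected: SELECT 1 UNION ALL SELECT 2 UNION ALL SELECT 3
```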
```diff
@@ -7410,15 +7442,9 @@ def to_interval(interval: str | Literal) -> Interval:
         interval = interval.this
 
-    interval_parts = INTERVAL_STRING_RE.match(interval)  # type: ignore
-
-    if not interval_parts:
-        raise ValueError("Invalid interval string.")
-
-    return Interval(
-        this=Literal.string(interval_parts.group(1)),
-        unit=Var(this=interval_parts.group(2).upper()),
-    )
+    interval = maybe_parse(f"INTERVAL {interval}")
+    assert isinstance(interval, Interval)
+    return interval
 
 
 def to_table(
```
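`to_interval` now round-trips through the regular INTERVAL parser instead of a hand-rolled regex, so anything that parser accepts should work. A quick sketch (expected outputs in comments, not verified against every dialect):

```python
from sqlglot.expressions import to_interval

print(to_interval("1 day").sql())      # expected: INTERVAL '1' DAY
print(to_interval("'5' month").sql())  # expected: INTERVAL '5' MONTH
```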
```diff
@@ -7795,7 +7821,7 @@ def rename_table(
         this=old_table,
         kind="TABLE",
         actions=[
-            RenameTable(this=new_table),
+            AlterRename(this=new_table),
         ],
     )
```

**sqlglot/generator.py**

```diff
@@ -185,6 +185,7 @@ class Generator(metaclass=_Generator):
         exp.Stream: lambda self, e: f"STREAM {self.sql(e, 'this')}",
         exp.StreamingTableProperty: lambda *_: "STREAMING",
         exp.StrictProperty: lambda *_: "STRICT",
+        exp.SwapTable: lambda self, e: f"SWAP WITH {self.sql(e, 'this')}",
         exp.TemporaryProperty: lambda *_: "TEMPORARY",
         exp.TagColumnConstraint: lambda self, e: f"TAG ({self.expressions(e, flat=True)})",
         exp.TitleColumnConstraint: lambda self, e: f"TITLE {self.sql(e, 'this')}",
@@ -200,6 +201,7 @@ class Generator(metaclass=_Generator):
         exp.ViewAttributeProperty: lambda self, e: f"WITH {self.sql(e, 'this')}",
         exp.VolatileProperty: lambda *_: "VOLATILE",
         exp.WithJournalTableProperty: lambda self, e: f"WITH JOURNAL TABLE={self.sql(e, 'this')}",
+        exp.WithProcedureOptions: lambda self, e: f"WITH {self.expressions(e, flat=True)}",
         exp.WithSchemaBindingProperty: lambda self, e: f"WITH SCHEMA {self.sql(e, 'this')}",
         exp.WithOperator: lambda self, e: f"{self.sql(e, 'this')} WITH {self.sql(e, 'op')}",
     }
@@ -564,6 +566,7 @@ class Generator(metaclass=_Generator):
         exp.VolatileProperty: exp.Properties.Location.POST_CREATE,
         exp.WithDataProperty: exp.Properties.Location.POST_EXPRESSION,
         exp.WithJournalTableProperty: exp.Properties.Location.POST_NAME,
+        exp.WithProcedureOptions: exp.Properties.Location.POST_SCHEMA,
         exp.WithSchemaBindingProperty: exp.Properties.Location.POST_SCHEMA,
         exp.WithSystemVersioningProperty: exp.Properties.Location.POST_SCHEMA,
     }
@@ -2144,6 +2147,10 @@ class Generator(metaclass=_Generator):
         this = expression.this
         this_sql = self.sql(this)
 
+        exprs = self.expressions(expression)
+        if exprs:
+            this_sql = f"{this_sql},{self.seg(exprs)}"
+
         if on_sql:
             on_sql = self.indent(on_sql, skip_first=True)
             space = self.seg(" " * self.pad) if self.pretty else " "
@@ -2510,13 +2517,16 @@ class Generator(metaclass=_Generator):
         )
         kind = ""
 
+        operation_modifiers = self.expressions(expression, key="operation_modifiers", sep=" ")
+        operation_modifiers = f"{self.sep()}{operation_modifiers}" if operation_modifiers else ""
+
         # We use LIMIT_IS_TOP as a proxy for whether DISTINCT should go first because tsql and Teradata
         # are the only dialects that use LIMIT_IS_TOP and both place DISTINCT first.
         top_distinct = f"{distinct}{hint}{top}" if self.LIMIT_IS_TOP else f"{top}{hint}{distinct}"
         expressions = f"{self.sep()}{expressions}" if expressions else expressions
         sql = self.query_modifiers(
             expression,
-            f"SELECT{top_distinct}{kind}{expressions}",
+            f"SELECT{top_distinct}{operation_modifiers}{kind}{expressions}",
             self.sql(expression, "into", comment=False),
             self.sql(expression, "from", comment=False),
         )
@@ -3225,12 +3235,12 @@ class Generator(metaclass=_Generator):
         expressions = f"({expressions})" if expressions else ""
         return f"ALTER{compound} SORTKEY {this or expressions}"
 
-    def renametable_sql(self, expression: exp.RenameTable) -> str:
+    def alterrename_sql(self, expression: exp.AlterRename) -> str:
         if not self.RENAME_TABLE_WITH_DB:
             # Remove db from tables
             expression = expression.transform(
                 lambda n: exp.table_(n.this) if isinstance(n, exp.Table) else n
-            ).assert_is(exp.RenameTable)
+            ).assert_is(exp.AlterRename)
         this = self.sql(expression, "this")
         return f"RENAME TO {this}"
@@ -3508,13 +3518,15 @@ class Generator(metaclass=_Generator):
         name = self.normalize_func(name) if normalize else name
         return f"{name}{prefix}{self.format_args(*args)}{suffix}"
 
-    def format_args(self, *args: t.Optional[str | exp.Expression]) -> str:
+    def format_args(self, *args: t.Optional[str | exp.Expression], sep: str = ", ") -> str:
         arg_sqls = tuple(
             self.sql(arg) for arg in args if arg is not None and not isinstance(arg, bool)
         )
         if self.pretty and self.too_wide(arg_sqls):
-            return self.indent("\n" + ",\n".join(arg_sqls) + "\n", skip_first=True, skip_last=True)
-        return ", ".join(arg_sqls)
+            return self.indent(
+                "\n" + f"{sep.strip()}\n".join(arg_sqls) + "\n", skip_first=True, skip_last=True
+            )
+        return sep.join(arg_sqls)
 
     def too_wide(self, args: t.Iterable) -> bool:
         return sum(len(arg) for arg in args) > self.max_text_width
@@ -3612,7 +3624,7 @@ class Generator(metaclass=_Generator):
         expressions = (
             self.wrap(expressions) if expression.args.get("wrapped") else f" {expressions}"
         )
-        return f"{this}{expressions}"
+        return f"{this}{expressions}" if expressions.strip() != "" else this
 
     def joinhint_sql(self, expression: exp.JoinHint) -> str:
         this = self.sql(expression, "this")
@@ -4243,7 +4255,7 @@ class Generator(metaclass=_Generator):
         else:
             rhs = self.expressions(expression)
 
-        return self.func(name, expression.this, rhs)
+        return self.func(name, expression.this, rhs or None)
 
     def converttimezone_sql(self, expression: exp.ConvertTimezone) -> str:
         if self.SUPPORTS_CONVERT_TIMEZONE:
@@ -4418,3 +4430,7 @@ class Generator(metaclass=_Generator):
         for_sql = f" FOR {for_sql}" if for_sql else ""
 
         return f"OVERLAY({this} PLACING {expr} FROM {from_sql}{for_sql})"
+
+    @unsupported_args("format")
+    def todouble_sql(self, expression: exp.ToDouble) -> str:
+        return self.sql(exp.cast(expression.this, exp.DataType.Type.DOUBLE))
```

**sqlglot/helper.py**

```diff
@@ -56,6 +56,10 @@ def seq_get(seq: t.Sequence[T], index: int) -> t.Optional[T]:
 def ensure_list(value: t.Collection[T]) -> t.List[T]: ...
 
 
+@t.overload
+def ensure_list(value: None) -> t.List: ...
+
+
 @t.overload
 def ensure_list(value: T) -> t.List[T]: ...
```
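The new overload teaches type checkers that `ensure_list(None)` produces a plain list rather than `t.List[None]`; runtime behavior is unchanged. A quick sketch:

```python
from sqlglot.helper import ensure_list

assert ensure_list(None) == []                # now typed as t.List, not t.List[None]
assert ensure_list("a") == ["a"]              # scalars are wrapped
assert ensure_list(["a", "b"]) == ["a", "b"]  # collections pass through as lists
```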

**sqlglot/optimizer/annotate_types.py**

```diff
@@ -287,15 +287,18 @@ class TypeAnnotator(metaclass=_TypeAnnotator):
     def _maybe_coerce(
         self, type1: exp.DataType | exp.DataType.Type, type2: exp.DataType | exp.DataType.Type
-    ) -> exp.DataType:
+    ) -> exp.DataType.Type:
         type1_value = type1.this if isinstance(type1, exp.DataType) else type1
         type2_value = type2.this if isinstance(type2, exp.DataType) else type2
 
         # We propagate the UNKNOWN type upwards if found
         if exp.DataType.Type.UNKNOWN in (type1_value, type2_value):
-            return exp.DataType.build("unknown")
+            return exp.DataType.Type.UNKNOWN
 
-        return type2_value if type2_value in self.coerces_to.get(type1_value, {}) else type1_value
+        return t.cast(
+            exp.DataType.Type,
+            type2_value if type2_value in self.coerces_to.get(type1_value, {}) else type1_value,
+        )
 
     def _annotate_binary(self, expression: B) -> B:
         self._annotate_args(expression)
```
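For intuition, here is a standalone sketch of the coercion rule `_maybe_coerce` implements; the `Type` enum and `COERCES_TO` table below are toy stand-ins, not sqlglot's real mappings:

```python
from enum import Enum, auto

class Type(Enum):
    UNKNOWN = auto()
    INT = auto()
    DOUBLE = auto()

COERCES_TO = {Type.INT: {Type.DOUBLE}}  # toy table: INT may widen to DOUBLE

def maybe_coerce(type1: Type, type2: Type) -> Type:
    # UNKNOWN is contagious: it propagates upwards through the tree.
    if Type.UNKNOWN in (type1, type2):
        return Type.UNKNOWN
    # Prefer type2 only when type1 is allowed to coerce to it.
    return type2 if type2 in COERCES_TO.get(type1, set()) else type1

assert maybe_coerce(Type.INT, Type.DOUBLE) is Type.DOUBLE
assert maybe_coerce(Type.DOUBLE, Type.INT) is Type.DOUBLE
```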

**sqlglot/optimizer/eliminate_subqueries.py**

```diff
@@ -1,11 +1,18 @@
+from __future__ import annotations
+
 import itertools
+import typing as t
 
 from sqlglot import expressions as exp
 from sqlglot.helper import find_new_name
-from sqlglot.optimizer.scope import build_scope
+from sqlglot.optimizer.scope import Scope, build_scope
 
+if t.TYPE_CHECKING:
+    ExistingCTEsMapping = t.Dict[exp.Expression, str]
+    TakenNameMapping = t.Dict[str, t.Union[Scope, exp.Expression]]
 
-def eliminate_subqueries(expression):
+
+def eliminate_subqueries(expression: exp.Expression) -> exp.Expression:
     """
     Rewrite derived tables as CTES, deduplicating if possible.
@@ -38,7 +45,7 @@ def eliminate_subqueries(expression):
     # Map of alias->Scope|Table
     # These are all aliases that are already used in the expression.
     # We don't want to create new CTEs that conflict with these names.
-    taken = {}
+    taken: TakenNameMapping = {}
 
     # All CTE aliases in the root scope are taken
     for scope in root.cte_scopes:
@@ -56,7 +63,7 @@ def eliminate_subqueries(expression):
     # Map of Expression->alias
     # Existing CTES in the root expression. We'll use this for deduplication.
-    existing_ctes = {}
+    existing_ctes: ExistingCTEsMapping = {}
 
     with_ = root.expression.args.get("with")
     recursive = False
@@ -95,15 +102,21 @@ def eliminate_subqueries(expression):
     return expression
 
 
-def _eliminate(scope, existing_ctes, taken):
+def _eliminate(
+    scope: Scope, existing_ctes: ExistingCTEsMapping, taken: TakenNameMapping
+) -> t.Optional[exp.Expression]:
     if scope.is_derived_table:
         return _eliminate_derived_table(scope, existing_ctes, taken)
 
     if scope.is_cte:
         return _eliminate_cte(scope, existing_ctes, taken)
 
+    return None
 
-def _eliminate_derived_table(scope, existing_ctes, taken):
+
+def _eliminate_derived_table(
+    scope: Scope, existing_ctes: ExistingCTEsMapping, taken: TakenNameMapping
+) -> t.Optional[exp.Expression]:
     # This makes sure that we don't:
     # - drop the "pivot" arg from a pivoted subquery
     # - eliminate a lateral correlated subquery
@@ -121,7 +134,9 @@ def _eliminate_derived_table(scope, existing_ctes, taken):
     return cte
 
 
-def _eliminate_cte(scope, existing_ctes, taken):
+def _eliminate_cte(
+    scope: Scope, existing_ctes: ExistingCTEsMapping, taken: TakenNameMapping
+) -> t.Optional[exp.Expression]:
     parent = scope.expression.parent
     name, cte = _new_cte(scope, existing_ctes, taken)
@@ -140,7 +155,9 @@ def _eliminate_cte(scope, existing_ctes, taken):
     return cte
 
 
-def _new_cte(scope, existing_ctes, taken):
+def _new_cte(
+    scope: Scope, existing_ctes: ExistingCTEsMapping, taken: TakenNameMapping
+) -> t.Tuple[str, t.Optional[exp.Expression]]:
     """
     Returns:
         tuple of (name, cte)
```
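Typical usage of this rule is unchanged by the typing pass; a sketch, with the expected output as a comment:

```python
from sqlglot import parse_one
from sqlglot.optimizer.eliminate_subqueries import eliminate_subqueries

# Derived tables are rewritten as CTEs, deduplicating where possible.
expression = parse_one("SELECT * FROM (SELECT * FROM x) AS y")
print(eliminate_subqueries(expression).sql())
# expected: WITH y AS (SELECT * FROM x) SELECT * FROM y
```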

**sqlglot/optimizer/merge_subqueries.py**

```diff
@@ -1,11 +1,20 @@
+from __future__ import annotations
+
+import typing as t
+
 from collections import defaultdict
 
 from sqlglot import expressions as exp
 from sqlglot.helper import find_new_name
 from sqlglot.optimizer.scope import Scope, traverse_scope
 
+if t.TYPE_CHECKING:
+    from sqlglot._typing import E
 
-def merge_subqueries(expression, leave_tables_isolated=False):
+    FromOrJoin = t.Union[exp.From, exp.Join]
+
+
+def merge_subqueries(expression: E, leave_tables_isolated: bool = False) -> E:
     """
     Rewrite sqlglot AST to merge derived tables into the outer query.
@@ -58,7 +67,7 @@ SAFE_TO_REPLACE_UNWRAPPED = (
 )
 
 
-def merge_ctes(expression, leave_tables_isolated=False):
+def merge_ctes(expression: E, leave_tables_isolated: bool = False) -> E:
     scopes = traverse_scope(expression)
 
     # All places where we select from CTEs.
@@ -92,7 +101,7 @@ def merge_ctes(expression, leave_tables_isolated=False):
     return expression
 
 
-def merge_derived_tables(expression, leave_tables_isolated=False):
+def merge_derived_tables(expression: E, leave_tables_isolated: bool = False) -> E:
     for outer_scope in traverse_scope(expression):
         for subquery in outer_scope.derived_tables:
             from_or_join = subquery.find_ancestor(exp.From, exp.Join)
@@ -111,17 +120,11 @@ def merge_derived_tables(expression, leave_tables_isolated=False):
     return expression
 
 
-def _mergeable(outer_scope, inner_scope, leave_tables_isolated, from_or_join):
+def _mergeable(
+    outer_scope: Scope, inner_scope: Scope, leave_tables_isolated: bool, from_or_join: FromOrJoin
+) -> bool:
     """
     Return True if `inner_select` can be merged into outer query.
-
-    Args:
-        outer_scope (Scope)
-        inner_scope (Scope)
-        leave_tables_isolated (bool)
-        from_or_join (exp.From|exp.Join)
-
-    Returns:
-        bool: True if can be merged
     """
     inner_select = inner_scope.expression.unnest()
@@ -195,7 +198,7 @@ def _mergeable(outer_scope, inner_scope, leave_tables_isolated, from_or_join):
         and not outer_scope.expression.is_star
         and isinstance(inner_select, exp.Select)
         and not any(inner_select.args.get(arg) for arg in UNMERGABLE_ARGS)
-        and inner_select.args.get("from")
+        and inner_select.args.get("from") is not None
         and not outer_scope.pivots
         and not any(e.find(exp.AggFunc, exp.Select, exp.Explode) for e in inner_select.expressions)
         and not (leave_tables_isolated and len(outer_scope.selected_sources) > 1)
@@ -218,19 +221,17 @@ def _mergeable(outer_scope, inner_scope, leave_tables_isolated, from_or_join):
     )
 
 
-def _rename_inner_sources(outer_scope, inner_scope, alias):
+def _rename_inner_sources(outer_scope: Scope, inner_scope: Scope, alias: str) -> None:
     """
     Renames any sources in the inner query that conflict with names in the outer query.
-
-    Args:
-        outer_scope (sqlglot.optimizer.scope.Scope)
-        inner_scope (sqlglot.optimizer.scope.Scope)
-        alias (str)
     """
-    taken = set(outer_scope.selected_sources)
-    conflicts = taken.intersection(set(inner_scope.selected_sources))
+    inner_taken = set(inner_scope.selected_sources)
+    outer_taken = set(outer_scope.selected_sources)
+    conflicts = outer_taken.intersection(inner_taken)
     conflicts -= {alias}
+    taken = outer_taken.union(inner_taken)
 
     for conflict in conflicts:
         new_name = find_new_name(taken, conflict)
@@ -250,15 +251,14 @@ def _rename_inner_sources(outer_scope, inner_scope, alias):
         inner_scope.rename_source(conflict, new_name)
 
 
-def _merge_from(outer_scope, inner_scope, node_to_replace, alias):
+def _merge_from(
+    outer_scope: Scope,
+    inner_scope: Scope,
+    node_to_replace: t.Union[exp.Subquery, exp.Table],
+    alias: str,
+) -> None:
     """
     Merge FROM clause of inner query into outer query.
-
-    Args:
-        outer_scope (sqlglot.optimizer.scope.Scope)
-        inner_scope (sqlglot.optimizer.scope.Scope)
-        node_to_replace (exp.Subquery|exp.Table)
-        alias (str)
     """
     new_subquery = inner_scope.expression.args["from"].this
     new_subquery.set("joins", node_to_replace.args.get("joins"))
@@ -274,14 +274,9 @@ def _merge_from(outer_scope, inner_scope, node_to_replace, alias):
     )
 
 
-def _merge_joins(outer_scope, inner_scope, from_or_join):
+def _merge_joins(outer_scope: Scope, inner_scope: Scope, from_or_join: FromOrJoin) -> None:
     """
     Merge JOIN clauses of inner query into outer query.
-
-    Args:
-        outer_scope (sqlglot.optimizer.scope.Scope)
-        inner_scope (sqlglot.optimizer.scope.Scope)
-        from_or_join (exp.From|exp.Join)
     """
     new_joins = []
@@ -304,7 +299,7 @@ def _merge_joins(outer_scope, inner_scope, from_or_join):
     outer_scope.expression.set("joins", outer_joins)
 
 
-def _merge_expressions(outer_scope, inner_scope, alias):
+def _merge_expressions(outer_scope: Scope, inner_scope: Scope, alias: str) -> None:
     """
     Merge projections of inner query into outer query.
@@ -338,7 +333,7 @@ def _merge_expressions(outer_scope, inner_scope, alias):
             column.replace(expression.copy())
 
 
-def _merge_where(outer_scope, inner_scope, from_or_join):
+def _merge_where(outer_scope: Scope, inner_scope: Scope, from_or_join: FromOrJoin) -> None:
     """
     Merge WHERE clause of inner query into outer query.
@@ -357,7 +352,7 @@ def _merge_where(outer_scope, inner_scope, from_or_join):
         # Merge predicates from an outer join to the ON clause
         # if it only has columns that are already joined
         from_ = expression.args.get("from")
-        sources = {from_.alias_or_name} if from_ else {}
+        sources = {from_.alias_or_name} if from_ else set()
 
         for join in expression.args["joins"]:
             source = join.alias_or_name
@@ -373,7 +368,7 @@ def _merge_where(outer_scope, inner_scope, from_or_join):
     expression.where(where.this, copy=False)
 
 
-def _merge_order(outer_scope, inner_scope):
+def _merge_order(outer_scope: Scope, inner_scope: Scope) -> None:
     """
     Merge ORDER clause of inner query into outer query.
@@ -393,7 +388,7 @@ def _merge_order(outer_scope, inner_scope):
     outer_scope.expression.set("order", inner_scope.expression.args.get("order"))
 
 
-def _merge_hints(outer_scope, inner_scope):
+def _merge_hints(outer_scope: Scope, inner_scope: Scope) -> None:
     inner_scope_hint = inner_scope.expression.args.get("hint")
     if not inner_scope_hint:
         return
@@ -405,7 +400,7 @@ def _merge_hints(outer_scope, inner_scope):
     outer_scope.expression.set("hint", inner_scope_hint)
 
 
-def _pop_cte(inner_scope):
+def _pop_cte(inner_scope: Scope) -> None:
     """
     Remove CTE from the AST.
```
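The typed entry point is called the same way as before; a sketch mirroring the module's docstring example:

```python
from sqlglot import parse_one
from sqlglot.optimizer.merge_subqueries import merge_subqueries

expression = parse_one("SELECT a FROM (SELECT x.a FROM x) CROSS JOIN y")
print(merge_subqueries(expression).sql())
# expected: SELECT x.a FROM x CROSS JOIN y
```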

**sqlglot/optimizer/qualify.py**

```diff
@@ -27,6 +27,7 @@ def qualify(
     infer_schema: t.Optional[bool] = None,
     isolate_tables: bool = False,
     qualify_columns: bool = True,
+    allow_partial_qualification: bool = False,
     validate_qualify_columns: bool = True,
     quote_identifiers: bool = True,
     identify: bool = True,
@@ -56,6 +57,7 @@ def qualify(
         infer_schema: Whether to infer the schema if missing.
         isolate_tables: Whether to isolate table selects.
         qualify_columns: Whether to qualify columns.
+        allow_partial_qualification: Whether to allow partial qualification.
         validate_qualify_columns: Whether to validate columns.
         quote_identifiers: Whether to run the quote_identifiers step.
             This step is necessary to ensure correctness for case sensitive queries.
@@ -90,6 +92,7 @@ def qualify(
             expand_alias_refs=expand_alias_refs,
             expand_stars=expand_stars,
             infer_schema=infer_schema,
+            allow_partial_qualification=allow_partial_qualification,
         )
 
     if quote_identifiers:
```
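A sketch of the new flag in action (hypothetical table and column names): with a schema that only knows `t.a`, qualifying a query that also references `t.b` would normally raise an OptimizeError.

```python
from sqlglot import parse_one
from sqlglot.optimizer.qualify import qualify

qualified = qualify(
    parse_one("SELECT t.a, t.b FROM t"),
    schema={"t": {"a": "INT"}},
    allow_partial_qualification=True,  # keep going instead of raising on t.b
    validate_qualify_columns=False,    # skip the validation pass as well
)
print(qualified.sql())
```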

**sqlglot/optimizer/qualify_columns.py**

```diff
@@ -22,6 +22,7 @@ def qualify_columns(
     expand_alias_refs: bool = True,
     expand_stars: bool = True,
     infer_schema: t.Optional[bool] = None,
+    allow_partial_qualification: bool = False,
 ) -> exp.Expression:
     """
     Rewrite sqlglot AST to have fully qualified columns.
@@ -41,6 +42,7 @@ def qualify_columns(
             for most of the optimizer's rules to work; do not set to False unless you
             know what you're doing!
         infer_schema: Whether to infer the schema if missing.
+        allow_partial_qualification: Whether to allow partial qualification.
 
     Returns:
         The qualified expression.
@@ -68,7 +70,7 @@ def qualify_columns(
             )
 
         _convert_columns_to_dots(scope, resolver)
-        _qualify_columns(scope, resolver)
+        _qualify_columns(scope, resolver, allow_partial_qualification=allow_partial_qualification)
 
         if not schema.empty and expand_alias_refs:
             _expand_alias_refs(scope, resolver)
@@ -240,13 +242,21 @@ def _expand_alias_refs(scope: Scope, resolver: Resolver, expand_only_groupby: bo
     def replace_columns(
         node: t.Optional[exp.Expression], resolve_table: bool = False, literal_index: bool = False
     ) -> None:
-        if not node or (expand_only_groupby and not isinstance(node, exp.Group)):
+        is_group_by = isinstance(node, exp.Group)
+        if not node or (expand_only_groupby and not is_group_by):
             return
 
         for column in walk_in_scope(node, prune=lambda node: node.is_star):
            if not isinstance(column, exp.Column):
                 continue
 
+            # BigQuery's GROUP BY allows alias expansion only for standalone names, e.g:
+            # SELECT FUNC(col) AS col FROM t GROUP BY col --> Can be expanded
+            # SELECT FUNC(col) AS col FROM t GROUP BY FUNC(col) --> Shouldn't be expanded, will result to FUNC(FUNC(col))
+            # This not required for the HAVING clause as it can evaluate expressions using both the alias & the table columns
+            if expand_only_groupby and is_group_by and column.parent is not node:
+                continue
+
             table = resolver.get_table(column.name) if resolve_table and not column.table else None
             alias_expr, i = alias_to_expression.get(column.name, (None, 1))
             double_agg = (
@@ -273,9 +283,8 @@ def _expand_alias_refs(scope: Scope, resolver: Resolver, expand_only_groupby: bo
             if simplified is not column:
                 column.replace(simplified)
 
-    for i, projection in enumerate(scope.expression.selects):
+    for i, projection in enumerate(expression.selects):
         replace_columns(projection)
 
         if isinstance(projection, exp.Alias):
             alias_to_expression[projection.alias] = (projection.this, i + 1)
@@ -434,7 +443,7 @@ def _convert_columns_to_dots(scope: Scope, resolver: Resolver) -> None:
     scope.clear_cache()
 
 
-def _qualify_columns(scope: Scope, resolver: Resolver) -> None:
+def _qualify_columns(scope: Scope, resolver: Resolver, allow_partial_qualification: bool) -> None:
     """Disambiguate columns, ensuring each column specifies a source"""
     for column in scope.columns:
         column_table = column.table
@@ -442,7 +451,12 @@ def _qualify_columns(scope: Scope, resolver: Resolver) -> None:
         if column_table and column_table in scope.sources:
             source_columns = resolver.get_source_columns(column_table)
-            if source_columns and column_name not in source_columns and "*" not in source_columns:
+            if (
+                not allow_partial_qualification
+                and source_columns
+                and column_name not in source_columns
+                and "*" not in source_columns
+            ):
                 raise OptimizeError(f"Unknown column: {column_name}")
 
         if not column_table:
@@ -526,7 +540,7 @@ def _expand_stars(
 ) -> None:
     """Expand stars to lists of column selections"""
 
-    new_selections = []
+    new_selections: t.List[exp.Expression] = []
     except_columns: t.Dict[int, t.Set[str]] = {}
     replace_columns: t.Dict[int, t.Dict[str, exp.Alias]] = {}
     rename_columns: t.Dict[int, t.Dict[str, str]] = {}
```

**sqlglot/optimizer/scope.py**

```diff
@@ -562,8 +562,8 @@ def _traverse_scope(scope):
     elif isinstance(expression, exp.DML):
         yield from _traverse_ctes(scope)
         for query in find_all_in_scope(expression, exp.Query):
-            # This check ensures we don't yield the CTE queries twice
-            if not isinstance(query.parent, exp.CTE):
+            # This check ensures we don't yield the CTE/nested queries twice
+            if not isinstance(query.parent, (exp.CTE, exp.Subquery)):
                 yield from _traverse_scope(Scope(query, cte_sources=scope.cte_sources))
         return
     else:
@@ -679,6 +679,8 @@ def _traverse_tables(scope):
     expressions.extend(scope.expression.args.get("laterals") or [])
 
     for expression in expressions:
+        if isinstance(expression, exp.Final):
+            expression = expression.this
         if isinstance(expression, exp.Table):
             table_name = expression.name
             source_name = expression.alias_or_name
```

**sqlglot/optimizer/simplify.py**

```diff
@@ -206,6 +206,11 @@ COMPLEMENT_COMPARISONS = {
     exp.NEQ: exp.EQ,
 }
 
+COMPLEMENT_SUBQUERY_PREDICATES = {
+    exp.All: exp.Any,
+    exp.Any: exp.All,
+}
+
 
 def simplify_not(expression):
     """
@@ -218,9 +223,12 @@ def simplify_not(expression):
         if is_null(this):
             return exp.null()
         if this.__class__ in COMPLEMENT_COMPARISONS:
-            return COMPLEMENT_COMPARISONS[this.__class__](
-                this=this.this, expression=this.expression
-            )
+            right = this.expression
+            complement_subquery_predicate = COMPLEMENT_SUBQUERY_PREDICATES.get(right.__class__)
+            if complement_subquery_predicate:
+                right = complement_subquery_predicate(this=right.this)
+
+            return COMPLEMENT_COMPARISONS[this.__class__](this=this.this, expression=right)
         if isinstance(this, exp.Paren):
             condition = this.unnest()
             if isinstance(condition, exp.And):
```
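With the new table, `NOT` over a comparison also flips an `ANY`/`ALL` subquery predicate on its right-hand side. A sketch (expected output in the comment):

```python
from sqlglot import parse_one
from sqlglot.optimizer.simplify import simplify

# =/ANY under NOT becomes <>/ALL.
print(simplify(parse_one("NOT x = ANY (SELECT y FROM t)")).sql())
# expected: x <> ALL (SELECT y FROM t)
```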

**sqlglot/parser.py**

```diff
@@ -1053,13 +1053,16 @@ class Parser(metaclass=_Parser):
     ALTER_PARSERS = {
         "ADD": lambda self: self._parse_alter_table_add(),
+        "AS": lambda self: self._parse_select(),
         "ALTER": lambda self: self._parse_alter_table_alter(),
         "CLUSTER BY": lambda self: self._parse_cluster(wrapped=True),
         "DELETE": lambda self: self.expression(exp.Delete, where=self._parse_where()),
         "DROP": lambda self: self._parse_alter_table_drop(),
         "RENAME": lambda self: self._parse_alter_table_rename(),
         "SET": lambda self: self._parse_alter_table_set(),
-        "AS": lambda self: self._parse_select(),
+        "SWAP": lambda self: self.expression(
+            exp.SwapTable, this=self._match(TokenType.WITH) and self._parse_table(schema=True)
+        ),
     }
 
     ALTER_ALTER_PARSERS = {
```
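The base parser now understands Snowflake-style `ALTER TABLE ... SWAP WITH`, producing the `exp.SwapTable` node that the generator change above renders back. A sketch:

```python
from sqlglot import exp, parse_one

alter = parse_one("ALTER TABLE a SWAP WITH b")
assert isinstance(alter.args["actions"][0], exp.SwapTable)
print(alter.sql())  # expected: ALTER TABLE a SWAP WITH b
```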
```diff
@@ -1222,6 +1225,10 @@ class Parser(metaclass=_Parser):
         **dict.fromkeys(("BINDING", "COMPENSATION", "EVOLUTION"), tuple()),
     }
 
+    PROCEDURE_OPTIONS: OPTIONS_TYPE = {}
+
+    EXECUTE_AS_OPTIONS: OPTIONS_TYPE = dict.fromkeys(("CALLER", "SELF", "OWNER"), tuple())
+
     KEY_CONSTRAINT_OPTIONS: OPTIONS_TYPE = {
         "NOT": ("ENFORCED",),
         "MATCH": (
@@ -1286,6 +1293,11 @@ class Parser(metaclass=_Parser):
     PRIVILEGE_FOLLOW_TOKENS = {TokenType.ON, TokenType.COMMA, TokenType.L_PAREN}
 
+    # The style options for the DESCRIBE statement
+    DESCRIBE_STYLES = {"ANALYZE", "EXTENDED", "FORMATTED", "HISTORY"}
+
+    OPERATION_MODIFIERS: t.Set[str] = set()
+
     STRICT_CAST = True
     PREFIXED_PIVOT_COLUMNS = False
@@ -2195,11 +2207,26 @@ class Parser(metaclass=_Parser):
                 this=self._parse_var_from_options(self.SCHEMA_BINDING_OPTIONS),
             )
 
+        if self._match_texts(self.PROCEDURE_OPTIONS, advance=False):
+            return self.expression(
+                exp.WithProcedureOptions, expressions=self._parse_csv(self._parse_procedure_option)
+            )
+
         if not self._next:
             return None
 
         return self._parse_withisolatedloading()
 
+    def _parse_procedure_option(self) -> exp.Expression | None:
+        if self._match_text_seq("EXECUTE", "AS"):
+            return self.expression(
+                exp.ExecuteAsProperty,
+                this=self._parse_var_from_options(self.EXECUTE_AS_OPTIONS, raise_unmatched=False)
+                or self._parse_string(),
+            )
+
+        return self._parse_var_from_options(self.PROCEDURE_OPTIONS)
+
     # https://dev.mysql.com/doc/refman/8.0/en/create-view.html
     def _parse_definer(self) -> t.Optional[exp.DefinerProperty]:
         self._match(TokenType.EQ)
@@ -2567,7 +2594,7 @@ class Parser(metaclass=_Parser):
     def _parse_describe(self) -> exp.Describe:
         kind = self._match_set(self.CREATABLES) and self._prev.text
-        style = self._match_texts(("EXTENDED", "FORMATTED", "HISTORY")) and self._prev.text.upper()
+        style = self._match_texts(self.DESCRIBE_STYLES) and self._prev.text.upper()
         if self._match(TokenType.DOT):
             style = None
             self._retreat(self._index - 2)
@@ -2955,6 +2982,10 @@ class Parser(metaclass=_Parser):
         if all_ and distinct:
             self.raise_error("Cannot specify both ALL and DISTINCT after SELECT")
 
+        operation_modifiers = []
+        while self._curr and self._match_texts(self.OPERATION_MODIFIERS):
+            operation_modifiers.append(exp.var(self._prev.text.upper()))
+
         limit = self._parse_limit(top=True)
         projections = self._parse_projections()
@@ -2965,6 +2996,7 @@ class Parser(metaclass=_Parser):
             distinct=distinct,
             expressions=projections,
             limit=limit,
+            operation_modifiers=operation_modifiers or None,
         )
         this.comments = comments
@@ -3400,6 +3432,10 @@ class Parser(metaclass=_Parser):
             return None
 
         kwargs: t.Dict[str, t.Any] = {"this": self._parse_table(parse_bracket=parse_bracket)}
+        if kind and kind.token_type == TokenType.ARRAY and self._match(TokenType.COMMA):
+            kwargs["expressions"] = self._parse_csv(
+                lambda: self._parse_table(parse_bracket=parse_bracket)
+            )
 
         if method:
             kwargs["method"] = method.text
@@ -3420,7 +3456,7 @@ class Parser(metaclass=_Parser):
         elif (
             not (outer_apply or cross_apply)
             and not isinstance(kwargs["this"], exp.Unnest)
-            and not (kind and kind.token_type == TokenType.CROSS)
+            and not (kind and kind.token_type in (TokenType.CROSS, TokenType.ARRAY))
         ):
             index = self._index
             joins: t.Optional[list] = list(self._parse_joins())
@@ -4470,7 +4506,7 @@ class Parser(metaclass=_Parser):
         elif not self._match(TokenType.R_BRACKET, expression=this):
             self.raise_error("Expecting ]")
         else:
-            this = self.expression(exp.In, this=this, field=self._parse_field())
+            this = self.expression(exp.In, this=this, field=self._parse_column())
 
         return this
```
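Switching the RHS of `<value> IN <name>` from `_parse_field` to `_parse_column` means the name parses as a column reference, so qualified names work too; the DuckDB tests further down exercise this. A sketch:

```python
from sqlglot import exp, parse_one

expr = parse_one("'red' IN flags", read="duckdb")
assert isinstance(expr.args["field"], exp.Column)

# Qualified references now parse as well.
parse_one("'red' IN tbl.flags", read="duckdb")
```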
```diff
@@ -5533,12 +5569,15 @@ class Parser(metaclass=_Parser):
         return None
 
     def _parse_column_constraint(self) -> t.Optional[exp.Expression]:
-        if self._match(TokenType.CONSTRAINT):
-            this = self._parse_id_var()
-        else:
-            this = None
+        this = self._match(TokenType.CONSTRAINT) and self._parse_id_var()
 
-        if self._match_texts(self.CONSTRAINT_PARSERS):
+        procedure_option_follows = (
+            self._match(TokenType.WITH, advance=False)
+            and self._next
+            and self._next.text.upper() in self.PROCEDURE_OPTIONS
+        )
+
+        if not procedure_option_follows and self._match_texts(self.CONSTRAINT_PARSERS):
             return self.expression(
                 exp.ColumnConstraint,
                 this=this,
@@ -6764,7 +6803,7 @@ class Parser(metaclass=_Parser):
             self._retreat(index)
             return self._parse_csv(self._parse_drop_column)
 
-    def _parse_alter_table_rename(self) -> t.Optional[exp.RenameTable | exp.RenameColumn]:
+    def _parse_alter_table_rename(self) -> t.Optional[exp.AlterRename | exp.RenameColumn]:
         if self._match(TokenType.COLUMN):
             exists = self._parse_exists()
             old_column = self._parse_column()
@@ -6777,7 +6816,7 @@ class Parser(metaclass=_Parser):
             return self.expression(exp.RenameColumn, this=old_column, to=new_column, exists=exists)
 
         self._match_text_seq("TO")
-        return self.expression(exp.RenameTable, this=self._parse_table(schema=True))
+        return self.expression(exp.AlterRename, this=self._parse_table(schema=True))
 
     def _parse_alter_table_set(self) -> exp.AlterSet:
         alter_set = self.expression(exp.AlterSet)
```

**tests/dialects/test_bigquery.py**

```diff
@@ -107,6 +107,7 @@ LANGUAGE js AS
         select_with_quoted_udf = self.validate_identity("SELECT `p.d.UdF`(data) FROM `p.d.t`")
         self.assertEqual(select_with_quoted_udf.selects[0].name, "p.d.UdF")
 
+        self.validate_identity("SELECT ARRAY_CONCAT([1])")
         self.validate_identity("SELECT * FROM READ_CSV('bla.csv')")
         self.validate_identity("CAST(x AS STRUCT<list ARRAY<INT64>>)")
         self.validate_identity("assert.true(1 = 1)")
```

**tests/dialects/test_clickhouse.py**

```diff
@@ -2,6 +2,7 @@ from datetime import date
 from sqlglot import exp, parse_one
 from sqlglot.dialects import ClickHouse
 from sqlglot.expressions import convert
+from sqlglot.optimizer import traverse_scope
 from tests.dialects.test_dialect import Validator
 from sqlglot.errors import ErrorLevel
@@ -28,6 +29,7 @@ class TestClickhouse(Validator):
         self.assertEqual(expr.sql(dialect="clickhouse"), "COUNT(x)")
         self.assertIsNone(expr._meta)
 
+        self.validate_identity("CAST(1 AS Bool)")
         self.validate_identity("SELECT toString(CHAR(104.1, 101, 108.9, 108.9, 111, 32))")
         self.validate_identity("@macro").assert_is(exp.Parameter).this.assert_is(exp.Var)
         self.validate_identity("SELECT toFloat(like)")
@@ -420,11 +422,6 @@ class TestClickhouse(Validator):
                 " GROUP BY loyalty ORDER BY loyalty ASC"
             },
         )
-        self.validate_identity("SELECT s, arr FROM arrays_test ARRAY JOIN arr")
-        self.validate_identity("SELECT s, arr, a FROM arrays_test LEFT ARRAY JOIN arr AS a")
-        self.validate_identity(
-            "SELECT s, arr_external FROM arrays_test ARRAY JOIN [1, 2, 3] AS arr_external"
-        )
         self.validate_all(
             "SELECT quantile(0.5)(a)",
             read={"duckdb": "SELECT quantile(a, 0.5)"},
@@ -1100,3 +1097,36 @@ LIFETIME(MIN 0 MAX 0)""",
     def test_grant(self):
         self.validate_identity("GRANT SELECT(x, y) ON db.table TO john WITH GRANT OPTION")
         self.validate_identity("GRANT INSERT(x, y) ON db.table TO john")
+
+    def test_array_join(self):
+        expr = self.validate_identity(
+            "SELECT * FROM arrays_test ARRAY JOIN arr1, arrays_test.arr2 AS foo, ['a', 'b', 'c'] AS elem"
+        )
+        joins = expr.args["joins"]
+        self.assertEqual(len(joins), 1)
+
+        join = joins[0]
+        self.assertEqual(join.kind, "ARRAY")
+        self.assertIsInstance(join.this, exp.Column)
+
+        self.assertEqual(len(join.expressions), 2)
+        self.assertIsInstance(join.expressions[0], exp.Alias)
+        self.assertIsInstance(join.expressions[0].this, exp.Column)
+        self.assertIsInstance(join.expressions[1], exp.Alias)
+        self.assertIsInstance(join.expressions[1].this, exp.Array)
+
+        self.validate_identity("SELECT s, arr FROM arrays_test ARRAY JOIN arr")
+        self.validate_identity("SELECT s, arr, a FROM arrays_test LEFT ARRAY JOIN arr AS a")
+        self.validate_identity(
+            "SELECT s, arr_external FROM arrays_test ARRAY JOIN [1, 2, 3] AS arr_external"
+        )
+        self.validate_identity(
+            "SELECT * FROM arrays_test ARRAY JOIN [1, 2, 3] AS arr_external1, ['a', 'b', 'c'] AS arr_external2, splitByString(',', 'asd,qwerty,zxc') AS arr_external3"
+        )
+
+    def test_traverse_scope(self):
+        sql = "SELECT * FROM t FINAL"
+        scopes = traverse_scope(parse_one(sql, dialect=self.dialect))
+
+        self.assertEqual(len(scopes), 1)
+        self.assertEqual(set(scopes[0].sources), {"t"})
```

**tests/dialects/test_databricks.py**

```diff
@@ -7,6 +7,7 @@ class TestDatabricks(Validator):
     dialect = "databricks"
 
     def test_databricks(self):
+        self.validate_identity("SELECT t.current_time FROM t")
         self.validate_identity("ALTER TABLE labels ADD COLUMN label_score FLOAT")
         self.validate_identity("DESCRIBE HISTORY a.b")
         self.validate_identity("DESCRIBE history.tbl")
```

**tests/dialects/test_dialect.py**

```diff
@@ -1762,6 +1762,7 @@ class TestDialect(Validator):
         self.validate_all(
             "LEVENSHTEIN(col1, col2)",
             write={
+                "bigquery": "EDIT_DISTANCE(col1, col2)",
                 "duckdb": "LEVENSHTEIN(col1, col2)",
                 "drill": "LEVENSHTEIN_DISTANCE(col1, col2)",
                 "presto": "LEVENSHTEIN_DISTANCE(col1, col2)",
@@ -1772,6 +1773,7 @@ class TestDialect(Validator):
         self.validate_all(
             "LEVENSHTEIN(coalesce(col1, col2), coalesce(col2, col1))",
             write={
+                "bigquery": "EDIT_DISTANCE(COALESCE(col1, col2), COALESCE(col2, col1))",
                 "duckdb": "LEVENSHTEIN(COALESCE(col1, col2), COALESCE(col2, col1))",
                 "drill": "LEVENSHTEIN_DISTANCE(COALESCE(col1, col2), COALESCE(col2, col1))",
                 "presto": "LEVENSHTEIN_DISTANCE(COALESCE(col1, col2), COALESCE(col2, col1))",
```

**tests/dialects/test_duckdb.py**

```diff
@@ -256,6 +256,9 @@ class TestDuckDB(Validator):
             parse_one("a // b", read="duckdb").assert_is(exp.IntDiv).sql(dialect="duckdb"), "a // b"
         )
 
+        self.validate_identity("SELECT UNNEST([1, 2])").selects[0].assert_is(exp.UDTF)
+        self.validate_identity("'red' IN flags").args["field"].assert_is(exp.Column)
+        self.validate_identity("'red' IN tbl.flags")
         self.validate_identity("CREATE TABLE tbl1 (u UNION(num INT, str TEXT))")
         self.validate_identity("INSERT INTO x BY NAME SELECT 1 AS y")
         self.validate_identity("SELECT 1 AS x UNION ALL BY NAME SELECT 2 AS x")
```

Some files were not shown because too many files have changed in this diff.