Merging upstream version 25.26.0.
Signed-off-by: Daniel Baumann <daniel@debian.org>
This commit is contained in: parent 9138e4b92a, commit 829a709061
117 changed files with 49296 additions and 47316 deletions
CHANGELOG.md: 69 lines changed
@@ -1,6 +1,72 @@
Changelog
=========

## [v25.25.1] - 2024-10-15
### :bug: Bug Fixes
- [`e6567ae`](https://github.com/tobymao/sqlglot/commit/e6567ae11650834874808a844a19836fbb9ee753) - small overload fix for ensure list taking None *(PR [#4248](https://github.com/tobymao/sqlglot/pull/4248) by [@benfdking](https://github.com/benfdking))*


## [v25.25.0] - 2024-10-14
### :boom: BREAKING CHANGES
- due to [`275b64b`](https://github.com/tobymao/sqlglot/commit/275b64b6a28722232a24870e443b249994220d54) - refactor set operation builders so they can work with N expressions *(PR [#4226](https://github.com/tobymao/sqlglot/pull/4226) by [@georgesittas](https://github.com/georgesittas))*:

  refactor set operation builders so they can work with N expressions (#4226)

- due to [`aee76da`](https://github.com/tobymao/sqlglot/commit/aee76da1cadec242f7428d23999f1752cb0708ca) - Native annotations for string functions *(PR [#4231](https://github.com/tobymao/sqlglot/pull/4231) by [@VaggelisD](https://github.com/VaggelisD))*:

  Native annotations for string functions (#4231)

- due to [`202aaa0`](https://github.com/tobymao/sqlglot/commit/202aaa0e7390142ee3ade41c28e2e77cde31f295) - Native annotations for string functions *(PR [#4234](https://github.com/tobymao/sqlglot/pull/4234) by [@VaggelisD](https://github.com/VaggelisD))*:

  Native annotations for string functions (#4234)

- due to [`5741180`](https://github.com/tobymao/sqlglot/commit/5741180e895eaaa75a07af388d36a0d2df97b28c) - produce exp.Column for the RHS of <value> IN <name> *(PR [#4239](https://github.com/tobymao/sqlglot/pull/4239) by [@georgesittas](https://github.com/georgesittas))*:

  produce exp.Column for the RHS of <value> IN <name> (#4239)

- due to [`4da2502`](https://github.com/tobymao/sqlglot/commit/4da25029b1c6f1425b4602f42da4fa1bcd3fccdb) - make Explode a UDTF subclass *(PR [#4242](https://github.com/tobymao/sqlglot/pull/4242) by [@georgesittas](https://github.com/georgesittas))*:

  make Explode a UDTF subclass (#4242)


### :sparkles: New Features
- [`163e943`](https://github.com/tobymao/sqlglot/commit/163e943cdaf449599640c198f69e73d2398eb323) - **tsql**: SPLIT_PART function and conversion to PARSENAME in tsql *(PR [#4211](https://github.com/tobymao/sqlglot/pull/4211) by [@daihuynh](https://github.com/daihuynh))*
- [`275b64b`](https://github.com/tobymao/sqlglot/commit/275b64b6a28722232a24870e443b249994220d54) - refactor set operation builders so they can work with N expressions *(PR [#4226](https://github.com/tobymao/sqlglot/pull/4226) by [@georgesittas](https://github.com/georgesittas))*
- [`3f6ba3e`](https://github.com/tobymao/sqlglot/commit/3f6ba3e69c9ba92429d2b3b00cac33f45518aa56) - **clickhouse**: Support varlen arrays for ARRAY JOIN *(PR [#4229](https://github.com/tobymao/sqlglot/pull/4229) by [@VaggelisD](https://github.com/VaggelisD))*
  - :arrow_lower_right: *addresses issue [#4227](https://github.com/tobymao/sqlglot/issues/4227) opened by [@brunorpinho](https://github.com/brunorpinho)*
- [`aee76da`](https://github.com/tobymao/sqlglot/commit/aee76da1cadec242f7428d23999f1752cb0708ca) - **bigquery**: Native annotations for string functions *(PR [#4231](https://github.com/tobymao/sqlglot/pull/4231) by [@VaggelisD](https://github.com/VaggelisD))*
- [`202aaa0`](https://github.com/tobymao/sqlglot/commit/202aaa0e7390142ee3ade41c28e2e77cde31f295) - **bigquery**: Native annotations for string functions *(PR [#4234](https://github.com/tobymao/sqlglot/pull/4234) by [@VaggelisD](https://github.com/VaggelisD))*
- [`eeae25e`](https://github.com/tobymao/sqlglot/commit/eeae25e03a883671f9d5e514f9bd3021fb6c0d32) - support EXPLAIN in mysql *(PR [#4235](https://github.com/tobymao/sqlglot/pull/4235) by [@xiaoyu-meng-mxy](https://github.com/xiaoyu-meng-mxy))*
- [`06748d9`](https://github.com/tobymao/sqlglot/commit/06748d93ccd232528003c37fdda25ae8163f3c18) - **mysql**: add support for operation modifiers like HIGH_PRIORITY *(PR [#4238](https://github.com/tobymao/sqlglot/pull/4238) by [@georgesittas](https://github.com/georgesittas))*
  - :arrow_lower_right: *addresses issue [#4236](https://github.com/tobymao/sqlglot/issues/4236) opened by [@asdfsx](https://github.com/asdfsx)*

### :bug: Bug Fixes
- [`dcdec95`](https://github.com/tobymao/sqlglot/commit/dcdec95f986426ae90469baca993b47ac390081b) - Make exp.Update a DML node *(PR [#4223](https://github.com/tobymao/sqlglot/pull/4223) by [@VaggelisD](https://github.com/VaggelisD))*
  - :arrow_lower_right: *fixes issue [#4221](https://github.com/tobymao/sqlglot/issues/4221) opened by [@rahul-ve](https://github.com/rahul-ve)*
- [`79caf51`](https://github.com/tobymao/sqlglot/commit/79caf519987718390a086bee19fdc89f6094496c) - **clickhouse**: rename BOOLEAN type to Bool, fixes [#4230](https://github.com/tobymao/sqlglot/pull/4230) *(commit by [@georgesittas](https://github.com/georgesittas))*
- [`b26a3f6`](https://github.com/tobymao/sqlglot/commit/b26a3f67b7113802ba1b4b3b211431e98258dc15) - satisfy mypy *(commit by [@georgesittas](https://github.com/georgesittas))*
- [`5741180`](https://github.com/tobymao/sqlglot/commit/5741180e895eaaa75a07af388d36a0d2df97b28c) - **parser**: produce exp.Column for the RHS of <value> IN <name> *(PR [#4239](https://github.com/tobymao/sqlglot/pull/4239) by [@georgesittas](https://github.com/georgesittas))*
  - :arrow_lower_right: *fixes issue [#4237](https://github.com/tobymao/sqlglot/issues/4237) opened by [@rustyconover](https://github.com/rustyconover)*
- [`daa6e78`](https://github.com/tobymao/sqlglot/commit/daa6e78e4b810eff826f995aa52f9e38197f1b7e) - **optimizer**: handle subquery predicate substitution correctly in De Morgan's rule *(PR [#4240](https://github.com/tobymao/sqlglot/pull/4240) by [@georgesittas](https://github.com/georgesittas))*
- [`c0a8355`](https://github.com/tobymao/sqlglot/commit/c0a83556acffcd77521f69bf51503a07310f749d) - **parser**: parse a column reference for the RHS of the IN clause *(PR [#4241](https://github.com/tobymao/sqlglot/pull/4241) by [@georgesittas](https://github.com/georgesittas))*

### :recycle: Refactors
- [`0882f03`](https://github.com/tobymao/sqlglot/commit/0882f03d526f593b2d415e85b7d7a7c113721806) - Rename exp.RenameTable to exp.AlterRename *(PR [#4224](https://github.com/tobymao/sqlglot/pull/4224) by [@VaggelisD](https://github.com/VaggelisD))*
  - :arrow_lower_right: *addresses issue [#4222](https://github.com/tobymao/sqlglot/issues/4222) opened by [@s1101010110](https://github.com/s1101010110)*
- [`fd42b5c`](https://github.com/tobymao/sqlglot/commit/fd42b5cdaf9421abb11e71d82726536af09e3ae3) - Simplify PARSENAME <-> SPLIT_PART transpilation *(PR [#4225](https://github.com/tobymao/sqlglot/pull/4225) by [@VaggelisD](https://github.com/VaggelisD))*
- [`4da2502`](https://github.com/tobymao/sqlglot/commit/4da25029b1c6f1425b4602f42da4fa1bcd3fccdb) - make Explode a UDTF subclass *(PR [#4242](https://github.com/tobymao/sqlglot/pull/4242) by [@georgesittas](https://github.com/georgesittas))*


## [v25.24.5] - 2024-10-08
### :sparkles: New Features
- [`22a1684`](https://github.com/tobymao/sqlglot/commit/22a16848d80a2fa6d310f99d21f7d81f90eb9440) - **bigquery**: Native annotations for more math functions *(PR [#4212](https://github.com/tobymao/sqlglot/pull/4212) by [@VaggelisD](https://github.com/VaggelisD))*
- [`354cfff`](https://github.com/tobymao/sqlglot/commit/354cfff13ab30d01c6123fca74eed0669d238aa0) - add builder methods to exp.Update and add with_ arg to exp.update *(PR [#4217](https://github.com/tobymao/sqlglot/pull/4217) by [@brdbry](https://github.com/brdbry))*

### :bug: Bug Fixes
- [`2c513b7`](https://github.com/tobymao/sqlglot/commit/2c513b71c7d4b1ff5c7c4e12d6c38694210b1a12) - Attach CTE comments before commas *(PR [#4218](https://github.com/tobymao/sqlglot/pull/4218) by [@VaggelisD](https://github.com/VaggelisD))*
  - :arrow_lower_right: *fixes issue [#4216](https://github.com/tobymao/sqlglot/issues/4216) opened by [@ajfriend](https://github.com/ajfriend)*


## [v25.24.4] - 2024-10-04
### :bug: Bug Fixes
- [`484df7d`](https://github.com/tobymao/sqlglot/commit/484df7d50df5cb314943e1810db18a7d7d5bb3eb) - tsql union with limit *(commit by [@tobymao](https://github.com/tobymao))*

@@ -4969,3 +5035,6 @@ Changelog
 [v25.24.2]: https://github.com/tobymao/sqlglot/compare/v25.24.1...v25.24.2
 [v25.24.3]: https://github.com/tobymao/sqlglot/compare/v25.24.2...v25.24.3
 [v25.24.4]: https://github.com/tobymao/sqlglot/compare/v25.24.3...v25.24.4
+[v25.24.5]: https://github.com/tobymao/sqlglot/compare/v25.24.4...v25.24.5
+[v25.25.0]: https://github.com/tobymao/sqlglot/compare/v25.24.5...v25.25.0
+[v25.25.1]: https://github.com/tobymao/sqlglot/compare/v25.25.0...v25.25.1
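The headline breaking change, set operation builders accepting N expressions, can be exercised directly; a minimal sketch (expected output inferred from PR #4226, not part of this commit):

    from sqlglot import exp

    q = exp.union("SELECT 1", "SELECT 2", "SELECT 3")
    print(q.sql())
    # expected: SELECT 1 UNION SELECT 2 UNION SELECT 3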
File diff suppressed because one or more lines are too long (63 files)
docs/sqlglot/parser.html: 24254 lines changed
File diff suppressed because one or more lines are too long (8 files)
sqlglot/dialects/bigquery.py

@@ -321,7 +321,24 @@ class BigQuery(Dialect):
            expr_type: lambda self, e: _annotate_math_functions(self, e)
            for expr_type in (exp.Floor, exp.Ceil, exp.Log, exp.Ln, exp.Sqrt, exp.Exp, exp.Round)
        },
+       **{
+           expr_type: lambda self, e: self._annotate_by_args(e, "this")
+           for expr_type in (
+               exp.Left,
+               exp.Right,
+               exp.Lower,
+               exp.Upper,
+               exp.Pad,
+               exp.Trim,
+               exp.RegexpExtract,
+               exp.RegexpReplace,
+               exp.Repeat,
+               exp.Substring,
+           )
+       },
+       exp.Concat: lambda self, e: self._annotate_by_args(e, "expressions"),
        exp.Sign: lambda self, e: self._annotate_by_args(e, "this"),
+       exp.Split: lambda self, e: self._annotate_by_args(e, "this", array=True),
    }

    def normalize_identifier(self, expression: E) -> E:

@@ -716,6 +733,7 @@ class BigQuery(Dialect):
        exp.ILike: no_ilike_sql,
        exp.IntDiv: rename_func("DIV"),
        exp.JSONFormat: rename_func("TO_JSON_STRING"),
+       exp.Levenshtein: rename_func("EDIT_DISTANCE"),
        exp.Max: max_or_greatest,
        exp.MD5: lambda self, e: self.func("TO_HEX", self.func("MD5", e.this)),
        exp.MD5Digest: rename_func("MD5"),
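A minimal sketch of the effect (the `dialect` argument to `annotate_types` is assumed from sqlglot's optimizer API, not shown in this commit):

    import sqlglot
    from sqlglot.optimizer.annotate_types import annotate_types

    expr = sqlglot.parse_one("SELECT UPPER(name) AS n FROM t", read="bigquery")
    annotated = annotate_types(expr, dialect="bigquery")
    print(annotated.selects[0].type)  # expected: a VARCHAR-like text type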
sqlglot/dialects/clickhouse.py

@@ -603,6 +603,12 @@ class ClickHouse(Dialect):
        if join:
            join.set("global", join.args.pop("method", None))

+           # tbl ARRAY JOIN arr <-- this should be a `Column` reference, not a `Table`
+           # https://clickhouse.com/docs/en/sql-reference/statements/select/array-join
+           if join.kind == "ARRAY":
+               for table in join.find_all(exp.Table):
+                   table.replace(table.to_column())
+
        return join

    def _parse_function(

@@ -627,15 +633,18 @@ class ClickHouse(Dialect):
        )

        if parts:
-           params = self._parse_func_params(func)
+           anon_func: exp.Anonymous = t.cast(exp.Anonymous, func)
+           params = self._parse_func_params(anon_func)

            kwargs = {
-               "this": func.this,
-               "expressions": func.expressions,
+               "this": anon_func.this,
+               "expressions": anon_func.expressions,
            }
            if parts[1]:
                kwargs["parts"] = parts
-               exp_class = exp.CombinedParameterizedAgg if params else exp.CombinedAggFunc
+               exp_class: t.Type[exp.Expression] = (
+                   exp.CombinedParameterizedAgg if params else exp.CombinedAggFunc
+               )
            else:
                exp_class = exp.ParameterizedAgg if params else exp.AnonymousAggFunc

@@ -825,6 +834,7 @@ class ClickHouse(Dialect):
        **generator.Generator.TYPE_MAPPING,
        **STRING_TYPE_MAPPING,
        exp.DataType.Type.ARRAY: "Array",
+       exp.DataType.Type.BOOLEAN: "Bool",
        exp.DataType.Type.BIGINT: "Int64",
        exp.DataType.Type.DATE32: "Date32",
        exp.DataType.Type.DATETIME: "DateTime",
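A quick sanity check of the new Bool mapping (a minimal sketch; the expected output is inferred from the TYPE_MAPPING change above):

    import sqlglot

    # BOOLEAN now renders as ClickHouse's native Bool type (fixes #4230)
    print(sqlglot.transpile("SELECT CAST(x AS BOOLEAN)", write="clickhouse")[0])
    # expected: SELECT CAST(x AS Bool)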
sqlglot/dialects/dialect.py

@@ -588,6 +588,7 @@ class Dialect(metaclass=_Dialect):
            exp.Stddev,
            exp.StddevPop,
            exp.StddevSamp,
+           exp.ToDouble,
            exp.Variance,
            exp.VariancePop,
        },

@@ -1697,3 +1698,18 @@ def build_regexp_extract(args: t.List, dialect: Dialect) -> exp.RegexpExtract:
        expression=seq_get(args, 1),
        group=seq_get(args, 2) or exp.Literal.number(dialect.REGEXP_EXTRACT_DEFAULT_GROUP),
    )
+
+
+def explode_to_unnest_sql(self: Generator, expression: exp.Lateral) -> str:
+    if isinstance(expression.this, exp.Explode):
+        return self.sql(
+            exp.Join(
+                this=exp.Unnest(
+                    expressions=[expression.this.this],
+                    alias=expression.args.get("alias"),
+                    offset=isinstance(expression.this, exp.Posexplode),
+                ),
+                kind="cross",
+            )
+        )
+    return self.lateral_sql(expression)
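With the helper hoisted into dialect.py, several dialects share the EXPLODE-to-UNNEST lowering. A hedged sketch (the exact output formatting is assumed, not taken from this commit):

    import sqlglot

    sql = "SELECT e FROM t LATERAL VIEW EXPLODE(xs) ex AS e"
    print(sqlglot.transpile(sql, read="spark", write="duckdb")[0])
    # expected along the lines of: SELECT e FROM t CROSS JOIN UNNEST(xs) AS ex(e)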
sqlglot/dialects/duckdb.py

@@ -35,6 +35,7 @@ from sqlglot.dialects.dialect import (
    unit_to_str,
    sha256_sql,
    build_regexp_extract,
+   explode_to_unnest_sql,
)
from sqlglot.helper import seq_get
from sqlglot.tokens import TokenType

@@ -538,6 +539,7 @@ class DuckDB(Dialect):
        exp.JSONExtract: _arrow_json_extract_sql,
        exp.JSONExtractScalar: _arrow_json_extract_sql,
        exp.JSONFormat: _json_format_sql,
+       exp.Lateral: explode_to_unnest_sql,
        exp.LogicalOr: rename_func("BOOL_OR"),
        exp.LogicalAnd: rename_func("BOOL_AND"),
        exp.MD5Digest: lambda self, e: self.func("UNHEX", self.func("MD5", e.this)),
sqlglot/dialects/hive.py

@@ -333,6 +333,9 @@ class Hive(Dialect):
            "TRANSFORM": lambda self: self._parse_transform(),
        }

+       NO_PAREN_FUNCTIONS = parser.Parser.NO_PAREN_FUNCTIONS.copy()
+       NO_PAREN_FUNCTIONS.pop(TokenType.CURRENT_TIME)
+
        PROPERTY_PARSERS = {
            **parser.Parser.PROPERTY_PARSERS,
            "SERDEPROPERTIES": lambda self: exp.SerdeProperties(
sqlglot/dialects/mysql.py

@@ -187,6 +187,9 @@ class MySQL(Dialect):
        KEYWORDS = {
            **tokens.Tokenizer.KEYWORDS,
            "CHARSET": TokenType.CHARACTER_SET,
+           # The DESCRIBE and EXPLAIN statements are synonyms.
+           # https://dev.mysql.com/doc/refman/8.4/en/explain.html
+           "EXPLAIN": TokenType.DESCRIBE,
            "FORCE": TokenType.FORCE,
            "IGNORE": TokenType.IGNORE,
            "KEY": TokenType.KEY,

@@ -453,6 +456,17 @@ class MySQL(Dialect):
            TokenType.SET,
        }

+       # SELECT [ ALL | DISTINCT | DISTINCTROW ] [ <OPERATION_MODIFIERS> ]
+       OPERATION_MODIFIERS = {
+           "HIGH_PRIORITY",
+           "STRAIGHT_JOIN",
+           "SQL_SMALL_RESULT",
+           "SQL_BIG_RESULT",
+           "SQL_BUFFER_RESULT",
+           "SQL_NO_CACHE",
+           "SQL_CALC_FOUND_ROWS",
+       }
+
        LOG_DEFAULTS_TO_LN = True
        STRING_ALIASES = True
        VALUES_FOLLOWED_BY_PAREN = False
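A round-trip sketch for the new modifiers (assumed output, mirroring PR #4238):

    import sqlglot

    sql = "SELECT HIGH_PRIORITY SQL_CALC_FOUND_ROWS * FROM t"
    print(sqlglot.transpile(sql, read="mysql", write="mysql")[0])
    # expected: SELECT HIGH_PRIORITY SQL_CALC_FOUND_ROWS * FROM t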
sqlglot/dialects/oracle.py

@@ -15,6 +15,7 @@ from sqlglot.dialects.dialect import (
from sqlglot.helper import seq_get
from sqlglot.parser import OPTIONS_TYPE, build_coalesce
from sqlglot.tokens import TokenType
+from sqlglot.errors import ParseError

if t.TYPE_CHECKING:
    from sqlglot._typing import E

@@ -205,6 +206,57 @@ class Oracle(Dialect):
        )

+       def _parse_hint(self) -> t.Optional[exp.Hint]:
+           start_index = self._index
+           should_fallback_to_string = False
+
+           if not self._match(TokenType.HINT):
+               return None
+
+           hints = []
+
+           try:
+               for hint in iter(
+                   lambda: self._parse_csv(
+                       lambda: self._parse_hint_function_call() or self._parse_var(upper=True),
+                   ),
+                   [],
+               ):
+                   hints.extend(hint)
+           except ParseError:
+               should_fallback_to_string = True
+
+           if not self._match_pair(TokenType.STAR, TokenType.SLASH):
+               should_fallback_to_string = True
+
+           if should_fallback_to_string:
+               self._retreat(start_index)
+               return self._parse_hint_fallback_to_string()
+
+           return self.expression(exp.Hint, expressions=hints)
+
+       def _parse_hint_function_call(self) -> t.Optional[exp.Expression]:
+           if not self._curr or not self._next or self._next.token_type != TokenType.L_PAREN:
+               return None
+
+           this = self._curr.text
+
+           self._advance(2)
+           args = self._parse_hint_args()
+           this = self.expression(exp.Anonymous, this=this, expressions=args)
+           self._match_r_paren(this)
+           return this
+
+       def _parse_hint_args(self):
+           args = []
+           result = self._parse_var()
+
+           while result:
+               args.append(result)
+               result = self._parse_var()
+
+           return args
+
        def _parse_hint_fallback_to_string(self) -> t.Optional[exp.Hint]:
            if self._match(TokenType.HINT):
                start = self._curr
                while self._curr and not self._match_pair(TokenType.STAR, TokenType.SLASH):

@@ -271,6 +323,7 @@ class Oracle(Dialect):
        LAST_DAY_SUPPORTS_DATE_PART = False
        SUPPORTS_SELECT_INTO = True
        TZ_TO_WITH_TIME_ZONE = True
+       QUERY_HINT_SEP = " "

        TYPE_MAPPING = {
            **generator.Generator.TYPE_MAPPING,

@@ -370,3 +423,15 @@ class Oracle(Dialect):
            return f"{self.seg(into)} {self.sql(expression, 'this')}"

        return f"{self.seg(into)} {self.expressions(expression)}"
+
+   def hint_sql(self, expression: exp.Hint) -> str:
+       expressions = []
+
+       for expression in expression.expressions:
+           if isinstance(expression, exp.Anonymous):
+               formatted_args = self.format_args(*expression.expressions, sep=" ")
+               expressions.append(f"{self.sql(expression, 'this')}({formatted_args})")
+           else:
+               expressions.append(self.sql(expression))
+
+       return f" /*+ {self.expressions(sqls=expressions, sep=self.QUERY_HINT_SEP).strip()} */"
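Hints now parse into structured exp.Hint nodes, with a string fallback when the body is not parseable. A hedged round-trip sketch (output is an assumption, not taken from this commit):

    import sqlglot

    sql = "SELECT /*+ LEADING(e d) USE_NL(d) */ * FROM e JOIN d ON e.id = d.id"
    print(sqlglot.transpile(sql, read="oracle", write="oracle")[0])
    # expected to preserve the hint: SELECT /*+ LEADING(e d) USE_NL(d) */ ...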
sqlglot/dialects/presto.py

@@ -30,6 +30,7 @@ from sqlglot.dialects.dialect import (
    unit_to_str,
    sequence_sql,
    build_regexp_extract,
+   explode_to_unnest_sql,
)
from sqlglot.dialects.hive import Hive
from sqlglot.dialects.mysql import MySQL

@@ -40,21 +41,6 @@ from sqlglot.transforms import unqualify_columns
DATE_ADD_OR_SUB = t.Union[exp.DateAdd, exp.TimestampAdd, exp.DateSub]


-def _explode_to_unnest_sql(self: Presto.Generator, expression: exp.Lateral) -> str:
-    if isinstance(expression.this, exp.Explode):
-        return self.sql(
-            exp.Join(
-                this=exp.Unnest(
-                    expressions=[expression.this.this],
-                    alias=expression.args.get("alias"),
-                    offset=isinstance(expression.this, exp.Posexplode),
-                ),
-                kind="cross",
-            )
-        )
-    return self.lateral_sql(expression)
-
-
def _initcap_sql(self: Presto.Generator, expression: exp.Initcap) -> str:
    regex = r"(\w)(\w*)"
    return f"REGEXP_REPLACE({self.sql(expression, 'this')}, '{regex}', x -> UPPER(x[1]) || LOWER(x[2]))"

@@ -340,16 +326,17 @@ class Presto(Dialect):

        TYPE_MAPPING = {
            **generator.Generator.TYPE_MAPPING,
-           exp.DataType.Type.INT: "INTEGER",
-           exp.DataType.Type.FLOAT: "REAL",
            exp.DataType.Type.BINARY: "VARBINARY",
-           exp.DataType.Type.TEXT: "VARCHAR",
-           exp.DataType.Type.TIMETZ: "TIME",
-           exp.DataType.Type.TIMESTAMPTZ: "TIMESTAMP",
-           exp.DataType.Type.STRUCT: "ROW",
            exp.DataType.Type.BIT: "BOOLEAN",
            exp.DataType.Type.DATETIME: "TIMESTAMP",
            exp.DataType.Type.DATETIME64: "TIMESTAMP",
+           exp.DataType.Type.FLOAT: "REAL",
+           exp.DataType.Type.HLLSKETCH: "HYPERLOGLOG",
+           exp.DataType.Type.INT: "INTEGER",
+           exp.DataType.Type.STRUCT: "ROW",
+           exp.DataType.Type.TEXT: "VARCHAR",
+           exp.DataType.Type.TIMESTAMPTZ: "TIMESTAMP",
+           exp.DataType.Type.TIMETZ: "TIME",
        }

        TRANSFORMS = {

@@ -400,9 +387,6 @@ class Presto(Dialect):
            exp.GenerateSeries: sequence_sql,
            exp.GenerateDateArray: sequence_sql,
            exp.Group: transforms.preprocess([transforms.unalias_group]),
-           exp.GroupConcat: lambda self, e: self.func(
-               "ARRAY_JOIN", self.func("ARRAY_AGG", e.this), e.args.get("separator")
-           ),
            exp.If: if_sql(),
            exp.ILike: no_ilike_sql,
            exp.Initcap: _initcap_sql,

@@ -410,7 +394,7 @@ class Presto(Dialect):
            exp.Last: _first_last_sql,
            exp.LastValue: _first_last_sql,
            exp.LastDay: lambda self, e: self.func("LAST_DAY_OF_MONTH", e.this),
-           exp.Lateral: _explode_to_unnest_sql,
+           exp.Lateral: explode_to_unnest_sql,
            exp.Left: left_to_substring_sql,
            exp.Levenshtein: rename_func("LEVENSHTEIN_DISTANCE"),
            exp.LogicalAnd: rename_func("BOOL_AND"),

@@ -694,3 +678,10 @@ class Presto(Dialect):
            expr = "".join(segments)

        return f"{this}{expr}"
+
+   def groupconcat_sql(self, expression: exp.GroupConcat) -> str:
+       return self.func(
+           "ARRAY_JOIN",
+           self.func("ARRAY_AGG", expression.this),
+           expression.args.get("separator"),
+       )
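The GROUP_CONCAT lambda became an overridable groupconcat_sql method. A hedged transpilation sketch (expected output inferred from the method above):

    import sqlglot

    sql = "SELECT GROUP_CONCAT(name SEPARATOR ',') FROM t"
    print(sqlglot.transpile(sql, read="mysql", write="presto")[0])
    # expected: SELECT ARRAY_JOIN(ARRAY_AGG(name), ',') FROM t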
sqlglot/dialects/snowflake.py

@@ -41,15 +41,23 @@ def _build_datetime(
    if isinstance(value, exp.Literal):
        # Converts calls like `TO_TIME('01:02:03')` into casts
        if len(args) == 1 and value.is_string and not int_value:
-           return exp.cast(value, kind)
+           return (
+               exp.TryCast(this=value, to=exp.DataType.build(kind))
+               if safe
+               else exp.cast(value, kind)
+           )

        # Handles `TO_TIMESTAMP(str, fmt)` and `TO_TIMESTAMP(num, scale)` as special
        # cases so we can transpile them, since they're relatively common
        if kind == exp.DataType.Type.TIMESTAMP:
-           if int_value:
+           if int_value and not safe:
+               # TRY_TO_TIMESTAMP('integer') is not parsed into exp.UnixToTime as
+               # it's not easily transpilable
                return exp.UnixToTime(this=value, scale=seq_get(args, 1))
            if not is_float(value.this):
-               return build_formatted_time(exp.StrToTime, "snowflake")(args)
+               expr = build_formatted_time(exp.StrToTime, "snowflake")(args)
+               expr.set("safe", safe)
+               return expr

    if kind == exp.DataType.Type.DATE and not int_value:
        formatted_exp = build_formatted_time(exp.TsOrDsToDate, "snowflake")(args)

@@ -345,6 +353,9 @@ class Snowflake(Dialect):
            "TIMESTAMP_FROM_PARTS": build_timestamp_from_parts,
            "TRY_PARSE_JSON": lambda args: exp.ParseJSON(this=seq_get(args, 0), safe=True),
            "TRY_TO_DATE": _build_datetime("TRY_TO_DATE", exp.DataType.Type.DATE, safe=True),
+           "TRY_TO_TIMESTAMP": _build_datetime(
+               "TRY_TO_TIMESTAMP", exp.DataType.Type.TIMESTAMP, safe=True
+           ),
            "TO_DATE": _build_datetime("TO_DATE", exp.DataType.Type.DATE),
            "TO_NUMBER": lambda args: exp.ToNumber(
                this=seq_get(args, 0),

@@ -384,7 +395,6 @@ class Snowflake(Dialect):
                expressions=self._parse_csv(self._parse_id_var),
                unset=True,
            ),
-           "SWAP": lambda self: self._parse_alter_table_swap(),
        }

        STATEMENT_PARSERS = {

@@ -654,10 +664,6 @@ class Snowflake(Dialect):
            },
        )

-       def _parse_alter_table_swap(self) -> exp.SwapTable:
-           self._match_text_seq("WITH")
-           return self.expression(exp.SwapTable, this=self._parse_table(schema=True))
-
        def _parse_location_property(self) -> exp.LocationProperty:
            self._match(TokenType.EQ)
            return self.expression(exp.LocationProperty, this=self._parse_location_path())

@@ -828,7 +834,6 @@ class Snowflake(Dialect):
            exp.StrPosition: lambda self, e: self.func(
                "POSITION", e.args.get("substr"), e.this, e.args.get("position")
            ),
-           exp.StrToTime: lambda self, e: self.func("TO_TIMESTAMP", e.this, self.format_time(e)),
            exp.Stuff: rename_func("INSERT"),
            exp.TimeAdd: date_delta_sql("TIMEADD"),
            exp.TimestampDiff: lambda self, e: self.func(

@@ -842,6 +847,7 @@ class Snowflake(Dialect):
            exp.TimeToUnix: lambda self, e: f"EXTRACT(epoch_second FROM {self.sql(e, 'this')})",
            exp.ToArray: rename_func("TO_ARRAY"),
            exp.ToChar: lambda self, e: self.function_fallback_sql(e),
+           exp.ToDouble: rename_func("TO_DOUBLE"),
            exp.TsOrDsAdd: date_delta_sql("DATEADD", cast=True),
            exp.TsOrDsDiff: date_delta_sql("DATEDIFF"),
            exp.TsOrDsToDate: lambda self, e: self.func(

@@ -1036,10 +1042,6 @@ class Snowflake(Dialect):
            increment = f" INCREMENT {increment}" if increment else ""
            return f"AUTOINCREMENT{start}{increment}"

-       def swaptable_sql(self, expression: exp.SwapTable) -> str:
-           this = self.sql(expression, "this")
-           return f"SWAP WITH {this}"
-
        def cluster_sql(self, expression: exp.Cluster) -> str:
            return f"CLUSTER BY ({self.expressions(expression, flat=True)})"

@@ -1074,3 +1076,9 @@ class Snowflake(Dialect):
        tag = f" TAG {tag}" if tag else ""

        return f"SET{exprs}{file_format}{copy_options}{tag}"
+
+   def strtotime_sql(self, expression: exp.StrToTime):
+       safe_prefix = "TRY_" if expression.args.get("safe") else ""
+       return self.func(
+           f"{safe_prefix}TO_TIMESTAMP", expression.this, self.format_time(expression)
+       )
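TRY_TO_TIMESTAMP now parses with safe=True and is re-emitted through strtotime_sql. A hedged round-trip sketch (expected output inferred from the code above):

    import sqlglot

    sql = "SELECT TRY_TO_TIMESTAMP(col, 'yyyy-mm-dd hh24:mi:ss')"
    print(sqlglot.transpile(sql, read="snowflake", write="snowflake")[0])
    # expected: SELECT TRY_TO_TIMESTAMP(col, 'yyyy-mm-dd hh24:mi:ss')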
sqlglot/dialects/trino.py

@@ -10,11 +10,15 @@ class Trino(Presto):
    SUPPORTS_USER_DEFINED_TYPES = False
    LOG_BASE_FIRST = True

+   class Tokenizer(Presto.Tokenizer):
+       HEX_STRINGS = [("X'", "'")]
+
    class Parser(Presto.Parser):
        FUNCTION_PARSERS = {
            **Presto.Parser.FUNCTION_PARSERS,
            "TRIM": lambda self: self._parse_trim(),
            "JSON_QUERY": lambda self: self._parse_json_query(),
+           "LISTAGG": lambda self: self._parse_string_agg(),
        }

        JSON_QUERY_OPTIONS: parser.OPTIONS_TYPE = {

@@ -65,5 +69,14 @@ class Trino(Presto):

        return self.func("JSON_QUERY", expression.this, json_path + option)

-   class Tokenizer(Presto.Tokenizer):
-       HEX_STRINGS = [("X'", "'")]
+   def groupconcat_sql(self, expression: exp.GroupConcat) -> str:
+       this = expression.this
+       separator = expression.args.get("separator") or exp.Literal.string(",")
+
+       if isinstance(this, exp.Order):
+           if this.this:
+               this = this.this.pop()
+
+           return f"LISTAGG({self.format_args(this, separator)}) WITHIN GROUP ({self.sql(expression.this).lstrip()})"
+
+       return super().groupconcat_sql(expression)
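LISTAGG now parses, and ordered GROUP_CONCATs generate LISTAGG ... WITHIN GROUP. A hedged sketch (expected output inferred from groupconcat_sql above):

    import sqlglot

    sql = "SELECT GROUP_CONCAT(x ORDER BY x SEPARATOR '|') FROM t"
    print(sqlglot.transpile(sql, read="mysql", write="trino")[0])
    # expected: SELECT LISTAGG(x, '|') WITHIN GROUP (ORDER BY x) FROM t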
sqlglot/dialects/tsql.py

@@ -324,6 +324,25 @@ def _build_with_arg_as_text(
    return _parse


+# https://learn.microsoft.com/en-us/sql/t-sql/functions/parsename-transact-sql?view=sql-server-ver16
+def _build_parsename(args: t.List) -> exp.SplitPart | exp.Anonymous:
+    # PARSENAME(...) will be stored into exp.SplitPart if:
+    # - All args are literals
+    # - The part index (2nd arg) is <= 4 (max valid value, otherwise TSQL returns NULL)
+    if len(args) == 2 and all(isinstance(arg, exp.Literal) for arg in args):
+        this = args[0]
+        part_index = args[1]
+        split_count = len(this.name.split("."))
+        if split_count <= 4:
+            return exp.SplitPart(
+                this=this,
+                delimiter=exp.Literal.string("."),
+                part_index=exp.Literal.number(split_count + 1 - part_index.to_py()),
+            )
+
+    return exp.Anonymous(this="PARSENAME", expressions=args)
+
+
def _build_json_query(args: t.List, dialect: Dialect) -> exp.JSONExtract:
    if len(args) == 1:
        # The default value for path is '$'. As a result, if you don't provide a

@@ -543,6 +562,7 @@ class TSQL(Dialect):
            "LEN": _build_with_arg_as_text(exp.Length),
            "LEFT": _build_with_arg_as_text(exp.Left),
            "RIGHT": _build_with_arg_as_text(exp.Right),
+           "PARSENAME": _build_parsename,
            "REPLICATE": exp.Repeat.from_arg_list,
            "SQUARE": lambda args: exp.Pow(this=seq_get(args, 0), expression=exp.Literal.number(2)),
            "SYSDATETIME": exp.CurrentTimestamp.from_arg_list,

@@ -554,6 +574,10 @@ class TSQL(Dialect):

        JOIN_HINTS = {"LOOP", "HASH", "MERGE", "REMOTE"}

+       PROCEDURE_OPTIONS = dict.fromkeys(
+           ("ENCRYPTION", "RECOMPILE", "SCHEMABINDING", "NATIVE_COMPILATION", "EXECUTE"), tuple()
+       )
+
        RETURNS_TABLE_TOKENS = parser.Parser.ID_VAR_TOKENS - {
            TokenType.TABLE,
            *parser.Parser.TYPE_TOKENS,

@@ -699,7 +723,11 @@ class TSQL(Dialect):
        ):
            return this

-       expressions = self._parse_csv(self._parse_function_parameter)
+       if not self._match(TokenType.WITH, advance=False):
+           expressions = self._parse_csv(self._parse_function_parameter)
+       else:
+           expressions = None
+
        return self.expression(exp.UserDefinedFunction, this=this, expressions=expressions)

    def _parse_id_var(

@@ -954,6 +982,27 @@ class TSQL(Dialect):
            self.unsupported("LATERAL clause is not supported.")
            return "LATERAL"

+       def splitpart_sql(self: TSQL.Generator, expression: exp.SplitPart) -> str:
+           this = expression.this
+           split_count = len(this.name.split("."))
+           delimiter = expression.args.get("delimiter")
+           part_index = expression.args.get("part_index")
+
+           if (
+               not all(isinstance(arg, exp.Literal) for arg in (this, delimiter, part_index))
+               or (delimiter and delimiter.name != ".")
+               or not part_index
+               or split_count > 4
+           ):
+               self.unsupported(
+                   "SPLIT_PART can be transpiled to PARSENAME only for '.' delimiter and literal values"
+               )
+               return ""
+
+           return self.func(
+               "PARSENAME", this, exp.Literal.number(split_count + 1 - part_index.to_py())
+           )
+
        def timefromparts_sql(self, expression: exp.TimeFromParts) -> str:
            nano = expression.args.get("nano")
            if nano is not None:
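A sketch of the SPLIT_PART to PARSENAME transpilation; PARSENAME counts parts from the right, hence the index arithmetic above:

    import sqlglot

    print(sqlglot.transpile("SELECT SPLIT_PART('a.b.c', '.', 1)", read="snowflake", write="tsql")[0])
    # expected: SELECT PARSENAME('a.b.c', 3)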
@@ -1166,7 +1215,7 @@ class TSQL(Dialect):

        def alter_sql(self, expression: exp.Alter) -> str:
            action = seq_get(expression.args.get("actions") or [], 0)
-           if isinstance(action, exp.RenameTable):
+           if isinstance(action, exp.AlterRename):
                return f"EXEC sp_rename '{self.sql(expression.this)}', '{action.this.name}'"
            return super().alter_sql(expression)
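Renames still reach sp_rename through the renamed node. A hedged sketch (expected output inferred from alter_sql above):

    import sqlglot

    print(sqlglot.transpile("ALTER TABLE t RENAME TO u", read="mysql", write="tsql")[0])
    # expected: EXEC sp_rename 't', 'u'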
sqlglot/diff.py

@@ -12,7 +12,7 @@ from dataclasses import dataclass
from heapq import heappop, heappush

from sqlglot import Dialect, expressions as exp
-from sqlglot.helper import ensure_list
+from sqlglot.helper import seq_get

if t.TYPE_CHECKING:
    from sqlglot.dialects.dialect import DialectType

@@ -185,7 +185,7 @@ class ChangeDistiller:
        self._unmatched_target_nodes = set(self._target_index) - set(pre_matched_nodes.values())
        self._bigram_histo_cache: t.Dict[int, t.DefaultDict[str, int]] = {}

-       matching_set = self._compute_matching_set() | {(s, t) for s, t in pre_matched_nodes.items()}
+       matching_set = self._compute_matching_set() | set(pre_matched_nodes.items())
        return self._generate_edit_script(matching_set, delta_only)

    def _generate_edit_script(

@@ -201,6 +201,7 @@ class ChangeDistiller:
        for kept_source_node_id, kept_target_node_id in matching_set:
            source_node = self._source_index[kept_source_node_id]
            target_node = self._target_index[kept_target_node_id]
+
            if (
                not isinstance(source_node, UPDATABLE_EXPRESSION_TYPES)
                or source_node == target_node

@@ -208,7 +209,13 @@ class ChangeDistiller:
                edit_script.extend(
                    self._generate_move_edits(source_node, target_node, matching_set)
                )
-               if not delta_only:
+
+               source_non_expression_leaves = dict(_get_non_expression_leaves(source_node))
+               target_non_expression_leaves = dict(_get_non_expression_leaves(target_node))
+
+               if source_non_expression_leaves != target_non_expression_leaves:
+                   edit_script.append(Update(source_node, target_node))
+               elif not delta_only:
                    edit_script.append(Keep(source_node, target_node))
            else:
                edit_script.append(Update(source_node, target_node))

@@ -246,8 +253,8 @@ class ChangeDistiller:
            source_node = self._source_index[source_node_id]
            target_node = self._target_index[target_node_id]
            if _is_same_type(source_node, target_node):
-               source_leaf_ids = {id(l) for l in _get_leaves(source_node)}
-               target_leaf_ids = {id(l) for l in _get_leaves(target_node)}
+               source_leaf_ids = {id(l) for l in _get_expression_leaves(source_node)}
+               target_leaf_ids = {id(l) for l in _get_expression_leaves(target_node)}

                max_leaves_num = max(len(source_leaf_ids), len(target_leaf_ids))
                if max_leaves_num:

@@ -277,10 +284,10 @@ class ChangeDistiller:

    def _compute_leaf_matching_set(self) -> t.Set[t.Tuple[int, int]]:
        candidate_matchings: t.List[t.Tuple[float, int, int, exp.Expression, exp.Expression]] = []
-       source_leaves = list(_get_leaves(self._source))
-       target_leaves = list(_get_leaves(self._target))
-       for source_leaf in source_leaves:
-           for target_leaf in target_leaves:
+       source_expression_leaves = list(_get_expression_leaves(self._source))
+       target_expression_leaves = list(_get_expression_leaves(self._target))
+       for source_leaf in source_expression_leaves:
+           for target_leaf in target_expression_leaves:
                if _is_same_type(source_leaf, target_leaf):
                    similarity_score = self._dice_coefficient(source_leaf, target_leaf)
                    if similarity_score >= self.f:

@@ -338,18 +345,28 @@ class ChangeDistiller:
        return bigram_histo


-def _get_leaves(expression: exp.Expression) -> t.Iterator[exp.Expression]:
+def _get_expression_leaves(expression: exp.Expression) -> t.Iterator[exp.Expression]:
    has_child_exprs = False

    for node in expression.iter_expressions():
        if not isinstance(node, IGNORED_LEAF_EXPRESSION_TYPES):
            has_child_exprs = True
-           yield from _get_leaves(node)
+           yield from _get_expression_leaves(node)

    if not has_child_exprs:
        yield expression


+def _get_non_expression_leaves(expression: exp.Expression) -> t.Iterator[t.Tuple[str, t.Any]]:
+    for arg, value in expression.args.items():
+        if isinstance(value, exp.Expression) or (
+            isinstance(value, list) and isinstance(seq_get(value, 0), exp.Expression)
+        ):
+            continue
+
+        yield (arg, value)
+
+
def _is_same_type(source: exp.Expression, target: exp.Expression) -> bool:
    if type(source) is type(target):
        if isinstance(source, exp.Join):

@@ -372,16 +389,12 @@ def _parent_similarity_score(
    return 1 + _parent_similarity_score(source.parent, target.parent)


-def _expression_only_args(expression: exp.Expression) -> t.List[exp.Expression]:
-    args: t.List[t.Union[exp.Expression, t.List]] = []
-    if expression:
-        for a in expression.args.values():
-            args.extend(ensure_list(a))
-    return [
-        a
-        for a in args
-        if isinstance(a, exp.Expression) and not isinstance(a, IGNORED_LEAF_EXPRESSION_TYPES)
-    ]
+def _expression_only_args(expression: exp.Expression) -> t.Iterator[exp.Expression]:
+    yield from (
+        arg
+        for arg in expression.iter_expressions()
+        if not isinstance(arg, IGNORED_LEAF_EXPRESSION_TYPES)
+    )


def _lcs(
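With non-expression leaves compared as well, literal-level changes now surface as Updates. A hedged sketch of the public diff API:

    from sqlglot import diff, parse_one

    edits = diff(parse_one("SELECT a FROM t"), parse_one("SELECT b FROM t"), delta_only=True)
    for edit in edits:
        print(type(edit).__name__)  # expected: Update-style edits for the renamed column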
sqlglot/expressions.py

@@ -404,9 +404,9 @@ class Expression(metaclass=_Expression):
    def iter_expressions(self, reverse: bool = False) -> t.Iterator[Expression]:
        """Yields the key and expression for all arguments, exploding list args."""
        # remove tuple when python 3.7 is deprecated
-       for vs in reversed(tuple(self.args.values())) if reverse else self.args.values():
+       for vs in reversed(tuple(self.args.values())) if reverse else self.args.values():  # type: ignore
            if type(vs) is list:
-               for v in reversed(vs) if reverse else vs:
+               for v in reversed(vs) if reverse else vs:  # type: ignore
                    if hasattr(v, "parent"):
                        yield v
                    else:
@@ -1247,7 +1247,7 @@ class Query(Expression):
        )

    def union(
-       self, expression: ExpOrStr, distinct: bool = True, dialect: DialectType = None, **opts
+       self, *expressions: ExpOrStr, distinct: bool = True, dialect: DialectType = None, **opts
    ) -> Union:
        """
        Builds a UNION expression.

@@ -1258,8 +1258,8 @@ class Query(Expression):
            'SELECT * FROM foo UNION SELECT * FROM bla'

        Args:
-           expression: the SQL code string.
-               If an `Expression` instance is passed, it will be used as-is.
+           expressions: the SQL code strings.
+               If `Expression` instances are passed, they will be used as-is.
            distinct: set the DISTINCT flag if and only if this is true.
            dialect: the dialect used to parse the input expression.
            opts: other options to use to parse the input expressions.

@@ -1267,10 +1267,10 @@ class Query(Expression):
        Returns:
            The new Union expression.
        """
-       return union(left=self, right=expression, distinct=distinct, dialect=dialect, **opts)
+       return union(self, *expressions, distinct=distinct, dialect=dialect, **opts)

    def intersect(
-       self, expression: ExpOrStr, distinct: bool = True, dialect: DialectType = None, **opts
+       self, *expressions: ExpOrStr, distinct: bool = True, dialect: DialectType = None, **opts
    ) -> Intersect:
        """
        Builds an INTERSECT expression.

@@ -1281,8 +1281,8 @@ class Query(Expression):
            'SELECT * FROM foo INTERSECT SELECT * FROM bla'

        Args:
-           expression: the SQL code string.
-               If an `Expression` instance is passed, it will be used as-is.
+           expressions: the SQL code strings.
+               If `Expression` instances are passed, they will be used as-is.
            distinct: set the DISTINCT flag if and only if this is true.
            dialect: the dialect used to parse the input expression.
            opts: other options to use to parse the input expressions.

@@ -1290,10 +1290,10 @@ class Query(Expression):
        Returns:
            The new Intersect expression.
        """
-       return intersect(left=self, right=expression, distinct=distinct, dialect=dialect, **opts)
+       return intersect(self, *expressions, distinct=distinct, dialect=dialect, **opts)

    def except_(
-       self, expression: ExpOrStr, distinct: bool = True, dialect: DialectType = None, **opts
+       self, *expressions: ExpOrStr, distinct: bool = True, dialect: DialectType = None, **opts
    ) -> Except:
        """
        Builds an EXCEPT expression.

@@ -1304,8 +1304,8 @@ class Query(Expression):
            'SELECT * FROM foo EXCEPT SELECT * FROM bla'

        Args:
-           expression: the SQL code string.
-               If an `Expression` instance is passed, it will be used as-is.
+           expressions: the SQL code strings.
+               If `Expression` instances are passed, they will be used as-is.
            distinct: set the DISTINCT flag if and only if this is true.
            dialect: the dialect used to parse the input expression.
            opts: other options to use to parse the input expressions.

@@ -1313,7 +1313,7 @@ class Query(Expression):
        Returns:
            The new Except expression.
        """
-       return except_(left=self, right=expression, distinct=distinct, dialect=dialect, **opts)
+       return except_(self, *expressions, distinct=distinct, dialect=dialect, **opts)


class UDTF(DerivedTable):
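The instance methods now accept multiple operands too. A hedged sketch (expected output inferred from the new signatures):

    from sqlglot import parse_one

    q = parse_one("SELECT 1").union("SELECT 2", "SELECT 3", distinct=False)
    print(q.sql())
    # expected: SELECT 1 UNION ALL SELECT 2 UNION ALL SELECT 3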
@@ -1697,7 +1697,7 @@ class RenameColumn(Expression):
    arg_types = {"this": True, "to": True, "exists": False}


-class RenameTable(Expression):
+class AlterRename(Expression):
    pass


@@ -2400,6 +2400,7 @@ class Join(Expression):
        "global": False,
        "hint": False,
        "match_condition": False,  # Snowflake
+       "expressions": False,
    }

    @property

@@ -2995,6 +2996,10 @@ class WithSystemVersioningProperty(Property):
    }


+class WithProcedureOptions(Property):
+    arg_types = {"expressions": True}
+
+
class Properties(Expression):
    arg_types = {"expressions": True}


@@ -3213,10 +3218,18 @@ class Table(Expression):

    def to_column(self, copy: bool = True) -> Alias | Column | Dot:
        parts = self.parts
        last_part = parts[-1]

        if isinstance(last_part, Identifier):
            col = column(*reversed(parts[0:4]), fields=parts[4:], copy=copy)  # type: ignore
        else:
            # This branch will be reached if a function or array is wrapped in a `Table`
            col = last_part

        alias = self.args.get("alias")
        if alias:
            col = alias_(col, alias.this, copy=copy)

        return col


@@ -3278,7 +3291,7 @@ class Intersect(SetOperation):
    pass


-class Update(Expression):
+class Update(DML):
    arg_types = {
        "with": False,
        "this": False,

@@ -3526,6 +3539,7 @@ class Select(Query):
        "distinct": False,
        "into": False,
        "from": False,
+       "operation_modifiers": False,
        **QUERY_MODIFIERS,
    }


@@ -5184,6 +5198,14 @@ class ToNumber(Func):
    }


+# https://docs.snowflake.com/en/sql-reference/functions/to_double
+class ToDouble(Func):
+    arg_types = {
+        "this": True,
+        "format": False,
+    }
+
+
class Columns(Func):
    arg_types = {"this": True, "unpack": False}


@@ -5641,7 +5663,7 @@ class Exp(Func):


# https://docs.snowflake.com/en/sql-reference/functions/flatten
-class Explode(Func):
+class Explode(Func, UDTF):
    arg_types = {"this": True, "expressions": False}
    is_var_len_args = True


@@ -6248,6 +6270,11 @@ class Split(Func):
    arg_types = {"this": True, "expression": True, "limit": False}


+# https://spark.apache.org/docs/latest/api/python/reference/pyspark.sql/api/pyspark.sql.functions.split_part.html
+class SplitPart(Func):
+    arg_types = {"this": True, "delimiter": True, "part_index": True}
+
+
# Start may be omitted in the case of postgres
# https://www.postgresql.org/docs/9.1/functions-string.html @ Table 9-6
class Substring(Func):
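Making exp.Update a DML node lets generic isinstance checks treat UPDATE like INSERT and DELETE (fix for #4221). A minimal sketch:

    from sqlglot import exp, parse_one

    assert isinstance(parse_one("UPDATE t SET a = 1"), exp.DML)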
@@ -6857,26 +6884,37 @@ def _wrap(expression: E, kind: t.Type[Expression]) -> E | Paren:
    return Paren(this=expression) if isinstance(expression, kind) else expression


+def _apply_set_operation(
+    *expressions: ExpOrStr,
+    set_operation: t.Type[S],
+    distinct: bool = True,
+    dialect: DialectType = None,
+    copy: bool = True,
+    **opts,
+) -> S:
+    return reduce(
+        lambda x, y: set_operation(this=x, expression=y, distinct=distinct),
+        (maybe_parse(e, dialect=dialect, copy=copy, **opts) for e in expressions),
+    )
+
+
def union(
-    left: ExpOrStr,
-    right: ExpOrStr,
+    *expressions: ExpOrStr,
    distinct: bool = True,
    dialect: DialectType = None,
    copy: bool = True,
    **opts,
) -> Union:
    """
-   Initializes a syntax tree from one UNION expression.
+   Initializes a syntax tree for the `UNION` operation.

    Example:
        >>> union("SELECT * FROM foo", "SELECT * FROM bla").sql()
        'SELECT * FROM foo UNION SELECT * FROM bla'

    Args:
-       left: the SQL code string corresponding to the left-hand side.
-           If an `Expression` instance is passed, it will be used as-is.
-       right: the SQL code string corresponding to the right-hand side.
-           If an `Expression` instance is passed, it will be used as-is.
+       expressions: the SQL code strings, corresponding to the `UNION`'s operands.
+           If `Expression` instances are passed, they will be used as-is.
        distinct: set the DISTINCT flag if and only if this is true.
        dialect: the dialect used to parse the input expression.
        copy: whether to copy the expression.

@@ -6885,32 +6923,29 @@ def union(
    Returns:
        The new Union instance.
    """
-   left = maybe_parse(sql_or_expression=left, dialect=dialect, copy=copy, **opts)
-   right = maybe_parse(sql_or_expression=right, dialect=dialect, copy=copy, **opts)
-
-   return Union(this=left, expression=right, distinct=distinct)
+   assert len(expressions) >= 2, "At least two expressions are required by `union`."
+   return _apply_set_operation(
+       *expressions, set_operation=Union, distinct=distinct, dialect=dialect, copy=copy, **opts
+   )


def intersect(
-    left: ExpOrStr,
-    right: ExpOrStr,
+    *expressions: ExpOrStr,
    distinct: bool = True,
    dialect: DialectType = None,
    copy: bool = True,
    **opts,
) -> Intersect:
    """
-   Initializes a syntax tree from one INTERSECT expression.
+   Initializes a syntax tree for the `INTERSECT` operation.

    Example:
        >>> intersect("SELECT * FROM foo", "SELECT * FROM bla").sql()
        'SELECT * FROM foo INTERSECT SELECT * FROM bla'

    Args:
-       left: the SQL code string corresponding to the left-hand side.
-           If an `Expression` instance is passed, it will be used as-is.
-       right: the SQL code string corresponding to the right-hand side.
-           If an `Expression` instance is passed, it will be used as-is.
+       expressions: the SQL code strings, corresponding to the `INTERSECT`'s operands.
+           If `Expression` instances are passed, they will be used as-is.
        distinct: set the DISTINCT flag if and only if this is true.
        dialect: the dialect used to parse the input expression.
        copy: whether to copy the expression.

@@ -6919,32 +6954,29 @@ def intersect(
    Returns:
        The new Intersect instance.
    """
-   left = maybe_parse(sql_or_expression=left, dialect=dialect, copy=copy, **opts)
-   right = maybe_parse(sql_or_expression=right, dialect=dialect, copy=copy, **opts)
-
-   return Intersect(this=left, expression=right, distinct=distinct)
+   assert len(expressions) >= 2, "At least two expressions are required by `intersect`."
+   return _apply_set_operation(
+       *expressions, set_operation=Intersect, distinct=distinct, dialect=dialect, copy=copy, **opts
+   )


def except_(
-    left: ExpOrStr,
-    right: ExpOrStr,
+    *expressions: ExpOrStr,
    distinct: bool = True,
    dialect: DialectType = None,
    copy: bool = True,
    **opts,
) -> Except:
    """
-   Initializes a syntax tree from one EXCEPT expression.
+   Initializes a syntax tree for the `EXCEPT` operation.

    Example:
        >>> except_("SELECT * FROM foo", "SELECT * FROM bla").sql()
        'SELECT * FROM foo EXCEPT SELECT * FROM bla'

    Args:
-       left: the SQL code string corresponding to the left-hand side.
-           If an `Expression` instance is passed, it will be used as-is.
-       right: the SQL code string corresponding to the right-hand side.
-           If an `Expression` instance is passed, it will be used as-is.
+       expressions: the SQL code strings, corresponding to the `EXCEPT`'s operands.
+           If `Expression` instances are passed, they will be used as-is.
        distinct: set the DISTINCT flag if and only if this is true.
        dialect: the dialect used to parse the input expression.
        copy: whether to copy the expression.

@@ -6953,10 +6985,10 @@ def except_(
    Returns:
        The new Except instance.
    """
-   left = maybe_parse(sql_or_expression=left, dialect=dialect, copy=copy, **opts)
-   right = maybe_parse(sql_or_expression=right, dialect=dialect, copy=copy, **opts)
-
-   return Except(this=left, expression=right, distinct=distinct)
+   assert len(expressions) >= 2, "At least two expressions are required by `except_`."
+   return _apply_set_operation(
+       *expressions, set_operation=Except, distinct=distinct, dialect=dialect, copy=copy, **opts
+   )


def select(*expressions: ExpOrStr, dialect: DialectType = None, **opts) -> Select:
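The module-level builders now left-fold any number of operands through _apply_set_operation. A hedged sketch:

    from sqlglot import exp

    q = exp.intersect("SELECT a FROM x", "SELECT a FROM y", "SELECT a FROM z", distinct=False)
    print(q.sql())
    # expected: SELECT a FROM x INTERSECT ALL SELECT a FROM y INTERSECT ALL SELECT a FROM z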
@@ -7410,15 +7442,9 @@ def to_interval(interval: str | Literal) -> Interval:

        interval = interval.this

-   interval_parts = INTERVAL_STRING_RE.match(interval)  # type: ignore
-
-   if not interval_parts:
-       raise ValueError("Invalid interval string.")
-
-   return Interval(
-       this=Literal.string(interval_parts.group(1)),
-       unit=Var(this=interval_parts.group(2).upper()),
-   )
+   interval = maybe_parse(f"INTERVAL {interval}")
+   assert isinstance(interval, Interval)
+   return interval


def to_table(

@@ -7795,7 +7821,7 @@ def rename_table(
        this=old_table,
        kind="TABLE",
        actions=[
-           RenameTable(this=new_table),
+           AlterRename(this=new_table),
        ],
    )
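to_interval now reuses the parser instead of a regex, so anything the parser accepts works. A hedged sketch (the exact rendering is an assumption):

    from sqlglot import exp

    print(exp.to_interval("5 day").sql())
    # expected something like: INTERVAL '5' DAY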
@ -185,6 +185,7 @@ class Generator(metaclass=_Generator):
|
|||
exp.Stream: lambda self, e: f"STREAM {self.sql(e, 'this')}",
|
||||
exp.StreamingTableProperty: lambda *_: "STREAMING",
|
||||
exp.StrictProperty: lambda *_: "STRICT",
|
||||
exp.SwapTable: lambda self, e: f"SWAP WITH {self.sql(e, 'this')}",
|
||||
exp.TemporaryProperty: lambda *_: "TEMPORARY",
|
||||
exp.TagColumnConstraint: lambda self, e: f"TAG ({self.expressions(e, flat=True)})",
|
||||
exp.TitleColumnConstraint: lambda self, e: f"TITLE {self.sql(e, 'this')}",
|
||||
|
@ -200,6 +201,7 @@ class Generator(metaclass=_Generator):
|
|||
exp.ViewAttributeProperty: lambda self, e: f"WITH {self.sql(e, 'this')}",
|
||||
exp.VolatileProperty: lambda *_: "VOLATILE",
|
||||
exp.WithJournalTableProperty: lambda self, e: f"WITH JOURNAL TABLE={self.sql(e, 'this')}",
|
||||
exp.WithProcedureOptions: lambda self, e: f"WITH {self.expressions(e, flat=True)}",
|
||||
exp.WithSchemaBindingProperty: lambda self, e: f"WITH SCHEMA {self.sql(e, 'this')}",
|
||||
exp.WithOperator: lambda self, e: f"{self.sql(e, 'this')} WITH {self.sql(e, 'op')}",
|
||||
}
|
||||
|
@ -564,6 +566,7 @@ class Generator(metaclass=_Generator):
|
|||
exp.VolatileProperty: exp.Properties.Location.POST_CREATE,
|
||||
exp.WithDataProperty: exp.Properties.Location.POST_EXPRESSION,
|
||||
exp.WithJournalTableProperty: exp.Properties.Location.POST_NAME,
|
||||
exp.WithProcedureOptions: exp.Properties.Location.POST_SCHEMA,
|
||||
exp.WithSchemaBindingProperty: exp.Properties.Location.POST_SCHEMA,
|
||||
exp.WithSystemVersioningProperty: exp.Properties.Location.POST_SCHEMA,
|
||||
}
|
||||
|
@ -2144,6 +2147,10 @@ class Generator(metaclass=_Generator):
|
|||
this = expression.this
|
||||
this_sql = self.sql(this)
|
||||
|
||||
exprs = self.expressions(expression)
|
||||
if exprs:
|
||||
this_sql = f"{this_sql},{self.seg(exprs)}"
|
||||
|
||||
if on_sql:
|
||||
on_sql = self.indent(on_sql, skip_first=True)
|
||||
space = self.seg(" " * self.pad) if self.pretty else " "
|
||||
|
@@ -2510,13 +2517,16 @@ class Generator(metaclass=_Generator):
         )
         kind = ""

+        operation_modifiers = self.expressions(expression, key="operation_modifiers", sep=" ")
+        operation_modifiers = f"{self.sep()}{operation_modifiers}" if operation_modifiers else ""
+
         # We use LIMIT_IS_TOP as a proxy for whether DISTINCT should go first because tsql and Teradata
         # are the only dialects that use LIMIT_IS_TOP and both place DISTINCT first.
         top_distinct = f"{distinct}{hint}{top}" if self.LIMIT_IS_TOP else f"{top}{hint}{distinct}"
         expressions = f"{self.sep()}{expressions}" if expressions else expressions
         sql = self.query_modifiers(
             expression,
-            f"SELECT{top_distinct}{kind}{expressions}",
+            f"SELECT{top_distinct}{operation_modifiers}{kind}{expressions}",
             self.sql(expression, "into", comment=False),
             self.sql(expression, "from", comment=False),
         )
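This wires the new operation_modifiers arg (captured between SELECT and the projections by the parser hunks later in this diff) back into SQL generation. Assuming MySQL is one of the dialects that populates OPERATION_MODIFIERS, a sketch:

import sqlglot

# Treating HIGH_PRIORITY / SQL_CALC_FOUND_ROWS as operation modifiers is an assumption here.
sql = "SELECT HIGH_PRIORITY SQL_CALC_FOUND_ROWS a FROM t"
print(sqlglot.transpile(sql, read="mysql", write="mysql")[0])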
@@ -3225,12 +3235,12 @@ class Generator(metaclass=_Generator):
         expressions = f"({expressions})" if expressions else ""
         return f"ALTER{compound} SORTKEY {this or expressions}"

-    def renametable_sql(self, expression: exp.RenameTable) -> str:
+    def alterrename_sql(self, expression: exp.AlterRename) -> str:
         if not self.RENAME_TABLE_WITH_DB:
             # Remove db from tables
             expression = expression.transform(
                 lambda n: exp.table_(n.this) if isinstance(n, exp.Table) else n
-            ).assert_is(exp.RenameTable)
+            ).assert_is(exp.AlterRename)
         this = self.sql(expression, "this")
         return f"RENAME TO {this}"

@@ -3508,13 +3518,15 @@ class Generator(metaclass=_Generator):
         name = self.normalize_func(name) if normalize else name
         return f"{name}{prefix}{self.format_args(*args)}{suffix}"

-    def format_args(self, *args: t.Optional[str | exp.Expression]) -> str:
+    def format_args(self, *args: t.Optional[str | exp.Expression], sep: str = ", ") -> str:
         arg_sqls = tuple(
             self.sql(arg) for arg in args if arg is not None and not isinstance(arg, bool)
         )
         if self.pretty and self.too_wide(arg_sqls):
-            return self.indent("\n" + ",\n".join(arg_sqls) + "\n", skip_first=True, skip_last=True)
-        return ", ".join(arg_sqls)
+            return self.indent(
+                "\n" + f"{sep.strip()}\n".join(arg_sqls) + "\n", skip_first=True, skip_last=True
+            )
+        return sep.join(arg_sqls)

     def too_wide(self, args: t.Iterable) -> bool:
         return sum(len(arg) for arg in args) > self.max_text_width
@@ -3612,7 +3624,7 @@ class Generator(metaclass=_Generator):
         expressions = (
             self.wrap(expressions) if expression.args.get("wrapped") else f" {expressions}"
         )
-        return f"{this}{expressions}"
+        return f"{this}{expressions}" if expressions.strip() != "" else this

     def joinhint_sql(self, expression: exp.JoinHint) -> str:
         this = self.sql(expression, "this")
@@ -4243,7 +4255,7 @@ class Generator(metaclass=_Generator):
         else:
             rhs = self.expressions(expression)

-        return self.func(name, expression.this, rhs)
+        return self.func(name, expression.this, rhs or None)

     def converttimezone_sql(self, expression: exp.ConvertTimezone) -> str:
         if self.SUPPORTS_CONVERT_TIMEZONE:
@@ -4418,3 +4430,7 @@ class Generator(metaclass=_Generator):
         for_sql = f" FOR {for_sql}" if for_sql else ""

         return f"OVERLAY({this} PLACING {expr} FROM {from_sql}{for_sql})"
+
+    @unsupported_args("format")
+    def todouble_sql(self, expression: exp.ToDouble) -> str:
+        return self.sql(exp.cast(expression.this, exp.DataType.Type.DOUBLE))
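The format_args change is the one API tweak in this batch: callers can now pick the separator while keeping the too_wide pretty-printing logic. A hypothetical subclass sketch (MYFUNC and the method are made up, not part of the commit):

from sqlglot.generator import Generator

class MyGenerator(Generator):
    def myfunc_sql(self, expression) -> str:
        # Join arguments with "; " instead of the default ", " (illustrative only).
        return f"MYFUNC({self.format_args(*expression.expressions, sep='; ')})"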
@@ -56,6 +56,10 @@ def seq_get(seq: t.Sequence[T], index: int) -> t.Optional[T]:
 def ensure_list(value: t.Collection[T]) -> t.List[T]: ...


+@t.overload
+def ensure_list(value: None) -> t.List: ...
+
+
 @t.overload
 def ensure_list(value: T) -> t.List[T]: ...

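This is the 25.25.1 overload fix called out in the changelog (PR #4248): ensure_list(None) already returned an empty list at runtime, and the new overload lets type checkers accept that call. A quick check:

from sqlglot.helper import ensure_list

assert ensure_list(None) == []       # now typed as t.List rather than rejected
assert ensure_list(1) == [1]
assert ensure_list([1, 2]) == [1, 2]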
@@ -287,15 +287,18 @@ class TypeAnnotator(metaclass=_TypeAnnotator):

     def _maybe_coerce(
         self, type1: exp.DataType | exp.DataType.Type, type2: exp.DataType | exp.DataType.Type
-    ) -> exp.DataType:
+    ) -> exp.DataType.Type:
         type1_value = type1.this if isinstance(type1, exp.DataType) else type1
         type2_value = type2.this if isinstance(type2, exp.DataType) else type2

         # We propagate the UNKNOWN type upwards if found
         if exp.DataType.Type.UNKNOWN in (type1_value, type2_value):
-            return exp.DataType.build("unknown")
+            return exp.DataType.Type.UNKNOWN

-        return type2_value if type2_value in self.coerces_to.get(type1_value, {}) else type1_value
+        return t.cast(
+            exp.DataType.Type,
+            type2_value if type2_value in self.coerces_to.get(type1_value, {}) else type1_value,
+        )

     def _annotate_binary(self, expression: B) -> B:
         self._annotate_args(expression)
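_maybe_coerce (apparently sqlglot/optimizer/annotate_types.py) now returns a bare exp.DataType.Type member instead of building a DataType node; observable behavior should be unchanged. A sketch, assuming default coercion rules:

from sqlglot import parse_one
from sqlglot.optimizer.annotate_types import annotate_types

expr = annotate_types(parse_one("SELECT 1 + 2.5 AS x"))
print(expr.selects[0].type.sql())  # expected: DOUBLE (INT coerces to DOUBLE)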
@@ -1,11 +1,18 @@
 from __future__ import annotations

 import itertools
+import typing as t

 from sqlglot import expressions as exp
 from sqlglot.helper import find_new_name
-from sqlglot.optimizer.scope import build_scope
+from sqlglot.optimizer.scope import Scope, build_scope
+
+if t.TYPE_CHECKING:
+    ExistingCTEsMapping = t.Dict[exp.Expression, str]
+    TakenNameMapping = t.Dict[str, t.Union[Scope, exp.Expression]]


-def eliminate_subqueries(expression):
+def eliminate_subqueries(expression: exp.Expression) -> exp.Expression:
     """
     Rewrite derived tables as CTES, deduplicating if possible.
@@ -38,7 +45,7 @@ def eliminate_subqueries(expression):
     # Map of alias->Scope|Table
     # These are all aliases that are already used in the expression.
     # We don't want to create new CTEs that conflict with these names.
-    taken = {}
+    taken: TakenNameMapping = {}

     # All CTE aliases in the root scope are taken
     for scope in root.cte_scopes:
@@ -56,7 +63,7 @@ def eliminate_subqueries(expression):

     # Map of Expression->alias
     # Existing CTES in the root expression. We'll use this for deduplication.
-    existing_ctes = {}
+    existing_ctes: ExistingCTEsMapping = {}

     with_ = root.expression.args.get("with")
     recursive = False
@@ -95,15 +102,21 @@ def eliminate_subqueries(expression):
     return expression


-def _eliminate(scope, existing_ctes, taken):
+def _eliminate(
+    scope: Scope, existing_ctes: ExistingCTEsMapping, taken: TakenNameMapping
+) -> t.Optional[exp.Expression]:
     if scope.is_derived_table:
         return _eliminate_derived_table(scope, existing_ctes, taken)

     if scope.is_cte:
         return _eliminate_cte(scope, existing_ctes, taken)

+    return None
+

-def _eliminate_derived_table(scope, existing_ctes, taken):
+def _eliminate_derived_table(
+    scope: Scope, existing_ctes: ExistingCTEsMapping, taken: TakenNameMapping
+) -> t.Optional[exp.Expression]:
     # This makes sure that we don't:
     # - drop the "pivot" arg from a pivoted subquery
     # - eliminate a lateral correlated subquery
@@ -121,7 +134,9 @@ def _eliminate_derived_table(scope, existing_ctes, taken):
     return cte


-def _eliminate_cte(scope, existing_ctes, taken):
+def _eliminate_cte(
+    scope: Scope, existing_ctes: ExistingCTEsMapping, taken: TakenNameMapping
+) -> t.Optional[exp.Expression]:
     parent = scope.expression.parent
     name, cte = _new_cte(scope, existing_ctes, taken)

@@ -140,7 +155,9 @@ def _eliminate_cte(scope, existing_ctes, taken):
     return cte


-def _new_cte(scope, existing_ctes, taken):
+def _new_cte(
+    scope: Scope, existing_ctes: ExistingCTEsMapping, taken: TakenNameMapping
+) -> t.Tuple[str, t.Optional[exp.Expression]]:
     """
     Returns:
         tuple of (name, cte)
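The eliminate_subqueries hunks are typing-only; the rule itself is unchanged. For context, a usage sketch (output shape approximate):

import sqlglot
from sqlglot.optimizer.eliminate_subqueries import eliminate_subqueries

expr = sqlglot.parse_one("SELECT a FROM (SELECT a FROM t) AS x")
print(eliminate_subqueries(expr).sql())
# roughly: WITH x AS (SELECT a FROM t) SELECT a FROM x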
@@ -1,11 +1,20 @@
 from __future__ import annotations

+import typing as t
+
 from collections import defaultdict

 from sqlglot import expressions as exp
 from sqlglot.helper import find_new_name
 from sqlglot.optimizer.scope import Scope, traverse_scope

+if t.TYPE_CHECKING:
+    from sqlglot._typing import E
+
+    FromOrJoin = t.Union[exp.From, exp.Join]
+

-def merge_subqueries(expression, leave_tables_isolated=False):
+def merge_subqueries(expression: E, leave_tables_isolated: bool = False) -> E:
     """
     Rewrite sqlglot AST to merge derived tables into the outer query.
@@ -58,7 +67,7 @@ SAFE_TO_REPLACE_UNWRAPPED = (
 )


-def merge_ctes(expression, leave_tables_isolated=False):
+def merge_ctes(expression: E, leave_tables_isolated: bool = False) -> E:
     scopes = traverse_scope(expression)

     # All places where we select from CTEs.
@@ -92,7 +101,7 @@ def merge_ctes(expression, leave_tables_isolated=False):
     return expression


-def merge_derived_tables(expression, leave_tables_isolated=False):
+def merge_derived_tables(expression: E, leave_tables_isolated: bool = False) -> E:
     for outer_scope in traverse_scope(expression):
         for subquery in outer_scope.derived_tables:
             from_or_join = subquery.find_ancestor(exp.From, exp.Join)
@@ -111,17 +120,11 @@ def merge_derived_tables(expression, leave_tables_isolated=False):
     return expression


-def _mergeable(outer_scope, inner_scope, leave_tables_isolated, from_or_join):
+def _mergeable(
+    outer_scope: Scope, inner_scope: Scope, leave_tables_isolated: bool, from_or_join: FromOrJoin
+) -> bool:
     """
     Return True if `inner_select` can be merged into outer query.
-
-    Args:
-        outer_scope (Scope)
-        inner_scope (Scope)
-        leave_tables_isolated (bool)
-        from_or_join (exp.From|exp.Join)
-    Returns:
-        bool: True if can be merged
     """
     inner_select = inner_scope.expression.unnest()
@@ -195,7 +198,7 @@ def _mergeable(outer_scope, inner_scope, leave_tables_isolated, from_or_join):
         and not outer_scope.expression.is_star
         and isinstance(inner_select, exp.Select)
         and not any(inner_select.args.get(arg) for arg in UNMERGABLE_ARGS)
-        and inner_select.args.get("from")
+        and inner_select.args.get("from") is not None
         and not outer_scope.pivots
         and not any(e.find(exp.AggFunc, exp.Select, exp.Explode) for e in inner_select.expressions)
         and not (leave_tables_isolated and len(outer_scope.selected_sources) > 1)
@@ -218,19 +221,17 @@ def _mergeable(outer_scope, inner_scope, leave_tables_isolated, from_or_join):
     )


-def _rename_inner_sources(outer_scope, inner_scope, alias):
+def _rename_inner_sources(outer_scope: Scope, inner_scope: Scope, alias: str) -> None:
     """
     Renames any sources in the inner query that conflict with names in the outer query.
-
-    Args:
-        outer_scope (sqlglot.optimizer.scope.Scope)
-        inner_scope (sqlglot.optimizer.scope.Scope)
-        alias (str)
     """
-    taken = set(outer_scope.selected_sources)
-    conflicts = taken.intersection(set(inner_scope.selected_sources))
+    inner_taken = set(inner_scope.selected_sources)
+    outer_taken = set(outer_scope.selected_sources)
+    conflicts = outer_taken.intersection(inner_taken)
     conflicts -= {alias}

+    taken = outer_taken.union(inner_taken)
+
     for conflict in conflicts:
         new_name = find_new_name(taken, conflict)
@@ -250,15 +251,14 @@ def _rename_inner_sources(outer_scope, inner_scope, alias):
         inner_scope.rename_source(conflict, new_name)


-def _merge_from(outer_scope, inner_scope, node_to_replace, alias):
+def _merge_from(
+    outer_scope: Scope,
+    inner_scope: Scope,
+    node_to_replace: t.Union[exp.Subquery, exp.Table],
+    alias: str,
+) -> None:
     """
     Merge FROM clause of inner query into outer query.
-
-    Args:
-        outer_scope (sqlglot.optimizer.scope.Scope)
-        inner_scope (sqlglot.optimizer.scope.Scope)
-        node_to_replace (exp.Subquery|exp.Table)
-        alias (str)
     """
     new_subquery = inner_scope.expression.args["from"].this
     new_subquery.set("joins", node_to_replace.args.get("joins"))
@@ -274,14 +274,9 @@ def _merge_from(outer_scope, inner_scope, node_to_replace, alias):
     )


-def _merge_joins(outer_scope, inner_scope, from_or_join):
+def _merge_joins(outer_scope: Scope, inner_scope: Scope, from_or_join: FromOrJoin) -> None:
     """
     Merge JOIN clauses of inner query into outer query.
-
-    Args:
-        outer_scope (sqlglot.optimizer.scope.Scope)
-        inner_scope (sqlglot.optimizer.scope.Scope)
-        from_or_join (exp.From|exp.Join)
     """

     new_joins = []
@@ -304,7 +299,7 @@ def _merge_joins(outer_scope, inner_scope, from_or_join):
     outer_scope.expression.set("joins", outer_joins)


-def _merge_expressions(outer_scope, inner_scope, alias):
+def _merge_expressions(outer_scope: Scope, inner_scope: Scope, alias: str) -> None:
     """
     Merge projections of inner query into outer query.

@@ -338,7 +333,7 @@ def _merge_expressions(outer_scope, inner_scope, alias):
             column.replace(expression.copy())


-def _merge_where(outer_scope, inner_scope, from_or_join):
+def _merge_where(outer_scope: Scope, inner_scope: Scope, from_or_join: FromOrJoin) -> None:
     """
     Merge WHERE clause of inner query into outer query.

@@ -357,7 +352,7 @@ def _merge_where(outer_scope, inner_scope, from_or_join):
         # Merge predicates from an outer join to the ON clause
         # if it only has columns that are already joined
         from_ = expression.args.get("from")
-        sources = {from_.alias_or_name} if from_ else {}
+        sources = {from_.alias_or_name} if from_ else set()

         for join in expression.args["joins"]:
             source = join.alias_or_name
@@ -373,7 +368,7 @@ def _merge_where(outer_scope, inner_scope, from_or_join):
         expression.where(where.this, copy=False)


-def _merge_order(outer_scope, inner_scope):
+def _merge_order(outer_scope: Scope, inner_scope: Scope) -> None:
     """
     Merge ORDER clause of inner query into outer query.

@@ -393,7 +388,7 @@ def _merge_order(outer_scope, inner_scope):
     outer_scope.expression.set("order", inner_scope.expression.args.get("order"))


-def _merge_hints(outer_scope, inner_scope):
+def _merge_hints(outer_scope: Scope, inner_scope: Scope) -> None:
     inner_scope_hint = inner_scope.expression.args.get("hint")
     if not inner_scope_hint:
         return
@@ -405,7 +400,7 @@ def _merge_hints(outer_scope, inner_scope):
     outer_scope.expression.set("hint", inner_scope_hint)


-def _pop_cte(inner_scope):
+def _pop_cte(inner_scope: Scope) -> None:
     """
     Remove CTE from the AST.

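Two of the merge_subqueries hunks are real fixes hiding among the annotations: the sources fallback was an empty dict rather than an empty set, and taken now unions inner and outer source names, so a renamed conflicting source cannot collide with another inner source. A usage sketch of the rule itself (toy schema assumed):

import sqlglot
from sqlglot.optimizer.merge_subqueries import merge_subqueries
from sqlglot.optimizer.qualify import qualify

expr = qualify(
    sqlglot.parse_one("SELECT x.a FROM (SELECT t.a FROM t) AS x"),
    schema={"t": {"a": "int"}},
)
print(merge_subqueries(expr).sql())  # roughly: SELECT t.a AS a FROM t AS t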
@@ -27,6 +27,7 @@ def qualify(
     infer_schema: t.Optional[bool] = None,
     isolate_tables: bool = False,
     qualify_columns: bool = True,
+    allow_partial_qualification: bool = False,
     validate_qualify_columns: bool = True,
     quote_identifiers: bool = True,
     identify: bool = True,
@@ -56,6 +57,7 @@ def qualify(
         infer_schema: Whether to infer the schema if missing.
         isolate_tables: Whether to isolate table selects.
         qualify_columns: Whether to qualify columns.
+        allow_partial_qualification: Whether to allow partial qualification.
         validate_qualify_columns: Whether to validate columns.
         quote_identifiers: Whether to run the quote_identifiers step.
             This step is necessary to ensure correctness for case sensitive queries.
@@ -90,6 +92,7 @@ def qualify(
             expand_alias_refs=expand_alias_refs,
             expand_stars=expand_stars,
             infer_schema=infer_schema,
+            allow_partial_qualification=allow_partial_qualification,
         )

     if quote_identifiers:
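The new allow_partial_qualification flag threads down into _qualify_columns (next file), where it suppresses the "Unknown column" error for columns missing from the supplied schema. A sketch:

import sqlglot
from sqlglot.optimizer.qualify import qualify

expr = sqlglot.parse_one("SELECT t.a, t.not_in_schema FROM t")
qualified = qualify(
    expr,
    schema={"t": {"a": "int"}},        # not_in_schema is deliberately absent
    allow_partial_qualification=True,  # previously this raised OptimizeError
)
print(qualified.sql())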
@@ -22,6 +22,7 @@ def qualify_columns(
     expand_alias_refs: bool = True,
     expand_stars: bool = True,
     infer_schema: t.Optional[bool] = None,
+    allow_partial_qualification: bool = False,
 ) -> exp.Expression:
     """
     Rewrite sqlglot AST to have fully qualified columns.
@@ -41,6 +42,7 @@ def qualify_columns(
             for most of the optimizer's rules to work; do not set to False unless you
             know what you're doing!
         infer_schema: Whether to infer the schema if missing.
+        allow_partial_qualification: Whether to allow partial qualification.

     Returns:
         The qualified expression.
@@ -68,7 +70,7 @@ def qualify_columns(
             )

             _convert_columns_to_dots(scope, resolver)
-            _qualify_columns(scope, resolver)
+            _qualify_columns(scope, resolver, allow_partial_qualification=allow_partial_qualification)

             if not schema.empty and expand_alias_refs:
                 _expand_alias_refs(scope, resolver)
@@ -240,13 +242,21 @@ def _expand_alias_refs(scope: Scope, resolver: Resolver, expand_only_groupby: bo
     def replace_columns(
         node: t.Optional[exp.Expression], resolve_table: bool = False, literal_index: bool = False
     ) -> None:
-        if not node or (expand_only_groupby and not isinstance(node, exp.Group)):
+        is_group_by = isinstance(node, exp.Group)
+        if not node or (expand_only_groupby and not is_group_by):
             return

         for column in walk_in_scope(node, prune=lambda node: node.is_star):
             if not isinstance(column, exp.Column):
                 continue

+            # BigQuery's GROUP BY allows alias expansion only for standalone names, e.g:
+            # SELECT FUNC(col) AS col FROM t GROUP BY col --> Can be expanded
+            # SELECT FUNC(col) AS col FROM t GROUP BY FUNC(col) --> Shouldn't be expanded, will result to FUNC(FUNC(col))
+            # This not required for the HAVING clause as it can evaluate expressions using both the alias & the table columns
+            if expand_only_groupby and is_group_by and column.parent is not node:
+                continue
+
             table = resolver.get_table(column.name) if resolve_table and not column.table else None
             alias_expr, i = alias_to_expression.get(column.name, (None, 1))
             double_agg = (
@@ -273,9 +283,8 @@ def _expand_alias_refs(scope: Scope, resolver: Resolver, expand_only_groupby: bo
         if simplified is not column:
             column.replace(simplified)

-    for i, projection in enumerate(scope.expression.selects):
+    for i, projection in enumerate(expression.selects):
         replace_columns(projection)

         if isinstance(projection, exp.Alias):
             alias_to_expression[projection.alias] = (projection.this, i + 1)

@@ -434,7 +443,7 @@ def _convert_columns_to_dots(scope: Scope, resolver: Resolver) -> None:
     scope.clear_cache()


-def _qualify_columns(scope: Scope, resolver: Resolver) -> None:
+def _qualify_columns(scope: Scope, resolver: Resolver, allow_partial_qualification: bool) -> None:
     """Disambiguate columns, ensuring each column specifies a source"""
     for column in scope.columns:
         column_table = column.table
@@ -442,7 +451,12 @@ def _qualify_columns(scope: Scope, resolver: Resolver) -> None:

         if column_table and column_table in scope.sources:
             source_columns = resolver.get_source_columns(column_table)
-            if source_columns and column_name not in source_columns and "*" not in source_columns:
+            if (
+                not allow_partial_qualification
+                and source_columns
+                and column_name not in source_columns
+                and "*" not in source_columns
+            ):
                 raise OptimizeError(f"Unknown column: {column_name}")

         if not column_table:
@@ -526,7 +540,7 @@ def _expand_stars(
 ) -> None:
     """Expand stars to lists of column selections"""

-    new_selections = []
+    new_selections: t.List[exp.Expression] = []
     except_columns: t.Dict[int, t.Set[str]] = {}
     replace_columns: t.Dict[int, t.Dict[str, exp.Alias]] = {}
     rename_columns: t.Dict[int, t.Dict[str, str]] = {}
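The comment block added to _expand_alias_refs is the crux of the BigQuery change: a GROUP BY alias is expanded only when it stands alone, otherwise expansion would nest the function call. A hedged sketch (FUNC stands in for any function; toy schema assumed):

import sqlglot
from sqlglot.optimizer.qualify import qualify

sql = "SELECT FUNC(col) AS col FROM t GROUP BY col"
expr = qualify(
    sqlglot.parse_one(sql, read="bigquery"),
    dialect="bigquery",
    schema={"t": {"col": "int"}},
)
# The standalone alias may expand to FUNC(col); GROUP BY FUNC(col) must not be
# rewritten again, or it would become FUNC(FUNC(col)).
print(expr.sql("bigquery"))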
@@ -562,8 +562,8 @@ def _traverse_scope(scope):
     elif isinstance(expression, exp.DML):
         yield from _traverse_ctes(scope)
         for query in find_all_in_scope(expression, exp.Query):
-            # This check ensures we don't yield the CTE queries twice
-            if not isinstance(query.parent, exp.CTE):
+            # This check ensures we don't yield the CTE/nested queries twice
+            if not isinstance(query.parent, (exp.CTE, exp.Subquery)):
                 yield from _traverse_scope(Scope(query, cte_sources=scope.cte_sources))
         return
     else:
@@ -679,6 +679,8 @@ def _traverse_tables(scope):
     expressions.extend(scope.expression.args.get("laterals") or [])

     for expression in expressions:
+        if isinstance(expression, exp.Final):
+            expression = expression.this
         if isinstance(expression, exp.Table):
             table_name = expression.name
             source_name = expression.alias_or_name
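The exp.Final unwrapping mirrors the new ClickHouse test at the bottom of this diff: tables wrapped by FINAL now register as scope sources. A direct check:

from sqlglot import parse_one
from sqlglot.optimizer.scope import traverse_scope

scopes = traverse_scope(parse_one("SELECT * FROM t FINAL", dialect="clickhouse"))
print(set(scopes[0].sources))  # expected: {'t'}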
@@ -206,6 +206,11 @@ COMPLEMENT_COMPARISONS = {
     exp.NEQ: exp.EQ,
 }

+COMPLEMENT_SUBQUERY_PREDICATES = {
+    exp.All: exp.Any,
+    exp.Any: exp.All,
+}
+

 def simplify_not(expression):
     """
@@ -218,9 +223,12 @@ def simplify_not(expression):
         if is_null(this):
             return exp.null()
         if this.__class__ in COMPLEMENT_COMPARISONS:
-            return COMPLEMENT_COMPARISONS[this.__class__](
-                this=this.this, expression=this.expression
-            )
+            right = this.expression
+            complement_subquery_predicate = COMPLEMENT_SUBQUERY_PREDICATES.get(right.__class__)
+            if complement_subquery_predicate:
+                right = complement_subquery_predicate(this=right.this)
+
+            return COMPLEMENT_COMPARISONS[this.__class__](this=this.this, expression=right)
         if isinstance(this, exp.Paren):
             condition = this.unnest()
             if isinstance(condition, exp.And):
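With COMPLEMENT_SUBQUERY_PREDICATES, pushing a NOT through a comparison also flips an ANY/ALL quantifier on its right-hand side. A sketch:

import sqlglot
from sqlglot.optimizer.simplify import simplify

expr = sqlglot.parse_one("SELECT * FROM t WHERE NOT a = ANY (SELECT b FROM s)")
print(simplify(expr).sql())
# expected, roughly: SELECT * FROM t WHERE a <> ALL (SELECT b FROM s)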
@@ -1053,13 +1053,16 @@ class Parser(metaclass=_Parser):

     ALTER_PARSERS = {
         "ADD": lambda self: self._parse_alter_table_add(),
+        "AS": lambda self: self._parse_select(),
         "ALTER": lambda self: self._parse_alter_table_alter(),
         "CLUSTER BY": lambda self: self._parse_cluster(wrapped=True),
         "DELETE": lambda self: self.expression(exp.Delete, where=self._parse_where()),
         "DROP": lambda self: self._parse_alter_table_drop(),
         "RENAME": lambda self: self._parse_alter_table_rename(),
         "SET": lambda self: self._parse_alter_table_set(),
-        "AS": lambda self: self._parse_select(),
         "SWAP": lambda self: self.expression(
             exp.SwapTable, this=self._match(TokenType.WITH) and self._parse_table(schema=True)
         ),
     }

     ALTER_ALTER_PARSERS = {
@@ -1222,6 +1225,10 @@ class Parser(metaclass=_Parser):
         **dict.fromkeys(("BINDING", "COMPENSATION", "EVOLUTION"), tuple()),
     }

+    PROCEDURE_OPTIONS: OPTIONS_TYPE = {}
+
+    EXECUTE_AS_OPTIONS: OPTIONS_TYPE = dict.fromkeys(("CALLER", "SELF", "OWNER"), tuple())
+
     KEY_CONSTRAINT_OPTIONS: OPTIONS_TYPE = {
         "NOT": ("ENFORCED",),
         "MATCH": (
@@ -1286,6 +1293,11 @@ class Parser(metaclass=_Parser):

     PRIVILEGE_FOLLOW_TOKENS = {TokenType.ON, TokenType.COMMA, TokenType.L_PAREN}

+    # The style options for the DESCRIBE statement
+    DESCRIBE_STYLES = {"ANALYZE", "EXTENDED", "FORMATTED", "HISTORY"}
+
+    OPERATION_MODIFIERS: t.Set[str] = set()
+
     STRICT_CAST = True

     PREFIXED_PIVOT_COLUMNS = False
@@ -2195,11 +2207,26 @@ class Parser(metaclass=_Parser):
                 this=self._parse_var_from_options(self.SCHEMA_BINDING_OPTIONS),
             )

+        if self._match_texts(self.PROCEDURE_OPTIONS, advance=False):
+            return self.expression(
+                exp.WithProcedureOptions, expressions=self._parse_csv(self._parse_procedure_option)
+            )
+
         if not self._next:
             return None

         return self._parse_withisolatedloading()

+    def _parse_procedure_option(self) -> exp.Expression | None:
+        if self._match_text_seq("EXECUTE", "AS"):
+            return self.expression(
+                exp.ExecuteAsProperty,
+                this=self._parse_var_from_options(self.EXECUTE_AS_OPTIONS, raise_unmatched=False)
+                or self._parse_string(),
+            )
+
+        return self._parse_var_from_options(self.PROCEDURE_OPTIONS)
+
     # https://dev.mysql.com/doc/refman/8.0/en/create-view.html
     def _parse_definer(self) -> t.Optional[exp.DefinerProperty]:
         self._match(TokenType.EQ)
@@ -2567,7 +2594,7 @@ class Parser(metaclass=_Parser):

     def _parse_describe(self) -> exp.Describe:
         kind = self._match_set(self.CREATABLES) and self._prev.text
-        style = self._match_texts(("EXTENDED", "FORMATTED", "HISTORY")) and self._prev.text.upper()
+        style = self._match_texts(self.DESCRIBE_STYLES) and self._prev.text.upper()
         if self._match(TokenType.DOT):
             style = None
             self._retreat(self._index - 2)
@@ -2955,6 +2982,10 @@ class Parser(metaclass=_Parser):
         if all_ and distinct:
             self.raise_error("Cannot specify both ALL and DISTINCT after SELECT")

+        operation_modifiers = []
+        while self._curr and self._match_texts(self.OPERATION_MODIFIERS):
+            operation_modifiers.append(exp.var(self._prev.text.upper()))
+
         limit = self._parse_limit(top=True)
         projections = self._parse_projections()

@@ -2965,6 +2996,7 @@ class Parser(metaclass=_Parser):
             distinct=distinct,
             expressions=projections,
             limit=limit,
+            operation_modifiers=operation_modifiers or None,
         )
         this.comments = comments

@@ -3400,6 +3432,10 @@ class Parser(metaclass=_Parser):
             return None

         kwargs: t.Dict[str, t.Any] = {"this": self._parse_table(parse_bracket=parse_bracket)}
+        if kind and kind.token_type == TokenType.ARRAY and self._match(TokenType.COMMA):
+            kwargs["expressions"] = self._parse_csv(
+                lambda: self._parse_table(parse_bracket=parse_bracket)
+            )

         if method:
             kwargs["method"] = method.text
@@ -3420,7 +3456,7 @@ class Parser(metaclass=_Parser):
         elif (
             not (outer_apply or cross_apply)
             and not isinstance(kwargs["this"], exp.Unnest)
-            and not (kind and kind.token_type == TokenType.CROSS)
+            and not (kind and kind.token_type in (TokenType.CROSS, TokenType.ARRAY))
         ):
             index = self._index
             joins: t.Optional[list] = list(self._parse_joins())
@@ -4470,7 +4506,7 @@ class Parser(metaclass=_Parser):
         elif not self._match(TokenType.R_BRACKET, expression=this):
             self.raise_error("Expecting ]")
         else:
-            this = self.expression(exp.In, this=this, field=self._parse_field())
+            this = self.expression(exp.In, this=this, field=self._parse_column())

         return this

@@ -5533,12 +5569,15 @@ class Parser(metaclass=_Parser):
             return None

     def _parse_column_constraint(self) -> t.Optional[exp.Expression]:
-        if self._match(TokenType.CONSTRAINT):
-            this = self._parse_id_var()
-        else:
-            this = None
+        this = self._match(TokenType.CONSTRAINT) and self._parse_id_var()

-        if self._match_texts(self.CONSTRAINT_PARSERS):
+        procedure_option_follows = (
+            self._match(TokenType.WITH, advance=False)
+            and self._next
+            and self._next.text.upper() in self.PROCEDURE_OPTIONS
+        )
+
+        if not procedure_option_follows and self._match_texts(self.CONSTRAINT_PARSERS):
             return self.expression(
                 exp.ColumnConstraint,
                 this=this,
@@ -6764,7 +6803,7 @@ class Parser(metaclass=_Parser):
         self._retreat(index)
         return self._parse_csv(self._parse_drop_column)

-    def _parse_alter_table_rename(self) -> t.Optional[exp.RenameTable | exp.RenameColumn]:
+    def _parse_alter_table_rename(self) -> t.Optional[exp.AlterRename | exp.RenameColumn]:
         if self._match(TokenType.COLUMN):
             exists = self._parse_exists()
             old_column = self._parse_column()
@@ -6777,7 +6816,7 @@ class Parser(metaclass=_Parser):
             return self.expression(exp.RenameColumn, this=old_column, to=new_column, exists=exists)

         self._match_text_seq("TO")
-        return self.expression(exp.RenameTable, this=self._parse_table(schema=True))
+        return self.expression(exp.AlterRename, this=self._parse_table(schema=True))

     def _parse_alter_table_set(self) -> exp.AlterSet:
         alter_set = self.expression(exp.AlterSet)
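PROCEDURE_OPTIONS and EXECUTE_AS_OPTIONS are empty at the base Parser and meant to be populated per dialect (T-SQL, going by the WithProcedureOptions generator hunks above). A sketch under that assumption:

import sqlglot

# Assumes T-SQL registers RECOMPILE and EXECUTE AS as procedure options in this release.
sql = "CREATE PROCEDURE a.b.c WITH EXECUTE AS OWNER, RECOMPILE AS SELECT 1"
print(sqlglot.transpile(sql, read="tsql", write="tsql")[0])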
@@ -107,6 +107,7 @@ LANGUAGE js AS
         select_with_quoted_udf = self.validate_identity("SELECT `p.d.UdF`(data) FROM `p.d.t`")
         self.assertEqual(select_with_quoted_udf.selects[0].name, "p.d.UdF")

+        self.validate_identity("SELECT ARRAY_CONCAT([1])")
         self.validate_identity("SELECT * FROM READ_CSV('bla.csv')")
         self.validate_identity("CAST(x AS STRUCT<list ARRAY<INT64>>)")
         self.validate_identity("assert.true(1 = 1)")
@@ -2,6 +2,7 @@ from datetime import date
 from sqlglot import exp, parse_one
 from sqlglot.dialects import ClickHouse
 from sqlglot.expressions import convert
+from sqlglot.optimizer import traverse_scope
 from tests.dialects.test_dialect import Validator
 from sqlglot.errors import ErrorLevel

@@ -28,6 +29,7 @@ class TestClickhouse(Validator):
         self.assertEqual(expr.sql(dialect="clickhouse"), "COUNT(x)")
         self.assertIsNone(expr._meta)

+        self.validate_identity("CAST(1 AS Bool)")
         self.validate_identity("SELECT toString(CHAR(104.1, 101, 108.9, 108.9, 111, 32))")
         self.validate_identity("@macro").assert_is(exp.Parameter).this.assert_is(exp.Var)
         self.validate_identity("SELECT toFloat(like)")
@@ -420,11 +422,6 @@ class TestClickhouse(Validator):
             " GROUP BY loyalty ORDER BY loyalty ASC"
             },
         )
-        self.validate_identity("SELECT s, arr FROM arrays_test ARRAY JOIN arr")
-        self.validate_identity("SELECT s, arr, a FROM arrays_test LEFT ARRAY JOIN arr AS a")
-        self.validate_identity(
-            "SELECT s, arr_external FROM arrays_test ARRAY JOIN [1, 2, 3] AS arr_external"
-        )
         self.validate_all(
             "SELECT quantile(0.5)(a)",
             read={"duckdb": "SELECT quantile(a, 0.5)"},
@@ -1100,3 +1097,36 @@ LIFETIME(MIN 0 MAX 0)""",
     def test_grant(self):
         self.validate_identity("GRANT SELECT(x, y) ON db.table TO john WITH GRANT OPTION")
         self.validate_identity("GRANT INSERT(x, y) ON db.table TO john")
+
+    def test_array_join(self):
+        expr = self.validate_identity(
+            "SELECT * FROM arrays_test ARRAY JOIN arr1, arrays_test.arr2 AS foo, ['a', 'b', 'c'] AS elem"
+        )
+        joins = expr.args["joins"]
+        self.assertEqual(len(joins), 1)
+
+        join = joins[0]
+        self.assertEqual(join.kind, "ARRAY")
+        self.assertIsInstance(join.this, exp.Column)
+
+        self.assertEqual(len(join.expressions), 2)
+        self.assertIsInstance(join.expressions[0], exp.Alias)
+        self.assertIsInstance(join.expressions[0].this, exp.Column)
+
+        self.assertIsInstance(join.expressions[1], exp.Alias)
+        self.assertIsInstance(join.expressions[1].this, exp.Array)
+
+        self.validate_identity("SELECT s, arr FROM arrays_test ARRAY JOIN arr")
+        self.validate_identity("SELECT s, arr, a FROM arrays_test LEFT ARRAY JOIN arr AS a")
+        self.validate_identity(
+            "SELECT s, arr_external FROM arrays_test ARRAY JOIN [1, 2, 3] AS arr_external"
+        )
+        self.validate_identity(
+            "SELECT * FROM arrays_test ARRAY JOIN [1, 2, 3] AS arr_external1, ['a', 'b', 'c'] AS arr_external2, splitByString(',', 'asd,qwerty,zxc') AS arr_external3"
+        )
+
+    def test_traverse_scope(self):
+        sql = "SELECT * FROM t FINAL"
+        scopes = traverse_scope(parse_one(sql, dialect=self.dialect))
+        self.assertEqual(len(scopes), 1)
+        self.assertEqual(set(scopes[0].sources), {"t"})
@@ -7,6 +7,7 @@ class TestDatabricks(Validator):
     dialect = "databricks"

     def test_databricks(self):
+        self.validate_identity("SELECT t.current_time FROM t")
         self.validate_identity("ALTER TABLE labels ADD COLUMN label_score FLOAT")
         self.validate_identity("DESCRIBE HISTORY a.b")
         self.validate_identity("DESCRIBE history.tbl")
@@ -1762,6 +1762,7 @@ class TestDialect(Validator):
         self.validate_all(
             "LEVENSHTEIN(col1, col2)",
             write={
+                "bigquery": "EDIT_DISTANCE(col1, col2)",
                 "duckdb": "LEVENSHTEIN(col1, col2)",
                 "drill": "LEVENSHTEIN_DISTANCE(col1, col2)",
                 "presto": "LEVENSHTEIN_DISTANCE(col1, col2)",
@@ -1772,6 +1773,7 @@ class TestDialect(Validator):
         self.validate_all(
             "LEVENSHTEIN(coalesce(col1, col2), coalesce(col2, col1))",
             write={
+                "bigquery": "EDIT_DISTANCE(COALESCE(col1, col2), COALESCE(col2, col1))",
                 "duckdb": "LEVENSHTEIN(COALESCE(col1, col2), COALESCE(col2, col1))",
                 "drill": "LEVENSHTEIN_DISTANCE(COALESCE(col1, col2), COALESCE(col2, col1))",
                 "presto": "LEVENSHTEIN_DISTANCE(COALESCE(col1, col2), COALESCE(col2, col1))",
@@ -256,6 +256,9 @@ class TestDuckDB(Validator):
             parse_one("a // b", read="duckdb").assert_is(exp.IntDiv).sql(dialect="duckdb"), "a // b"
         )

+        self.validate_identity("SELECT UNNEST([1, 2])").selects[0].assert_is(exp.UDTF)
+        self.validate_identity("'red' IN flags").args["field"].assert_is(exp.Column)
+        self.validate_identity("'red' IN tbl.flags")
         self.validate_identity("CREATE TABLE tbl1 (u UNION(num INT, str TEXT))")
         self.validate_identity("INSERT INTO x BY NAME SELECT 1 AS y")
         self.validate_identity("SELECT 1 AS x UNION ALL BY NAME SELECT 2 AS x")
Some files were not shown because too many files have changed in this diff.