
Merging upstream version 25.26.0.

Signed-off-by: Daniel Baumann <daniel@debian.org>
Daniel Baumann 2025-02-13 21:56:02 +01:00
parent 9138e4b92a
commit 829a709061
Signed by: daniel
GPG key ID: FBB4F0E80A80222F
117 changed files with 49296 additions and 47316 deletions

CHANGELOG.md

@ -1,6 +1,72 @@
Changelog
=========
## [v25.25.1] - 2024-10-15
### :bug: Bug Fixes
- [`e6567ae`](https://github.com/tobymao/sqlglot/commit/e6567ae11650834874808a844a19836fbb9ee753) - small overload fix for ensure list taking None *(PR [#4248](https://github.com/tobymao/sqlglot/pull/4248) by [@benfdking](https://github.com/benfdking))*
## [v25.25.0] - 2024-10-14
### :boom: BREAKING CHANGES
- due to [`275b64b`](https://github.com/tobymao/sqlglot/commit/275b64b6a28722232a24870e443b249994220d54) - refactor set operation builders so they can work with N expressions *(PR [#4226](https://github.com/tobymao/sqlglot/pull/4226) by [@georgesittas](https://github.com/georgesittas))*:
refactor set operation builders so they can work with N expressions (#4226)
- due to [`aee76da`](https://github.com/tobymao/sqlglot/commit/aee76da1cadec242f7428d23999f1752cb0708ca) - Native annotations for string functions *(PR [#4231](https://github.com/tobymao/sqlglot/pull/4231) by [@VaggelisD](https://github.com/VaggelisD))*:
Native annotations for string functions (#4231)
- due to [`202aaa0`](https://github.com/tobymao/sqlglot/commit/202aaa0e7390142ee3ade41c28e2e77cde31f295) - Native annotations for string functions *(PR [#4234](https://github.com/tobymao/sqlglot/pull/4234) by [@VaggelisD](https://github.com/VaggelisD))*:
Native annotations for string functions (#4234)
- due to [`5741180`](https://github.com/tobymao/sqlglot/commit/5741180e895eaaa75a07af388d36a0d2df97b28c) - produce exp.Column for the RHS of <value> IN <name> *(PR [#4239](https://github.com/tobymao/sqlglot/pull/4239) by [@georgesittas](https://github.com/georgesittas))*:
produce exp.Column for the RHS of <value> IN <name> (#4239)
- due to [`4da2502`](https://github.com/tobymao/sqlglot/commit/4da25029b1c6f1425b4602f42da4fa1bcd3fccdb) - make Explode a UDTF subclass *(PR [#4242](https://github.com/tobymao/sqlglot/pull/4242) by [@georgesittas](https://github.com/georgesittas))*:
make Explode a UDTF subclass (#4242)
### :sparkles: New Features
- [`163e943`](https://github.com/tobymao/sqlglot/commit/163e943cdaf449599640c198f69e73d2398eb323) - **tsql**: SPLIT_PART function and conversion to PARSENAME in tsql *(PR [#4211](https://github.com/tobymao/sqlglot/pull/4211) by [@daihuynh](https://github.com/daihuynh))*
- [`275b64b`](https://github.com/tobymao/sqlglot/commit/275b64b6a28722232a24870e443b249994220d54) - refactor set operation builders so they can work with N expressions *(PR [#4226](https://github.com/tobymao/sqlglot/pull/4226) by [@georgesittas](https://github.com/georgesittas))*
- [`3f6ba3e`](https://github.com/tobymao/sqlglot/commit/3f6ba3e69c9ba92429d2b3b00cac33f45518aa56) - **clickhouse**: Support varlen arrays for ARRAY JOIN *(PR [#4229](https://github.com/tobymao/sqlglot/pull/4229) by [@VaggelisD](https://github.com/VaggelisD))*
- :arrow_lower_right: *addresses issue [#4227](https://github.com/tobymao/sqlglot/issues/4227) opened by [@brunorpinho](https://github.com/brunorpinho)*
- [`aee76da`](https://github.com/tobymao/sqlglot/commit/aee76da1cadec242f7428d23999f1752cb0708ca) - **bigquery**: Native annotations for string functions *(PR [#4231](https://github.com/tobymao/sqlglot/pull/4231) by [@VaggelisD](https://github.com/VaggelisD))*
- [`202aaa0`](https://github.com/tobymao/sqlglot/commit/202aaa0e7390142ee3ade41c28e2e77cde31f295) - **bigquery**: Native annotations for string functions *(PR [#4234](https://github.com/tobymao/sqlglot/pull/4234) by [@VaggelisD](https://github.com/VaggelisD))*
- [`eeae25e`](https://github.com/tobymao/sqlglot/commit/eeae25e03a883671f9d5e514f9bd3021fb6c0d32) - support EXPLAIN in mysql *(PR [#4235](https://github.com/tobymao/sqlglot/pull/4235) by [@xiaoyu-meng-mxy](https://github.com/xiaoyu-meng-mxy))*
- [`06748d9`](https://github.com/tobymao/sqlglot/commit/06748d93ccd232528003c37fdda25ae8163f3c18) - **mysql**: add support for operation modifiers like HIGH_PRIORITY *(PR [#4238](https://github.com/tobymao/sqlglot/pull/4238) by [@georgesittas](https://github.com/georgesittas))*
- :arrow_lower_right: *addresses issue [#4236](https://github.com/tobymao/sqlglot/issues/4236) opened by [@asdfsx](https://github.com/asdfsx)*
### :bug: Bug Fixes
- [`dcdec95`](https://github.com/tobymao/sqlglot/commit/dcdec95f986426ae90469baca993b47ac390081b) - Make exp.Update a DML node *(PR [#4223](https://github.com/tobymao/sqlglot/pull/4223) by [@VaggelisD](https://github.com/VaggelisD))*
- :arrow_lower_right: *fixes issue [#4221](https://github.com/tobymao/sqlglot/issues/4221) opened by [@rahul-ve](https://github.com/rahul-ve)*
- [`79caf51`](https://github.com/tobymao/sqlglot/commit/79caf519987718390a086bee19fdc89f6094496c) - **clickhouse**: rename BOOLEAN type to Bool fixes [#4230](https://github.com/tobymao/sqlglot/pull/4230) *(commit by [@georgesittas](https://github.com/georgesittas))*
- [`b26a3f6`](https://github.com/tobymao/sqlglot/commit/b26a3f67b7113802ba1b4b3b211431e98258dc15) - satisfy mypy *(commit by [@georgesittas](https://github.com/georgesittas))*
- [`5741180`](https://github.com/tobymao/sqlglot/commit/5741180e895eaaa75a07af388d36a0d2df97b28c) - **parser**: produce exp.Column for the RHS of <value> IN <name> *(PR [#4239](https://github.com/tobymao/sqlglot/pull/4239) by [@georgesittas](https://github.com/georgesittas))*
- :arrow_lower_right: *fixes issue [#4237](https://github.com/tobymao/sqlglot/issues/4237) opened by [@rustyconover](https://github.com/rustyconover)*
- [`daa6e78`](https://github.com/tobymao/sqlglot/commit/daa6e78e4b810eff826f995aa52f9e38197f1b7e) - **optimizer**: handle subquery predicate substitution correctly in de morgan's rule *(PR [#4240](https://github.com/tobymao/sqlglot/pull/4240) by [@georgesittas](https://github.com/georgesittas))*
- [`c0a8355`](https://github.com/tobymao/sqlglot/commit/c0a83556acffcd77521f69bf51503a07310f749d) - **parser**: parse a column reference for the RHS of the IN clause *(PR [#4241](https://github.com/tobymao/sqlglot/pull/4241) by [@georgesittas](https://github.com/georgesittas))*
### :recycle: Refactors
- [`0882f03`](https://github.com/tobymao/sqlglot/commit/0882f03d526f593b2d415e85b7d7a7c113721806) - Rename exp.RenameTable to exp.AlterRename *(PR [#4224](https://github.com/tobymao/sqlglot/pull/4224) by [@VaggelisD](https://github.com/VaggelisD))*
- :arrow_lower_right: *addresses issue [#4222](https://github.com/tobymao/sqlglot/issues/4222) opened by [@s1101010110](https://github.com/s1101010110)*
- [`fd42b5c`](https://github.com/tobymao/sqlglot/commit/fd42b5cdaf9421abb11e71d82726536af09e3ae3) - Simplify PARSENAME <-> SPLIT_PART transpilation *(PR [#4225](https://github.com/tobymao/sqlglot/pull/4225) by [@VaggelisD](https://github.com/VaggelisD))*
- [`4da2502`](https://github.com/tobymao/sqlglot/commit/4da25029b1c6f1425b4602f42da4fa1bcd3fccdb) - make Explode a UDTF subclass *(PR [#4242](https://github.com/tobymao/sqlglot/pull/4242) by [@georgesittas](https://github.com/georgesittas))*
## [v25.24.5] - 2024-10-08
### :sparkles: New Features
- [`22a1684`](https://github.com/tobymao/sqlglot/commit/22a16848d80a2fa6d310f99d21f7d81f90eb9440) - **bigquery**: Native annotations for more math functions *(PR [#4212](https://github.com/tobymao/sqlglot/pull/4212) by [@VaggelisD](https://github.com/VaggelisD))*
- [`354cfff`](https://github.com/tobymao/sqlglot/commit/354cfff13ab30d01c6123fca74eed0669d238aa0) - add builder methods to exp.Update and add with_ arg to exp.update *(PR [#4217](https://github.com/tobymao/sqlglot/pull/4217) by [@brdbry](https://github.com/brdbry))*
### :bug: Bug Fixes
- [`2c513b7`](https://github.com/tobymao/sqlglot/commit/2c513b71c7d4b1ff5c7c4e12d6c38694210b1a12) - Attach CTE comments before commas *(PR [#4218](https://github.com/tobymao/sqlglot/pull/4218) by [@VaggelisD](https://github.com/VaggelisD))*
- :arrow_lower_right: *fixes issue [#4216](https://github.com/tobymao/sqlglot/issues/4216) opened by [@ajfriend](https://github.com/ajfriend)*
## [v25.24.4] - 2024-10-04
### :bug: Bug Fixes
- [`484df7d`](https://github.com/tobymao/sqlglot/commit/484df7d50df5cb314943e1810db18a7d7d5bb3eb) - tsql union with limit *(commit by [@tobymao](https://github.com/tobymao))*
@ -4969,3 +5035,6 @@ Changelog
[v25.24.2]: https://github.com/tobymao/sqlglot/compare/v25.24.1...v25.24.2
[v25.24.3]: https://github.com/tobymao/sqlglot/compare/v25.24.2...v25.24.3
[v25.24.4]: https://github.com/tobymao/sqlglot/compare/v25.24.3...v25.24.4
[v25.24.5]: https://github.com/tobymao/sqlglot/compare/v25.24.4...v25.24.5
[v25.25.0]: https://github.com/tobymao/sqlglot/compare/v25.24.5...v25.25.0
[v25.25.1]: https://github.com/tobymao/sqlglot/compare/v25.25.0...v25.25.1

File diff suppressed because one or more lines are too long (71 files)

sqlglot/dialects/bigquery.py

@ -321,7 +321,24 @@ class BigQuery(Dialect):
expr_type: lambda self, e: _annotate_math_functions(self, e)
for expr_type in (exp.Floor, exp.Ceil, exp.Log, exp.Ln, exp.Sqrt, exp.Exp, exp.Round)
},
**{
expr_type: lambda self, e: self._annotate_by_args(e, "this")
for expr_type in (
exp.Left,
exp.Right,
exp.Lower,
exp.Upper,
exp.Pad,
exp.Trim,
exp.RegexpExtract,
exp.RegexpReplace,
exp.Repeat,
exp.Substring,
)
},
exp.Concat: lambda self, e: self._annotate_by_args(e, "expressions"),
exp.Sign: lambda self, e: self._annotate_by_args(e, "this"),
exp.Split: lambda self, e: self._annotate_by_args(e, "this", array=True),
}
def normalize_identifier(self, expression: E) -> E:
@ -716,6 +733,7 @@ class BigQuery(Dialect):
exp.ILike: no_ilike_sql,
exp.IntDiv: rename_func("DIV"),
exp.JSONFormat: rename_func("TO_JSON_STRING"),
exp.Levenshtein: rename_func("EDIT_DISTANCE"),
exp.Max: max_or_greatest,
exp.MD5: lambda self, e: self.func("TO_HEX", self.func("MD5", e.this)),
exp.MD5Digest: rename_func("MD5"),
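The new string-function annotators and the EDIT_DISTANCE mapping are both visible from the public API. A minimal sketch; expected outputs are inferred from the diff above, not verified against this exact build:

```python
import sqlglot
from sqlglot import parse_one
from sqlglot.optimizer.annotate_types import annotate_types

# exp.Levenshtein now renders with BigQuery's native function name.
print(sqlglot.transpile("SELECT LEVENSHTEIN(a, b)", write="bigquery")[0])
# Expected: SELECT EDIT_DISTANCE(a, b)

# The new ANNOTATORS entries type string functions from their first argument
# ("this"), so a literal input yields a VARCHAR-typed projection. This assumes
# annotate_types accepts a `dialect` argument in this version.
typed = annotate_types(parse_one("SELECT UPPER('x') AS u", read="bigquery"), dialect="bigquery")
print(typed.selects[0].type)  # Expected: VARCHAR
```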

sqlglot/dialects/clickhouse.py

@ -603,6 +603,12 @@ class ClickHouse(Dialect):
if join:
join.set("global", join.args.pop("method", None))
# tbl ARRAY JOIN arr <-- this should be a `Column` reference, not a `Table`
# https://clickhouse.com/docs/en/sql-reference/statements/select/array-join
if join.kind == "ARRAY":
for table in join.find_all(exp.Table):
table.replace(table.to_column())
return join
def _parse_function(
@ -627,15 +633,18 @@ class ClickHouse(Dialect):
)
if parts:
params = self._parse_func_params(func)
anon_func: exp.Anonymous = t.cast(exp.Anonymous, func)
params = self._parse_func_params(anon_func)
kwargs = {
"this": func.this,
"expressions": func.expressions,
"this": anon_func.this,
"expressions": anon_func.expressions,
}
if parts[1]:
kwargs["parts"] = parts
exp_class = exp.CombinedParameterizedAgg if params else exp.CombinedAggFunc
exp_class: t.Type[exp.Expression] = (
exp.CombinedParameterizedAgg if params else exp.CombinedAggFunc
)
else:
exp_class = exp.ParameterizedAgg if params else exp.AnonymousAggFunc
@ -825,6 +834,7 @@ class ClickHouse(Dialect):
**generator.Generator.TYPE_MAPPING,
**STRING_TYPE_MAPPING,
exp.DataType.Type.ARRAY: "Array",
exp.DataType.Type.BOOLEAN: "Bool",
exp.DataType.Type.BIGINT: "Int64",
exp.DataType.Type.DATE32: "Date32",
exp.DataType.Type.DATETIME: "DateTime",
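Two of the ClickHouse changes above are easy to exercise directly; a sketch, with expected outputs inferred from the diff:

```python
import sqlglot

# ARRAY JOIN operands are now parsed as column references instead of tables,
# so a variable-length ARRAY JOIN round-trips cleanly.
sql = "SELECT x FROM tbl ARRAY JOIN arr1, arr2"
print(sqlglot.transpile(sql, read="clickhouse", write="clickhouse")[0])

# BOOLEAN now maps to ClickHouse's Bool type.
print(sqlglot.transpile("SELECT CAST(1 AS BOOLEAN)", write="clickhouse")[0])
# Expected: SELECT CAST(1 AS Bool)
```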

sqlglot/dialects/dialect.py

@ -588,6 +588,7 @@ class Dialect(metaclass=_Dialect):
exp.Stddev,
exp.StddevPop,
exp.StddevSamp,
exp.ToDouble,
exp.Variance,
exp.VariancePop,
},
@ -1697,3 +1698,18 @@ def build_regexp_extract(args: t.List, dialect: Dialect) -> exp.RegexpExtract:
expression=seq_get(args, 1),
group=seq_get(args, 2) or exp.Literal.number(dialect.REGEXP_EXTRACT_DEFAULT_GROUP),
)
def explode_to_unnest_sql(self: Generator, expression: exp.Lateral) -> str:
if isinstance(expression.this, exp.Explode):
return self.sql(
exp.Join(
this=exp.Unnest(
expressions=[expression.this.this],
alias=expression.args.get("alias"),
offset=isinstance(expression.this, exp.Posexplode),
),
kind="cross",
)
)
return self.lateral_sql(expression)
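`explode_to_unnest_sql` is Presto's lateral-explode rewrite hoisted into the shared dialect module so other dialects (DuckDB below) can reuse it. A sketch of the transpilation it enables, assuming Spark-style input:

```python
import sqlglot

# A Spark LATERAL VIEW EXPLODE parses into exp.Lateral(this=exp.Explode(...)),
# which explode_to_unnest_sql rewrites into a CROSS JOIN UNNEST for engines
# without LATERAL VIEW support.
sql = "SELECT col FROM t LATERAL VIEW EXPLODE(arr) x AS col"
print(sqlglot.transpile(sql, read="spark", write="duckdb")[0])
# Roughly: SELECT col FROM t CROSS JOIN UNNEST(arr) AS x(col)
```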

sqlglot/dialects/duckdb.py

@ -35,6 +35,7 @@ from sqlglot.dialects.dialect import (
unit_to_str,
sha256_sql,
build_regexp_extract,
explode_to_unnest_sql,
)
from sqlglot.helper import seq_get
from sqlglot.tokens import TokenType
@ -538,6 +539,7 @@ class DuckDB(Dialect):
exp.JSONExtract: _arrow_json_extract_sql,
exp.JSONExtractScalar: _arrow_json_extract_sql,
exp.JSONFormat: _json_format_sql,
exp.Lateral: explode_to_unnest_sql,
exp.LogicalOr: rename_func("BOOL_OR"),
exp.LogicalAnd: rename_func("BOOL_AND"),
exp.MD5Digest: lambda self, e: self.func("UNHEX", self.func("MD5", e.this)),

sqlglot/dialects/hive.py

@ -333,6 +333,9 @@ class Hive(Dialect):
"TRANSFORM": lambda self: self._parse_transform(),
}
NO_PAREN_FUNCTIONS = parser.Parser.NO_PAREN_FUNCTIONS.copy()
NO_PAREN_FUNCTIONS.pop(TokenType.CURRENT_TIME)
PROPERTY_PARSERS = {
**parser.Parser.PROPERTY_PARSERS,
"SERDEPROPERTIES": lambda self: exp.SerdeProperties(

sqlglot/dialects/mysql.py

@ -187,6 +187,9 @@ class MySQL(Dialect):
KEYWORDS = {
**tokens.Tokenizer.KEYWORDS,
"CHARSET": TokenType.CHARACTER_SET,
# The DESCRIBE and EXPLAIN statements are synonyms.
# https://dev.mysql.com/doc/refman/8.4/en/explain.html
"EXPLAIN": TokenType.DESCRIBE,
"FORCE": TokenType.FORCE,
"IGNORE": TokenType.IGNORE,
"KEY": TokenType.KEY,
@ -453,6 +456,17 @@ class MySQL(Dialect):
TokenType.SET,
}
# SELECT [ ALL | DISTINCT | DISTINCTROW ] [ <OPERATION_MODIFIERS> ]
OPERATION_MODIFIERS = {
"HIGH_PRIORITY",
"STRAIGHT_JOIN",
"SQL_SMALL_RESULT",
"SQL_BIG_RESULT",
"SQL_BUFFER_RESULT",
"SQL_NO_CACHE",
"SQL_CALC_FOUND_ROWS",
}
LOG_DEFAULTS_TO_LN = True
STRING_ALIASES = True
VALUES_FOLLOWED_BY_PAREN = False
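Both MySQL additions can be checked from the public API; a sketch:

```python
import sqlglot

# EXPLAIN now tokenizes as DESCRIBE (they are synonyms in MySQL), so both
# keywords parse to the same Describe statement.
print(repr(sqlglot.parse_one("EXPLAIN SELECT 1", read="mysql")))

# SELECT operation modifiers are parsed and re-emitted on round-trip.
sql = "SELECT HIGH_PRIORITY SQL_NO_CACHE * FROM t"
print(sqlglot.transpile(sql, read="mysql", write="mysql")[0])
```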

sqlglot/dialects/oracle.py

@ -15,6 +15,7 @@ from sqlglot.dialects.dialect import (
from sqlglot.helper import seq_get
from sqlglot.parser import OPTIONS_TYPE, build_coalesce
from sqlglot.tokens import TokenType
from sqlglot.errors import ParseError
if t.TYPE_CHECKING:
from sqlglot._typing import E
@ -205,6 +206,57 @@ class Oracle(Dialect):
)
def _parse_hint(self) -> t.Optional[exp.Hint]:
start_index = self._index
should_fallback_to_string = False
if not self._match(TokenType.HINT):
return None
hints = []
try:
for hint in iter(
lambda: self._parse_csv(
lambda: self._parse_hint_function_call() or self._parse_var(upper=True),
),
[],
):
hints.extend(hint)
except ParseError:
should_fallback_to_string = True
if not self._match_pair(TokenType.STAR, TokenType.SLASH):
should_fallback_to_string = True
if should_fallback_to_string:
self._retreat(start_index)
return self._parse_hint_fallback_to_string()
return self.expression(exp.Hint, expressions=hints)
def _parse_hint_function_call(self) -> t.Optional[exp.Expression]:
if not self._curr or not self._next or self._next.token_type != TokenType.L_PAREN:
return None
this = self._curr.text
self._advance(2)
args = self._parse_hint_args()
this = self.expression(exp.Anonymous, this=this, expressions=args)
self._match_r_paren(this)
return this
def _parse_hint_args(self):
args = []
result = self._parse_var()
while result:
args.append(result)
result = self._parse_var()
return args
def _parse_hint_fallback_to_string(self) -> t.Optional[exp.Hint]:
if self._match(TokenType.HINT):
start = self._curr
while self._curr and not self._match_pair(TokenType.STAR, TokenType.SLASH):
@ -271,6 +323,7 @@ class Oracle(Dialect):
LAST_DAY_SUPPORTS_DATE_PART = False
SUPPORTS_SELECT_INTO = True
TZ_TO_WITH_TIME_ZONE = True
QUERY_HINT_SEP = " "
TYPE_MAPPING = {
**generator.Generator.TYPE_MAPPING,
@ -370,3 +423,15 @@ class Oracle(Dialect):
return f"{self.seg(into)} {self.sql(expression, 'this')}"
return f"{self.seg(into)} {self.expressions(expression)}"
def hint_sql(self, expression: exp.Hint) -> str:
expressions = []
for expression in expression.expressions:
if isinstance(expression, exp.Anonymous):
formatted_args = self.format_args(*expression.expressions, sep=" ")
expressions.append(f"{self.sql(expression, 'this')}({formatted_args})")
else:
expressions.append(self.sql(expression))
return f" /*+ {self.expressions(sqls=expressions, sep=self.QUERY_HINT_SEP).strip()} */"

sqlglot/dialects/presto.py

@ -30,6 +30,7 @@ from sqlglot.dialects.dialect import (
unit_to_str,
sequence_sql,
build_regexp_extract,
explode_to_unnest_sql,
)
from sqlglot.dialects.hive import Hive
from sqlglot.dialects.mysql import MySQL
@ -40,21 +41,6 @@ from sqlglot.transforms import unqualify_columns
DATE_ADD_OR_SUB = t.Union[exp.DateAdd, exp.TimestampAdd, exp.DateSub]
def _explode_to_unnest_sql(self: Presto.Generator, expression: exp.Lateral) -> str:
if isinstance(expression.this, exp.Explode):
return self.sql(
exp.Join(
this=exp.Unnest(
expressions=[expression.this.this],
alias=expression.args.get("alias"),
offset=isinstance(expression.this, exp.Posexplode),
),
kind="cross",
)
)
return self.lateral_sql(expression)
def _initcap_sql(self: Presto.Generator, expression: exp.Initcap) -> str:
regex = r"(\w)(\w*)"
return f"REGEXP_REPLACE({self.sql(expression, 'this')}, '{regex}', x -> UPPER(x[1]) || LOWER(x[2]))"
@ -340,16 +326,17 @@ class Presto(Dialect):
TYPE_MAPPING = {
**generator.Generator.TYPE_MAPPING,
exp.DataType.Type.INT: "INTEGER",
exp.DataType.Type.FLOAT: "REAL",
exp.DataType.Type.BINARY: "VARBINARY",
exp.DataType.Type.TEXT: "VARCHAR",
exp.DataType.Type.TIMETZ: "TIME",
exp.DataType.Type.TIMESTAMPTZ: "TIMESTAMP",
exp.DataType.Type.STRUCT: "ROW",
exp.DataType.Type.BIT: "BOOLEAN",
exp.DataType.Type.DATETIME: "TIMESTAMP",
exp.DataType.Type.DATETIME64: "TIMESTAMP",
exp.DataType.Type.FLOAT: "REAL",
exp.DataType.Type.HLLSKETCH: "HYPERLOGLOG",
exp.DataType.Type.INT: "INTEGER",
exp.DataType.Type.STRUCT: "ROW",
exp.DataType.Type.TEXT: "VARCHAR",
exp.DataType.Type.TIMESTAMPTZ: "TIMESTAMP",
exp.DataType.Type.TIMETZ: "TIME",
}
TRANSFORMS = {
@ -400,9 +387,6 @@ class Presto(Dialect):
exp.GenerateSeries: sequence_sql,
exp.GenerateDateArray: sequence_sql,
exp.Group: transforms.preprocess([transforms.unalias_group]),
exp.GroupConcat: lambda self, e: self.func(
"ARRAY_JOIN", self.func("ARRAY_AGG", e.this), e.args.get("separator")
),
exp.If: if_sql(),
exp.ILike: no_ilike_sql,
exp.Initcap: _initcap_sql,
@ -410,7 +394,7 @@ class Presto(Dialect):
exp.Last: _first_last_sql,
exp.LastValue: _first_last_sql,
exp.LastDay: lambda self, e: self.func("LAST_DAY_OF_MONTH", e.this),
exp.Lateral: _explode_to_unnest_sql,
exp.Lateral: explode_to_unnest_sql,
exp.Left: left_to_substring_sql,
exp.Levenshtein: rename_func("LEVENSHTEIN_DISTANCE"),
exp.LogicalAnd: rename_func("BOOL_AND"),
@ -694,3 +678,10 @@ class Presto(Dialect):
expr = "".join(segments)
return f"{this}{expr}"
def groupconcat_sql(self, expression: exp.GroupConcat) -> str:
return self.func(
"ARRAY_JOIN",
self.func("ARRAY_AGG", expression.this),
expression.args.get("separator"),
)
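The `exp.GroupConcat` lambda removed from TRANSFORMS above reappears as the overridable `groupconcat_sql` method: the ARRAY_JOIN/ARRAY_AGG rewrite is unchanged, but Trino can now override it (see the next file). A sketch:

```python
import sqlglot

# Presto has no GROUP_CONCAT, so it is lowered to ARRAY_JOIN(ARRAY_AGG(...)).
sql = "SELECT GROUP_CONCAT(name SEPARATOR ', ') FROM users"
print(sqlglot.transpile(sql, read="mysql", write="presto")[0])
# Roughly: SELECT ARRAY_JOIN(ARRAY_AGG(name), ', ') FROM users
```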

sqlglot/dialects/snowflake.py

@ -41,15 +41,23 @@ def _build_datetime(
if isinstance(value, exp.Literal):
# Converts calls like `TO_TIME('01:02:03')` into casts
if len(args) == 1 and value.is_string and not int_value:
return exp.cast(value, kind)
return (
exp.TryCast(this=value, to=exp.DataType.build(kind))
if safe
else exp.cast(value, kind)
)
# Handles `TO_TIMESTAMP(str, fmt)` and `TO_TIMESTAMP(num, scale)` as special
# cases so we can transpile them, since they're relatively common
if kind == exp.DataType.Type.TIMESTAMP:
if int_value:
if int_value and not safe:
# TRY_TO_TIMESTAMP('integer') is not parsed into exp.UnixToTime as
# it's not easily transpilable
return exp.UnixToTime(this=value, scale=seq_get(args, 1))
if not is_float(value.this):
return build_formatted_time(exp.StrToTime, "snowflake")(args)
expr = build_formatted_time(exp.StrToTime, "snowflake")(args)
expr.set("safe", safe)
return expr
if kind == exp.DataType.Type.DATE and not int_value:
formatted_exp = build_formatted_time(exp.TsOrDsToDate, "snowflake")(args)
@ -345,6 +353,9 @@ class Snowflake(Dialect):
"TIMESTAMP_FROM_PARTS": build_timestamp_from_parts,
"TRY_PARSE_JSON": lambda args: exp.ParseJSON(this=seq_get(args, 0), safe=True),
"TRY_TO_DATE": _build_datetime("TRY_TO_DATE", exp.DataType.Type.DATE, safe=True),
"TRY_TO_TIMESTAMP": _build_datetime(
"TRY_TO_TIMESTAMP", exp.DataType.Type.TIMESTAMP, safe=True
),
"TO_DATE": _build_datetime("TO_DATE", exp.DataType.Type.DATE),
"TO_NUMBER": lambda args: exp.ToNumber(
this=seq_get(args, 0),
@ -384,7 +395,6 @@ class Snowflake(Dialect):
expressions=self._parse_csv(self._parse_id_var),
unset=True,
),
"SWAP": lambda self: self._parse_alter_table_swap(),
}
STATEMENT_PARSERS = {
@ -654,10 +664,6 @@ class Snowflake(Dialect):
},
)
def _parse_alter_table_swap(self) -> exp.SwapTable:
self._match_text_seq("WITH")
return self.expression(exp.SwapTable, this=self._parse_table(schema=True))
def _parse_location_property(self) -> exp.LocationProperty:
self._match(TokenType.EQ)
return self.expression(exp.LocationProperty, this=self._parse_location_path())
@ -828,7 +834,6 @@ class Snowflake(Dialect):
exp.StrPosition: lambda self, e: self.func(
"POSITION", e.args.get("substr"), e.this, e.args.get("position")
),
exp.StrToTime: lambda self, e: self.func("TO_TIMESTAMP", e.this, self.format_time(e)),
exp.Stuff: rename_func("INSERT"),
exp.TimeAdd: date_delta_sql("TIMEADD"),
exp.TimestampDiff: lambda self, e: self.func(
@ -842,6 +847,7 @@ class Snowflake(Dialect):
exp.TimeToUnix: lambda self, e: f"EXTRACT(epoch_second FROM {self.sql(e, 'this')})",
exp.ToArray: rename_func("TO_ARRAY"),
exp.ToChar: lambda self, e: self.function_fallback_sql(e),
exp.ToDouble: rename_func("TO_DOUBLE"),
exp.TsOrDsAdd: date_delta_sql("DATEADD", cast=True),
exp.TsOrDsDiff: date_delta_sql("DATEDIFF"),
exp.TsOrDsToDate: lambda self, e: self.func(
@ -1036,10 +1042,6 @@ class Snowflake(Dialect):
increment = f" INCREMENT {increment}" if increment else ""
return f"AUTOINCREMENT{start}{increment}"
def swaptable_sql(self, expression: exp.SwapTable) -> str:
this = self.sql(expression, "this")
return f"SWAP WITH {this}"
def cluster_sql(self, expression: exp.Cluster) -> str:
return f"CLUSTER BY ({self.expressions(expression, flat=True)})"
@ -1074,3 +1076,9 @@ class Snowflake(Dialect):
tag = f" TAG {tag}" if tag else ""
return f"SET{exprs}{file_format}{copy_options}{tag}"
def strtotime_sql(self, expression: exp.StrToTime):
safe_prefix = "TRY_" if expression.args.get("safe") else ""
return self.func(
f"{safe_prefix}TO_TIMESTAMP", expression.this, self.format_time(expression)
)
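Together, the parser entry (`safe=True`), the `safe` flag set by `_build_datetime`, and the new `strtotime_sql` make TRY_TO_TIMESTAMP round-trip instead of collapsing into TO_TIMESTAMP. A sketch, with expected outputs inferred from the diff:

```python
import sqlglot

# TRY_TO_TIMESTAMP with a format string round-trips with its TRY_ prefix.
print(sqlglot.transpile("SELECT TRY_TO_TIMESTAMP(col, 'yyyy-mm-dd')",
                        read="snowflake", write="snowflake")[0])

# A bare string literal still becomes a cast, now TRY_CAST in the safe case.
print(sqlglot.transpile("SELECT TRY_TO_DATE('2024-10-14')",
                        read="snowflake", write="snowflake")[0])
# Expected: SELECT TRY_CAST('2024-10-14' AS DATE)
```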

sqlglot/dialects/trino.py

@ -10,11 +10,15 @@ class Trino(Presto):
SUPPORTS_USER_DEFINED_TYPES = False
LOG_BASE_FIRST = True
class Tokenizer(Presto.Tokenizer):
HEX_STRINGS = [("X'", "'")]
class Parser(Presto.Parser):
FUNCTION_PARSERS = {
**Presto.Parser.FUNCTION_PARSERS,
"TRIM": lambda self: self._parse_trim(),
"JSON_QUERY": lambda self: self._parse_json_query(),
"LISTAGG": lambda self: self._parse_string_agg(),
}
JSON_QUERY_OPTIONS: parser.OPTIONS_TYPE = {
@ -65,5 +69,14 @@ class Trino(Presto):
return self.func("JSON_QUERY", expression.this, json_path + option)
class Tokenizer(Presto.Tokenizer):
HEX_STRINGS = [("X'", "'")]
def groupconcat_sql(self, expression: exp.GroupConcat) -> str:
this = expression.this
separator = expression.args.get("separator") or exp.Literal.string(",")
if isinstance(this, exp.Order):
if this.this:
this = this.this.pop()
return f"LISTAGG({self.format_args(this, separator)}) WITHIN GROUP ({self.sql(expression.this).lstrip()})"
return super().groupconcat_sql(expression)
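With the LISTAGG parser entry and this override, an aggregate carrying an ORDER BY is emitted as Trino's LISTAGG ... WITHIN GROUP, while the plain case falls back to Presto's ARRAY_JOIN(ARRAY_AGG(...)) form. A sketch:

```python
import sqlglot

sql = "SELECT GROUP_CONCAT(name ORDER BY name SEPARATOR ',') FROM users"
print(sqlglot.transpile(sql, read="mysql", write="trino")[0])
# Roughly: SELECT LISTAGG(name, ',') WITHIN GROUP (ORDER BY name) FROM users
```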

sqlglot/dialects/tsql.py

@ -324,6 +324,25 @@ def _build_with_arg_as_text(
return _parse
# https://learn.microsoft.com/en-us/sql/t-sql/functions/parsename-transact-sql?view=sql-server-ver16
def _build_parsename(args: t.List) -> exp.SplitPart | exp.Anonymous:
# PARSENAME(...) will be stored into exp.SplitPart if:
# - All args are literals
# - The part index (2nd arg) is <= 4 (max valid value, otherwise TSQL returns NULL)
if len(args) == 2 and all(isinstance(arg, exp.Literal) for arg in args):
this = args[0]
part_index = args[1]
split_count = len(this.name.split("."))
if split_count <= 4:
return exp.SplitPart(
this=this,
delimiter=exp.Literal.string("."),
part_index=exp.Literal.number(split_count + 1 - part_index.to_py()),
)
return exp.Anonymous(this="PARSENAME", expressions=args)
def _build_json_query(args: t.List, dialect: Dialect) -> exp.JSONExtract:
if len(args) == 1:
# The default value for path is '$'. As a result, if you don't provide a
@ -543,6 +562,7 @@ class TSQL(Dialect):
"LEN": _build_with_arg_as_text(exp.Length),
"LEFT": _build_with_arg_as_text(exp.Left),
"RIGHT": _build_with_arg_as_text(exp.Right),
"PARSENAME": _build_parsename,
"REPLICATE": exp.Repeat.from_arg_list,
"SQUARE": lambda args: exp.Pow(this=seq_get(args, 0), expression=exp.Literal.number(2)),
"SYSDATETIME": exp.CurrentTimestamp.from_arg_list,
@ -554,6 +574,10 @@ class TSQL(Dialect):
JOIN_HINTS = {"LOOP", "HASH", "MERGE", "REMOTE"}
PROCEDURE_OPTIONS = dict.fromkeys(
("ENCRYPTION", "RECOMPILE", "SCHEMABINDING", "NATIVE_COMPILATION", "EXECUTE"), tuple()
)
RETURNS_TABLE_TOKENS = parser.Parser.ID_VAR_TOKENS - {
TokenType.TABLE,
*parser.Parser.TYPE_TOKENS,
@ -699,7 +723,11 @@ class TSQL(Dialect):
):
return this
if not self._match(TokenType.WITH, advance=False):
expressions = self._parse_csv(self._parse_function_parameter)
else:
expressions = None
return self.expression(exp.UserDefinedFunction, this=this, expressions=expressions)
def _parse_id_var(
@ -954,6 +982,27 @@ class TSQL(Dialect):
self.unsupported("LATERAL clause is not supported.")
return "LATERAL"
def splitpart_sql(self: TSQL.Generator, expression: exp.SplitPart) -> str:
this = expression.this
split_count = len(this.name.split("."))
delimiter = expression.args.get("delimiter")
part_index = expression.args.get("part_index")
if (
not all(isinstance(arg, exp.Literal) for arg in (this, delimiter, part_index))
or (delimiter and delimiter.name != ".")
or not part_index
or split_count > 4
):
self.unsupported(
"SPLIT_PART can be transpiled to PARSENAME only for '.' delimiter and literal values"
)
return ""
return self.func(
"PARSENAME", this, exp.Literal.number(split_count + 1 - part_index.to_py())
)
def timefromparts_sql(self, expression: exp.TimeFromParts) -> str:
nano = expression.args.get("nano")
if nano is not None:
@ -1166,7 +1215,7 @@ class TSQL(Dialect):
def alter_sql(self, expression: exp.Alter) -> str:
action = seq_get(expression.args.get("actions") or [], 0)
if isinstance(action, exp.RenameTable):
if isinstance(action, exp.AlterRename):
return f"EXEC sp_rename '{self.sql(expression.this)}', '{action.this.name}'"
return super().alter_sql(expression)
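PARSENAME numbers parts from the right while SPLIT_PART numbers from the left, which is what the `split_count + 1 - part_index` arithmetic in `_build_parsename` and `splitpart_sql` compensates for. A sketch:

```python
import sqlglot

# PARSENAME('a.b.c', 1) is 'c' (rightmost); SPLIT_PART counts from the left.
print(sqlglot.transpile("SELECT PARSENAME('a.b.c', 1)", read="tsql", write="spark")[0])
# Expected: SELECT SPLIT_PART('a.b.c', '.', 3)

print(sqlglot.transpile("SELECT SPLIT_PART('a.b.c', '.', 1)", read="spark", write="tsql")[0])
# Expected: SELECT PARSENAME('a.b.c', 3)
```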

sqlglot/diff.py

@ -12,7 +12,7 @@ from dataclasses import dataclass
from heapq import heappop, heappush
from sqlglot import Dialect, expressions as exp
from sqlglot.helper import ensure_list
from sqlglot.helper import seq_get
if t.TYPE_CHECKING:
from sqlglot.dialects.dialect import DialectType
@ -185,7 +185,7 @@ class ChangeDistiller:
self._unmatched_target_nodes = set(self._target_index) - set(pre_matched_nodes.values())
self._bigram_histo_cache: t.Dict[int, t.DefaultDict[str, int]] = {}
matching_set = self._compute_matching_set() | {(s, t) for s, t in pre_matched_nodes.items()}
matching_set = self._compute_matching_set() | set(pre_matched_nodes.items())
return self._generate_edit_script(matching_set, delta_only)
def _generate_edit_script(
@ -201,6 +201,7 @@ class ChangeDistiller:
for kept_source_node_id, kept_target_node_id in matching_set:
source_node = self._source_index[kept_source_node_id]
target_node = self._target_index[kept_target_node_id]
if (
not isinstance(source_node, UPDATABLE_EXPRESSION_TYPES)
or source_node == target_node
@ -208,7 +209,13 @@ class ChangeDistiller:
edit_script.extend(
self._generate_move_edits(source_node, target_node, matching_set)
)
if not delta_only:
source_non_expression_leaves = dict(_get_non_expression_leaves(source_node))
target_non_expression_leaves = dict(_get_non_expression_leaves(target_node))
if source_non_expression_leaves != target_non_expression_leaves:
edit_script.append(Update(source_node, target_node))
elif not delta_only:
edit_script.append(Keep(source_node, target_node))
else:
edit_script.append(Update(source_node, target_node))
@ -246,8 +253,8 @@ class ChangeDistiller:
source_node = self._source_index[source_node_id]
target_node = self._target_index[target_node_id]
if _is_same_type(source_node, target_node):
source_leaf_ids = {id(l) for l in _get_leaves(source_node)}
target_leaf_ids = {id(l) for l in _get_leaves(target_node)}
source_leaf_ids = {id(l) for l in _get_expression_leaves(source_node)}
target_leaf_ids = {id(l) for l in _get_expression_leaves(target_node)}
max_leaves_num = max(len(source_leaf_ids), len(target_leaf_ids))
if max_leaves_num:
@ -277,10 +284,10 @@ class ChangeDistiller:
def _compute_leaf_matching_set(self) -> t.Set[t.Tuple[int, int]]:
candidate_matchings: t.List[t.Tuple[float, int, int, exp.Expression, exp.Expression]] = []
source_leaves = list(_get_leaves(self._source))
target_leaves = list(_get_leaves(self._target))
for source_leaf in source_leaves:
for target_leaf in target_leaves:
source_expression_leaves = list(_get_expression_leaves(self._source))
target_expression_leaves = list(_get_expression_leaves(self._target))
for source_leaf in source_expression_leaves:
for target_leaf in target_expression_leaves:
if _is_same_type(source_leaf, target_leaf):
similarity_score = self._dice_coefficient(source_leaf, target_leaf)
if similarity_score >= self.f:
@ -338,18 +345,28 @@ class ChangeDistiller:
return bigram_histo
def _get_leaves(expression: exp.Expression) -> t.Iterator[exp.Expression]:
def _get_expression_leaves(expression: exp.Expression) -> t.Iterator[exp.Expression]:
has_child_exprs = False
for node in expression.iter_expressions():
if not isinstance(node, IGNORED_LEAF_EXPRESSION_TYPES):
has_child_exprs = True
yield from _get_leaves(node)
yield from _get_expression_leaves(node)
if not has_child_exprs:
yield expression
def _get_non_expression_leaves(expression: exp.Expression) -> t.Iterator[t.Tuple[str, t.Any]]:
for arg, value in expression.args.items():
if isinstance(value, exp.Expression) or (
isinstance(value, list) and isinstance(seq_get(value, 0), exp.Expression)
):
continue
yield (arg, value)
def _is_same_type(source: exp.Expression, target: exp.Expression) -> bool:
if type(source) is type(target):
if isinstance(source, exp.Join):
@ -372,16 +389,12 @@ def _parent_similarity_score(
return 1 + _parent_similarity_score(source.parent, target.parent)
def _expression_only_args(expression: exp.Expression) -> t.List[exp.Expression]:
args: t.List[t.Union[exp.Expression, t.List]] = []
if expression:
for a in expression.args.values():
args.extend(ensure_list(a))
return [
a
for a in args
if isinstance(a, exp.Expression) and not isinstance(a, IGNORED_LEAF_EXPRESSION_TYPES)
]
def _expression_only_args(expression: exp.Expression) -> t.Iterator[exp.Expression]:
yield from (
arg
for arg in expression.iter_expressions()
if not isinstance(arg, IGNORED_LEAF_EXPRESSION_TYPES)
)
def _lcs(
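The distiller refactor splits leaf handling into expression leaves and non-expression leaves (plain arg values), so nodes that differ only in a non-expression arg now yield an Update instead of a Keep. The public entry point is unchanged; a sketch:

```python
from sqlglot import diff, parse_one

# Produces an edit script of Insert/Remove/Move/Update/Keep operations.
edit_script = diff(parse_one("SELECT a + b"), parse_one("SELECT a - b"))
print(edit_script)
```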

sqlglot/expressions.py

@ -404,9 +404,9 @@ class Expression(metaclass=_Expression):
def iter_expressions(self, reverse: bool = False) -> t.Iterator[Expression]:
"""Yields the key and expression for all arguments, exploding list args."""
# remove tuple when python 3.7 is deprecated
for vs in reversed(tuple(self.args.values())) if reverse else self.args.values():
for vs in reversed(tuple(self.args.values())) if reverse else self.args.values(): # type: ignore
if type(vs) is list:
for v in reversed(vs) if reverse else vs:
for v in reversed(vs) if reverse else vs: # type: ignore
if hasattr(v, "parent"):
yield v
else:
@ -1247,7 +1247,7 @@ class Query(Expression):
)
def union(
self, expression: ExpOrStr, distinct: bool = True, dialect: DialectType = None, **opts
self, *expressions: ExpOrStr, distinct: bool = True, dialect: DialectType = None, **opts
) -> Union:
"""
Builds a UNION expression.
@ -1258,8 +1258,8 @@ class Query(Expression):
'SELECT * FROM foo UNION SELECT * FROM bla'
Args:
expression: the SQL code string.
If an `Expression` instance is passed, it will be used as-is.
expressions: the SQL code strings.
If `Expression` instances are passed, they will be used as-is.
distinct: set the DISTINCT flag if and only if this is true.
dialect: the dialect used to parse the input expression.
opts: other options to use to parse the input expressions.
@ -1267,10 +1267,10 @@ class Query(Expression):
Returns:
The new Union expression.
"""
return union(left=self, right=expression, distinct=distinct, dialect=dialect, **opts)
return union(self, *expressions, distinct=distinct, dialect=dialect, **opts)
def intersect(
self, expression: ExpOrStr, distinct: bool = True, dialect: DialectType = None, **opts
self, *expressions: ExpOrStr, distinct: bool = True, dialect: DialectType = None, **opts
) -> Intersect:
"""
Builds an INTERSECT expression.
@ -1281,8 +1281,8 @@ class Query(Expression):
'SELECT * FROM foo INTERSECT SELECT * FROM bla'
Args:
expression: the SQL code string.
If an `Expression` instance is passed, it will be used as-is.
expressions: the SQL code strings.
If `Expression` instances are passed, they will be used as-is.
distinct: set the DISTINCT flag if and only if this is true.
dialect: the dialect used to parse the input expression.
opts: other options to use to parse the input expressions.
@ -1290,10 +1290,10 @@ class Query(Expression):
Returns:
The new Intersect expression.
"""
return intersect(left=self, right=expression, distinct=distinct, dialect=dialect, **opts)
return intersect(self, *expressions, distinct=distinct, dialect=dialect, **opts)
def except_(
self, expression: ExpOrStr, distinct: bool = True, dialect: DialectType = None, **opts
self, *expressions: ExpOrStr, distinct: bool = True, dialect: DialectType = None, **opts
) -> Except:
"""
Builds an EXCEPT expression.
@ -1304,8 +1304,8 @@ class Query(Expression):
'SELECT * FROM foo EXCEPT SELECT * FROM bla'
Args:
expression: the SQL code string.
If an `Expression` instance is passed, it will be used as-is.
expressions: the SQL code strings.
If `Expression` instances are passed, they will be used as-is.
distinct: set the DISTINCT flag if and only if this is true.
dialect: the dialect used to parse the input expression.
opts: other options to use to parse the input expressions.
@ -1313,7 +1313,7 @@ class Query(Expression):
Returns:
The new Except expression.
"""
return except_(left=self, right=expression, distinct=distinct, dialect=dialect, **opts)
return except_(self, *expressions, distinct=distinct, dialect=dialect, **opts)
class UDTF(DerivedTable):
@ -1697,7 +1697,7 @@ class RenameColumn(Expression):
arg_types = {"this": True, "to": True, "exists": False}
class RenameTable(Expression):
class AlterRename(Expression):
pass
@ -2400,6 +2400,7 @@ class Join(Expression):
"global": False,
"hint": False,
"match_condition": False, # Snowflake
"expressions": False,
}
@property
@ -2995,6 +2996,10 @@ class WithSystemVersioningProperty(Property):
}
class WithProcedureOptions(Property):
arg_types = {"expressions": True}
class Properties(Expression):
arg_types = {"expressions": True}
@ -3213,10 +3218,18 @@ class Table(Expression):
def to_column(self, copy: bool = True) -> Alias | Column | Dot:
parts = self.parts
last_part = parts[-1]
if isinstance(last_part, Identifier):
col = column(*reversed(parts[0:4]), fields=parts[4:], copy=copy) # type: ignore
else:
# This branch will be reached if a function or array is wrapped in a `Table`
col = last_part
alias = self.args.get("alias")
if alias:
col = alias_(col, alias.this, copy=copy)
return col
@ -3278,7 +3291,7 @@ class Intersect(SetOperation):
pass
class Update(Expression):
class Update(DML):
arg_types = {
"with": False,
"this": False,
@ -3526,6 +3539,7 @@ class Select(Query):
"distinct": False,
"into": False,
"from": False,
"operation_modifiers": False,
**QUERY_MODIFIERS,
}
@ -5184,6 +5198,14 @@ class ToNumber(Func):
}
# https://docs.snowflake.com/en/sql-reference/functions/to_double
class ToDouble(Func):
arg_types = {
"this": True,
"format": False,
}
class Columns(Func):
arg_types = {"this": True, "unpack": False}
@ -5641,7 +5663,7 @@ class Exp(Func):
# https://docs.snowflake.com/en/sql-reference/functions/flatten
class Explode(Func):
class Explode(Func, UDTF):
arg_types = {"this": True, "expressions": False}
is_var_len_args = True
@ -6248,6 +6270,11 @@ class Split(Func):
arg_types = {"this": True, "expression": True, "limit": False}
# https://spark.apache.org/docs/latest/api/python/reference/pyspark.sql/api/pyspark.sql.functions.split_part.html
class SplitPart(Func):
arg_types = {"this": True, "delimiter": True, "part_index": True}
# Start may be omitted in the case of postgres
# https://www.postgresql.org/docs/9.1/functions-string.html @ Table 9-6
class Substring(Func):
@ -6857,26 +6884,37 @@ def _wrap(expression: E, kind: t.Type[Expression]) -> E | Paren:
return Paren(this=expression) if isinstance(expression, kind) else expression
def _apply_set_operation(
*expressions: ExpOrStr,
set_operation: t.Type[S],
distinct: bool = True,
dialect: DialectType = None,
copy: bool = True,
**opts,
) -> S:
return reduce(
lambda x, y: set_operation(this=x, expression=y, distinct=distinct),
(maybe_parse(e, dialect=dialect, copy=copy, **opts) for e in expressions),
)
def union(
left: ExpOrStr,
right: ExpOrStr,
*expressions: ExpOrStr,
distinct: bool = True,
dialect: DialectType = None,
copy: bool = True,
**opts,
) -> Union:
"""
Initializes a syntax tree from one UNION expression.
Initializes a syntax tree for the `UNION` operation.
Example:
>>> union("SELECT * FROM foo", "SELECT * FROM bla").sql()
'SELECT * FROM foo UNION SELECT * FROM bla'
Args:
left: the SQL code string corresponding to the left-hand side.
If an `Expression` instance is passed, it will be used as-is.
right: the SQL code string corresponding to the right-hand side.
If an `Expression` instance is passed, it will be used as-is.
expressions: the SQL code strings, corresponding to the `UNION`'s operands.
If `Expression` instances are passed, they will be used as-is.
distinct: set the DISTINCT flag if and only if this is true.
dialect: the dialect used to parse the input expression.
copy: whether to copy the expression.
@ -6885,32 +6923,29 @@ def union(
Returns:
The new Union instance.
"""
left = maybe_parse(sql_or_expression=left, dialect=dialect, copy=copy, **opts)
right = maybe_parse(sql_or_expression=right, dialect=dialect, copy=copy, **opts)
return Union(this=left, expression=right, distinct=distinct)
assert len(expressions) >= 2, "At least two expressions are required by `union`."
return _apply_set_operation(
*expressions, set_operation=Union, distinct=distinct, dialect=dialect, copy=copy, **opts
)
def intersect(
left: ExpOrStr,
right: ExpOrStr,
*expressions: ExpOrStr,
distinct: bool = True,
dialect: DialectType = None,
copy: bool = True,
**opts,
) -> Intersect:
"""
Initializes a syntax tree from one INTERSECT expression.
Initializes a syntax tree for the `INTERSECT` operation.
Example:
>>> intersect("SELECT * FROM foo", "SELECT * FROM bla").sql()
'SELECT * FROM foo INTERSECT SELECT * FROM bla'
Args:
left: the SQL code string corresponding to the left-hand side.
If an `Expression` instance is passed, it will be used as-is.
right: the SQL code string corresponding to the right-hand side.
If an `Expression` instance is passed, it will be used as-is.
expressions: the SQL code strings, corresponding to the `INTERSECT`'s operands.
If `Expression` instances are passed, they will be used as-is.
distinct: set the DISTINCT flag if and only if this is true.
dialect: the dialect used to parse the input expression.
copy: whether to copy the expression.
@ -6919,32 +6954,29 @@ def intersect(
Returns:
The new Intersect instance.
"""
left = maybe_parse(sql_or_expression=left, dialect=dialect, copy=copy, **opts)
right = maybe_parse(sql_or_expression=right, dialect=dialect, copy=copy, **opts)
return Intersect(this=left, expression=right, distinct=distinct)
assert len(expressions) >= 2, "At least two expressions are required by `intersect`."
return _apply_set_operation(
*expressions, set_operation=Intersect, distinct=distinct, dialect=dialect, copy=copy, **opts
)
def except_(
left: ExpOrStr,
right: ExpOrStr,
*expressions: ExpOrStr,
distinct: bool = True,
dialect: DialectType = None,
copy: bool = True,
**opts,
) -> Except:
"""
Initializes a syntax tree from one EXCEPT expression.
Initializes a syntax tree for the `EXCEPT` operation.
Example:
>>> except_("SELECT * FROM foo", "SELECT * FROM bla").sql()
'SELECT * FROM foo EXCEPT SELECT * FROM bla'
Args:
left: the SQL code string corresponding to the left-hand side.
If an `Expression` instance is passed, it will be used as-is.
right: the SQL code string corresponding to the right-hand side.
If an `Expression` instance is passed, it will be used as-is.
expressions: the SQL code strings, corresponding to the `EXCEPT`'s operands.
If `Expression` instances are passed, they will be used as-is.
distinct: set the DISTINCT flag if and only if this is true.
dialect: the dialect used to parse the input expression.
copy: whether to copy the expression.
@ -6953,10 +6985,10 @@ def except_(
Returns:
The new Except instance.
"""
left = maybe_parse(sql_or_expression=left, dialect=dialect, copy=copy, **opts)
right = maybe_parse(sql_or_expression=right, dialect=dialect, copy=copy, **opts)
return Except(this=left, expression=right, distinct=distinct)
assert len(expressions) >= 2, "At least two expressions are required by `except_`."
return _apply_set_operation(
*expressions, set_operation=Except, distinct=distinct, dialect=dialect, copy=copy, **opts
)
def select(*expressions: ExpOrStr, dialect: DialectType = None, **opts) -> Select:
@ -7410,15 +7442,9 @@ def to_interval(interval: str | Literal) -> Interval:
interval = interval.this
interval_parts = INTERVAL_STRING_RE.match(interval) # type: ignore
if not interval_parts:
raise ValueError("Invalid interval string.")
return Interval(
this=Literal.string(interval_parts.group(1)),
unit=Var(this=interval_parts.group(2).upper()),
)
interval = maybe_parse(f"INTERVAL {interval}")
assert isinstance(interval, Interval)
return interval
def to_table(
@ -7795,7 +7821,7 @@ def rename_table(
this=old_table,
kind="TABLE",
actions=[
RenameTable(this=new_table),
AlterRename(this=new_table),
],
)
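The set-operation builders now take N operands and left-fold them with `_apply_set_operation`; the two-argument form still works. A sketch:

```python
from sqlglot import exp

# union/intersect/except_ accept any number of operands (at least two).
query = exp.union(
    "SELECT * FROM a",
    "SELECT * FROM b",
    "SELECT * FROM c",
    distinct=False,
)
print(query.sql())
# SELECT * FROM a UNION ALL SELECT * FROM b UNION ALL SELECT * FROM c

# The Query methods gained the same variadic signature.
print(exp.select("*").from_("a").union("SELECT * FROM b", "SELECT * FROM c").sql())
```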

sqlglot/generator.py

@ -185,6 +185,7 @@ class Generator(metaclass=_Generator):
exp.Stream: lambda self, e: f"STREAM {self.sql(e, 'this')}",
exp.StreamingTableProperty: lambda *_: "STREAMING",
exp.StrictProperty: lambda *_: "STRICT",
exp.SwapTable: lambda self, e: f"SWAP WITH {self.sql(e, 'this')}",
exp.TemporaryProperty: lambda *_: "TEMPORARY",
exp.TagColumnConstraint: lambda self, e: f"TAG ({self.expressions(e, flat=True)})",
exp.TitleColumnConstraint: lambda self, e: f"TITLE {self.sql(e, 'this')}",
@ -200,6 +201,7 @@ class Generator(metaclass=_Generator):
exp.ViewAttributeProperty: lambda self, e: f"WITH {self.sql(e, 'this')}",
exp.VolatileProperty: lambda *_: "VOLATILE",
exp.WithJournalTableProperty: lambda self, e: f"WITH JOURNAL TABLE={self.sql(e, 'this')}",
exp.WithProcedureOptions: lambda self, e: f"WITH {self.expressions(e, flat=True)}",
exp.WithSchemaBindingProperty: lambda self, e: f"WITH SCHEMA {self.sql(e, 'this')}",
exp.WithOperator: lambda self, e: f"{self.sql(e, 'this')} WITH {self.sql(e, 'op')}",
}
@ -564,6 +566,7 @@ class Generator(metaclass=_Generator):
exp.VolatileProperty: exp.Properties.Location.POST_CREATE,
exp.WithDataProperty: exp.Properties.Location.POST_EXPRESSION,
exp.WithJournalTableProperty: exp.Properties.Location.POST_NAME,
exp.WithProcedureOptions: exp.Properties.Location.POST_SCHEMA,
exp.WithSchemaBindingProperty: exp.Properties.Location.POST_SCHEMA,
exp.WithSystemVersioningProperty: exp.Properties.Location.POST_SCHEMA,
}
@ -2144,6 +2147,10 @@ class Generator(metaclass=_Generator):
this = expression.this
this_sql = self.sql(this)
exprs = self.expressions(expression)
if exprs:
this_sql = f"{this_sql},{self.seg(exprs)}"
if on_sql:
on_sql = self.indent(on_sql, skip_first=True)
space = self.seg(" " * self.pad) if self.pretty else " "
@ -2510,13 +2517,16 @@ class Generator(metaclass=_Generator):
)
kind = ""
operation_modifiers = self.expressions(expression, key="operation_modifiers", sep=" ")
operation_modifiers = f"{self.sep()}{operation_modifiers}" if operation_modifiers else ""
# We use LIMIT_IS_TOP as a proxy for whether DISTINCT should go first because tsql and Teradata
# are the only dialects that use LIMIT_IS_TOP and both place DISTINCT first.
top_distinct = f"{distinct}{hint}{top}" if self.LIMIT_IS_TOP else f"{top}{hint}{distinct}"
expressions = f"{self.sep()}{expressions}" if expressions else expressions
sql = self.query_modifiers(
expression,
f"SELECT{top_distinct}{kind}{expressions}",
f"SELECT{top_distinct}{operation_modifiers}{kind}{expressions}",
self.sql(expression, "into", comment=False),
self.sql(expression, "from", comment=False),
)
@ -3225,12 +3235,12 @@ class Generator(metaclass=_Generator):
expressions = f"({expressions})" if expressions else ""
return f"ALTER{compound} SORTKEY {this or expressions}"
def renametable_sql(self, expression: exp.RenameTable) -> str:
def alterrename_sql(self, expression: exp.AlterRename) -> str:
if not self.RENAME_TABLE_WITH_DB:
# Remove db from tables
expression = expression.transform(
lambda n: exp.table_(n.this) if isinstance(n, exp.Table) else n
).assert_is(exp.RenameTable)
).assert_is(exp.AlterRename)
this = self.sql(expression, "this")
return f"RENAME TO {this}"
@ -3508,13 +3518,15 @@ class Generator(metaclass=_Generator):
name = self.normalize_func(name) if normalize else name
return f"{name}{prefix}{self.format_args(*args)}{suffix}"
def format_args(self, *args: t.Optional[str | exp.Expression]) -> str:
def format_args(self, *args: t.Optional[str | exp.Expression], sep: str = ", ") -> str:
arg_sqls = tuple(
self.sql(arg) for arg in args if arg is not None and not isinstance(arg, bool)
)
if self.pretty and self.too_wide(arg_sqls):
return self.indent("\n" + ",\n".join(arg_sqls) + "\n", skip_first=True, skip_last=True)
return ", ".join(arg_sqls)
return self.indent(
"\n" + f"{sep.strip()}\n".join(arg_sqls) + "\n", skip_first=True, skip_last=True
)
return sep.join(arg_sqls)
def too_wide(self, args: t.Iterable) -> bool:
return sum(len(arg) for arg in args) > self.max_text_width
@ -3612,7 +3624,7 @@ class Generator(metaclass=_Generator):
expressions = (
self.wrap(expressions) if expression.args.get("wrapped") else f" {expressions}"
)
return f"{this}{expressions}"
return f"{this}{expressions}" if expressions.strip() != "" else this
def joinhint_sql(self, expression: exp.JoinHint) -> str:
this = self.sql(expression, "this")
@ -4243,7 +4255,7 @@ class Generator(metaclass=_Generator):
else:
rhs = self.expressions(expression)
return self.func(name, expression.this, rhs)
return self.func(name, expression.this, rhs or None)
def converttimezone_sql(self, expression: exp.ConvertTimezone) -> str:
if self.SUPPORTS_CONVERT_TIMEZONE:
@ -4418,3 +4430,7 @@ class Generator(metaclass=_Generator):
for_sql = f" FOR {for_sql}" if for_sql else ""
return f"OVERLAY({this} PLACING {expr} FROM {from_sql}{for_sql})"
@unsupported_args("format")
def todouble_sql(self, expression: exp.ToDouble) -> str:
return self.sql(exp.cast(expression.this, exp.DataType.Type.DOUBLE))
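The base generator lowers `exp.ToDouble` to a plain cast (warning on the unsupported `format` arg), while Snowflake keeps its native function via the `rename_func("TO_DOUBLE")` mapping shown earlier. A sketch:

```python
import sqlglot

print(sqlglot.transpile("SELECT TO_DOUBLE('1.5')", read="snowflake", write="duckdb")[0])
# Expected: SELECT CAST('1.5' AS DOUBLE)

print(sqlglot.transpile("SELECT TO_DOUBLE('1.5')", read="snowflake", write="snowflake")[0])
# Expected: SELECT TO_DOUBLE('1.5')
```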

sqlglot/helper.py

@ -56,6 +56,10 @@ def seq_get(seq: t.Sequence[T], index: int) -> t.Optional[T]:
def ensure_list(value: t.Collection[T]) -> t.List[T]: ...
@t.overload
def ensure_list(value: None) -> t.List: ...
@t.overload
def ensure_list(value: T) -> t.List[T]: ...

sqlglot/optimizer/annotate_types.py

@ -287,15 +287,18 @@ class TypeAnnotator(metaclass=_TypeAnnotator):
def _maybe_coerce(
self, type1: exp.DataType | exp.DataType.Type, type2: exp.DataType | exp.DataType.Type
) -> exp.DataType:
) -> exp.DataType.Type:
type1_value = type1.this if isinstance(type1, exp.DataType) else type1
type2_value = type2.this if isinstance(type2, exp.DataType) else type2
# We propagate the UNKNOWN type upwards if found
if exp.DataType.Type.UNKNOWN in (type1_value, type2_value):
return exp.DataType.build("unknown")
return exp.DataType.Type.UNKNOWN
return type2_value if type2_value in self.coerces_to.get(type1_value, {}) else type1_value
return t.cast(
exp.DataType.Type,
type2_value if type2_value in self.coerces_to.get(type1_value, {}) else type1_value,
)
def _annotate_binary(self, expression: B) -> B:
self._annotate_args(expression)

sqlglot/optimizer/eliminate_subqueries.py

@ -1,11 +1,18 @@
from __future__ import annotations
import itertools
import typing as t
from sqlglot import expressions as exp
from sqlglot.helper import find_new_name
from sqlglot.optimizer.scope import build_scope
from sqlglot.optimizer.scope import Scope, build_scope
if t.TYPE_CHECKING:
ExistingCTEsMapping = t.Dict[exp.Expression, str]
TakenNameMapping = t.Dict[str, t.Union[Scope, exp.Expression]]
def eliminate_subqueries(expression):
def eliminate_subqueries(expression: exp.Expression) -> exp.Expression:
"""
Rewrite derived tables as CTES, deduplicating if possible.
@ -38,7 +45,7 @@ def eliminate_subqueries(expression):
# Map of alias->Scope|Table
# These are all aliases that are already used in the expression.
# We don't want to create new CTEs that conflict with these names.
taken = {}
taken: TakenNameMapping = {}
# All CTE aliases in the root scope are taken
for scope in root.cte_scopes:
@ -56,7 +63,7 @@ def eliminate_subqueries(expression):
# Map of Expression->alias
# Existing CTES in the root expression. We'll use this for deduplication.
existing_ctes = {}
existing_ctes: ExistingCTEsMapping = {}
with_ = root.expression.args.get("with")
recursive = False
@ -95,15 +102,21 @@ def eliminate_subqueries(expression):
return expression
def _eliminate(scope, existing_ctes, taken):
def _eliminate(
scope: Scope, existing_ctes: ExistingCTEsMapping, taken: TakenNameMapping
) -> t.Optional[exp.Expression]:
if scope.is_derived_table:
return _eliminate_derived_table(scope, existing_ctes, taken)
if scope.is_cte:
return _eliminate_cte(scope, existing_ctes, taken)
return None
def _eliminate_derived_table(scope, existing_ctes, taken):
def _eliminate_derived_table(
scope: Scope, existing_ctes: ExistingCTEsMapping, taken: TakenNameMapping
) -> t.Optional[exp.Expression]:
# This makes sure that we don't:
# - drop the "pivot" arg from a pivoted subquery
# - eliminate a lateral correlated subquery
@ -121,7 +134,9 @@ def _eliminate_derived_table(scope, existing_ctes, taken):
return cte
def _eliminate_cte(scope, existing_ctes, taken):
def _eliminate_cte(
scope: Scope, existing_ctes: ExistingCTEsMapping, taken: TakenNameMapping
) -> t.Optional[exp.Expression]:
parent = scope.expression.parent
name, cte = _new_cte(scope, existing_ctes, taken)
@ -140,7 +155,9 @@ def _eliminate_cte(scope, existing_ctes, taken):
return cte
def _new_cte(scope, existing_ctes, taken):
def _new_cte(
scope: Scope, existing_ctes: ExistingCTEsMapping, taken: TakenNameMapping
) -> t.Tuple[str, t.Optional[exp.Expression]]:
"""
Returns:
tuple of (name, cte)
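This hunk only adds type annotations; behavior is unchanged. For reference, the documented usage:

```python
from sqlglot import parse_one
from sqlglot.optimizer.eliminate_subqueries import eliminate_subqueries

# Derived tables are rewritten as CTEs, deduplicating where possible.
expr = eliminate_subqueries(parse_one("SELECT a FROM (SELECT a FROM tbl) AS sub"))
print(expr.sql())
# Roughly: WITH sub AS (SELECT a FROM tbl) SELECT a FROM sub
```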

sqlglot/optimizer/merge_subqueries.py

@ -1,11 +1,20 @@
from __future__ import annotations
import typing as t
from collections import defaultdict
from sqlglot import expressions as exp
from sqlglot.helper import find_new_name
from sqlglot.optimizer.scope import Scope, traverse_scope
if t.TYPE_CHECKING:
from sqlglot._typing import E
def merge_subqueries(expression, leave_tables_isolated=False):
FromOrJoin = t.Union[exp.From, exp.Join]
def merge_subqueries(expression: E, leave_tables_isolated: bool = False) -> E:
"""
Rewrite sqlglot AST to merge derived tables into the outer query.
@ -58,7 +67,7 @@ SAFE_TO_REPLACE_UNWRAPPED = (
)
def merge_ctes(expression, leave_tables_isolated=False):
def merge_ctes(expression: E, leave_tables_isolated: bool = False) -> E:
scopes = traverse_scope(expression)
# All places where we select from CTEs.
@ -92,7 +101,7 @@ def merge_ctes(expression, leave_tables_isolated=False):
return expression
def merge_derived_tables(expression, leave_tables_isolated=False):
def merge_derived_tables(expression: E, leave_tables_isolated: bool = False) -> E:
for outer_scope in traverse_scope(expression):
for subquery in outer_scope.derived_tables:
from_or_join = subquery.find_ancestor(exp.From, exp.Join)
@ -111,17 +120,11 @@ def merge_derived_tables(expression, leave_tables_isolated=False):
return expression
def _mergeable(outer_scope, inner_scope, leave_tables_isolated, from_or_join):
def _mergeable(
outer_scope: Scope, inner_scope: Scope, leave_tables_isolated: bool, from_or_join: FromOrJoin
) -> bool:
"""
Return True if `inner_select` can be merged into outer query.
Args:
outer_scope (Scope)
inner_scope (Scope)
leave_tables_isolated (bool)
from_or_join (exp.From|exp.Join)
Returns:
bool: True if can be merged
"""
inner_select = inner_scope.expression.unnest()
@ -195,7 +198,7 @@ def _mergeable(outer_scope, inner_scope, leave_tables_isolated, from_or_join):
and not outer_scope.expression.is_star
and isinstance(inner_select, exp.Select)
and not any(inner_select.args.get(arg) for arg in UNMERGABLE_ARGS)
and inner_select.args.get("from")
and inner_select.args.get("from") is not None
and not outer_scope.pivots
and not any(e.find(exp.AggFunc, exp.Select, exp.Explode) for e in inner_select.expressions)
and not (leave_tables_isolated and len(outer_scope.selected_sources) > 1)
@@ -218,19 +221,17 @@ def _mergeable(outer_scope, inner_scope, leave_tables_isolated, from_or_join):
)
def _rename_inner_sources(outer_scope, inner_scope, alias):
def _rename_inner_sources(outer_scope: Scope, inner_scope: Scope, alias: str) -> None:
"""
Renames any sources in the inner query that conflict with names in the outer query.
Args:
outer_scope (sqlglot.optimizer.scope.Scope)
inner_scope (sqlglot.optimizer.scope.Scope)
alias (str)
"""
taken = set(outer_scope.selected_sources)
conflicts = taken.intersection(set(inner_scope.selected_sources))
inner_taken = set(inner_scope.selected_sources)
outer_taken = set(outer_scope.selected_sources)
conflicts = outer_taken.intersection(inner_taken)
conflicts -= {alias}
taken = outer_taken.union(inner_taken)
for conflict in conflicts:
new_name = find_new_name(taken, conflict)
@@ -250,15 +251,14 @@ def _rename_inner_sources(outer_scope, inner_scope, alias):
inner_scope.rename_source(conflict, new_name)
def _merge_from(outer_scope, inner_scope, node_to_replace, alias):
def _merge_from(
outer_scope: Scope,
inner_scope: Scope,
node_to_replace: t.Union[exp.Subquery, exp.Table],
alias: str,
) -> None:
"""
Merge FROM clause of inner query into outer query.
Args:
outer_scope (sqlglot.optimizer.scope.Scope)
inner_scope (sqlglot.optimizer.scope.Scope)
node_to_replace (exp.Subquery|exp.Table)
alias (str)
"""
new_subquery = inner_scope.expression.args["from"].this
new_subquery.set("joins", node_to_replace.args.get("joins"))
@@ -274,14 +274,9 @@ def _merge_from(outer_scope, inner_scope, node_to_replace, alias):
)
def _merge_joins(outer_scope, inner_scope, from_or_join):
def _merge_joins(outer_scope: Scope, inner_scope: Scope, from_or_join: FromOrJoin) -> None:
"""
Merge JOIN clauses of inner query into outer query.
Args:
outer_scope (sqlglot.optimizer.scope.Scope)
inner_scope (sqlglot.optimizer.scope.Scope)
from_or_join (exp.From|exp.Join)
"""
new_joins = []
@@ -304,7 +299,7 @@ def _merge_joins(outer_scope, inner_scope, from_or_join):
outer_scope.expression.set("joins", outer_joins)
def _merge_expressions(outer_scope, inner_scope, alias):
def _merge_expressions(outer_scope: Scope, inner_scope: Scope, alias: str) -> None:
"""
Merge projections of inner query into outer query.
@@ -338,7 +333,7 @@ def _merge_expressions(outer_scope, inner_scope, alias):
column.replace(expression.copy())
def _merge_where(outer_scope, inner_scope, from_or_join):
def _merge_where(outer_scope: Scope, inner_scope: Scope, from_or_join: FromOrJoin) -> None:
"""
Merge WHERE clause of inner query into outer query.
@@ -357,7 +352,7 @@ def _merge_where(outer_scope, inner_scope, from_or_join):
# Merge predicates from an outer join to the ON clause
# if it only has columns that are already joined
from_ = expression.args.get("from")
sources = {from_.alias_or_name} if from_ else {}
sources = {from_.alias_or_name} if from_ else set()
for join in expression.args["joins"]:
source = join.alias_or_name
@@ -373,7 +368,7 @@ def _merge_where(outer_scope, inner_scope, from_or_join):
expression.where(where.this, copy=False)
def _merge_order(outer_scope, inner_scope):
def _merge_order(outer_scope: Scope, inner_scope: Scope) -> None:
"""
Merge ORDER clause of inner query into outer query.
@@ -393,7 +388,7 @@ def _merge_order(outer_scope, inner_scope):
outer_scope.expression.set("order", inner_scope.expression.args.get("order"))
def _merge_hints(outer_scope, inner_scope):
def _merge_hints(outer_scope: Scope, inner_scope: Scope) -> None:
inner_scope_hint = inner_scope.expression.args.get("hint")
if not inner_scope_hint:
return
@@ -405,7 +400,7 @@ def _merge_hints(outer_scope, inner_scope):
outer_scope.expression.set("hint", inner_scope_hint)
def _pop_cte(inner_scope):
def _pop_cte(inner_scope: Scope) -> None:
"""
Remove CTE from the AST.

File: sqlglot/optimizer/qualify.py

@@ -27,6 +27,7 @@ def qualify(
infer_schema: t.Optional[bool] = None,
isolate_tables: bool = False,
qualify_columns: bool = True,
allow_partial_qualification: bool = False,
validate_qualify_columns: bool = True,
quote_identifiers: bool = True,
identify: bool = True,
@@ -56,6 +57,7 @@ def qualify(
infer_schema: Whether to infer the schema if missing.
isolate_tables: Whether to isolate table selects.
qualify_columns: Whether to qualify columns.
allow_partial_qualification: Whether to allow partial qualification.
validate_qualify_columns: Whether to validate columns.
quote_identifiers: Whether to run the quote_identifiers step.
This step is necessary to ensure correctness for case sensitive queries.
@@ -90,6 +92,7 @@ def qualify(
expand_alias_refs=expand_alias_refs,
expand_stars=expand_stars,
infer_schema=infer_schema,
allow_partial_qualification=allow_partial_qualification,
)
if quote_identifiers:
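A hedged sketch of the new flag: with a deliberately incomplete schema, allow_partial_qualification=True lets qualification proceed instead of raising on the unknown column (schema contents here are illustrative):

import sqlglot
from sqlglot.optimizer.qualify import qualify

ast = sqlglot.parse_one("SELECT t.a, t.b FROM t")
qualified = qualify(
    ast,
    schema={"t": {"a": "INT"}},        # "b" is intentionally missing
    allow_partial_qualification=True,  # with False this raises "Unknown column: b"
    validate_qualify_columns=False,
)
print(qualified.sql())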

File: sqlglot/optimizer/qualify_columns.py

@@ -22,6 +22,7 @@ def qualify_columns(
expand_alias_refs: bool = True,
expand_stars: bool = True,
infer_schema: t.Optional[bool] = None,
allow_partial_qualification: bool = False,
) -> exp.Expression:
"""
Rewrite sqlglot AST to have fully qualified columns.
@@ -41,6 +42,7 @@ def qualify_columns(
for most of the optimizer's rules to work; do not set to False unless you
know what you're doing!
infer_schema: Whether to infer the schema if missing.
allow_partial_qualification: Whether to allow partial qualification.
Returns:
The qualified expression.
@@ -68,7 +70,7 @@ def qualify_columns(
)
_convert_columns_to_dots(scope, resolver)
_qualify_columns(scope, resolver)
_qualify_columns(scope, resolver, allow_partial_qualification=allow_partial_qualification)
if not schema.empty and expand_alias_refs:
_expand_alias_refs(scope, resolver)
@@ -240,13 +242,21 @@ def _expand_alias_refs(scope: Scope, resolver: Resolver, expand_only_groupby: bo
def replace_columns(
node: t.Optional[exp.Expression], resolve_table: bool = False, literal_index: bool = False
) -> None:
if not node or (expand_only_groupby and not isinstance(node, exp.Group)):
is_group_by = isinstance(node, exp.Group)
if not node or (expand_only_groupby and not is_group_by):
return
for column in walk_in_scope(node, prune=lambda node: node.is_star):
if not isinstance(column, exp.Column):
continue
# BigQuery's GROUP BY allows alias expansion only for standalone names, e.g.:
# SELECT FUNC(col) AS col FROM t GROUP BY col --> Can be expanded
# SELECT FUNC(col) AS col FROM t GROUP BY FUNC(col) --> Shouldn't be expanded, as it would result in FUNC(FUNC(col))
# This is not required for the HAVING clause, as it can evaluate expressions using both the alias and the table columns
if expand_only_groupby and is_group_by and column.parent is not node:
continue
table = resolver.get_table(column.name) if resolve_table and not column.table else None
alias_expr, i = alias_to_expression.get(column.name, (None, 1))
double_agg = (
@@ -273,9 +283,8 @@ def _expand_alias_refs(scope: Scope, resolver: Resolver, expand_only_groupby: bo
if simplified is not column:
column.replace(simplified)
for i, projection in enumerate(scope.expression.selects):
for i, projection in enumerate(expression.selects):
replace_columns(projection)
if isinstance(projection, exp.Alias):
alias_to_expression[projection.alias] = (projection.this, i + 1)
@@ -434,7 +443,7 @@ def _convert_columns_to_dots(scope: Scope, resolver: Resolver) -> None:
scope.clear_cache()
def _qualify_columns(scope: Scope, resolver: Resolver) -> None:
def _qualify_columns(scope: Scope, resolver: Resolver, allow_partial_qualification: bool) -> None:
"""Disambiguate columns, ensuring each column specifies a source"""
for column in scope.columns:
column_table = column.table
@@ -442,7 +451,12 @@ def _qualify_columns(scope: Scope, resolver: Resolver) -> None:
if column_table and column_table in scope.sources:
source_columns = resolver.get_source_columns(column_table)
if source_columns and column_name not in source_columns and "*" not in source_columns:
if (
not allow_partial_qualification
and source_columns
and column_name not in source_columns
and "*" not in source_columns
):
raise OptimizeError(f"Unknown column: {column_name}")
if not column_table:
@@ -526,7 +540,7 @@ def _expand_stars(
) -> None:
"""Expand stars to lists of column selections"""
new_selections = []
new_selections: t.List[exp.Expression] = []
except_columns: t.Dict[int, t.Set[str]] = {}
replace_columns: t.Dict[int, t.Dict[str, exp.Alias]] = {}
rename_columns: t.Dict[int, t.Dict[str, str]] = {}

File: sqlglot/optimizer/scope.py

@@ -562,8 +562,8 @@ def _traverse_scope(scope):
elif isinstance(expression, exp.DML):
yield from _traverse_ctes(scope)
for query in find_all_in_scope(expression, exp.Query):
# This check ensures we don't yield the CTE queries twice
if not isinstance(query.parent, exp.CTE):
# This check ensures we don't yield the CTE/nested queries twice
if not isinstance(query.parent, (exp.CTE, exp.Subquery)):
yield from _traverse_scope(Scope(query, cte_sources=scope.cte_sources))
return
else:
@@ -679,6 +679,8 @@ def _traverse_tables(scope):
expressions.extend(scope.expression.args.get("laterals") or [])
for expression in expressions:
if isinstance(expression, exp.Final):
expression = expression.this
if isinstance(expression, exp.Table):
table_name = expression.name
source_name = expression.alias_or_name
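A sketch of the traversal fixes (table names illustrative): queries nested inside DML statements now yield a single scope each, and a FINAL-wrapped table resolves to its underlying source.

import sqlglot
from sqlglot.optimizer.scope import traverse_scope

# The exp.Subquery guard keeps the inner SELECT from being visited twice.
ast = sqlglot.parse_one("DELETE FROM t WHERE a IN (SELECT a FROM u)")
for scope in traverse_scope(ast):
    print(scope.expression.sql())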

File: sqlglot/optimizer/simplify.py

@@ -206,6 +206,11 @@ COMPLEMENT_COMPARISONS = {
exp.NEQ: exp.EQ,
}
COMPLEMENT_SUBQUERY_PREDICATES = {
exp.All: exp.Any,
exp.Any: exp.All,
}
def simplify_not(expression):
"""
@@ -218,9 +223,12 @@ def simplify_not(expression):
if is_null(this):
return exp.null()
if this.__class__ in COMPLEMENT_COMPARISONS:
return COMPLEMENT_COMPARISONS[this.__class__](
this=this.this, expression=this.expression
)
right = this.expression
complement_subquery_predicate = COMPLEMENT_SUBQUERY_PREDICATES.get(right.__class__)
if complement_subquery_predicate:
right = complement_subquery_predicate(this=right.this)
return COMPLEMENT_COMPARISONS[this.__class__](this=this.this, expression=right)
if isinstance(this, exp.Paren):
condition = this.unnest()
if isinstance(condition, exp.And):
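A sketch of the new complement rule: negating a quantified comparison flips the quantifier along with the operator (= ANY becomes <> ALL), instead of leaving the subquery predicate as-is.

import sqlglot
from sqlglot.optimizer.simplify import simplify

ast = sqlglot.parse_one("SELECT * FROM t WHERE NOT x = ANY (SELECT y FROM u)")
print(simplify(ast).sql())
# Roughly: SELECT * FROM t WHERE x <> ALL (SELECT y FROM u)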

File: sqlglot/parser.py

@@ -1053,13 +1053,16 @@ class Parser(metaclass=_Parser):
ALTER_PARSERS = {
"ADD": lambda self: self._parse_alter_table_add(),
"AS": lambda self: self._parse_select(),
"ALTER": lambda self: self._parse_alter_table_alter(),
"CLUSTER BY": lambda self: self._parse_cluster(wrapped=True),
"DELETE": lambda self: self.expression(exp.Delete, where=self._parse_where()),
"DROP": lambda self: self._parse_alter_table_drop(),
"RENAME": lambda self: self._parse_alter_table_rename(),
"SET": lambda self: self._parse_alter_table_set(),
"AS": lambda self: self._parse_select(),
"SWAP": lambda self: self.expression(
exp.SwapTable, this=self._match(TokenType.WITH) and self._parse_table(schema=True)
),
}
ALTER_ALTER_PARSERS = {
@@ -1222,6 +1225,10 @@ class Parser(metaclass=_Parser):
**dict.fromkeys(("BINDING", "COMPENSATION", "EVOLUTION"), tuple()),
}
PROCEDURE_OPTIONS: OPTIONS_TYPE = {}
EXECUTE_AS_OPTIONS: OPTIONS_TYPE = dict.fromkeys(("CALLER", "SELF", "OWNER"), tuple())
KEY_CONSTRAINT_OPTIONS: OPTIONS_TYPE = {
"NOT": ("ENFORCED",),
"MATCH": (
@@ -1286,6 +1293,11 @@ class Parser(metaclass=_Parser):
PRIVILEGE_FOLLOW_TOKENS = {TokenType.ON, TokenType.COMMA, TokenType.L_PAREN}
# The style options for the DESCRIBE statement
DESCRIBE_STYLES = {"ANALYZE", "EXTENDED", "FORMATTED", "HISTORY"}
OPERATION_MODIFIERS: t.Set[str] = set()
STRICT_CAST = True
PREFIXED_PIVOT_COLUMNS = False
@@ -2195,11 +2207,26 @@ class Parser(metaclass=_Parser):
this=self._parse_var_from_options(self.SCHEMA_BINDING_OPTIONS),
)
if self._match_texts(self.PROCEDURE_OPTIONS, advance=False):
return self.expression(
exp.WithProcedureOptions, expressions=self._parse_csv(self._parse_procedure_option)
)
if not self._next:
return None
return self._parse_withisolatedloading()
def _parse_procedure_option(self) -> exp.Expression | None:
if self._match_text_seq("EXECUTE", "AS"):
return self.expression(
exp.ExecuteAsProperty,
this=self._parse_var_from_options(self.EXECUTE_AS_OPTIONS, raise_unmatched=False)
or self._parse_string(),
)
return self._parse_var_from_options(self.PROCEDURE_OPTIONS)
# https://dev.mysql.com/doc/refman/8.0/en/create-view.html
def _parse_definer(self) -> t.Optional[exp.DefinerProperty]:
self._match(TokenType.EQ)
@@ -2567,7 +2594,7 @@ class Parser(metaclass=_Parser):
def _parse_describe(self) -> exp.Describe:
kind = self._match_set(self.CREATABLES) and self._prev.text
style = self._match_texts(("EXTENDED", "FORMATTED", "HISTORY")) and self._prev.text.upper()
style = self._match_texts(self.DESCRIBE_STYLES) and self._prev.text.upper()
if self._match(TokenType.DOT):
style = None
self._retreat(self._index - 2)
@@ -2955,6 +2982,10 @@ class Parser(metaclass=_Parser):
if all_ and distinct:
self.raise_error("Cannot specify both ALL and DISTINCT after SELECT")
operation_modifiers = []
while self._curr and self._match_texts(self.OPERATION_MODIFIERS):
operation_modifiers.append(exp.var(self._prev.text.upper()))
limit = self._parse_limit(top=True)
projections = self._parse_projections()
@@ -2965,6 +2996,7 @@ class Parser(metaclass=_Parser):
distinct=distinct,
expressions=projections,
limit=limit,
operation_modifiers=operation_modifiers or None,
)
this.comments = comments
@@ -3400,6 +3432,10 @@ class Parser(metaclass=_Parser):
return None
kwargs: t.Dict[str, t.Any] = {"this": self._parse_table(parse_bracket=parse_bracket)}
if kind and kind.token_type == TokenType.ARRAY and self._match(TokenType.COMMA):
kwargs["expressions"] = self._parse_csv(
lambda: self._parse_table(parse_bracket=parse_bracket)
)
if method:
kwargs["method"] = method.text
@@ -3420,7 +3456,7 @@ class Parser(metaclass=_Parser):
elif (
not (outer_apply or cross_apply)
and not isinstance(kwargs["this"], exp.Unnest)
and not (kind and kind.token_type == TokenType.CROSS)
and not (kind and kind.token_type in (TokenType.CROSS, TokenType.ARRAY))
):
index = self._index
joins: t.Optional[list] = list(self._parse_joins())
@@ -4470,7 +4506,7 @@ class Parser(metaclass=_Parser):
elif not self._match(TokenType.R_BRACKET, expression=this):
self.raise_error("Expecting ]")
else:
this = self.expression(exp.In, this=this, field=self._parse_field())
this = self.expression(exp.In, this=this, field=self._parse_column())
return this
@@ -5533,12 +5569,15 @@ class Parser(metaclass=_Parser):
return None
def _parse_column_constraint(self) -> t.Optional[exp.Expression]:
if self._match(TokenType.CONSTRAINT):
this = self._parse_id_var()
else:
this = None
this = self._match(TokenType.CONSTRAINT) and self._parse_id_var()
if self._match_texts(self.CONSTRAINT_PARSERS):
procedure_option_follows = (
self._match(TokenType.WITH, advance=False)
and self._next
and self._next.text.upper() in self.PROCEDURE_OPTIONS
)
if not procedure_option_follows and self._match_texts(self.CONSTRAINT_PARSERS):
return self.expression(
exp.ColumnConstraint,
this=this,
@@ -6764,7 +6803,7 @@ class Parser(metaclass=_Parser):
self._retreat(index)
return self._parse_csv(self._parse_drop_column)
def _parse_alter_table_rename(self) -> t.Optional[exp.RenameTable | exp.RenameColumn]:
def _parse_alter_table_rename(self) -> t.Optional[exp.AlterRename | exp.RenameColumn]:
if self._match(TokenType.COLUMN):
exists = self._parse_exists()
old_column = self._parse_column()
@@ -6777,7 +6816,7 @@ class Parser(metaclass=_Parser):
return self.expression(exp.RenameColumn, this=old_column, to=new_column, exists=exists)
self._match_text_seq("TO")
return self.expression(exp.RenameTable, this=self._parse_table(schema=True))
return self.expression(exp.AlterRename, this=self._parse_table(schema=True))
def _parse_alter_table_set(self) -> exp.AlterSet:
alter_set = self.expression(exp.AlterSet)
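Two of the parser changes above, sketched with illustrative statements:

import sqlglot
from sqlglot import exp

# ALTER ... RENAME now builds exp.AlterRename (formerly exp.RenameTable).
alter = sqlglot.parse_one("ALTER TABLE a RENAME TO b")
assert isinstance(alter.args["actions"][0], exp.AlterRename)

# The RHS of `<value> IN <name>` is now parsed as an exp.Column rather than a
# bare field, so it can be qualified like any other column reference.
contains = sqlglot.parse_one("'red' IN flags", read="duckdb")
assert isinstance(contains.args["field"], exp.Column)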

File: tests/dialects/test_bigquery.py

@@ -107,6 +107,7 @@ LANGUAGE js AS
select_with_quoted_udf = self.validate_identity("SELECT `p.d.UdF`(data) FROM `p.d.t`")
self.assertEqual(select_with_quoted_udf.selects[0].name, "p.d.UdF")
self.validate_identity("SELECT ARRAY_CONCAT([1])")
self.validate_identity("SELECT * FROM READ_CSV('bla.csv')")
self.validate_identity("CAST(x AS STRUCT<list ARRAY<INT64>>)")
self.validate_identity("assert.true(1 = 1)")

File: tests/dialects/test_clickhouse.py

@@ -2,6 +2,7 @@ from datetime import date
from sqlglot import exp, parse_one
from sqlglot.dialects import ClickHouse
from sqlglot.expressions import convert
from sqlglot.optimizer import traverse_scope
from tests.dialects.test_dialect import Validator
from sqlglot.errors import ErrorLevel
@@ -28,6 +29,7 @@ class TestClickhouse(Validator):
self.assertEqual(expr.sql(dialect="clickhouse"), "COUNT(x)")
self.assertIsNone(expr._meta)
self.validate_identity("CAST(1 AS Bool)")
self.validate_identity("SELECT toString(CHAR(104.1, 101, 108.9, 108.9, 111, 32))")
self.validate_identity("@macro").assert_is(exp.Parameter).this.assert_is(exp.Var)
self.validate_identity("SELECT toFloat(like)")
@@ -420,11 +422,6 @@ class TestClickhouse(Validator):
" GROUP BY loyalty ORDER BY loyalty ASC"
},
)
self.validate_identity("SELECT s, arr FROM arrays_test ARRAY JOIN arr")
self.validate_identity("SELECT s, arr, a FROM arrays_test LEFT ARRAY JOIN arr AS a")
self.validate_identity(
"SELECT s, arr_external FROM arrays_test ARRAY JOIN [1, 2, 3] AS arr_external"
)
self.validate_all(
"SELECT quantile(0.5)(a)",
read={"duckdb": "SELECT quantile(a, 0.5)"},
@@ -1100,3 +1097,36 @@ LIFETIME(MIN 0 MAX 0)""",
def test_grant(self):
self.validate_identity("GRANT SELECT(x, y) ON db.table TO john WITH GRANT OPTION")
self.validate_identity("GRANT INSERT(x, y) ON db.table TO john")
def test_array_join(self):
expr = self.validate_identity(
"SELECT * FROM arrays_test ARRAY JOIN arr1, arrays_test.arr2 AS foo, ['a', 'b', 'c'] AS elem"
)
joins = expr.args["joins"]
self.assertEqual(len(joins), 1)
join = joins[0]
self.assertEqual(join.kind, "ARRAY")
self.assertIsInstance(join.this, exp.Column)
self.assertEqual(len(join.expressions), 2)
self.assertIsInstance(join.expressions[0], exp.Alias)
self.assertIsInstance(join.expressions[0].this, exp.Column)
self.assertIsInstance(join.expressions[1], exp.Alias)
self.assertIsInstance(join.expressions[1].this, exp.Array)
self.validate_identity("SELECT s, arr FROM arrays_test ARRAY JOIN arr")
self.validate_identity("SELECT s, arr, a FROM arrays_test LEFT ARRAY JOIN arr AS a")
self.validate_identity(
"SELECT s, arr_external FROM arrays_test ARRAY JOIN [1, 2, 3] AS arr_external"
)
self.validate_identity(
"SELECT * FROM arrays_test ARRAY JOIN [1, 2, 3] AS arr_external1, ['a', 'b', 'c'] AS arr_external2, splitByString(',', 'asd,qwerty,zxc') AS arr_external3"
)
def test_traverse_scope(self):
sql = "SELECT * FROM t FINAL"
scopes = traverse_scope(parse_one(sql, dialect=self.dialect))
self.assertEqual(len(scopes), 1)
self.assertEqual(set(scopes[0].sources), {"t"})

File: tests/dialects/test_databricks.py

@@ -7,6 +7,7 @@ class TestDatabricks(Validator):
dialect = "databricks"
def test_databricks(self):
self.validate_identity("SELECT t.current_time FROM t")
self.validate_identity("ALTER TABLE labels ADD COLUMN label_score FLOAT")
self.validate_identity("DESCRIBE HISTORY a.b")
self.validate_identity("DESCRIBE history.tbl")

File: tests/dialects/test_dialect.py

@@ -1762,6 +1762,7 @@ class TestDialect(Validator):
self.validate_all(
"LEVENSHTEIN(col1, col2)",
write={
"bigquery": "EDIT_DISTANCE(col1, col2)",
"duckdb": "LEVENSHTEIN(col1, col2)",
"drill": "LEVENSHTEIN_DISTANCE(col1, col2)",
"presto": "LEVENSHTEIN_DISTANCE(col1, col2)",
@@ -1772,6 +1773,7 @@ class TestDialect(Validator):
self.validate_all(
"LEVENSHTEIN(coalesce(col1, col2), coalesce(col2, col1))",
write={
"bigquery": "EDIT_DISTANCE(COALESCE(col1, col2), COALESCE(col2, col1))",
"duckdb": "LEVENSHTEIN(COALESCE(col1, col2), COALESCE(col2, col1))",
"drill": "LEVENSHTEIN_DISTANCE(COALESCE(col1, col2), COALESCE(col2, col1))",
"presto": "LEVENSHTEIN_DISTANCE(COALESCE(col1, col2), COALESCE(col2, col1))",

File: tests/dialects/test_duckdb.py

@@ -256,6 +256,9 @@ class TestDuckDB(Validator):
parse_one("a // b", read="duckdb").assert_is(exp.IntDiv).sql(dialect="duckdb"), "a // b"
)
self.validate_identity("SELECT UNNEST([1, 2])").selects[0].assert_is(exp.UDTF)
self.validate_identity("'red' IN flags").args["field"].assert_is(exp.Column)
self.validate_identity("'red' IN tbl.flags")
self.validate_identity("CREATE TABLE tbl1 (u UNION(num INT, str TEXT))")
self.validate_identity("INSERT INTO x BY NAME SELECT 1 AS y")
self.validate_identity("SELECT 1 AS x UNION ALL BY NAME SELECT 2 AS x")
