1
0
Fork 0

Merging upstream version 17.3.0.

Signed-off-by: Daniel Baumann <daniel@debian.org>
This commit is contained in:
Daniel Baumann 2025-02-13 20:44:18 +01:00
parent 335ae02913
commit 133b8dfc8d
Signed by: daniel
GPG key ID: FBB4F0E80A80222F
43 changed files with 5488 additions and 5047 deletions

View file

@ -174,6 +174,12 @@ def _parse_date(args: t.List) -> exp.Date | exp.DateFromParts:
return expr_type.from_arg_list(args)
def _parse_to_hex(args: t.List) -> exp.Hex | exp.MD5:
# TO_HEX(MD5(..)) is common in BigQuery, so it's parsed into MD5 to simplify its transpilation
arg = seq_get(args, 0)
return exp.MD5(this=arg.this) if isinstance(arg, exp.MD5Digest) else exp.Hex(this=arg)
class BigQuery(Dialect):
UNNEST_COLUMN_ONLY = True
@ -275,6 +281,8 @@ class BigQuery(Dialect):
"DATETIME_SUB": parse_date_delta_with_interval(exp.DatetimeSub),
"DIV": lambda args: exp.IntDiv(this=seq_get(args, 0), expression=seq_get(args, 1)),
"GENERATE_ARRAY": exp.GenerateSeries.from_arg_list,
"MD5": exp.MD5Digest.from_arg_list,
"TO_HEX": _parse_to_hex,
"PARSE_DATE": lambda args: format_time_lambda(exp.StrToDate, "bigquery")(
[seq_get(args, 1), seq_get(args, 0)]
),
@ -379,22 +387,27 @@ class BigQuery(Dialect):
exp.ApproxDistinct: rename_func("APPROX_COUNT_DISTINCT"),
exp.ArraySize: rename_func("ARRAY_LENGTH"),
exp.Cast: transforms.preprocess([transforms.remove_precision_parameterized_types]),
exp.Create: _create_sql,
exp.CTE: transforms.preprocess([_pushdown_cte_column_names]),
exp.DateAdd: _date_add_sql("DATE", "ADD"),
exp.DateDiff: lambda self, e: f"DATE_DIFF({self.sql(e, 'this')}, {self.sql(e, 'expression')}, {self.sql(e.args.get('unit', 'DAY'))})",
exp.DateFromParts: rename_func("DATE"),
exp.DateStrToDate: datestrtodate_sql,
exp.DateSub: _date_add_sql("DATE", "SUB"),
exp.DatetimeAdd: _date_add_sql("DATETIME", "ADD"),
exp.DatetimeSub: _date_add_sql("DATETIME", "SUB"),
exp.DateDiff: lambda self, e: f"DATE_DIFF({self.sql(e, 'this')}, {self.sql(e, 'expression')}, {self.sql(e.args.get('unit', 'DAY'))})",
exp.DateStrToDate: datestrtodate_sql,
exp.DateTrunc: lambda self, e: self.func("DATE_TRUNC", e.this, e.text("unit")),
exp.JSONFormat: rename_func("TO_JSON_STRING"),
exp.GenerateSeries: rename_func("GENERATE_ARRAY"),
exp.GroupConcat: rename_func("STRING_AGG"),
exp.Hex: rename_func("TO_HEX"),
exp.ILike: no_ilike_sql,
exp.IntDiv: rename_func("DIV"),
exp.JSONFormat: rename_func("TO_JSON_STRING"),
exp.Max: max_or_greatest,
exp.MD5: lambda self, e: self.func("TO_HEX", self.func("MD5", e.this)),
exp.MD5Digest: rename_func("MD5"),
exp.Min: min_or_least,
exp.PartitionedByProperty: lambda self, e: f"PARTITION BY {self.sql(e, 'this')}",
exp.RegexpExtract: lambda self, e: self.func(
"REGEXP_EXTRACT",
e.this,
@ -403,6 +416,7 @@ class BigQuery(Dialect):
e.args.get("occurrence"),
),
exp.RegexpLike: rename_func("REGEXP_CONTAINS"),
exp.ReturnsProperty: _returnsproperty_sql,
exp.Select: transforms.preprocess(
[
transforms.explode_to_unnest,
@ -411,6 +425,9 @@ class BigQuery(Dialect):
_alias_ordered_group,
]
),
exp.StabilityProperty: lambda self, e: f"DETERMINISTIC"
if e.name == "IMMUTABLE"
else "NOT DETERMINISTIC",
exp.StrToDate: lambda self, e: f"PARSE_DATE({self.format_time(e)}, {self.sql(e, 'this')})",
exp.StrToTime: lambda self, e: self.func(
"PARSE_TIMESTAMP", self.format_time(e), e.this, e.args.get("zone")
@ -420,17 +437,12 @@ class BigQuery(Dialect):
exp.TimestampAdd: _date_add_sql("TIMESTAMP", "ADD"),
exp.TimestampSub: _date_add_sql("TIMESTAMP", "SUB"),
exp.TimeStrToTime: timestrtotime_sql,
exp.TsOrDsToDate: ts_or_ds_to_date_sql("bigquery"),
exp.TsOrDsAdd: _date_add_sql("DATE", "ADD"),
exp.PartitionedByProperty: lambda self, e: f"PARTITION BY {self.sql(e, 'this')}",
exp.VariancePop: rename_func("VAR_POP"),
exp.Values: _derived_table_values_to_unnest,
exp.ReturnsProperty: _returnsproperty_sql,
exp.Create: _create_sql,
exp.Trim: lambda self, e: self.func(f"TRIM", e.this, e.expression),
exp.StabilityProperty: lambda self, e: f"DETERMINISTIC"
if e.name == "IMMUTABLE"
else "NOT DETERMINISTIC",
exp.TsOrDsAdd: _date_add_sql("DATE", "ADD"),
exp.TsOrDsToDate: ts_or_ds_to_date_sql("bigquery"),
exp.Unhex: rename_func("FROM_HEX"),
exp.Values: _derived_table_values_to_unnest,
exp.VariancePop: rename_func("VAR_POP"),
}
TYPE_MAPPING = {

View file

@ -357,6 +357,7 @@ class Hive(Dialect):
exp.Left: left_to_substring_sql,
exp.Map: var_map_sql,
exp.Max: max_or_greatest,
exp.MD5Digest: lambda self, e: self.func("UNHEX", self.func("MD5", e.this)),
exp.Min: min_or_least,
exp.VarMap: var_map_sql,
exp.Create: create_with_partitions_sql,

View file

@ -263,6 +263,7 @@ class Postgres(Dialect):
"DO": TokenType.COMMAND,
"HSTORE": TokenType.HSTORE,
"JSONB": TokenType.JSONB,
"MONEY": TokenType.MONEY,
"REFRESH": TokenType.COMMAND,
"REINDEX": TokenType.COMMAND,
"RESET": TokenType.COMMAND,

View file

@ -41,6 +41,12 @@ class Spark(Spark2):
}
class Generator(Spark2.Generator):
TYPE_MAPPING = {
**Spark2.Generator.TYPE_MAPPING,
exp.DataType.Type.MONEY: "DECIMAL(15, 4)",
exp.DataType.Type.SMALLMONEY: "DECIMAL(6, 4)",
exp.DataType.Type.UNIQUEIDENTIFIER: "STRING",
}
TRANSFORMS = Spark2.Generator.TRANSFORMS.copy()
TRANSFORMS.pop(exp.DateDiff)
TRANSFORMS.pop(exp.Group)

View file

@ -177,9 +177,6 @@ class Spark2(Hive):
TYPE_MAPPING = {
**Hive.Generator.TYPE_MAPPING,
exp.DataType.Type.TINYINT: "BYTE",
exp.DataType.Type.SMALLINT: "SHORT",
exp.DataType.Type.BIGINT: "LONG",
}
PROPERTIES_LOCATION = {

View file

@ -192,7 +192,7 @@ class SQLite(Dialect):
if len(expression.expressions) > 1:
return rename_func("MIN")(self, expression)
return self.expressions(expression)
return self.sql(expression, "this")
def transaction_sql(self, expression: exp.Transaction) -> str:
this = expression.this

View file

@ -274,12 +274,16 @@ class Expression(metaclass=_Expression):
def set(self, arg_key: str, value: t.Any) -> None:
"""
Sets `arg_key` to `value`.
Sets arg_key to value.
Args:
arg_key (str): name of the expression arg.
arg_key: name of the expression arg.
value: value to set the arg to.
"""
if value is None:
self.args.pop(arg_key, None)
return
self.args[arg_key] = value
self._set_parent(arg_key, value)
@ -2278,6 +2282,7 @@ class Table(Expression):
"pivots": False,
"hints": False,
"system_time": False,
"wrapped": False,
}
@property
@ -4249,7 +4254,7 @@ class JSONArrayContains(Binary, Predicate, Func):
class Least(Func):
arg_types = {"expressions": False}
arg_types = {"this": True, "expressions": False}
is_var_len_args = True
@ -4342,6 +4347,11 @@ class MD5(Func):
_sql_names = ["MD5"]
# Represents the variant of the MD5 function that returns a binary value
class MD5Digest(Func):
_sql_names = ["MD5_DIGEST"]
class Min(AggFunc):
arg_types = {"this": True, "expressions": False}
is_var_len_args = True

View file

@ -1215,7 +1215,8 @@ class Generator:
system_time = expression.args.get("system_time")
system_time = f" {self.sql(expression, 'system_time')}" if system_time else ""
return f"{table}{system_time}{alias}{hints}{pivots}{joins}{laterals}"
sql = f"{table}{system_time}{alias}{hints}{pivots}{joins}{laterals}"
return f"({sql})" if expression.args.get("wrapped") else sql
def tablesample_sql(
self, expression: exp.TableSample, seed_prefix: str = "SEED", sep=" AS "
@ -2289,11 +2290,14 @@ class Generator:
def function_fallback_sql(self, expression: exp.Func) -> str:
args = []
for arg_value in expression.args.values():
for key in expression.arg_types:
arg_value = expression.args.get(key)
if isinstance(arg_value, list):
for value in arg_value:
args.append(value)
else:
elif arg_value is not None:
args.append(arg_value)
return self.func(expression.sql_name(), *args)

View file

@ -15,8 +15,7 @@ def qualify_tables(
schema: t.Optional[Schema] = None,
) -> E:
"""
Rewrite sqlglot AST to have fully qualified tables. Additionally, this
replaces "join constructs" (*) by equivalent SELECT * subqueries.
Rewrite sqlglot AST to have fully qualified, unnested tables.
Examples:
>>> import sqlglot
@ -24,9 +23,18 @@ def qualify_tables(
>>> qualify_tables(expression, db="db").sql()
'SELECT 1 FROM db.tbl AS tbl'
>>>
>>> expression = sqlglot.parse_one("SELECT * FROM (tbl)")
>>> qualify_tables(expression).sql()
'SELECT * FROM tbl AS tbl'
>>>
>>> expression = sqlglot.parse_one("SELECT * FROM (tbl1 JOIN tbl2 ON id1 = id2)")
>>> qualify_tables(expression).sql()
'SELECT * FROM (SELECT * FROM tbl1 AS tbl1 JOIN tbl2 AS tbl2 ON id1 = id2) AS _q_0'
'SELECT * FROM tbl1 AS tbl1 JOIN tbl2 AS tbl2 ON id1 = id2'
Note:
This rule effectively enforces a left-to-right join order, since all joins
are unnested. This means that the optimizer doesn't necessarily preserve the
original join order, e.g. when parentheses are used to specify it explicitly.
Args:
expression: Expression to qualify
@ -36,19 +44,11 @@ def qualify_tables(
Returns:
The qualified expression.
(*) See section 7.2.1.2 in https://www.postgresql.org/docs/current/queries-table-expressions.html
"""
next_alias_name = name_sequence("_q_")
for scope in traverse_scope(expression):
for derived_table in itertools.chain(scope.ctes, scope.derived_tables):
# Expand join construct
if isinstance(derived_table, exp.Subquery):
unnested = derived_table.unnest()
if isinstance(unnested, exp.Table):
derived_table.this.replace(exp.select("*").from_(unnested.copy(), copy=False))
if not derived_table.args.get("alias"):
alias_ = next_alias_name()
derived_table.set("alias", exp.TableAlias(this=exp.to_identifier(alias_)))
@ -66,13 +66,17 @@ def qualify_tables(
if not source.args.get("catalog"):
source.set("catalog", exp.to_identifier(catalog))
# Unnest joins attached in tables by appending them to the closest query
for join in source.args.get("joins") or []:
scope.expression.append("joins", join)
source.set("joins", None)
source.set("wrapped", None)
if not source.alias:
source = source.replace(
alias(
source,
name or source.name or next_alias_name(),
copy=True,
table=True,
source, name or source.name or next_alias_name(), copy=True, table=True
)
)

View file

@ -548,9 +548,6 @@ def _traverse_scope(scope):
yield from _traverse_union(scope)
elif isinstance(scope.expression, exp.Subquery):
yield from _traverse_subqueries(scope)
elif isinstance(scope.expression, exp.Table):
# This case corresponds to a "join construct", i.e. (tbl1 JOIN tbl2 ON ..)
yield from _traverse_tables(scope)
elif isinstance(scope.expression, exp.UDTF):
pass
else:
@ -632,8 +629,9 @@ def _traverse_tables(scope):
if from_:
expressions.append(from_.this)
for join in scope.expression.args.get("joins") or []:
expressions.append(join.this)
for expression in (scope.expression, *scope.find_all(exp.Table)):
for join in expression.args.get("joins") or []:
expressions.append(join.this)
if isinstance(scope.expression, exp.Table):
expressions.append(scope.expression)

View file

@ -1969,10 +1969,31 @@ class Parser(metaclass=_Parser):
self._match_r_paren()
# early return so that subquery unions aren't parsed again
# SELECT * FROM (SELECT 1) UNION ALL SELECT 1
# Union ALL should be a property of the top select node, not the subquery
return self._parse_subquery(this, parse_alias=parse_subquery_alias)
alias = None
# Ensure "wrapped" tables are not parsed as Subqueries. The exception to this is when there's
# an alias that can be applied to the parentheses, because that would shadow all wrapped table
# names, and so we want to parse it as a Subquery to represent the inner scope appropriately.
# Additionally, we want the node under the Subquery to be an actual query, so we will replace
# the table reference with a star query that selects from it.
if isinstance(this, exp.Table):
alias = self._parse_table_alias()
if not alias:
this.set("wrapped", True)
return this
this.set("wrapped", None)
joins = this.args.pop("joins", None)
this = this.replace(exp.select("*").from_(this.copy(), copy=False))
this.set("joins", joins)
subquery = self._parse_subquery(this, parse_alias=parse_subquery_alias and not alias)
if subquery and alias:
subquery.set("alias", alias)
# We return early here so that the UNION isn't attached to the subquery by the
# following call to _parse_set_operations, but instead becomes the parent node
return subquery
elif self._match(TokenType.VALUES):
this = self.expression(
exp.Values,
@ -2292,6 +2313,7 @@ class Parser(metaclass=_Parser):
else:
joins = None
self._retreat(index)
kwargs["this"].set("joins", joins)
return self.expression(exp.Join, **kwargs)