1
0
Fork 0

Merging upstream version 19.0.1.

Signed-off-by: Daniel Baumann <daniel@debian.org>
This commit is contained in:
Daniel Baumann 2025-02-13 21:16:09 +01:00
parent 348b067e1b
commit 89acb78953
Signed by: daniel
GPG key ID: FBB4F0E80A80222F
91 changed files with 45416 additions and 43096 deletions

View file

@ -230,6 +230,12 @@ class Generator:
# Whether or not data types support additional specifiers like e.g. CHAR or BYTE (oracle)
DATA_TYPE_SPECIFIERS_ALLOWED = False
# Whether or not nested CTEs (e.g. defined inside of subqueries) are allowed
SUPPORTS_NESTED_CTES = True
# Whether or not the "RECURSIVE" keyword is required when defining recursive CTEs
CTE_RECURSIVE_KEYWORD_REQUIRED = True
TYPE_MAPPING = {
exp.DataType.Type.NCHAR: "CHAR",
exp.DataType.Type.NVARCHAR: "VARCHAR",
@ -304,6 +310,7 @@ class Generator:
exp.Order: exp.Properties.Location.POST_SCHEMA,
exp.OutputModelProperty: exp.Properties.Location.POST_SCHEMA,
exp.PartitionedByProperty: exp.Properties.Location.POST_WITH,
exp.PartitionedOfProperty: exp.Properties.Location.POST_SCHEMA,
exp.PrimaryKey: exp.Properties.Location.POST_SCHEMA,
exp.Property: exp.Properties.Location.POST_WITH,
exp.RemoteWithConnectionModelProperty: exp.Properties.Location.POST_SCHEMA,
@ -407,7 +414,6 @@ class Generator:
"unsupported_messages",
"_escaped_quote_end",
"_escaped_identifier_end",
"_cache",
)
def __init__(
@ -447,30 +453,38 @@ class Generator:
self._escaped_identifier_end: str = (
self.TOKENIZER_CLASS.IDENTIFIER_ESCAPES[0] + self.IDENTIFIER_END
)
self._cache: t.Optional[t.Dict[int, str]] = None
def generate(
self,
expression: t.Optional[exp.Expression],
cache: t.Optional[t.Dict[int, str]] = None,
) -> str:
def generate(self, expression: exp.Expression, copy: bool = True) -> str:
"""
Generates the SQL string corresponding to the given syntax tree.
Args:
expression: The syntax tree.
cache: An optional sql string cache. This leverages the hash of an Expression
which can be slow to compute, so only use it if you set _hash on each node.
copy: Whether or not to copy the expression. The generator performs mutations so
it is safer to copy.
Returns:
The SQL string corresponding to `expression`.
"""
if cache is not None:
self._cache = cache
if copy:
expression = expression.copy()
# Some dialects only support CTEs at the top level expression, so we need to bubble up nested
# CTEs to that level in order to produce a syntactically valid expression. This transformation
# happens here to minimize code duplication, since many expressions support CTEs.
if (
not self.SUPPORTS_NESTED_CTES
and isinstance(expression, exp.Expression)
and not expression.parent
and "with" in expression.arg_types
and any(node.parent is not expression for node in expression.find_all(exp.With))
):
from sqlglot.transforms import move_ctes_to_top_level
expression = move_ctes_to_top_level(expression)
self.unsupported_messages = []
sql = self.sql(expression).strip()
self._cache = None
if self.unsupported_level == ErrorLevel.IGNORE:
return sql
@ -595,12 +609,6 @@ class Generator:
return self.sql(value)
return ""
if self._cache is not None:
expression_id = hash(expression)
if expression_id in self._cache:
return self._cache[expression_id]
transform = self.TRANSFORMS.get(expression.__class__)
if callable(transform):
@ -621,11 +629,7 @@ class Generator:
else:
raise ValueError(f"Expected an Expression. Received {type(expression)}: {expression}")
sql = self.maybe_comment(sql, expression) if self.comments and comment else sql
if self._cache is not None:
self._cache[expression_id] = sql
return sql
return self.maybe_comment(sql, expression) if self.comments and comment else sql
def uncache_sql(self, expression: exp.Uncache) -> str:
table = self.sql(expression, "this")
@ -879,7 +883,11 @@ class Generator:
def with_sql(self, expression: exp.With) -> str:
sql = self.expressions(expression, flat=True)
recursive = "RECURSIVE " if expression.args.get("recursive") else ""
recursive = (
"RECURSIVE "
if self.CTE_RECURSIVE_KEYWORD_REQUIRED and expression.args.get("recursive")
else ""
)
return f"WITH {recursive}{sql}"
@ -1022,7 +1030,7 @@ class Generator:
where = self.sql(expression, "expression").strip()
return f"{this} FILTER({where})"
agg = expression.this.copy()
agg = expression.this
agg_arg = agg.this
cond = expression.expression.this
agg_arg.replace(exp.If(this=cond.copy(), true=agg_arg.copy()))
@ -1088,9 +1096,9 @@ class Generator:
for p in expression.expressions:
p_loc = self.PROPERTIES_LOCATION[p.__class__]
if p_loc == exp.Properties.Location.POST_WITH:
with_properties.append(p.copy())
with_properties.append(p)
elif p_loc == exp.Properties.Location.POST_SCHEMA:
root_properties.append(p.copy())
root_properties.append(p)
return self.root_properties(
exp.Properties(expressions=root_properties)
@ -1124,7 +1132,7 @@ class Generator:
for p in properties.expressions:
p_loc = self.PROPERTIES_LOCATION[p.__class__]
if p_loc != exp.Properties.Location.UNSUPPORTED:
properties_locs[p_loc].append(p.copy())
properties_locs[p_loc].append(p)
else:
self.unsupported(f"Unsupported property {p.key}")
@ -1238,6 +1246,29 @@ class Generator:
for_ = " FOR NONE"
return f"WITH{no}{concurrent} ISOLATED LOADING{for_}"
def partitionboundspec_sql(self, expression: exp.PartitionBoundSpec) -> str:
if isinstance(expression.this, list):
return f"IN ({self.expressions(expression, key='this', flat=True)})"
if expression.this:
modulus = self.sql(expression, "this")
remainder = self.sql(expression, "expression")
return f"WITH (MODULUS {modulus}, REMAINDER {remainder})"
from_expressions = self.expressions(expression, key="from_expressions", flat=True)
to_expressions = self.expressions(expression, key="to_expressions", flat=True)
return f"FROM ({from_expressions}) TO ({to_expressions})"
def partitionedofproperty_sql(self, expression: exp.PartitionedOfProperty) -> str:
this = self.sql(expression, "this")
for_values_or_default = expression.expression
if isinstance(for_values_or_default, exp.PartitionBoundSpec):
for_values_or_default = f" FOR VALUES {self.sql(for_values_or_default)}"
else:
for_values_or_default = " DEFAULT"
return f"PARTITION OF {this}{for_values_or_default}"
def lockingproperty_sql(self, expression: exp.LockingProperty) -> str:
kind = expression.args.get("kind")
this = f" {self.sql(expression, 'this')}" if expression.this else ""
@ -1385,7 +1416,12 @@ class Generator:
index = self.sql(expression, "index")
index = f" AT {index}" if index else ""
return f"{table}{version}{file_format}{alias}{index}{hints}{pivots}{joins}{laterals}"
ordinality = expression.args.get("ordinality") or ""
if ordinality:
ordinality = f" WITH ORDINALITY{alias}"
alias = ""
return f"{table}{version}{file_format}{alias}{index}{hints}{pivots}{joins}{laterals}{ordinality}"
def tablesample_sql(
self, expression: exp.TableSample, seed_prefix: str = "SEED", sep=" AS "
@ -1489,7 +1525,6 @@ class Generator:
return f"{values} AS {alias}" if alias else values
# Converts `VALUES...` expression into a series of select unions.
expression = expression.copy()
alias_node = expression.args.get("alias")
column_names = alias_node and alias_node.columns
@ -1972,8 +2007,7 @@ class Generator:
if self.UNNEST_WITH_ORDINALITY:
if alias and isinstance(offset, exp.Expression):
alias = alias.copy()
alias.append("columns", offset.copy())
alias.append("columns", offset)
if alias and self.UNNEST_COLUMN_ONLY:
columns = alias.columns
@ -2138,7 +2172,6 @@ class Generator:
return f"PRIMARY KEY ({expressions}){options}"
def if_sql(self, expression: exp.If) -> str:
expression = expression.copy()
return self.case_sql(exp.Case(ifs=[expression], default=expression.args.get("false")))
def matchagainst_sql(self, expression: exp.MatchAgainst) -> str:
@ -2367,7 +2400,9 @@ class Generator:
def cast_sql(self, expression: exp.Cast, safe_prefix: t.Optional[str] = None) -> str:
format_sql = self.sql(expression, "format")
format_sql = f" FORMAT {format_sql}" if format_sql else ""
return f"{safe_prefix or ''}CAST({self.sql(expression, 'this')} AS {self.sql(expression, 'to')}{format_sql})"
to_sql = self.sql(expression, "to")
to_sql = f" {to_sql}" if to_sql else ""
return f"{safe_prefix or ''}CAST({self.sql(expression, 'this')} AS{to_sql}{format_sql})"
def currentdate_sql(self, expression: exp.CurrentDate) -> str:
zone = self.sql(expression, "this")
@ -2510,7 +2545,7 @@ class Generator:
def intdiv_sql(self, expression: exp.IntDiv) -> str:
return self.sql(
exp.Cast(
this=exp.Div(this=expression.this.copy(), expression=expression.expression.copy()),
this=exp.Div(this=expression.this, expression=expression.expression),
to=exp.DataType(this=exp.DataType.Type.INT),
)
)
@ -2779,7 +2814,6 @@ class Generator:
hints = table.args.get("hints")
if hints and table.alias and isinstance(hints[0], exp.WithTableHint):
# T-SQL syntax is MERGE ... <target_table> [WITH (<merge_hint>)] [[AS] table_alias]
table = table.copy()
table_alias = f" AS {self.sql(table.args['alias'].pop())}"
this = self.sql(table)
@ -2787,7 +2821,9 @@ class Generator:
on = f"ON {self.sql(expression, 'on')}"
expressions = self.expressions(expression, sep=" ")
return f"MERGE INTO {this}{table_alias} {using} {on} {expressions}"
return self.prepend_ctes(
expression, f"MERGE INTO {this}{table_alias} {using} {on} {expressions}"
)
def tochar_sql(self, expression: exp.ToChar) -> str:
if expression.args.get("format"):
@ -2896,12 +2932,12 @@ class Generator:
case = exp.Case().when(
expression.this.is_(exp.null()).not_(copy=False),
expression.args["true"].copy(),
expression.args["true"],
copy=False,
)
else_cond = expression.args.get("false")
if else_cond:
case.else_(else_cond.copy(), copy=False)
case.else_(else_cond, copy=False)
return self.sql(case)
@ -2931,15 +2967,6 @@ class Generator:
if not isinstance(expression, exp.Literal):
from sqlglot.optimizer.simplify import simplify
expression = simplify(expression.copy())
expression = simplify(expression)
return expression
def cached_generator(
cache: t.Optional[t.Dict[int, str]] = None
) -> t.Callable[[exp.Expression], str]:
"""Returns a cached generator."""
cache = {} if cache is None else cache
generator = Generator(normalize=True, identify="safe")
return lambda e: generator.generate(e, cache)