1
0
Fork 0

Merging upstream version 6.1.1.

Signed-off-by: Daniel Baumann <daniel@debian.org>
This commit is contained in:
Daniel Baumann 2025-02-13 08:04:41 +01:00
parent 3c6d649c90
commit 08ecea3adf
Signed by: daniel
GPG key ID: FBB4F0E80A80222F
61 changed files with 1844 additions and 1555 deletions

View file

@ -41,6 +41,8 @@ class Generator:
max_unsupported (int): Maximum number of unsupported messages to include in a raised UnsupportedError.
This is only relevant if unsupported_level is ErrorLevel.RAISE.
Default: 3
leading_comma (bool): if the the comma is leading or trailing in select statements
Default: False
"""
TRANSFORMS = {
@ -108,6 +110,7 @@ class Generator:
"_indent",
"_replace_backslash",
"_escaped_quote_end",
"_leading_comma",
)
def __init__(
@ -131,6 +134,7 @@ class Generator:
unsupported_level=ErrorLevel.WARN,
null_ordering=None,
max_unsupported=3,
leading_comma=False,
):
import sqlglot
@ -157,6 +161,7 @@ class Generator:
self._indent = indent
self._replace_backslash = self.escape == "\\"
self._escaped_quote_end = self.escape + self.quote_end
self._leading_comma = leading_comma
def generate(self, expression):
"""
@ -178,9 +183,7 @@ class Generator:
for msg in self.unsupported_messages:
logger.warning(msg)
elif self.unsupported_level == ErrorLevel.RAISE and self.unsupported_messages:
raise UnsupportedError(
concat_errors(self.unsupported_messages, self.max_unsupported)
)
raise UnsupportedError(concat_errors(self.unsupported_messages, self.max_unsupported))
return sql
@ -197,9 +200,7 @@ class Generator:
def wrap(self, expression):
this_sql = self.indent(
self.sql(expression)
if isinstance(expression, (exp.Select, exp.Union))
else self.sql(expression, "this"),
self.sql(expression) if isinstance(expression, (exp.Select, exp.Union)) else self.sql(expression, "this"),
level=1,
pad=0,
)
@ -251,9 +252,7 @@ class Generator:
return transform
if not isinstance(expression, exp.Expression):
raise ValueError(
f"Expected an Expression. Received {type(expression)}: {expression}"
)
raise ValueError(f"Expected an Expression. Received {type(expression)}: {expression}")
exp_handler_name = f"{expression.key}_sql"
if hasattr(self, exp_handler_name):
@ -276,11 +275,7 @@ class Generator:
lazy = " LAZY" if expression.args.get("lazy") else ""
table = self.sql(expression, "this")
options = expression.args.get("options")
options = (
f" OPTIONS({self.sql(options[0])} = {self.sql(options[1])})"
if options
else ""
)
options = f" OPTIONS({self.sql(options[0])} = {self.sql(options[1])})" if options else ""
sql = self.sql(expression, "expression")
sql = f" AS{self.sep()}{sql}" if sql else ""
sql = f"CACHE{lazy} TABLE {table}{options}{sql}"
@ -306,9 +301,7 @@ class Generator:
def columndef_sql(self, expression):
column = self.sql(expression, "this")
kind = self.sql(expression, "kind")
constraints = self.expressions(
expression, key="constraints", sep=" ", flat=True
)
constraints = self.expressions(expression, key="constraints", sep=" ", flat=True)
if not constraints:
return f"{column} {kind}"
@ -338,6 +331,9 @@ class Generator:
default = self.sql(expression, "this")
return f"DEFAULT {default}"
def generatedasidentitycolumnconstraint_sql(self, expression):
return f"GENERATED {'ALWAYS' if expression.this else 'BY DEFAULT'} AS IDENTITY"
def notnullcolumnconstraint_sql(self, _):
return "NOT NULL"
@ -384,7 +380,10 @@ class Generator:
return f"{alias}{columns}"
def bitstring_sql(self, expression):
return f"b'{self.sql(expression, 'this')}'"
return self.sql(expression, "this")
def hexstring_sql(self, expression):
return self.sql(expression, "this")
def datatype_sql(self, expression):
type_value = expression.this
@ -452,10 +451,7 @@ class Generator:
def partition_sql(self, expression):
keys = csv(
*[
f"{k.args['this']}='{v.args['this']}'" if v else k.args["this"]
for k, v in expression.args.get("this")
]
*[f"{k.args['this']}='{v.args['this']}'" if v else k.args["this"] for k, v in expression.args.get("this")]
)
return f"PARTITION({keys})"
@ -470,9 +466,9 @@ class Generator:
elif p_class in self.WITH_PROPERTIES:
with_properties.append(p)
return self.root_properties(
exp.Properties(expressions=root_properties)
) + self.with_properties(exp.Properties(expressions=with_properties))
return self.root_properties(exp.Properties(expressions=root_properties)) + self.with_properties(
exp.Properties(expressions=with_properties)
)
def root_properties(self, properties):
if properties.expressions:
@ -508,11 +504,7 @@ class Generator:
kind = "OVERWRITE TABLE" if expression.args.get("overwrite") else "INTO"
this = self.sql(expression, "this")
exists = " IF EXISTS " if expression.args.get("exists") else " "
partition_sql = (
self.sql(expression, "partition")
if expression.args.get("partition")
else ""
)
partition_sql = self.sql(expression, "partition") if expression.args.get("partition") else ""
expression_sql = self.sql(expression, "expression")
sep = self.sep() if partition_sql else ""
sql = f"INSERT {kind} {this}{exists}{partition_sql}{sep}{expression_sql}"
@ -531,7 +523,7 @@ class Generator:
return f"{self.sql(expression, 'this')} {self.sql(expression, 'expression')}"
def table_sql(self, expression):
return ".".join(
table = ".".join(
part
for part in [
self.sql(expression, "catalog"),
@ -541,6 +533,10 @@ class Generator:
if part
)
laterals = self.expressions(expression, key="laterals", sep="")
joins = self.expressions(expression, key="joins", sep="")
return f"{table}{laterals}{joins}"
def tablesample_sql(self, expression):
if self.alias_post_tablesample and isinstance(expression.this, exp.Alias):
this = self.sql(expression.this, "this")
@ -586,11 +582,7 @@ class Generator:
def group_sql(self, expression):
group_by = self.op_expressions("GROUP BY", expression)
grouping_sets = self.expressions(expression, key="grouping_sets", indent=False)
grouping_sets = (
f"{self.seg('GROUPING SETS')} {self.wrap(grouping_sets)}"
if grouping_sets
else ""
)
grouping_sets = f"{self.seg('GROUPING SETS')} {self.wrap(grouping_sets)}" if grouping_sets else ""
cube = self.expressions(expression, key="cube", indent=False)
cube = f"{self.seg('CUBE')} {self.wrap(cube)}" if cube else ""
rollup = self.expressions(expression, key="rollup", indent=False)
@ -603,7 +595,16 @@ class Generator:
def join_sql(self, expression):
op_sql = self.seg(
" ".join(op for op in (expression.side, expression.kind, "JOIN") if op)
" ".join(
op
for op in (
"NATURAL" if expression.args.get("natural") else None,
expression.side,
expression.kind,
"JOIN",
)
if op
)
)
on_sql = self.sql(expression, "on")
using = expression.args.get("using")
@ -630,9 +631,9 @@ class Generator:
def lateral_sql(self, expression):
this = self.sql(expression, "this")
op_sql = self.seg(
f"LATERAL VIEW{' OUTER' if expression.args.get('outer') else ''}"
)
if isinstance(expression.this, exp.Subquery):
return f"LATERAL{self.sep()}{this}"
op_sql = self.seg(f"LATERAL VIEW{' OUTER' if expression.args.get('outer') else ''}")
alias = expression.args["alias"]
table = alias.name
table = f" {table}" if table else table
@ -688,21 +689,13 @@ class Generator:
sort_order = " DESC" if desc else ""
nulls_sort_change = ""
if nulls_first and (
(asc and nulls_are_large) or (desc and nulls_are_small) or nulls_are_last
):
if nulls_first and ((asc and nulls_are_large) or (desc and nulls_are_small) or nulls_are_last):
nulls_sort_change = " NULLS FIRST"
elif (
nulls_last
and ((asc and nulls_are_small) or (desc and nulls_are_large))
and not nulls_are_last
):
elif nulls_last and ((asc and nulls_are_small) or (desc and nulls_are_large)) and not nulls_are_last:
nulls_sort_change = " NULLS LAST"
if nulls_sort_change and not self.NULL_ORDERING_SUPPORTED:
self.unsupported(
"Sorting in an ORDER BY on NULLS FIRST/NULLS LAST is not supported by this dialect"
)
self.unsupported("Sorting in an ORDER BY on NULLS FIRST/NULLS LAST is not supported by this dialect")
nulls_sort_change = ""
return f"{self.sql(expression, 'this')}{sort_order}{nulls_sort_change}"
@ -798,14 +791,20 @@ class Generator:
def window_sql(self, expression):
this = self.sql(expression, "this")
partition = self.expressions(expression, key="partition_by", flat=True)
partition = f"PARTITION BY {partition}" if partition else ""
order = expression.args.get("order")
order_sql = self.order_sql(order, flat=True) if order else ""
partition_sql = partition + " " if partition and order else partition
spec = expression.args.get("spec")
spec_sql = " " + self.window_spec_sql(spec) if spec else ""
alias = self.sql(expression, "alias")
if expression.arg_key == "window":
this = this = f"{self.seg('WINDOW')} {this} AS"
else:
@ -818,13 +817,8 @@ class Generator:
def window_spec_sql(self, expression):
kind = self.sql(expression, "kind")
start = csv(
self.sql(expression, "start"), self.sql(expression, "start_side"), sep=" "
)
end = (
csv(self.sql(expression, "end"), self.sql(expression, "end_side"), sep=" ")
or "CURRENT ROW"
)
start = csv(self.sql(expression, "start"), self.sql(expression, "start_side"), sep=" ")
end = csv(self.sql(expression, "end"), self.sql(expression, "end_side"), sep=" ") or "CURRENT ROW"
return f"{kind} BETWEEN {start} AND {end}"
def withingroup_sql(self, expression):
@ -879,6 +873,17 @@ class Generator:
expression_sql = self.sql(expression, "expression")
return f"EXTRACT({this} FROM {expression_sql})"
def trim_sql(self, expression):
target = self.sql(expression, "this")
trim_type = self.sql(expression, "position")
if trim_type == "LEADING":
return f"LTRIM({target})"
elif trim_type == "TRAILING":
return f"RTRIM({target})"
else:
return f"TRIM({target})"
def check_sql(self, expression):
this = self.sql(expression, key="this")
return f"CHECK ({this})"
@ -898,9 +903,7 @@ class Generator:
return f"UNIQUE ({columns})"
def if_sql(self, expression):
return self.case_sql(
exp.Case(ifs=[expression], default=expression.args.get("false"))
)
return self.case_sql(exp.Case(ifs=[expression], default=expression.args.get("false")))
def in_sql(self, expression):
query = expression.args.get("query")
@ -917,7 +920,9 @@ class Generator:
return f"(SELECT {self.sql(unnest)})"
def interval_sql(self, expression):
return f"INTERVAL {self.sql(expression, 'this')} {self.sql(expression, 'unit')}"
unit = self.sql(expression, "unit")
unit = f" {unit}" if unit else ""
return f"INTERVAL {self.sql(expression, 'this')}{unit}"
def reference_sql(self, expression):
this = self.sql(expression, "this")
@ -925,9 +930,7 @@ class Generator:
return f"REFERENCES {this}({expressions})"
def anonymous_sql(self, expression):
args = self.indent(
self.expressions(expression, flat=True), skip_first=True, skip_last=True
)
args = self.indent(self.expressions(expression, flat=True), skip_first=True, skip_last=True)
return f"{self.normalize_func(self.sql(expression, 'this'))}({args})"
def paren_sql(self, expression):
@ -1006,6 +1009,9 @@ class Generator:
def ignorenulls_sql(self, expression):
return f"{self.sql(expression, 'this')} IGNORE NULLS"
def respectnulls_sql(self, expression):
return f"{self.sql(expression, 'this')} RESPECT NULLS"
def intdiv_sql(self, expression):
return self.sql(
exp.Cast(
@ -1023,6 +1029,9 @@ class Generator:
def div_sql(self, expression):
return self.binary(expression, "/")
def distance_sql(self, expression):
return self.binary(expression, "<->")
def dot_sql(self, expression):
return f"{self.sql(expression, 'this')}.{self.sql(expression, 'expression')}"
@ -1047,6 +1056,9 @@ class Generator:
def like_sql(self, expression):
return self.binary(expression, "LIKE")
def similarto_sql(self, expression):
return self.binary(expression, "SIMILAR TO")
def lt_sql(self, expression):
return self.binary(expression, "<")
@ -1069,14 +1081,10 @@ class Generator:
return self.binary(expression, "-")
def trycast_sql(self, expression):
return (
f"TRY_CAST({self.sql(expression, 'this')} AS {self.sql(expression, 'to')})"
)
return f"TRY_CAST({self.sql(expression, 'this')} AS {self.sql(expression, 'to')})"
def binary(self, expression, op):
return (
f"{self.sql(expression, 'this')} {op} {self.sql(expression, 'expression')}"
)
return f"{self.sql(expression, 'this')} {op} {self.sql(expression, 'expression')}"
def function_fallback_sql(self, expression):
args = []
@ -1089,9 +1097,7 @@ class Generator:
return f"{self.normalize_func(expression.sql_name())}({args_str})"
def format_time(self, expression):
return format_time(
self.sql(expression, "format"), self.time_mapping, self.time_trie
)
return format_time(self.sql(expression, "format"), self.time_mapping, self.time_trie)
def expressions(self, expression, key=None, flat=False, indent=True, sep=", "):
expressions = expression.args.get(key or "expressions")
@ -1102,7 +1108,14 @@ class Generator:
if flat:
return sep.join(self.sql(e) for e in expressions)
expressions = self.sep(sep).join(self.sql(e) for e in expressions)
sql = (self.sql(e) for e in expressions)
# the only time leading_comma changes the output is if pretty print is enabled
if self._leading_comma and self.pretty:
pad = " " * self.pad
expressions = "\n".join(f"{sep}{s}" if i > 0 else f"{pad}{s}" for i, s in enumerate(sql))
else:
expressions = self.sep(sep).join(sql)
if indent:
return self.indent(expressions, skip_first=False)
return expressions
@ -1116,9 +1129,7 @@ class Generator:
def set_operation(self, expression, op):
this = self.sql(expression, "this")
op = self.seg(op)
return self.query_modifiers(
expression, f"{this}{op}{self.sep()}{self.sql(expression, 'expression')}"
)
return self.query_modifiers(expression, f"{this}{op}{self.sep()}{self.sql(expression, 'expression')}")
def token_sql(self, token_type):
return self.TOKEN_MAPPING.get(token_type, token_type.name)