1
0
Fork 0

Merging upstream version 10.0.1.

Signed-off-by: Daniel Baumann <daniel@debian.org>
This commit is contained in:
Daniel Baumann 2025-02-13 14:53:05 +01:00
parent 528822bfd4
commit b7d21c45b7
Signed by: daniel
GPG key ID: FBB4F0E80A80222F
98 changed files with 4080 additions and 1666 deletions

View file

@ -1,4 +1,8 @@
from __future__ import annotations
import logging
import re
import typing as t
from sqlglot import exp
from sqlglot.errors import ErrorLevel, UnsupportedError, concat_errors
@ -8,6 +12,8 @@ from sqlglot.tokens import TokenType
logger = logging.getLogger("sqlglot")
NEWLINE_RE = re.compile("\r\n?|\n")
class Generator:
"""
@ -47,8 +53,7 @@ class Generator:
The default is on the smaller end because the length only represents a segment and not the true
line length.
Default: 80
annotations: Whether or not to show annotations in the SQL when `pretty` is True.
Annotations can only be shown in pretty mode otherwise they may clobber resulting sql.
comments: Whether or not to preserve comments in the output SQL code.
Default: True
"""
@ -65,14 +70,16 @@ class Generator:
exp.VolatilityProperty: lambda self, e: self.sql(e.name),
}
# whether 'CREATE ... TRANSIENT ... TABLE' is allowed
# can override in dialects
# Whether 'CREATE ... TRANSIENT ... TABLE' is allowed
CREATE_TRANSIENT = False
# whether or not null ordering is supported in order by
# Whether or not null ordering is supported in order by
NULL_ORDERING_SUPPORTED = True
# always do union distinct or union all
# Always do union distinct or union all
EXPLICIT_UNION = False
# wrap derived values in parens, usually standard but spark doesn't support it
# Wrap derived values in parens, usually standard but spark doesn't support it
WRAP_DERIVED_VALUES = True
TYPE_MAPPING = {
@ -80,7 +87,7 @@ class Generator:
exp.DataType.Type.NVARCHAR: "VARCHAR",
}
TOKEN_MAPPING = {}
TOKEN_MAPPING: t.Dict[TokenType, str] = {}
STRUCT_DELIMITER = ("<", ">")
@ -96,6 +103,8 @@ class Generator:
exp.TableFormatProperty,
}
WITH_SEPARATED_COMMENTS = (exp.Select,)
__slots__ = (
"time_mapping",
"time_trie",
@ -122,7 +131,7 @@ class Generator:
"_escaped_quote_end",
"_leading_comma",
"_max_text_width",
"_annotations",
"_comments",
)
def __init__(
@ -148,7 +157,7 @@ class Generator:
max_unsupported=3,
leading_comma=False,
max_text_width=80,
annotations=True,
comments=True,
):
import sqlglot
@ -177,7 +186,7 @@ class Generator:
self._escaped_quote_end = self.escape + self.quote_end
self._leading_comma = leading_comma
self._max_text_width = max_text_width
self._annotations = annotations
self._comments = comments
def generate(self, expression):
"""
@ -204,7 +213,6 @@ class Generator:
return sql
def unsupported(self, message):
if self.unsupported_level == ErrorLevel.IMMEDIATE:
raise UnsupportedError(message)
self.unsupported_messages.append(message)
@ -215,9 +223,31 @@ class Generator:
def seg(self, sql, sep=" "):
    """
    Build a SQL segment by prefixing `sql` with a separator.

    Args:
        sql: the SQL text of the segment.
        sep: the separator handed to `self.sep` to obtain the prefix.

    Returns:
        The result of `self.sep(sep)` immediately followed by `sql`.
    """
    prefix = self.sep(sep)
    return f"{prefix}{sql}"
def maybe_comment(self, sql, expression, single_line=False):
    """
    Attach the comment carried by `expression` (if any) to `sql`.

    Args:
        sql: the SQL text to decorate.
        expression: the object whose `comment` attribute is read.
        single_line: when pretty printing a comment that has no line break,
            emit it as a trailing `--` comment instead of a `/* */` block.

    Returns:
        `sql` unchanged when comments are disabled or the expression has
        none; otherwise `sql` combined with the comment.
    """
    if not self._comments:
        return sql

    text = expression.comment
    if not text:
        return sql

    # Make sure the comment body is separated from the delimiters by whitespace.
    if text[0].strip():
        text = " " + text
    if text[-1].strip():
        text = text + " "

    block = f"/*{text}*/"

    # Expression types listed in WITH_SEPARATED_COMMENTS get the comment in front.
    if isinstance(expression, self.WITH_SEPARATED_COMMENTS):
        return f"{block}{self.sep()}{sql}"

    pretty = self.pretty

    # A comment spanning several lines goes on its own line(s) above the SQL.
    if pretty and NEWLINE_RE.search(text):
        return f"{block}\n{sql}"

    if pretty and single_line:
        return f"{sql} --{text.rstrip()}"

    return f"{sql} {block}"
def wrap(self, expression):
this_sql = self.indent(
self.sql(expression) if isinstance(expression, (exp.Select, exp.Union)) else self.sql(expression, "this"),
self.sql(expression)
if isinstance(expression, (exp.Select, exp.Union))
else self.sql(expression, "this"),
level=1,
pad=0,
)
@ -251,7 +281,7 @@ class Generator:
for i, line in enumerate(lines)
)
def sql(self, expression, key=None):
def sql(self, expression, key=None, comment=True):
if not expression:
return ""
@ -264,29 +294,24 @@ class Generator:
transform = self.TRANSFORMS.get(expression.__class__)
if callable(transform):
return transform(self, expression)
if transform:
return transform
sql = transform(self, expression)
elif transform:
sql = transform
elif isinstance(expression, exp.Expression):
exp_handler_name = f"{expression.key}_sql"
if not isinstance(expression, exp.Expression):
if hasattr(self, exp_handler_name):
sql = getattr(self, exp_handler_name)(expression)
elif isinstance(expression, exp.Func):
sql = self.function_fallback_sql(expression)
elif isinstance(expression, exp.Property):
sql = self.property_sql(expression)
else:
raise ValueError(f"Unsupported expression type {expression.__class__.__name__}")
else:
raise ValueError(f"Expected an Expression. Received {type(expression)}: {expression}")
exp_handler_name = f"{expression.key}_sql"
if hasattr(self, exp_handler_name):
return getattr(self, exp_handler_name)(expression)
if isinstance(expression, exp.Func):
return self.function_fallback_sql(expression)
if isinstance(expression, exp.Property):
return self.property_sql(expression)
raise ValueError(f"Unsupported expression type {expression.__class__.__name__}")
def annotation_sql(self, expression):
if self._annotations and self.pretty:
return f"{self.sql(expression, 'expression')} # {expression.name}"
return self.sql(expression, "expression")
return self.maybe_comment(sql, expression) if self._comments and comment else sql
def uncache_sql(self, expression):
table = self.sql(expression, "this")
@ -371,7 +396,9 @@ class Generator:
expression_sql = self.sql(expression, "expression")
expression_sql = f"AS{self.sep()}{expression_sql}" if expression_sql else ""
temporary = " TEMPORARY" if expression.args.get("temporary") else ""
transient = " TRANSIENT" if self.CREATE_TRANSIENT and expression.args.get("transient") else ""
transient = (
" TRANSIENT" if self.CREATE_TRANSIENT and expression.args.get("transient") else ""
)
replace = " OR REPLACE" if expression.args.get("replace") else ""
exists_sql = " IF NOT EXISTS" if expression.args.get("exists") else ""
unique = " UNIQUE" if expression.args.get("unique") else ""
@ -434,7 +461,9 @@ class Generator:
def delete_sql(self, expression):
this = self.sql(expression, "this")
using_sql = (
f" USING {self.expressions(expression, 'using', sep=', USING ')}" if expression.args.get("using") else ""
f" USING {self.expressions(expression, 'using', sep=', USING ')}"
if expression.args.get("using")
else ""
)
where_sql = self.sql(expression, "where")
sql = f"DELETE FROM {this}{using_sql}{where_sql}"
@ -481,15 +510,18 @@ class Generator:
return f"{this} ON {table} {columns}"
def identifier_sql(self, expression):
value = expression.name
value = value.lower() if self.normalize else value
text = expression.name
text = text.lower() if self.normalize else text
if expression.args.get("quoted") or self.identify:
return f"{self.identifier_start}{value}{self.identifier_end}"
return value
text = f"{self.identifier_start}{text}{self.identifier_end}"
return text
def partition_sql(self, expression):
keys = csv(
*[f"{k.args['this']}='{v.args['this']}'" if v else k.args["this"] for k, v in expression.args.get("this")]
*[
f"""{prop.name}='{prop.text("value")}'""" if prop.text("value") else prop.name
for prop in expression.this
]
)
return f"PARTITION({keys})"
@ -504,9 +536,9 @@ class Generator:
elif p_class in self.ROOT_PROPERTIES:
root_properties.append(p)
return self.root_properties(exp.Properties(expressions=root_properties)) + self.with_properties(
exp.Properties(expressions=with_properties)
)
return self.root_properties(
exp.Properties(expressions=root_properties)
) + self.with_properties(exp.Properties(expressions=with_properties))
def root_properties(self, properties):
if properties.expressions:
@ -551,7 +583,9 @@ class Generator:
this = f"{this}{self.sql(expression, 'this')}"
exists = " IF EXISTS " if expression.args.get("exists") else " "
partition_sql = self.sql(expression, "partition") if expression.args.get("partition") else ""
partition_sql = (
self.sql(expression, "partition") if expression.args.get("partition") else ""
)
expression_sql = self.sql(expression, "expression")
sep = self.sep() if partition_sql else ""
sql = f"INSERT {this}{exists}{partition_sql}{sep}{expression_sql}"
@ -669,7 +703,9 @@ class Generator:
def group_sql(self, expression):
group_by = self.op_expressions("GROUP BY", expression)
grouping_sets = self.expressions(expression, key="grouping_sets", indent=False)
grouping_sets = f"{self.seg('GROUPING SETS')} {self.wrap(grouping_sets)}" if grouping_sets else ""
grouping_sets = (
f"{self.seg('GROUPING SETS')} {self.wrap(grouping_sets)}" if grouping_sets else ""
)
cube = self.expressions(expression, key="cube", indent=False)
cube = f"{self.seg('CUBE')} {self.wrap(cube)}" if cube else ""
rollup = self.expressions(expression, key="rollup", indent=False)
@ -711,10 +747,10 @@ class Generator:
this_sql = self.sql(expression, "this")
return f"{expression_sql}{op_sql} {this_sql}{on_sql}"
def lambda_sql(self, expression):
def lambda_sql(self, expression, arrow_sep="->"):
args = self.expressions(expression, flat=True)
args = f"({args})" if len(args.split(",")) > 1 else args
return self.no_identify(lambda: f"{args} -> {self.sql(expression, 'this')}")
return self.no_identify(lambda: f"{args} {arrow_sep} {self.sql(expression, 'this')}")
def lateral_sql(self, expression):
this = self.sql(expression, "this")
@ -748,7 +784,7 @@ class Generator:
if self._replace_backslash:
text = text.replace("\\", "\\\\")
text = text.replace(self.quote_end, self._escaped_quote_end)
return f"{self.quote_start}{text}{self.quote_end}"
text = f"{self.quote_start}{text}{self.quote_end}"
return text
def loaddata_sql(self, expression):
@ -796,13 +832,21 @@ class Generator:
sort_order = " DESC" if desc else ""
nulls_sort_change = ""
if nulls_first and ((asc and nulls_are_large) or (desc and nulls_are_small) or nulls_are_last):
if nulls_first and (
(asc and nulls_are_large) or (desc and nulls_are_small) or nulls_are_last
):
nulls_sort_change = " NULLS FIRST"
elif nulls_last and ((asc and nulls_are_small) or (desc and nulls_are_large)) and not nulls_are_last:
elif (
nulls_last
and ((asc and nulls_are_small) or (desc and nulls_are_large))
and not nulls_are_last
):
nulls_sort_change = " NULLS LAST"
if nulls_sort_change and not self.NULL_ORDERING_SUPPORTED:
self.unsupported("Sorting in an ORDER BY on NULLS FIRST/NULLS LAST is not supported by this dialect")
self.unsupported(
"Sorting in an ORDER BY on NULLS FIRST/NULLS LAST is not supported by this dialect"
)
nulls_sort_change = ""
return f"{self.sql(expression, 'this')}{sort_order}{nulls_sort_change}"
@ -835,7 +879,7 @@ class Generator:
sql = self.query_modifiers(
expression,
f"SELECT{hint}{distinct}{expressions}",
self.sql(expression, "from"),
self.sql(expression, "from", comment=False),
)
return self.prepend_ctes(expression, sql)
@ -858,6 +902,13 @@ class Generator:
def parameter_sql(self, expression):
    """Render a parameter as `@` followed by the SQL of its `this` argument."""
    name = self.sql(expression, "this")
    return f"@{name}"
def sessionparameter_sql(self, expression):
    """
    Render a session parameter as `@@[kind.]name`.

    The name is the SQL of the `this` argument; the optional `kind` text,
    when present, is placed before the name and followed by a dot.
    """
    name = self.sql(expression, "this")
    scope = expression.text("kind")
    # Only a non-empty kind gets the trailing dot; otherwise it is used as-is.
    prefix = f"{scope}." if scope else scope
    return f"@@{prefix}{name}"
def placeholder_sql(self, expression):
    """Render a placeholder: `:name` when the expression has a name, `?` otherwise."""
    name = expression.name
    if not name:
        return "?"
    return f":{name}"
@ -931,7 +982,10 @@ class Generator:
def window_spec_sql(self, expression):
kind = self.sql(expression, "kind")
start = csv(self.sql(expression, "start"), self.sql(expression, "start_side"), sep=" ")
end = csv(self.sql(expression, "end"), self.sql(expression, "end_side"), sep=" ") or "CURRENT ROW"
end = (
csv(self.sql(expression, "end"), self.sql(expression, "end_side"), sep=" ")
or "CURRENT ROW"
)
return f"{kind} BETWEEN {start} AND {end}"
def withingroup_sql(self, expression):
@ -1020,7 +1074,9 @@ class Generator:
return f"UNIQUE ({columns})"
def if_sql(self, expression):
return self.case_sql(exp.Case(ifs=[expression.copy()], default=expression.args.get("false")))
return self.case_sql(
exp.Case(ifs=[expression.copy()], default=expression.args.get("false"))
)
def in_sql(self, expression):
query = expression.args.get("query")
@ -1196,6 +1252,12 @@ class Generator:
def neq_sql(self, expression):
    """Render an inequality comparison using the `<>` operator."""
    operator = "<>"
    return self.binary(expression, operator)
def nullsafeeq_sql(self, expression):
    """Render a null-safe equality as `IS NOT DISTINCT FROM`."""
    operator = "IS NOT DISTINCT FROM"
    return self.binary(expression, operator)
def nullsafeneq_sql(self, expression):
    """Render a null-safe inequality as `IS DISTINCT FROM`."""
    operator = "IS DISTINCT FROM"
    return self.binary(expression, operator)
def or_sql(self, expression):
    """Render a disjunction by delegating to `connector_sql` with the `OR` keyword."""
    keyword = "OR"
    return self.connector_sql(expression, keyword)
@ -1205,6 +1267,9 @@ class Generator:
def trycast_sql(self, expression):
    """Render `TRY_CAST(<this> AS <to>)` from the expression's `this` and `to` arguments."""
    value = self.sql(expression, "this")
    target_type = self.sql(expression, "to")
    return f"TRY_CAST({value} AS {target_type})"
def use_sql(self, expression):
    """Render a `USE` statement followed by the SQL of the `this` argument."""
    target = self.sql(expression, "this")
    return f"USE {target}"
def binary(self, expression, op):
    """
    Render a binary operation.

    Args:
        expression: the node whose `this` and `expression` arguments are
            the left and right operands.
        op: the operator text placed between the operands.

    Returns:
        The left operand, `op` and the right operand joined by single spaces.
    """
    left = self.sql(expression, "this")
    right = self.sql(expression, "expression")
    return f"{left} {op} {right}"
@ -1240,17 +1305,27 @@ class Generator:
if flat:
return sep.join(self.sql(e) for e in expressions)
sql = (self.sql(e) for e in expressions)
# the only time leading_comma changes the output is if pretty print is enabled
if self._leading_comma and self.pretty:
pad = " " * self.pad
expressions = "\n".join(f"{sep}{s}" if i > 0 else f"{pad}{s}" for i, s in enumerate(sql))
else:
expressions = self.sep(sep).join(sql)
num_sqls = len(expressions)
if indent:
return self.indent(expressions, skip_first=False)
return expressions
# These are calculated once in case we have the leading_comma / pretty option set, correspondingly
pad = " " * self.pad
stripped_sep = sep.strip()
result_sqls = []
for i, e in enumerate(expressions):
sql = self.sql(e, comment=False)
comment = self.maybe_comment("", e, single_line=True)
if self.pretty:
if self._leading_comma:
result_sqls.append(f"{sep if i > 0 else pad}{sql}{comment}")
else:
result_sqls.append(f"{sql}{stripped_sep if i + 1 < num_sqls else ''}{comment}")
else:
result_sqls.append(f"{sql}{comment}{sep if i + 1 < num_sqls else ''}")
result_sqls = "\n".join(result_sqls) if self.pretty else "".join(result_sqls)
return self.indent(result_sqls, skip_first=False) if indent else result_sqls
def op_expressions(self, op, expression, flat=False):
expressions_sql = self.expressions(expression, flat=flat)
@ -1264,7 +1339,9 @@ class Generator:
def set_operation(self, expression, op):
this = self.sql(expression, "this")
op = self.seg(op)
return self.query_modifiers(expression, f"{this}{op}{self.sep()}{self.sql(expression, 'expression')}")
return self.query_modifiers(
expression, f"{this}{op}{self.sep()}{self.sql(expression, 'expression')}"
)
def token_sql(self, token_type):
    """
    Return the SQL text for a token type.

    Looks `token_type` up in `TOKEN_MAPPING` and falls back to the token
    type's own `name` when there is no entry.
    """
    lookup = self.TOKEN_MAPPING.get
    default_text = token_type.name
    return lookup(token_type, default_text)
@ -1283,3 +1360,6 @@ class Generator:
this = self.sql(expression, "this")
expressions = self.expressions(expression, flat=True)
return f"{this}({expressions})"
def kwarg_sql(self, expression):
    """Render a keyword argument using the `=>` operator."""
    operator = "=>"
    return self.binary(expression, operator)