Merging upstream version 10.0.1.
Signed-off-by: Daniel Baumann <daniel@debian.org>
This commit is contained in:
parent 528822bfd4
commit b7d21c45b7
98 changed files with 4080 additions and 1666 deletions
@@ -1,4 +1,8 @@
from __future__ import annotations

import logging
import re
import typing as t

from sqlglot import exp
from sqlglot.errors import ErrorLevel, UnsupportedError, concat_errors
@@ -8,6 +12,8 @@ from sqlglot.tokens import TokenType

logger = logging.getLogger("sqlglot")

NEWLINE_RE = re.compile("\r\n?|\n")


class Generator:
    """
@@ -47,8 +53,7 @@ class Generator:
            The default is on the smaller end because the length only represents a segment and not the true
            line length.
            Default: 80
        annotations: Whether or not to show annotations in the SQL when `pretty` is True.
            Annotations can only be shown in pretty mode otherwise they may clobber resulting sql.
        comments: Whether or not to preserve comments in the output SQL code.
            Default: True
    """
@@ -65,14 +70,16 @@ class Generator:
        exp.VolatilityProperty: lambda self, e: self.sql(e.name),
    }

    # whether 'CREATE ... TRANSIENT ... TABLE' is allowed
    # can override in dialects
    # Whether 'CREATE ... TRANSIENT ... TABLE' is allowed
    CREATE_TRANSIENT = False
    # whether or not null ordering is supported in order by

    # Whether or not null ordering is supported in order by
    NULL_ORDERING_SUPPORTED = True
    # always do union distinct or union all

    # Always do union distinct or union all
    EXPLICIT_UNION = False
    # wrap derived values in parens, usually standard but spark doesn't support it

    # Wrap derived values in parens, usually standard but spark doesn't support it
    WRAP_DERIVED_VALUES = True

    TYPE_MAPPING = {
@@ -80,7 +87,7 @@ class Generator:
        exp.DataType.Type.NVARCHAR: "VARCHAR",
    }

    TOKEN_MAPPING = {}
    TOKEN_MAPPING: t.Dict[TokenType, str] = {}

    STRUCT_DELIMITER = ("<", ">")
@@ -96,6 +103,8 @@ class Generator:
        exp.TableFormatProperty,
    }

    WITH_SEPARATED_COMMENTS = (exp.Select,)

    __slots__ = (
        "time_mapping",
        "time_trie",
@@ -122,7 +131,7 @@ class Generator:
        "_escaped_quote_end",
        "_leading_comma",
        "_max_text_width",
        "_annotations",
        "_comments",
    )

    def __init__(
@@ -148,7 +157,7 @@ class Generator:
        max_unsupported=3,
        leading_comma=False,
        max_text_width=80,
        annotations=True,
        comments=True,
    ):
        import sqlglot
@@ -177,7 +186,7 @@ class Generator:
        self._escaped_quote_end = self.escape + self.quote_end
        self._leading_comma = leading_comma
        self._max_text_width = max_text_width
        self._annotations = annotations
        self._comments = comments

    def generate(self, expression):
        """
@@ -204,7 +213,6 @@ class Generator:
        return sql

    def unsupported(self, message):

        if self.unsupported_level == ErrorLevel.IMMEDIATE:
            raise UnsupportedError(message)
        self.unsupported_messages.append(message)
@@ -215,9 +223,31 @@ class Generator:
    def seg(self, sql, sep=" "):
        return f"{self.sep(sep)}{sql}"

    def maybe_comment(self, sql, expression, single_line=False):
        comment = expression.comment if self._comments else None

        if not comment:
            return sql

        comment = " " + comment if comment[0].strip() else comment
        comment = comment + " " if comment[-1].strip() else comment

        if isinstance(expression, self.WITH_SEPARATED_COMMENTS):
            return f"/*{comment}*/{self.sep()}{sql}"

        if not self.pretty:
            return f"{sql} /*{comment}*/"

        if not NEWLINE_RE.search(comment):
            return f"{sql} --{comment.rstrip()}" if single_line else f"{sql} /*{comment}*/"

        return f"/*{comment}*/\n{sql}"

    def wrap(self, expression):
        this_sql = self.indent(
            self.sql(expression) if isinstance(expression, (exp.Select, exp.Union)) else self.sql(expression, "this"),
            self.sql(expression)
            if isinstance(expression, (exp.Select, exp.Union))
            else self.sql(expression, "this"),
            level=1,
            pad=0,
        )
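
Note: maybe_comment is the new hook that re-attaches a node's parsed comment to its generated
SQL: an inline /* ... */ block in compact output, a trailing -- comment for single-line items in
pretty output, and a leading block for WITH_SEPARATED_COMMENTS nodes such as Select. A minimal
sketch of the first two shapes (assuming the parsed comment lives on the node's comment
attribute, which is how the method reads it):

import sqlglot
from sqlglot.generator import Generator

node = sqlglot.parse_one("x")       # a bare Column expression
node.comment = "projected column"   # normally attached by the parser

print(Generator().maybe_comment("x", node))
# x /* projected column */

print(Generator(pretty=True).maybe_comment("x", node, single_line=True))
# x -- projected column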
@@ -251,7 +281,7 @@ class Generator:
            for i, line in enumerate(lines)
        )

    def sql(self, expression, key=None):
    def sql(self, expression, key=None, comment=True):
        if not expression:
            return ""
@@ -264,29 +294,24 @@ class Generator:
        transform = self.TRANSFORMS.get(expression.__class__)

        if callable(transform):
            return transform(self, expression)
        if transform:
            return transform
            sql = transform(self, expression)
        elif transform:
            sql = transform
        elif isinstance(expression, exp.Expression):
            exp_handler_name = f"{expression.key}_sql"

        if not isinstance(expression, exp.Expression):
            if hasattr(self, exp_handler_name):
                sql = getattr(self, exp_handler_name)(expression)
            elif isinstance(expression, exp.Func):
                sql = self.function_fallback_sql(expression)
            elif isinstance(expression, exp.Property):
                sql = self.property_sql(expression)
            else:
                raise ValueError(f"Unsupported expression type {expression.__class__.__name__}")
        else:
            raise ValueError(f"Expected an Expression. Received {type(expression)}: {expression}")

        exp_handler_name = f"{expression.key}_sql"
        if hasattr(self, exp_handler_name):
            return getattr(self, exp_handler_name)(expression)

        if isinstance(expression, exp.Func):
            return self.function_fallback_sql(expression)

        if isinstance(expression, exp.Property):
            return self.property_sql(expression)

        raise ValueError(f"Unsupported expression type {expression.__class__.__name__}")

    def annotation_sql(self, expression):
        if self._annotations and self.pretty:
            return f"{self.sql(expression, 'expression')} # {expression.name}"
        return self.sql(expression, "expression")
        return self.maybe_comment(sql, expression) if self._comments and comment else sql

    def uncache_sql(self, expression):
        table = self.sql(expression, "this")
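
Note: sql() now funnels every node through a single exit point so maybe_comment can decorate
whatever a TRANSFORMS entry, a <node>_sql handler, or the function fallback produced. A rough
sketch of a custom generator relying on that dispatch (the subclass and the SAFE_CAST rewrite
are illustrative, not part of this commit):

from sqlglot import exp, parse_one
from sqlglot.generator import Generator

class MyGenerator(Generator):
    # TRANSFORMS values may be strings or callables taking (generator, expression).
    TRANSFORMS = {
        **Generator.TRANSFORMS,
        exp.TryCast: lambda self, e: f"SAFE_CAST({self.sql(e, 'this')} AS {self.sql(e, 'to')})",
    }

print(MyGenerator().generate(parse_one("TRY_CAST(x AS INT)")))
# SAFE_CAST(x AS INT)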
@@ -371,7 +396,9 @@ class Generator:
        expression_sql = self.sql(expression, "expression")
        expression_sql = f"AS{self.sep()}{expression_sql}" if expression_sql else ""
        temporary = " TEMPORARY" if expression.args.get("temporary") else ""
        transient = " TRANSIENT" if self.CREATE_TRANSIENT and expression.args.get("transient") else ""
        transient = (
            " TRANSIENT" if self.CREATE_TRANSIENT and expression.args.get("transient") else ""
        )
        replace = " OR REPLACE" if expression.args.get("replace") else ""
        exists_sql = " IF NOT EXISTS" if expression.args.get("exists") else ""
        unique = " UNIQUE" if expression.args.get("unique") else ""
@@ -434,7 +461,9 @@ class Generator:
    def delete_sql(self, expression):
        this = self.sql(expression, "this")
        using_sql = (
            f" USING {self.expressions(expression, 'using', sep=', USING ')}" if expression.args.get("using") else ""
            f" USING {self.expressions(expression, 'using', sep=', USING ')}"
            if expression.args.get("using")
            else ""
        )
        where_sql = self.sql(expression, "where")
        sql = f"DELETE FROM {this}{using_sql}{where_sql}"
@@ -481,15 +510,18 @@ class Generator:
        return f"{this} ON {table} {columns}"

    def identifier_sql(self, expression):
        value = expression.name
        value = value.lower() if self.normalize else value
        text = expression.name
        text = text.lower() if self.normalize else text
        if expression.args.get("quoted") or self.identify:
            return f"{self.identifier_start}{value}{self.identifier_end}"
        return value
            text = f"{self.identifier_start}{text}{self.identifier_end}"
        return text

    def partition_sql(self, expression):
        keys = csv(
            *[f"{k.args['this']}='{v.args['this']}'" if v else k.args["this"] for k, v in expression.args.get("this")]
            *[
                f"""{prop.name}='{prop.text("value")}'""" if prop.text("value") else prop.name
                for prop in expression.this
            ]
        )
        return f"PARTITION({keys})"
@@ -504,9 +536,9 @@ class Generator:
            elif p_class in self.ROOT_PROPERTIES:
                root_properties.append(p)

        return self.root_properties(exp.Properties(expressions=root_properties)) + self.with_properties(
            exp.Properties(expressions=with_properties)
        )
        return self.root_properties(
            exp.Properties(expressions=root_properties)
        ) + self.with_properties(exp.Properties(expressions=with_properties))

    def root_properties(self, properties):
        if properties.expressions:
@@ -551,7 +583,9 @@ class Generator:

        this = f"{this}{self.sql(expression, 'this')}"
        exists = " IF EXISTS " if expression.args.get("exists") else " "
        partition_sql = self.sql(expression, "partition") if expression.args.get("partition") else ""
        partition_sql = (
            self.sql(expression, "partition") if expression.args.get("partition") else ""
        )
        expression_sql = self.sql(expression, "expression")
        sep = self.sep() if partition_sql else ""
        sql = f"INSERT {this}{exists}{partition_sql}{sep}{expression_sql}"
@@ -669,7 +703,9 @@ class Generator:
    def group_sql(self, expression):
        group_by = self.op_expressions("GROUP BY", expression)
        grouping_sets = self.expressions(expression, key="grouping_sets", indent=False)
        grouping_sets = f"{self.seg('GROUPING SETS')} {self.wrap(grouping_sets)}" if grouping_sets else ""
        grouping_sets = (
            f"{self.seg('GROUPING SETS')} {self.wrap(grouping_sets)}" if grouping_sets else ""
        )
        cube = self.expressions(expression, key="cube", indent=False)
        cube = f"{self.seg('CUBE')} {self.wrap(cube)}" if cube else ""
        rollup = self.expressions(expression, key="rollup", indent=False)
@@ -711,10 +747,10 @@ class Generator:
        this_sql = self.sql(expression, "this")
        return f"{expression_sql}{op_sql} {this_sql}{on_sql}"

    def lambda_sql(self, expression):
    def lambda_sql(self, expression, arrow_sep="->"):
        args = self.expressions(expression, flat=True)
        args = f"({args})" if len(args.split(",")) > 1 else args
        return self.no_identify(lambda: f"{args} -> {self.sql(expression, 'this')}")
        return self.no_identify(lambda: f"{args} {arrow_sep} {self.sql(expression, 'this')}")

    def lateral_sql(self, expression):
        this = self.sql(expression, "this")
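
Note: the new arrow_sep parameter lets dialects reuse lambda_sql while swapping the arrow token;
the base output keeps ->. A small sketch of a dialect-style override (the subclass is
illustrative, not from this commit):

from sqlglot import parse_one
from sqlglot.generator import Generator

class FatArrowGenerator(Generator):
    def lambda_sql(self, expression, arrow_sep="=>"):
        # Reuse the base implementation, only changing the separator.
        return super().lambda_sql(expression, arrow_sep=arrow_sep)

expr = parse_one("TRANSFORM(a, x -> x + 1)")
print(Generator().generate(expr))           # TRANSFORM(a, x -> x + 1)
print(FatArrowGenerator().generate(expr))   # TRANSFORM(a, x => x + 1)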
@@ -748,7 +784,7 @@ class Generator:
            if self._replace_backslash:
                text = text.replace("\\", "\\\\")
            text = text.replace(self.quote_end, self._escaped_quote_end)
            return f"{self.quote_start}{text}{self.quote_end}"
            text = f"{self.quote_start}{text}{self.quote_end}"
        return text

    def loaddata_sql(self, expression):
@@ -796,13 +832,21 @@ class Generator:

        sort_order = " DESC" if desc else ""
        nulls_sort_change = ""
        if nulls_first and ((asc and nulls_are_large) or (desc and nulls_are_small) or nulls_are_last):
        if nulls_first and (
            (asc and nulls_are_large) or (desc and nulls_are_small) or nulls_are_last
        ):
            nulls_sort_change = " NULLS FIRST"
        elif nulls_last and ((asc and nulls_are_small) or (desc and nulls_are_large)) and not nulls_are_last:
        elif (
            nulls_last
            and ((asc and nulls_are_small) or (desc and nulls_are_large))
            and not nulls_are_last
        ):
            nulls_sort_change = " NULLS LAST"

        if nulls_sort_change and not self.NULL_ORDERING_SUPPORTED:
            self.unsupported("Sorting in an ORDER BY on NULLS FIRST/NULLS LAST is not supported by this dialect")
            self.unsupported(
                "Sorting in an ORDER BY on NULLS FIRST/NULLS LAST is not supported by this dialect"
            )
            nulls_sort_change = ""

        return f"{self.sql(expression, 'this')}{sort_order}{nulls_sort_change}"
@@ -835,7 +879,7 @@ class Generator:
        sql = self.query_modifiers(
            expression,
            f"SELECT{hint}{distinct}{expressions}",
            self.sql(expression, "from"),
            self.sql(expression, "from", comment=False),
        )
        return self.prepend_ctes(expression, sql)
@@ -858,6 +902,13 @@ class Generator:
    def parameter_sql(self, expression):
        return f"@{self.sql(expression, 'this')}"

    def sessionparameter_sql(self, expression):
        this = self.sql(expression, "this")
        kind = expression.text("kind")
        if kind:
            kind = f"{kind}."
        return f"@@{kind}{this}"

    def placeholder_sql(self, expression):
        return f":{expression.name}" if expression.name else "?"
@@ -931,7 +982,10 @@ class Generator:
    def window_spec_sql(self, expression):
        kind = self.sql(expression, "kind")
        start = csv(self.sql(expression, "start"), self.sql(expression, "start_side"), sep=" ")
        end = csv(self.sql(expression, "end"), self.sql(expression, "end_side"), sep=" ") or "CURRENT ROW"
        end = (
            csv(self.sql(expression, "end"), self.sql(expression, "end_side"), sep=" ")
            or "CURRENT ROW"
        )
        return f"{kind} BETWEEN {start} AND {end}"

    def withingroup_sql(self, expression):
@@ -1020,7 +1074,9 @@ class Generator:
        return f"UNIQUE ({columns})"

    def if_sql(self, expression):
        return self.case_sql(exp.Case(ifs=[expression.copy()], default=expression.args.get("false")))
        return self.case_sql(
            exp.Case(ifs=[expression.copy()], default=expression.args.get("false"))
        )

    def in_sql(self, expression):
        query = expression.args.get("query")
@@ -1196,6 +1252,12 @@ class Generator:
    def neq_sql(self, expression):
        return self.binary(expression, "<>")

    def nullsafeeq_sql(self, expression):
        return self.binary(expression, "IS NOT DISTINCT FROM")

    def nullsafeneq_sql(self, expression):
        return self.binary(expression, "IS DISTINCT FROM")

    def or_sql(self, expression):
        return self.connector_sql(expression, "OR")
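
Note: the two new handlers give null-safe (in)equality a portable spelling; the obvious use is
transpiling MySQL's <=> operator. A sketch of the expected mapping (assuming the MySQL parser
produces exp.NullSafeEQ for <=>, the counterpart of these generator methods):

import sqlglot

print(sqlglot.transpile("SELECT a <=> b FROM t", read="mysql", write="postgres")[0])
# SELECT a IS NOT DISTINCT FROM b FROM t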
@@ -1205,6 +1267,9 @@ class Generator:
    def trycast_sql(self, expression):
        return f"TRY_CAST({self.sql(expression, 'this')} AS {self.sql(expression, 'to')})"

    def use_sql(self, expression):
        return f"USE {self.sql(expression, 'this')}"

    def binary(self, expression, op):
        return f"{self.sql(expression, 'this')} {op} {self.sql(expression, 'expression')}"
@@ -1240,17 +1305,27 @@ class Generator:
        if flat:
            return sep.join(self.sql(e) for e in expressions)

        sql = (self.sql(e) for e in expressions)
        # the only time leading_comma changes the output is if pretty print is enabled
        if self._leading_comma and self.pretty:
            pad = " " * self.pad
            expressions = "\n".join(f"{sep}{s}" if i > 0 else f"{pad}{s}" for i, s in enumerate(sql))
        else:
            expressions = self.sep(sep).join(sql)
        num_sqls = len(expressions)

        if indent:
            return self.indent(expressions, skip_first=False)
        return expressions
        # These are calculated once in case we have the leading_comma / pretty option set, correspondingly
        pad = " " * self.pad
        stripped_sep = sep.strip()

        result_sqls = []
        for i, e in enumerate(expressions):
            sql = self.sql(e, comment=False)
            comment = self.maybe_comment("", e, single_line=True)

            if self.pretty:
                if self._leading_comma:
                    result_sqls.append(f"{sep if i > 0 else pad}{sql}{comment}")
                else:
                    result_sqls.append(f"{sql}{stripped_sep if i + 1 < num_sqls else ''}{comment}")
            else:
                result_sqls.append(f"{sql}{comment}{sep if i + 1 < num_sqls else ''}")

        result_sqls = "\n".join(result_sqls) if self.pretty else "".join(result_sqls)
        return self.indent(result_sqls, skip_first=False) if indent else result_sqls

    def op_expressions(self, op, expression, flat=False):
        expressions_sql = self.expressions(expression, flat=flat)
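
Note: expressions() now builds the list item by item so each element can carry its own comment
(sql(e, comment=False) plus an explicit maybe_comment per item), and the leading_comma / pretty
combinations are handled in one loop. A hedged sketch of the two pretty layouts (assuming
leading_comma, like pretty, is forwarded to the Generator by sqlglot.transpile; the exact
indentation of the leading-comma form may differ):

import sqlglot

sql = "SELECT a, b, c FROM t"

print(sqlglot.transpile(sql, pretty=True)[0])
# SELECT
#   a,
#   b,
#   c
# FROM t

print(sqlglot.transpile(sql, pretty=True, leading_comma=True)[0])
# SELECT
#     a
#   , b
#   , c
# FROM t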
@@ -1264,7 +1339,9 @@ class Generator:
    def set_operation(self, expression, op):
        this = self.sql(expression, "this")
        op = self.seg(op)
        return self.query_modifiers(expression, f"{this}{op}{self.sep()}{self.sql(expression, 'expression')}")
        return self.query_modifiers(
            expression, f"{this}{op}{self.sep()}{self.sql(expression, 'expression')}"
        )

    def token_sql(self, token_type):
        return self.TOKEN_MAPPING.get(token_type, token_type.name)
@@ -1283,3 +1360,6 @@ class Generator:
        this = self.sql(expression, "this")
        expressions = self.expressions(expression, flat=True)
        return f"{this}({expressions})"

    def kwarg_sql(self, expression):
        return self.binary(expression, "=>")