Merging upstream version 19.0.1.
Signed-off-by: Daniel Baumann <daniel@debian.org>
parent 348b067e1b
commit 89acb78953
91 changed files with 45416 additions and 43096 deletions
@@ -158,6 +158,6 @@ def transpile(
     """
     write = (read if write is None else write) if identity else write
     return [
-        Dialect.get_or_raise(write)().generate(expression, **opts)
+        Dialect.get_or_raise(write)().generate(expression, copy=False, **opts) if expression else ""
         for expression in parse(sql, read, error_level=error_level)
     ]

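Review note: `transpile` now skips the defensive copy (the generator copies once, see the `Generator.generate` change further down) and emits an empty string for statements that parse to `None`. A minimal sketch of the intended behavior; the exact inputs and outputs are an assumption, not taken from this diff:

```python
import sqlglot

# An empty statement (e.g. between stray semicolons) parses to None and
# should now come back as "" instead of blowing up in the generator.
print(sqlglot.transpile("SELECT 1;;", read="mysql", write="duckdb"))
```
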
@@ -69,7 +69,6 @@ def _create_sql(self: BigQuery.Generator, expression: exp.Create) -> str:
     returns = expression.find(exp.ReturnsProperty)
 
     if kind.upper() == "FUNCTION" and returns and returns.args.get("is_table"):
-        expression = expression.copy()
         expression.set("kind", "TABLE FUNCTION")
 
         if isinstance(expression.expression, (exp.Subquery, exp.Literal)):

@@ -699,6 +698,5 @@ class BigQuery(Dialect):
 
         def version_sql(self, expression: exp.Version) -> str:
             if expression.name == "TIMESTAMP":
-                expression = expression.copy()
                 expression.set("this", "SYSTEM_TIME")
             return super().version_sql(expression)

@@ -461,7 +461,6 @@ class ClickHouse(Dialect):
 
         def safeconcat_sql(self, expression: exp.SafeConcat) -> str:
             # Clickhouse errors out if we try to cast a NULL value to TEXT
-            expression = expression.copy()
             return self.func(
                 "CONCAT",
                 *[

@@ -35,7 +35,7 @@ class Databricks(Spark):
             exp.DatetimeSub: lambda self, e: self.func(
                 "TIMESTAMPADD",
                 e.text("unit"),
-                exp.Mul(this=e.expression.copy(), expression=exp.Literal.number(-1)),
+                exp.Mul(this=e.expression, expression=exp.Literal.number(-1)),
                 e.this,
             ),
             exp.DatetimeDiff: lambda self, e: self.func(

@@ -63,21 +63,14 @@ class Databricks(Spark):
                 and kind.this in exp.DataType.INTEGER_TYPES
             ):
                 # only BIGINT generated identity constraints are supported
-                expression = expression.copy()
                 expression.set("kind", exp.DataType.build("bigint"))
             return super().columndef_sql(expression, sep)
 
         def generatedasidentitycolumnconstraint_sql(
             self, expression: exp.GeneratedAsIdentityColumnConstraint
         ) -> str:
-            expression = expression.copy()
             expression.set("this", True)  # trigger ALWAYS in super class
             return super().generatedasidentitycolumnconstraint_sql(expression)
 
     class Tokenizer(Spark.Tokenizer):
         HEX_STRINGS = []
 
         SINGLE_TOKENS = {
             **Spark.Tokenizer.SINGLE_TOKENS,
             "$": TokenType.PARAMETER,
         }

@@ -315,11 +315,14 @@ class Dialect(metaclass=_Dialect):
     ) -> t.List[t.Optional[exp.Expression]]:
         return self.parser(**opts).parse_into(expression_type, self.tokenize(sql), sql)
 
-    def generate(self, expression: t.Optional[exp.Expression], **opts) -> str:
-        return self.generator(**opts).generate(expression)
+    def generate(self, expression: exp.Expression, copy: bool = True, **opts) -> str:
+        return self.generator(**opts).generate(expression, copy=copy)
 
     def transpile(self, sql: str, **opts) -> t.List[str]:
-        return [self.generate(expression, **opts) for expression in self.parse(sql)]
+        return [
+            self.generate(expression, copy=False, **opts) if expression else ""
+            for expression in self.parse(sql)
+        ]
 
     def tokenize(self, sql: str) -> t.List[Token]:
         return self.tokenizer.tokenize(sql)

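Review note: `Dialect.generate` now threads the new `copy` flag through to the generator, so callers who own the tree can opt out of the defensive copy. A small usage sketch under that assumption:

```python
from sqlglot import parse_one
from sqlglot.dialects.dialect import Dialect

ast = parse_one("SELECT a FROM t")
# get_or_raise returns the dialect class; copy=False skips the defensive copy
# and may mutate `ast` in place.
sql = Dialect.get_or_raise("duckdb")().generate(ast, copy=False)
```
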
@@ -380,9 +383,7 @@ def inline_array_sql(self: Generator, expression: exp.Array) -> str:
 
 def no_ilike_sql(self: Generator, expression: exp.ILike) -> str:
     return self.like_sql(
-        exp.Like(
-            this=exp.Lower(this=expression.this.copy()), expression=expression.expression.copy()
-        )
+        exp.Like(this=exp.Lower(this=expression.this), expression=expression.expression)
     )
 
 

@@ -518,7 +519,6 @@ def create_with_partitions_sql(self: Generator, expression: exp.Create) -> str:
     is_partitionable = expression.args.get("kind") in ("TABLE", "VIEW")
 
     if has_schema and is_partitionable:
-        expression = expression.copy()
         prop = expression.find(exp.PartitionedByProperty)
         if prop and prop.this and not isinstance(prop.this, exp.Schema):
             schema = expression.this

@@ -583,7 +583,7 @@ def date_add_interval_sql(
         this = self.sql(expression, "this")
         unit = expression.args.get("unit")
         unit = exp.var(unit.name.upper() if unit else "DAY")
-        interval = exp.Interval(this=expression.expression.copy(), unit=unit)
+        interval = exp.Interval(this=expression.expression, unit=unit)
         return f"{data_type}_{kind}({this}, {self.sql(interval)})"
 
     return func

@@ -621,7 +621,6 @@ def strposition_to_locate_sql(self: Generator, expression: exp.StrPosition) -> s
 
 
 def left_to_substring_sql(self: Generator, expression: exp.Left) -> str:
-    expression = expression.copy()
     return self.sql(
         exp.Substring(
             this=expression.this, start=exp.Literal.number(1), length=expression.expression

@@ -630,7 +629,6 @@ def left_to_substring_sql(self: Generator, expression: exp.Left) -> str:
 
 
 def right_to_substring_sql(self: Generator, expression: exp.Left) -> str:
-    expression = expression.copy()
     return self.sql(
         exp.Substring(
             this=expression.this,

@@ -675,7 +673,7 @@ def count_if_to_sum(self: Generator, expression: exp.CountIf) -> str:
         cond = expression.this.expressions[0]
         self.unsupported("DISTINCT is not supported when converting COUNT_IF to SUM")
 
-    return self.func("sum", exp.func("if", cond.copy(), 1, 0))
+    return self.func("sum", exp.func("if", cond, 1, 0))
 
 
 def trim_sql(self: Generator, expression: exp.Trim) -> str:

@@ -716,12 +714,10 @@ def ts_or_ds_to_date_sql(dialect: str) -> t.Callable:
 
 
 def concat_to_dpipe_sql(self: Generator, expression: exp.Concat | exp.SafeConcat) -> str:
-    expression = expression.copy()
     return self.sql(reduce(lambda x, y: exp.DPipe(this=x, expression=y), expression.expressions))
 
 
 def concat_ws_to_dpipe_sql(self: Generator, expression: exp.ConcatWs) -> str:
-    expression = expression.copy()
     delim, *rest_args = expression.expressions
     return self.sql(
         reduce(

@@ -809,13 +805,6 @@ def isnull_to_is_null(args: t.List) -> exp.Expression:
     return exp.Paren(this=exp.Is(this=seq_get(args, 0), expression=exp.null()))
 
 
-def move_insert_cte_sql(self: Generator, expression: exp.Insert) -> str:
-    if expression.expression.args.get("with"):
-        expression = expression.copy()
-        expression.set("with", expression.expression.args["with"].pop())
-    return self.insert_sql(expression)
-
-
 def generatedasidentitycolumnconstraint_sql(
     self: Generator, expression: exp.GeneratedAsIdentityColumnConstraint
 ) -> str:

@@ -20,7 +20,9 @@ def _date_add_sql(kind: str) -> t.Callable[[Drill.Generator, exp.DateAdd | exp.D
     def func(self: Drill.Generator, expression: exp.DateAdd | exp.DateSub) -> str:
         this = self.sql(expression, "this")
         unit = exp.var(expression.text("unit").upper() or "DAY")
-        return f"DATE_{kind}({this}, {self.sql(exp.Interval(this=expression.expression.copy(), unit=unit))})"
+        return (
+            f"DATE_{kind}({this}, {self.sql(exp.Interval(this=expression.expression, unit=unit))})"
+        )
 
     return func
 

@@ -147,7 +149,7 @@ class Drill(Dialect):
             exp.TimeToUnix: rename_func("UNIX_TIMESTAMP"),
             exp.ToChar: lambda self, e: self.function_fallback_sql(e),
             exp.TryCast: no_trycast_sql,
-            exp.TsOrDsAdd: lambda self, e: f"DATE_ADD(CAST({self.sql(e, 'this')} AS DATE), {self.sql(exp.Interval(this=e.expression.copy(), unit=exp.var('DAY')))})",
+            exp.TsOrDsAdd: lambda self, e: f"DATE_ADD(CAST({self.sql(e, 'this')} AS DATE), {self.sql(exp.Interval(this=e.expression, unit=exp.var('DAY')))})",
             exp.TsOrDsToDate: ts_or_ds_to_date_sql("drill"),
             exp.TsOrDiToDi: lambda self, e: f"CAST(SUBSTR(REPLACE(CAST({self.sql(e, 'this')} AS VARCHAR), '-', ''), 1, 8) AS INT)",
         }

@@ -36,14 +36,14 @@ from sqlglot.tokens import TokenType
 def _ts_or_ds_add_sql(self: DuckDB.Generator, expression: exp.TsOrDsAdd) -> str:
     this = self.sql(expression, "this")
     unit = self.sql(expression, "unit").strip("'") or "DAY"
-    return f"CAST({this} AS DATE) + {self.sql(exp.Interval(this=expression.expression.copy(), unit=unit))}"
+    return f"CAST({this} AS DATE) + {self.sql(exp.Interval(this=expression.expression, unit=unit))}"
 
 
 def _date_delta_sql(self: DuckDB.Generator, expression: exp.DateAdd | exp.DateSub) -> str:
     this = self.sql(expression, "this")
     unit = self.sql(expression, "unit").strip("'") or "DAY"
     op = "+" if isinstance(expression, exp.DateAdd) else "-"
-    return f"{this} {op} {self.sql(exp.Interval(this=expression.expression.copy(), unit=unit))}"
+    return f"{this} {op} {self.sql(exp.Interval(this=expression.expression, unit=unit))}"
 
 
 # BigQuery -> DuckDB conversion for the DATE function

@@ -365,7 +365,7 @@ class DuckDB(Dialect):
                 multiplier = 90
 
             if multiplier:
-                return f"({multiplier} * {super().interval_sql(exp.Interval(this=expression.this.copy(), unit=exp.var('day')))})"
+                return f"({multiplier} * {super().interval_sql(exp.Interval(this=expression.this, unit=exp.var('day')))})"
 
             return super().interval_sql(expression)
 

@@ -53,8 +53,6 @@ DIFF_MONTH_SWITCH = ("YEAR", "QUARTER", "MONTH")
 
 
 def _create_sql(self, expression: exp.Create) -> str:
-    expression = expression.copy()
-
     # remove UNIQUE column constraints
     for constraint in expression.find_all(exp.UniqueColumnConstraint):
         if constraint.parent:

@@ -88,7 +86,7 @@ def _add_date_sql(self: Hive.Generator, expression: exp.DateAdd | exp.DateSub) -
     if expression.expression.is_number:
         modified_increment = exp.Literal.number(int(expression.text("expression")) * multiplier)
     else:
-        modified_increment = expression.expression.copy()
+        modified_increment = expression.expression
         if multiplier != 1:
             modified_increment = exp.Mul(  # type: ignore
                 this=modified_increment, expression=exp.Literal.number(multiplier)

@@ -229,6 +227,11 @@ class Hive(Dialect):
         STRING_ESCAPES = ["\\"]
         ENCODE = "utf-8"
 
+        SINGLE_TOKENS = {
+            **tokens.Tokenizer.SINGLE_TOKENS,
+            "$": TokenType.PARAMETER,
+        }
+
         KEYWORDS = {
             **tokens.Tokenizer.KEYWORDS,
             "ADD ARCHIVE": TokenType.COMMAND,

@@ -408,6 +411,7 @@ class Hive(Dialect):
         INDEX_ON = "ON TABLE"
         EXTRACT_ALLOWS_QUOTES = False
         NVL2_SUPPORTED = False
+        SUPPORTS_NESTED_CTES = False
 
         TYPE_MAPPING = {
             **generator.Generator.TYPE_MAPPING,

@@ -521,7 +525,10 @@ class Hive(Dialect):
 
         def parameter_sql(self, expression: exp.Parameter) -> str:
             this = self.sql(expression, "this")
+            expression_sql = self.sql(expression, "expression")
+
             parent = expression.parent
+            this = f"{this}:{expression_sql}" if expression_sql else this
 
             if isinstance(parent, exp.EQ) and isinstance(parent.parent, exp.SetItem):
                 # We need to produce SET key = value instead of SET ${key} = value

@@ -530,8 +537,6 @@ class Hive(Dialect):
             return f"${{{this}}}"
 
         def schema_sql(self, expression: exp.Schema) -> str:
-            expression = expression.copy()
-
             for ordered in expression.find_all(exp.Ordered):
                 if ordered.args.get("desc") is False:
                     ordered.set("desc", None)

@@ -539,8 +544,6 @@ class Hive(Dialect):
             return super().schema_sql(expression)
 
         def constraint_sql(self, expression: exp.Constraint) -> str:
-            expression = expression.copy()
-
             for prop in list(expression.find_all(exp.Properties)):
                 prop.pop()
 

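Review note: registering `"$"` as a single token lets the Hive tokenizer pick up `${...}` variables, and the reworked `parameter_sql` renders the `key:value` form back out. A hedged sketch (the exact round-trip output is an assumption):

```python
import sqlglot

# ${hiveconf:x} should tokenize as a parameter and round-trip,
# including the colon form handled by the new parameter_sql logic.
ast = sqlglot.parse_one("SELECT ${hiveconf:x}", read="hive")
print(ast.sql(dialect="hive"))
```
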
@@ -60,9 +60,33 @@ def _date_trunc_sql(self: MySQL.Generator, expression: exp.DateTrunc) -> str:
     return f"STR_TO_DATE({concat}, '{date_format}')"
 
 
-def _str_to_date(args: t.List) -> exp.StrToDate:
-    date_format = MySQL.format_time(seq_get(args, 1))
-    return exp.StrToDate(this=seq_get(args, 0), format=date_format)
+# All specifiers for time parts (as opposed to date parts)
+# https://dev.mysql.com/doc/refman/8.0/en/date-and-time-functions.html#function_date-format
+TIME_SPECIFIERS = {"f", "H", "h", "I", "i", "k", "l", "p", "r", "S", "s", "T"}
+
+
+def _has_time_specifier(date_format: str) -> bool:
+    i = 0
+    length = len(date_format)
+
+    while i < length:
+        if date_format[i] == "%":
+            i += 1
+            if i < length and date_format[i] in TIME_SPECIFIERS:
+                return True
+        i += 1
+    return False
+
+
+def _str_to_date(args: t.List) -> exp.StrToDate | exp.StrToTime:
+    mysql_date_format = seq_get(args, 1)
+    date_format = MySQL.format_time(mysql_date_format)
+    this = seq_get(args, 0)
+
+    if mysql_date_format and _has_time_specifier(mysql_date_format.name):
+        return exp.StrToTime(this=this, format=date_format)
+
+    return exp.StrToDate(this=this, format=date_format)
 
 
 def _str_to_date_sql(

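Review note: the new helper inspects the raw MySQL format string, so `STR_TO_DATE` with any time-part specifier now parses as `StrToTime` instead of `StrToDate`. A sketch of the intended behavior (the example queries are illustrative):

```python
import sqlglot
from sqlglot import exp

date_only = sqlglot.parse_one("SELECT STR_TO_DATE(x, '%Y-%m-%d')", read="mysql")
with_time = sqlglot.parse_one("SELECT STR_TO_DATE(x, '%H:%i:%s')", read="mysql")
assert date_only.find(exp.StrToDate) is not None
assert with_time.find(exp.StrToTime) is not None
```
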
@@ -93,7 +117,9 @@ def _date_add_sql(
     def func(self: MySQL.Generator, expression: exp.Expression) -> str:
         this = self.sql(expression, "this")
         unit = expression.text("unit").upper() or "DAY"
-        return f"DATE_{kind}({this}, {self.sql(exp.Interval(this=expression.expression.copy(), unit=unit))})"
+        return (
+            f"DATE_{kind}({this}, {self.sql(exp.Interval(this=expression.expression, unit=unit))})"
+        )
 
     return func
 

@@ -110,8 +136,6 @@ def _remove_ts_or_ds_to_date(
     args: t.Tuple[str, ...] = ("this",),
 ) -> t.Callable[[MySQL.Generator, exp.Func], str]:
     def func(self: MySQL.Generator, expression: exp.Func) -> str:
-        expression = expression.copy()
-
         for arg_key in args:
             arg = expression.args.get(arg_key)
             if isinstance(arg, exp.TsOrDsToDate) and not arg.args.get("format"):

@@ -629,6 +653,7 @@ class MySQL(Dialect):
                     transforms.eliminate_distinct_on,
                     transforms.eliminate_semi_and_anti_joins,
                     transforms.eliminate_qualify,
+                    transforms.eliminate_full_outer_join,
                 ]
             ),
             exp.StrPosition: strposition_to_locate_sql,

@@ -728,7 +753,6 @@ class MySQL(Dialect):
             to = self.CAST_MAPPING.get(expression.to.this)
 
             if to:
-                expression = expression.copy()
                 expression.to.set("this", to)
             return super().cast_sql(expression)
 

@@ -43,8 +43,6 @@ DATE_DIFF_FACTOR = {
 
 def _date_add_sql(kind: str) -> t.Callable[[Postgres.Generator, exp.DateAdd | exp.DateSub], str]:
     def func(self: Postgres.Generator, expression: exp.DateAdd | exp.DateSub) -> str:
-        expression = expression.copy()
-
         this = self.sql(expression, "this")
         unit = expression.args.get("unit")
 

@@ -96,7 +94,6 @@ def _substring_sql(self: Postgres.Generator, expression: exp.Substring) -> str:
 
 
 def _string_agg_sql(self: Postgres.Generator, expression: exp.GroupConcat) -> str:
-    expression = expression.copy()
     separator = expression.args.get("separator") or exp.Literal.string(",")
 
     order = ""

@@ -119,7 +116,6 @@ def _auto_increment_to_serial(expression: exp.Expression) -> exp.Expression:
     auto = expression.find(exp.AutoIncrementColumnConstraint)
 
     if auto:
-        expression = expression.copy()
         expression.args["constraints"].remove(auto.parent)
         kind = expression.args["kind"]
 

@@ -134,7 +130,9 @@ def _auto_increment_to_serial(expression: exp.Expression) -> exp.Expression:
 
 
 def _serial_to_generated(expression: exp.Expression) -> exp.Expression:
-    kind = expression.args["kind"]
+    kind = expression.args.get("kind")
+    if not kind:
+        return expression
 
     if kind.this == exp.DataType.Type.SERIAL:
         data_type = exp.DataType(this=exp.DataType.Type.INT)

@@ -146,7 +144,6 @@ def _serial_to_generated(expression: exp.Expression) -> exp.Expression:
         data_type = None
 
     if data_type:
-        expression = expression.copy()
         expression.args["kind"].replace(data_type)
         constraints = expression.args["constraints"]
         generated = exp.ColumnConstraint(kind=exp.GeneratedAsIdentityColumnConstraint(this=False))

@@ -409,6 +406,7 @@ class Postgres(Dialect):
             exp.MapFromEntries: no_map_from_entries_sql,
             exp.Min: min_or_least,
             exp.Merge: transforms.preprocess([_remove_target_from_merge]),
+            exp.PartitionedByProperty: lambda self, e: f"PARTITION BY {self.sql(e, 'this')}",
             exp.PercentileCont: transforms.preprocess(
                 [transforms.add_within_group_for_percentiles]
             ),

@@ -445,6 +443,7 @@ class Postgres(Dialect):
 
         PROPERTIES_LOCATION = {
             **generator.Generator.PROPERTIES_LOCATION,
+            exp.PartitionedByProperty: exp.Properties.Location.POST_SCHEMA,
             exp.TransientProperty: exp.Properties.Location.UNSUPPORTED,
             exp.VolatileProperty: exp.Properties.Location.UNSUPPORTED,
         }

@@ -452,7 +451,6 @@ class Postgres(Dialect):
         def bracket_sql(self, expression: exp.Bracket) -> str:
             """Forms like ARRAY[1, 2, 3][3] aren't allowed; we need to wrap the ARRAY."""
             if isinstance(expression.this, exp.Array):
-                expression = expression.copy()
                 expression.set("this", exp.paren(expression.this, copy=False))
 
             return super().bracket_sql(expression)

@@ -36,7 +36,6 @@ def _approx_distinct_sql(self: Presto.Generator, expression: exp.ApproxDistinct)
 
 def _explode_to_unnest_sql(self: Presto.Generator, expression: exp.Lateral) -> str:
     if isinstance(expression.this, exp.Explode):
-        expression = expression.copy()
         return self.sql(
             exp.Join(
                 this=exp.Unnest(

@@ -72,7 +71,6 @@ def _schema_sql(self: Presto.Generator, expression: exp.Schema) -> str:
         for schema in expression.parent.find_all(exp.Schema):
             column_defs = schema.find_all(exp.ColumnDef)
             if column_defs and isinstance(schema.parent, exp.Property):
-                expression = expression.copy()
                 expression.expressions.extend(column_defs)
 
     return self.schema_sql(expression)

@@ -407,12 +405,10 @@ class Presto(Dialect):
                 target_type = None
 
             if target_type and target_type.is_type("timestamp"):
-                to = target_type.copy()
-
                 if target_type is start.to:
-                    end = exp.cast(end, to)
+                    end = exp.cast(end, target_type)
                 else:
-                    start = exp.cast(start, to)
+                    start = exp.cast(start, target_type)
 
             return self.func("SEQUENCE", start, end, step)
 

@@ -432,6 +428,5 @@ class Presto(Dialect):
             kind = expression.args["kind"]
             schema = expression.this
             if kind == "VIEW" and schema.expressions:
-                expression = expression.copy()
                 expression.this.set("expressions", None)
             return super().create_sql(expression)

@@ -27,6 +27,14 @@ def _parse_date_add(args: t.List) -> exp.DateAdd:
     )
 
 
+def _parse_datediff(args: t.List) -> exp.DateDiff:
+    return exp.DateDiff(
+        this=exp.TsOrDsToDate(this=seq_get(args, 2)),
+        expression=exp.TsOrDsToDate(this=seq_get(args, 1)),
+        unit=seq_get(args, 0),
+    )
+
+
 class Redshift(Postgres):
     # https://docs.aws.amazon.com/redshift/latest/dg/r_names.html
     RESOLVES_IDENTIFIERS_AS_UPPERCASE = None

@@ -51,11 +59,9 @@ class Redshift(Postgres):
             ),
             "DATEADD": _parse_date_add,
             "DATE_ADD": _parse_date_add,
-            "DATEDIFF": lambda args: exp.DateDiff(
-                this=exp.TsOrDsToDate(this=seq_get(args, 2)),
-                expression=exp.TsOrDsToDate(this=seq_get(args, 1)),
-                unit=seq_get(args, 0),
-            ),
+            "DATEDIFF": _parse_datediff,
+            "DATE_DIFF": _parse_datediff,
+            "LISTAGG": exp.GroupConcat.from_arg_list,
             "STRTOL": exp.FromBase.from_arg_list,
         }
 

@@ -175,6 +181,7 @@ class Redshift(Postgres):
             exp.GeneratedAsIdentityColumnConstraint: generatedasidentitycolumnconstraint_sql,
             exp.JSONExtract: _json_sql,
             exp.JSONExtractScalar: _json_sql,
+            exp.GroupConcat: rename_func("LISTAGG"),
             exp.ParseJSON: rename_func("JSON_PARSE"),
             exp.SafeConcat: concat_to_dpipe_sql,
             exp.Select: transforms.preprocess(

@@ -207,7 +214,6 @@ class Redshift(Postgres):
             `TEXT` to `VARCHAR`.
             """
             if expression.is_type("text"):
-                expression = expression.copy()
                 expression.set("this", exp.DataType.Type.VARCHAR)
                 precision = expression.args.get("expressions")
 

@@ -32,7 +32,7 @@ def _check_int(s: str) -> bool:
 
 
 # from https://docs.snowflake.com/en/sql-reference/functions/to_timestamp.html
-def _parse_to_timestamp(args: t.List) -> t.Union[exp.StrToTime, exp.UnixToTime]:
+def _parse_to_timestamp(args: t.List) -> t.Union[exp.StrToTime, exp.UnixToTime, exp.TimeStrToTime]:
     if len(args) == 2:
         first_arg, second_arg = args
         if second_arg.is_string:

@@ -60,8 +60,8 @@ def _parse_to_timestamp(args: t.List) -> t.Union[exp.StrToTime, exp.UnixToTime]:
     # reduce it using `simplify_literals` first and then check if it's a Literal.
     first_arg = seq_get(args, 0)
     if not isinstance(simplify_literals(first_arg, root=True), Literal):
-        # case: <variant_expr>
-        return format_time_lambda(exp.StrToTime, "snowflake", default=True)(args)
+        # case: <variant_expr> or other expressions such as columns
+        return exp.TimeStrToTime.from_arg_list(args)
 
     if first_arg.is_string:
         if _check_int(first_arg.this):

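Review note: with this change, Snowflake's `TO_TIMESTAMP` over a non-literal argument (a column, a VARIANT) maps to `TimeStrToTime` rather than a formatted `StrToTime`. A hedged sketch of the expected parse:

```python
import sqlglot
from sqlglot import exp

node = sqlglot.parse_one("SELECT TO_TIMESTAMP(col) FROM t", read="snowflake")
assert node.find(exp.TimeStrToTime) is not None
```
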
@@ -560,7 +560,6 @@ class Snowflake(Dialect):
             offset = expression.args.get("offset")
             if offset:
                 if unnest_alias:
-                    expression = expression.copy()
                     unnest_alias.append("columns", offset.pop())
 
                 selects.append("index")

@@ -63,6 +63,8 @@ class Spark(Spark2):
             return this
 
     class Generator(Spark2.Generator):
+        SUPPORTS_NESTED_CTES = True
+
         TYPE_MAPPING = {
             **Spark2.Generator.TYPE_MAPPING,
             exp.DataType.Type.MONEY: "DECIMAL(15, 4)",

@@ -7,7 +7,6 @@ from sqlglot.dialects.dialect import (
     binary_from_function,
     format_time_lambda,
     is_parse_json,
-    move_insert_cte_sql,
     pivot_column_names,
     rename_func,
     trim_sql,

@@ -70,7 +69,9 @@ def _unalias_pivot(expression: exp.Expression) -> exp.Expression:
         alias = pivot.args["alias"].pop()
         return exp.From(
             this=expression.this.replace(
-                exp.select("*").from_(expression.this.copy()).subquery(alias=alias)
+                exp.select("*")
+                .from_(expression.this.copy(), copy=False)
+                .subquery(alias=alias, copy=False)
             )
         )
 

@@ -188,7 +189,6 @@ class Spark2(Hive):
             exp.DayOfYear: rename_func("DAYOFYEAR"),
             exp.FileFormatProperty: lambda self, e: f"USING {e.name.upper()}",
             exp.From: transforms.preprocess([_unalias_pivot]),
-            exp.Insert: move_insert_cte_sql,
             exp.LogicalAnd: rename_func("BOOL_AND"),
             exp.LogicalOr: rename_func("BOOL_OR"),
             exp.Map: _map_sql,

@@ -50,7 +50,7 @@ def _transform_create(expression: exp.Expression) -> exp.Expression:
     else:
         for column in defs.values():
             auto_increment = None
-            for constraint in column.constraints.copy():
+            for constraint in column.constraints:
                 if isinstance(constraint.kind, exp.PrimaryKeyColumnConstraint):
                     break
                 if isinstance(constraint.kind, exp.AutoIncrementColumnConstraint):

@@ -38,12 +38,15 @@ class Teradata(Dialect):
             "^=": TokenType.NEQ,
             "BYTEINT": TokenType.SMALLINT,
             "COLLECT": TokenType.COMMAND,
+            "DEL": TokenType.DELETE,
             "EQ": TokenType.EQ,
             "GE": TokenType.GTE,
             "GT": TokenType.GT,
             "HELP": TokenType.COMMAND,
+            "INS": TokenType.INSERT,
             "LE": TokenType.LTE,
             "LT": TokenType.LT,
+            "MINUS": TokenType.EXCEPT,
             "MOD": TokenType.MOD,
             "NE": TokenType.NEQ,
             "NOT=": TokenType.NEQ,

@@ -51,6 +54,7 @@ class Teradata(Dialect):
             "SEL": TokenType.SELECT,
             "ST_GEOMETRY": TokenType.GEOMETRY,
             "TOP": TokenType.TOP,
+            "UPD": TokenType.UPDATE,
         }
 
         # Teradata does not support % as a modulo operator

@@ -181,6 +185,13 @@ class Teradata(Dialect):
             exp.Use: lambda self, e: f"DATABASE {self.sql(e, 'this')}",
         }
 
+        def cast_sql(self, expression: exp.Cast, safe_prefix: t.Optional[str] = None) -> str:
+            if expression.to.this == exp.DataType.Type.UNKNOWN and expression.args.get("format"):
+                # We don't actually want to print the unknown type in CAST(<value> AS FORMAT <format>)
+                expression.to.pop()
+
+            return super().cast_sql(expression, safe_prefix=safe_prefix)
+
         def tablesample_sql(
             self, expression: exp.TableSample, seed_prefix: str = "SEED", sep=" AS "
         ) -> str:

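Review note: popping the UNKNOWN type lets Teradata's FORMAT-only cast regenerate without a bogus type name. A hedged sketch, assuming the parser accepts the type-less `CAST(x AS FORMAT ...)` form:

```python
import sqlglot

# The UNKNOWN type should no longer be printed on the way back out.
sql = "SELECT CAST(x AS FORMAT 'YYYY-MM-DD')"
print(sqlglot.transpile(sql, read="teradata", write="teradata")[0])
```
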
@@ -11,7 +11,6 @@ from sqlglot.dialects.dialect import (
     generatedasidentitycolumnconstraint_sql,
     max_or_greatest,
     min_or_least,
-    move_insert_cte_sql,
     parse_date_delta,
     rename_func,
     timestrtotime_sql,

@@ -158,8 +157,6 @@ def _format_sql(self: TSQL.Generator, expression: exp.NumberToStr | exp.TimeToSt
 
 
 def _string_agg_sql(self: TSQL.Generator, expression: exp.GroupConcat) -> str:
-    expression = expression.copy()
-
     this = expression.this
     distinct = expression.find(exp.Distinct)
     if distinct:

@@ -246,6 +243,7 @@ class TSQL(Dialect):
         "MMM": "%b",
         "MM": "%m",
         "M": "%-m",
+        "dddd": "%A",
         "dd": "%d",
         "d": "%-d",
         "HH": "%H",

@@ -596,6 +594,8 @@ class TSQL(Dialect):
         ALTER_TABLE_ADD_COLUMN_KEYWORD = False
         LIMIT_FETCH = "FETCH"
         COMPUTED_COLUMN_WITH_TYPE = False
+        SUPPORTS_NESTED_CTES = False
+        CTE_RECURSIVE_KEYWORD_REQUIRED = False
 
         TYPE_MAPPING = {
             **generator.Generator.TYPE_MAPPING,

@@ -622,7 +622,6 @@ class TSQL(Dialect):
             exp.GeneratedAsIdentityColumnConstraint: generatedasidentitycolumnconstraint_sql,
             exp.GroupConcat: _string_agg_sql,
             exp.If: rename_func("IIF"),
-            exp.Insert: move_insert_cte_sql,
             exp.Max: max_or_greatest,
             exp.MD5: lambda self, e: self.func("HASHBYTES", exp.Literal.string("MD5"), e.this),
             exp.Min: min_or_least,

@@ -685,7 +684,6 @@ class TSQL(Dialect):
             return sql
 
         def create_sql(self, expression: exp.Create) -> str:
-            expression = expression.copy()
             kind = self.sql(expression, "kind").upper()
             exists = expression.args.pop("exists", None)
             sql = super().create_sql(expression)

@@ -714,7 +712,7 @@ class TSQL(Dialect):
             elif expression.args.get("replace"):
                 sql = sql.replace("CREATE OR REPLACE ", "CREATE OR ALTER ", 1)
 
-            return sql
+            return self.prepend_ctes(expression, sql)
 
         def offset_sql(self, expression: exp.Offset) -> str:
             return f"{super().offset_sql(expression)} ROWS"

@@ -2145,6 +2145,22 @@ class PartitionedByProperty(Property):
     arg_types = {"this": True}
 
 
+# https://www.postgresql.org/docs/current/sql-createtable.html
+class PartitionBoundSpec(Expression):
+    # this -> IN / MODULUS, expression -> REMAINDER, from_expressions -> FROM (...), to_expressions -> TO (...)
+    arg_types = {
+        "this": False,
+        "expression": False,
+        "from_expressions": False,
+        "to_expressions": False,
+    }
+
+
+class PartitionedOfProperty(Property):
+    # this -> parent_table (schema), expression -> FOR VALUES ... / DEFAULT
+    arg_types = {"this": True, "expression": True}
+
+
 class RemoteWithConnectionModelProperty(Property):
     arg_types = {"this": True}
 

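Review note: these two nodes model Postgres declarative partitioning. A hedged sketch of how a partition-child DDL statement should map onto them once the parser changes below land (the table names are illustrative):

```python
import sqlglot
from sqlglot import exp

ddl = "CREATE TABLE sales_q1 PARTITION OF sales FOR VALUES FROM ('2023-01-01') TO ('2023-04-01')"
prop = sqlglot.parse_one(ddl, read="postgres").find(exp.PartitionedOfProperty)
assert isinstance(prop.expression, exp.PartitionBoundSpec)
```
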
@@ -2486,6 +2502,7 @@ class Table(Expression):
         "format": False,
         "pattern": False,
         "index": False,
+        "ordinality": False,
     }
 
     @property

@@ -2649,11 +2666,7 @@ class Update(Expression):
 
 
 class Values(UDTF):
-    arg_types = {
-        "expressions": True,
-        "ordinality": False,
-        "alias": False,
-    }
+    arg_types = {"expressions": True, "alias": False}
 
 
 class Var(Expression):

@@ -3501,7 +3514,7 @@ class Star(Expression):
 
 
 class Parameter(Condition):
-    arg_types = {"this": True, "wrapped": False}
+    arg_types = {"this": True, "expression": False}
 
 
 class SessionParameter(Condition):

@@ -5036,7 +5049,7 @@ class FromBase(Func):
 
 
 class Struct(Func):
-    arg_types = {"expressions": True}
+    arg_types = {"expressions": False}
     is_var_len_args = True
 
 

@@ -5171,7 +5184,7 @@ class Use(Expression):
 
 
 class Merge(Expression):
-    arg_types = {"this": True, "using": True, "on": True, "expressions": True}
+    arg_types = {"this": True, "using": True, "on": True, "expressions": True, "with": False}
 
 
 class When(Func):

@@ -5459,7 +5472,12 @@ def _wrap(expression: E, kind: t.Type[Expression]) -> E | Paren:
 
 
 def union(
-    left: ExpOrStr, right: ExpOrStr, distinct: bool = True, dialect: DialectType = None, **opts
+    left: ExpOrStr,
+    right: ExpOrStr,
+    distinct: bool = True,
+    dialect: DialectType = None,
+    copy: bool = True,
+    **opts,
 ) -> Union:
     """
     Initializes a syntax tree from one UNION expression.
 

@@ -5475,19 +5493,25 @@ def union(
             If an `Expression` instance is passed, it will be used as-is.
         distinct: set the DISTINCT flag if and only if this is true.
         dialect: the dialect used to parse the input expression.
+        copy: whether or not to copy the expression.
         opts: other options to use to parse the input expressions.
 
     Returns:
         The new Union instance.
     """
-    left = maybe_parse(sql_or_expression=left, dialect=dialect, **opts)
-    right = maybe_parse(sql_or_expression=right, dialect=dialect, **opts)
+    left = maybe_parse(sql_or_expression=left, dialect=dialect, copy=copy, **opts)
+    right = maybe_parse(sql_or_expression=right, dialect=dialect, copy=copy, **opts)
 
     return Union(this=left, expression=right, distinct=distinct)
 
 
 def intersect(
-    left: ExpOrStr, right: ExpOrStr, distinct: bool = True, dialect: DialectType = None, **opts
+    left: ExpOrStr,
+    right: ExpOrStr,
+    distinct: bool = True,
+    dialect: DialectType = None,
+    copy: bool = True,
+    **opts,
 ) -> Intersect:
     """
     Initializes a syntax tree from one INTERSECT expression.
 

@@ -5503,19 +5527,25 @@ def intersect(
             If an `Expression` instance is passed, it will be used as-is.
         distinct: set the DISTINCT flag if and only if this is true.
         dialect: the dialect used to parse the input expression.
+        copy: whether or not to copy the expression.
         opts: other options to use to parse the input expressions.
 
     Returns:
         The new Intersect instance.
     """
-    left = maybe_parse(sql_or_expression=left, dialect=dialect, **opts)
-    right = maybe_parse(sql_or_expression=right, dialect=dialect, **opts)
+    left = maybe_parse(sql_or_expression=left, dialect=dialect, copy=copy, **opts)
+    right = maybe_parse(sql_or_expression=right, dialect=dialect, copy=copy, **opts)
 
     return Intersect(this=left, expression=right, distinct=distinct)
 
 
 def except_(
-    left: ExpOrStr, right: ExpOrStr, distinct: bool = True, dialect: DialectType = None, **opts
+    left: ExpOrStr,
+    right: ExpOrStr,
+    distinct: bool = True,
+    dialect: DialectType = None,
+    copy: bool = True,
+    **opts,
 ) -> Except:
     """
     Initializes a syntax tree from one EXCEPT expression.
 

@@ -5531,13 +5561,14 @@ def except_(
             If an `Expression` instance is passed, it will be used as-is.
         distinct: set the DISTINCT flag if and only if this is true.
         dialect: the dialect used to parse the input expression.
+        copy: whether or not to copy the expression.
         opts: other options to use to parse the input expressions.
 
     Returns:
         The new Except instance.
     """
-    left = maybe_parse(sql_or_expression=left, dialect=dialect, **opts)
-    right = maybe_parse(sql_or_expression=right, dialect=dialect, **opts)
+    left = maybe_parse(sql_or_expression=left, dialect=dialect, copy=copy, **opts)
+    right = maybe_parse(sql_or_expression=right, dialect=dialect, copy=copy, **opts)
 
     return Except(this=left, expression=right, distinct=distinct)
 

@@ -5861,7 +5892,7 @@ def to_identifier(name, quoted=None, copy=True):
     Args:
         name: The name to turn into an identifier.
         quoted: Whether or not force quote the identifier.
-        copy: Whether or not to copy a passed in Identefier node.
+        copy: Whether or not to copy name if it's an Identifier.
 
     Returns:
         The identifier ast node.

@@ -5882,6 +5913,25 @@ def to_identifier(name, quoted=None, copy=True):
     return identifier
 
 
+def parse_identifier(name: str, dialect: DialectType = None) -> Identifier:
+    """
+    Parses a given string into an identifier.
+
+    Args:
+        name: The name to parse into an identifier.
+        dialect: The dialect to parse against.
+
+    Returns:
+        The identifier ast node.
+    """
+    try:
+        expression = maybe_parse(name, dialect=dialect, into=Identifier)
+    except ParseError:
+        expression = to_identifier(name)
+
+    return expression
+
+
 INTERVAL_STRING_RE = re.compile(r"\s*([0-9]+)\s*([a-zA-Z]+)\s*")
 

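Review note: `parse_identifier` gives a dialect-aware way to turn a raw string into an `Identifier` node, falling back to `to_identifier` when the string does not parse. A small sketch (the printed values are what I'd expect, not taken from this diff):

```python
from sqlglot import exp

ident = exp.parse_identifier('"Quoted Name"', dialect="postgres")
print(ident.quoted, ident.name)  # expected: True, Quoted Name
```
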
@@ -230,6 +230,12 @@ class Generator:
     # Whether or not data types support additional specifiers like e.g. CHAR or BYTE (oracle)
     DATA_TYPE_SPECIFIERS_ALLOWED = False
 
+    # Whether or not nested CTEs (e.g. defined inside of subqueries) are allowed
+    SUPPORTS_NESTED_CTES = True
+
+    # Whether or not the "RECURSIVE" keyword is required when defining recursive CTEs
+    CTE_RECURSIVE_KEYWORD_REQUIRED = True
+
     TYPE_MAPPING = {
         exp.DataType.Type.NCHAR: "CHAR",
         exp.DataType.Type.NVARCHAR: "VARCHAR",

@@ -304,6 +310,7 @@ class Generator:
         exp.Order: exp.Properties.Location.POST_SCHEMA,
         exp.OutputModelProperty: exp.Properties.Location.POST_SCHEMA,
         exp.PartitionedByProperty: exp.Properties.Location.POST_WITH,
+        exp.PartitionedOfProperty: exp.Properties.Location.POST_SCHEMA,
         exp.PrimaryKey: exp.Properties.Location.POST_SCHEMA,
         exp.Property: exp.Properties.Location.POST_WITH,
         exp.RemoteWithConnectionModelProperty: exp.Properties.Location.POST_SCHEMA,

@@ -407,7 +414,6 @@ class Generator:
         "unsupported_messages",
         "_escaped_quote_end",
         "_escaped_identifier_end",
-        "_cache",
     )
 
     def __init__(

@@ -447,30 +453,38 @@ class Generator:
         self._escaped_identifier_end: str = (
             self.TOKENIZER_CLASS.IDENTIFIER_ESCAPES[0] + self.IDENTIFIER_END
         )
-        self._cache: t.Optional[t.Dict[int, str]] = None
 
-    def generate(
-        self,
-        expression: t.Optional[exp.Expression],
-        cache: t.Optional[t.Dict[int, str]] = None,
-    ) -> str:
+    def generate(self, expression: exp.Expression, copy: bool = True) -> str:
         """
         Generates the SQL string corresponding to the given syntax tree.
 
         Args:
             expression: The syntax tree.
-            cache: An optional sql string cache. This leverages the hash of an Expression
-                which can be slow to compute, so only use it if you set _hash on each node.
+            copy: Whether or not to copy the expression. The generator performs mutations so
+                it is safer to copy.
 
         Returns:
             The SQL string corresponding to `expression`.
         """
-        if cache is not None:
-            self._cache = cache
+        if copy:
+            expression = expression.copy()
+
+        # Some dialects only support CTEs at the top level expression, so we need to bubble up nested
+        # CTEs to that level in order to produce a syntactically valid expression. This transformation
+        # happens here to minimize code duplication, since many expressions support CTEs.
+        if (
+            not self.SUPPORTS_NESTED_CTES
+            and isinstance(expression, exp.Expression)
+            and not expression.parent
+            and "with" in expression.arg_types
+            and any(node.parent is not expression for node in expression.find_all(exp.With))
+        ):
+            from sqlglot.transforms import move_ctes_to_top_level
+
+            expression = move_ctes_to_top_level(expression)
 
         self.unsupported_messages = []
         sql = self.sql(expression).strip()
-        self._cache = None
 
         if self.unsupported_level == ErrorLevel.IGNORE:
             return sql

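Review note: `generate()` now owns the single defensive copy that used to be scattered across dialect methods; `copy=False` opts out when the caller can afford in-place mutation. A quick sketch:

```python
from sqlglot import parse_one
from sqlglot.generator import Generator

ast = parse_one("SELECT a FROM t")
sql = Generator().generate(ast)               # default: copies, `ast` stays intact
fast = Generator().generate(ast, copy=False)  # caller allows in-place mutation
```
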
@@ -595,12 +609,6 @@ class Generator:
                 return self.sql(value)
             return ""
 
-        if self._cache is not None:
-            expression_id = hash(expression)
-
-            if expression_id in self._cache:
-                return self._cache[expression_id]
-
         transform = self.TRANSFORMS.get(expression.__class__)
 
         if callable(transform):

@@ -621,11 +629,7 @@ class Generator:
         else:
             raise ValueError(f"Expected an Expression. Received {type(expression)}: {expression}")
 
-        sql = self.maybe_comment(sql, expression) if self.comments and comment else sql
-
-        if self._cache is not None:
-            self._cache[expression_id] = sql
-        return sql
+        return self.maybe_comment(sql, expression) if self.comments and comment else sql
 
     def uncache_sql(self, expression: exp.Uncache) -> str:
         table = self.sql(expression, "this")

@@ -879,7 +883,11 @@ class Generator:
 
     def with_sql(self, expression: exp.With) -> str:
         sql = self.expressions(expression, flat=True)
-        recursive = "RECURSIVE " if expression.args.get("recursive") else ""
+        recursive = (
+            "RECURSIVE "
+            if self.CTE_RECURSIVE_KEYWORD_REQUIRED and expression.args.get("recursive")
+            else ""
+        )
 
         return f"WITH {recursive}{sql}"
 

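Review note: `CTE_RECURSIVE_KEYWORD_REQUIRED` lets dialects such as T-SQL, which accept recursive CTEs without the keyword, drop `RECURSIVE` on output. A hedged sketch of the expected transpilation:

```python
import sqlglot

sql = "WITH RECURSIVE t(n) AS (SELECT 1) SELECT n FROM t"
# T-SQL sets CTE_RECURSIVE_KEYWORD_REQUIRED = False, so RECURSIVE is dropped.
print(sqlglot.transpile(sql, read="postgres", write="tsql")[0])
```
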
@@ -1022,7 +1030,7 @@ class Generator:
             where = self.sql(expression, "expression").strip()
             return f"{this} FILTER({where})"
 
-        agg = expression.this.copy()
+        agg = expression.this
         agg_arg = agg.this
         cond = expression.expression.this
         agg_arg.replace(exp.If(this=cond.copy(), true=agg_arg.copy()))

@@ -1088,9 +1096,9 @@ class Generator:
         for p in expression.expressions:
             p_loc = self.PROPERTIES_LOCATION[p.__class__]
             if p_loc == exp.Properties.Location.POST_WITH:
-                with_properties.append(p.copy())
+                with_properties.append(p)
             elif p_loc == exp.Properties.Location.POST_SCHEMA:
-                root_properties.append(p.copy())
+                root_properties.append(p)
 
         return self.root_properties(
             exp.Properties(expressions=root_properties)

@@ -1124,7 +1132,7 @@ class Generator:
         for p in properties.expressions:
             p_loc = self.PROPERTIES_LOCATION[p.__class__]
             if p_loc != exp.Properties.Location.UNSUPPORTED:
-                properties_locs[p_loc].append(p.copy())
+                properties_locs[p_loc].append(p)
             else:
                 self.unsupported(f"Unsupported property {p.key}")
 

@@ -1238,6 +1246,29 @@ class Generator:
             for_ = " FOR NONE"
         return f"WITH{no}{concurrent} ISOLATED LOADING{for_}"
 
+    def partitionboundspec_sql(self, expression: exp.PartitionBoundSpec) -> str:
+        if isinstance(expression.this, list):
+            return f"IN ({self.expressions(expression, key='this', flat=True)})"
+        if expression.this:
+            modulus = self.sql(expression, "this")
+            remainder = self.sql(expression, "expression")
+            return f"WITH (MODULUS {modulus}, REMAINDER {remainder})"
+
+        from_expressions = self.expressions(expression, key="from_expressions", flat=True)
+        to_expressions = self.expressions(expression, key="to_expressions", flat=True)
+        return f"FROM ({from_expressions}) TO ({to_expressions})"
+
+    def partitionedofproperty_sql(self, expression: exp.PartitionedOfProperty) -> str:
+        this = self.sql(expression, "this")
+
+        for_values_or_default = expression.expression
+        if isinstance(for_values_or_default, exp.PartitionBoundSpec):
+            for_values_or_default = f" FOR VALUES {self.sql(for_values_or_default)}"
+        else:
+            for_values_or_default = " DEFAULT"
+
+        return f"PARTITION OF {this}{for_values_or_default}"
+
     def lockingproperty_sql(self, expression: exp.LockingProperty) -> str:
         kind = expression.args.get("kind")
         this = f" {self.sql(expression, 'this')}" if expression.this else ""

@@ -1385,7 +1416,12 @@ class Generator:
         index = self.sql(expression, "index")
         index = f" AT {index}" if index else ""
 
-        return f"{table}{version}{file_format}{alias}{index}{hints}{pivots}{joins}{laterals}"
+        ordinality = expression.args.get("ordinality") or ""
+        if ordinality:
+            ordinality = f" WITH ORDINALITY{alias}"
+            alias = ""
+
+        return f"{table}{version}{file_format}{alias}{index}{hints}{pivots}{joins}{laterals}{ordinality}"
 
     def tablesample_sql(
         self, expression: exp.TableSample, seed_prefix: str = "SEED", sep=" AS "

@@ -1489,7 +1525,6 @@ class Generator:
             return f"{values} AS {alias}" if alias else values
 
         # Converts `VALUES...` expression into a series of select unions.
-        expression = expression.copy()
         alias_node = expression.args.get("alias")
         column_names = alias_node and alias_node.columns
 

@@ -1972,8 +2007,7 @@ class Generator:
 
         if self.UNNEST_WITH_ORDINALITY:
             if alias and isinstance(offset, exp.Expression):
-                alias = alias.copy()
-                alias.append("columns", offset.copy())
+                alias.append("columns", offset)
 
         if alias and self.UNNEST_COLUMN_ONLY:
             columns = alias.columns

@@ -2138,7 +2172,6 @@ class Generator:
         return f"PRIMARY KEY ({expressions}){options}"
 
     def if_sql(self, expression: exp.If) -> str:
-        expression = expression.copy()
         return self.case_sql(exp.Case(ifs=[expression], default=expression.args.get("false")))
 
     def matchagainst_sql(self, expression: exp.MatchAgainst) -> str:

@@ -2367,7 +2400,9 @@ class Generator:
     def cast_sql(self, expression: exp.Cast, safe_prefix: t.Optional[str] = None) -> str:
         format_sql = self.sql(expression, "format")
         format_sql = f" FORMAT {format_sql}" if format_sql else ""
-        return f"{safe_prefix or ''}CAST({self.sql(expression, 'this')} AS {self.sql(expression, 'to')}{format_sql})"
+        to_sql = self.sql(expression, "to")
+        to_sql = f" {to_sql}" if to_sql else ""
+        return f"{safe_prefix or ''}CAST({self.sql(expression, 'this')} AS{to_sql}{format_sql})"
 
     def currentdate_sql(self, expression: exp.CurrentDate) -> str:
         zone = self.sql(expression, "this")

@@ -2510,7 +2545,7 @@ class Generator:
     def intdiv_sql(self, expression: exp.IntDiv) -> str:
         return self.sql(
             exp.Cast(
-                this=exp.Div(this=expression.this.copy(), expression=expression.expression.copy()),
+                this=exp.Div(this=expression.this, expression=expression.expression),
                 to=exp.DataType(this=exp.DataType.Type.INT),
             )
         )

@@ -2779,7 +2814,6 @@ class Generator:
         hints = table.args.get("hints")
         if hints and table.alias and isinstance(hints[0], exp.WithTableHint):
             # T-SQL syntax is MERGE ... <target_table> [WITH (<merge_hint>)] [[AS] table_alias]
-            table = table.copy()
             table_alias = f" AS {self.sql(table.args['alias'].pop())}"
 
         this = self.sql(table)

@@ -2787,7 +2821,9 @@ class Generator:
         on = f"ON {self.sql(expression, 'on')}"
         expressions = self.expressions(expression, sep=" ")
 
-        return f"MERGE INTO {this}{table_alias} {using} {on} {expressions}"
+        return self.prepend_ctes(
+            expression, f"MERGE INTO {this}{table_alias} {using} {on} {expressions}"
+        )
 
     def tochar_sql(self, expression: exp.ToChar) -> str:
         if expression.args.get("format"):

@@ -2896,12 +2932,12 @@ class Generator:
 
             case = exp.Case().when(
                 expression.this.is_(exp.null()).not_(copy=False),
-                expression.args["true"].copy(),
+                expression.args["true"],
                 copy=False,
             )
             else_cond = expression.args.get("false")
             if else_cond:
-                case.else_(else_cond.copy(), copy=False)
+                case.else_(else_cond, copy=False)
 
             return self.sql(case)
 

@@ -2931,15 +2967,6 @@ class Generator:
         if not isinstance(expression, exp.Literal):
             from sqlglot.optimizer.simplify import simplify
 
-            expression = simplify(expression.copy())
+            expression = simplify(expression)
 
         return expression
-
-
-def cached_generator(
-    cache: t.Optional[t.Dict[int, str]] = None
-) -> t.Callable[[exp.Expression], str]:
-    """Returns a cached generator."""
-    cache = {} if cache is None else cache
-    generator = Generator(normalize=True, identify="safe")
-    return lambda e: generator.generate(e, cache)

@@ -184,9 +184,7 @@ def apply_index_offset(
             annotate_types(expression)
             if t.cast(exp.DataType, expression.type).this in exp.DataType.INTEGER_TYPES:
                 logger.warning("Applying array index offset (%s)", offset)
-                expression = simplify(
-                    exp.Add(this=expression.copy(), expression=exp.Literal.number(offset))
-                )
+                expression = simplify(exp.Add(this=expression, expression=exp.Literal.number(offset)))
                 return [expression]
 
     return expressions

@@ -4,7 +4,6 @@ import logging
 
 from sqlglot import exp
 from sqlglot.errors import OptimizeError
-from sqlglot.generator import cached_generator
 from sqlglot.helper import while_changing
 from sqlglot.optimizer.scope import find_all_in_scope
 from sqlglot.optimizer.simplify import flatten, rewrite_between, uniq_sort

@@ -29,8 +28,6 @@ def normalize(expression: exp.Expression, dnf: bool = False, max_distance: int =
     Returns:
         sqlglot.Expression: normalized expression
     """
-    generate = cached_generator()
-
     for node, *_ in tuple(expression.walk(prune=lambda e, *_: isinstance(e, exp.Connector))):
         if isinstance(node, exp.Connector):
             if normalized(node, dnf=dnf):

@@ -49,7 +46,7 @@ def normalize(expression: exp.Expression, dnf: bool = False, max_distance: int =
 
         try:
             node = node.replace(
-                while_changing(node, lambda e: distributive_law(e, dnf, max_distance, generate))
+                while_changing(node, lambda e: distributive_law(e, dnf, max_distance))
             )
         except OptimizeError as e:
             logger.info(e)

@@ -133,7 +130,7 @@ def _predicate_lengths(expression, dnf):
     return _predicate_lengths(left, dnf) + _predicate_lengths(right, dnf)
 
 
-def distributive_law(expression, dnf, max_distance, generate):
+def distributive_law(expression, dnf, max_distance):
     """
     x OR (y AND z) -> (x OR y) AND (x OR z)
     (x AND y) OR (y AND z) -> (x OR y) AND (x OR z) AND (y OR y) AND (y OR z)

@@ -146,7 +143,7 @@ def distributive_law(expression, dnf, max_distance, generate):
     if distance > max_distance:
         raise OptimizeError(f"Normalization distance {distance} exceeds max {max_distance}")
 
-    exp.replace_children(expression, lambda e: distributive_law(e, dnf, max_distance, generate))
+    exp.replace_children(expression, lambda e: distributive_law(e, dnf, max_distance))
     to_exp, from_exp = (exp.Or, exp.And) if dnf else (exp.And, exp.Or)
 
     if isinstance(expression, from_exp):

@@ -157,30 +154,30 @@
 
         if isinstance(a, to_exp) and isinstance(b, to_exp):
             if len(tuple(a.find_all(exp.Connector))) > len(tuple(b.find_all(exp.Connector))):
-                return _distribute(a, b, from_func, to_func, generate)
-            return _distribute(b, a, from_func, to_func, generate)
+                return _distribute(a, b, from_func, to_func)
+            return _distribute(b, a, from_func, to_func)
         if isinstance(a, to_exp):
-            return _distribute(b, a, from_func, to_func, generate)
+            return _distribute(b, a, from_func, to_func)
         if isinstance(b, to_exp):
-            return _distribute(a, b, from_func, to_func, generate)
+            return _distribute(a, b, from_func, to_func)
 
     return expression
 
 
-def _distribute(a, b, from_func, to_func, generate):
+def _distribute(a, b, from_func, to_func):
     if isinstance(a, exp.Connector):
         exp.replace_children(
             a,
             lambda c: to_func(
-                uniq_sort(flatten(from_func(c, b.left)), generate),
-                uniq_sort(flatten(from_func(c, b.right)), generate),
+                uniq_sort(flatten(from_func(c, b.left))),
+                uniq_sort(flatten(from_func(c, b.right))),
                 copy=False,
             ),
         )
     else:
         a = to_func(
-            uniq_sort(flatten(from_func(a, b.left)), generate),
-            uniq_sort(flatten(from_func(a, b.right)), generate),
+            uniq_sort(flatten(from_func(a, b.left))),
+            uniq_sort(flatten(from_func(a, b.right))),
             copy=False,
         )
 

@@ -2,7 +2,7 @@ from __future__ import annotations
 
 import typing as t
 
-from sqlglot import exp, parse_one
+from sqlglot import exp
 from sqlglot._typing import E
 from sqlglot.dialects.dialect import Dialect, DialectType
 

@@ -49,7 +49,7 @@ def normalize_identifiers(expression, dialect=None):
         The transformed expression.
     """
     if isinstance(expression, str):
-        expression = parse_one(expression, dialect=dialect, into=exp.Identifier)
+        expression = exp.parse_identifier(expression, dialect=dialect)
 
     dialect = Dialect.get_or_raise(dialect)
 

@@ -62,7 +62,7 @@ def qualify_tables(
                 if isinstance(source.this, exp.Identifier):
                     if not source.args.get("db"):
                         source.set("db", exp.to_identifier(db))
-                    if not source.args.get("catalog"):
+                    if not source.args.get("catalog") and source.args.get("db"):
                         source.set("catalog", exp.to_identifier(catalog))
 
                 if not source.alias:

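Review note: the extra `source.args.get("db")` guard stops `qualify_tables` from emitting a dangling catalog qualifier (`c..t`) when only a catalog default is supplied. A hedged sketch:

```python
from sqlglot import parse_one
from sqlglot.optimizer.qualify_tables import qualify_tables

# With a catalog default but no db, the table should no longer
# come back qualified as `c..t`.
print(qualify_tables(parse_one("SELECT 1 FROM t"), catalog="c").sql())
```
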
@ -7,8 +7,7 @@ from decimal import Decimal
|
|||
|
||||
import sqlglot
|
||||
from sqlglot import exp
|
||||
from sqlglot.generator import cached_generator
|
||||
from sqlglot.helper import first, merge_ranges, while_changing
|
||||
from sqlglot.helper import first, is_iterable, merge_ranges, while_changing
|
||||
from sqlglot.optimizer.scope import find_all_in_scope, walk_in_scope
|
||||
|
||||
# Final means that an expression should not be simplified
|
||||
|
@ -37,8 +36,6 @@ def simplify(expression, constant_propagation=False):
|
|||
sqlglot.Expression: simplified expression
|
||||
"""
|
||||
|
||||
generate = cached_generator()
|
||||
|
||||
# group by expressions cannot be simplified, for example
|
||||
# select x + 1 + 1 FROM y GROUP BY x + 1 + 1
|
||||
# the projection must exactly match the group by key
|
||||
|
@ -67,7 +64,7 @@ def simplify(expression, constant_propagation=False):
|
|||
# Pre-order transformations
|
||||
node = expression
|
||||
node = rewrite_between(node)
|
||||
node = uniq_sort(node, generate, root)
|
||||
node = uniq_sort(node, root)
|
||||
node = absorb_and_eliminate(node, root)
|
||||
node = simplify_concat(node)
|
||||
node = simplify_conditionals(node)
|
||||
|
@ -311,7 +308,7 @@ def remove_complements(expression, root=True):
|
|||
return expression
|
||||
|
||||
|
||||
def uniq_sort(expression, generate, root=True):
|
||||
def uniq_sort(expression, root=True):
|
||||
"""
|
||||
Uniq and sort a connector.
|
||||
|
||||
|
@ -320,7 +317,7 @@ def uniq_sort(expression, generate, root=True):
|
|||
if isinstance(expression, exp.Connector) and (root or not expression.same_parent):
|
||||
result_func = exp.and_ if isinstance(expression, exp.And) else exp.or_
|
||||
flattened = tuple(expression.flatten())
|
||||
deduped = {generate(e): e for e in flattened}
|
||||
deduped = {gen(e): e for e in flattened}
|
||||
arr = tuple(deduped.items())
|
||||
|
||||
# check if the operands are already sorted, if not sort them
|
||||
|
@ -1070,3 +1067,69 @@ def _flat_simplify(expression, simplifier, root=True):
                lambda a, b: expression.__class__(this=a, expression=b), operands
            )
    return expression


def gen(expression: t.Any) -> str:
    """Simple pseudo sql generator for quickly generating sortable and uniq strings.

    Sorting and deduping sql is a necessary step for optimization. Calling the actual
    generator is expensive so we have a bare minimum sql generator here.
    """
    if expression is None:
        return "_"
    if is_iterable(expression):
        return ",".join(gen(e) for e in expression)
    if not isinstance(expression, exp.Expression):
        return str(expression)

    etype = type(expression)
    if etype in GEN_MAP:
        return GEN_MAP[etype](expression)
    return f"{expression.key} {gen(expression.args.values())}"


GEN_MAP = {
    exp.Add: lambda e: _binary(e, "+"),
    exp.And: lambda e: _binary(e, "AND"),
    exp.Anonymous: lambda e: f"{e.this} {','.join(gen(e) for e in e.expressions)}",
    exp.Between: lambda e: f"{gen(e.this)} BETWEEN {gen(e.args.get('low'))} AND {gen(e.args.get('high'))}",
    exp.Boolean: lambda e: "TRUE" if e.this else "FALSE",
    exp.Bracket: lambda e: f"{gen(e.this)}[{gen(e.expressions)}]",
    exp.Column: lambda e: ".".join(gen(p) for p in e.parts),
    exp.DataType: lambda e: f"{e.this.name} {gen(tuple(e.args.values())[1:])}",
    exp.Div: lambda e: _binary(e, "/"),
    exp.Dot: lambda e: _binary(e, "."),
    exp.DPipe: lambda e: _binary(e, "||"),
    exp.SafeDPipe: lambda e: _binary(e, "||"),
    exp.EQ: lambda e: _binary(e, "="),
    exp.GT: lambda e: _binary(e, ">"),
    exp.GTE: lambda e: _binary(e, ">="),
    exp.Identifier: lambda e: f'"{e.name}"' if e.quoted else e.name,
    exp.ILike: lambda e: _binary(e, "ILIKE"),
    exp.In: lambda e: f"{gen(e.this)} IN ({gen(tuple(e.args.values())[1:])})",
    exp.Is: lambda e: _binary(e, "IS"),
    exp.Like: lambda e: _binary(e, "LIKE"),
    exp.Literal: lambda e: f"'{e.name}'" if e.is_string else e.name,
    exp.LT: lambda e: _binary(e, "<"),
    exp.LTE: lambda e: _binary(e, "<="),
    exp.Mod: lambda e: _binary(e, "%"),
    exp.Mul: lambda e: _binary(e, "*"),
    exp.Neg: lambda e: _unary(e, "-"),
    exp.NEQ: lambda e: _binary(e, "<>"),
    exp.Not: lambda e: _unary(e, "NOT"),
    exp.Null: lambda e: "NULL",
    exp.Or: lambda e: _binary(e, "OR"),
    exp.Paren: lambda e: f"({gen(e.this)})",
    exp.Sub: lambda e: _binary(e, "-"),
    exp.Subquery: lambda e: f"({gen(e.args.values())})",
    exp.Table: lambda e: gen(e.args.values()),
    exp.Var: lambda e: e.name,
}


def _binary(e: exp.Binary, op: str) -> str:
    return f"{gen(e.left)} {op} {gen(e.right)}"


def _unary(e: exp.Unary, op: str) -> str:
    return f"{op} {gen(e.this)}"
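A minimal usage sketch of the cheap pseudo-generator (not part of the diff; assumes sqlglot >= 19.0, where the new helper lives in sqlglot.optimizer.simplify):

    import sqlglot
    from sqlglot.optimizer.simplify import gen

    condition = sqlglot.parse_one("x = 1 AND y = 2 AND x = 1")
    # flatten() yields the EQ operands of the AND chain
    operands = list(condition.flatten())
    # gen() builds deterministic pseudo-SQL keys, so the duplicate "x = 1" collapses
    deduped = {gen(e): e for e in operands}
    print(sorted(deduped))  # plain string keys sort cheaply, no full SQL generation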
|
|
@ -674,6 +674,7 @@ class Parser(metaclass=_Parser):
        "ON": lambda self: self._parse_on_property(),
        "ORDER BY": lambda self: self._parse_order(skip_order_token=True),
        "OUTPUT": lambda self: self.expression(exp.OutputModelProperty, this=self._parse_schema()),
        "PARTITION": lambda self: self._parse_partitioned_of(),
        "PARTITION BY": lambda self: self._parse_partitioned_by(),
        "PARTITIONED BY": lambda self: self._parse_partitioned_by(),
        "PARTITIONED_BY": lambda self: self._parse_partitioned_by(),
@ -1743,6 +1744,58 @@ class Parser(metaclass=_Parser):
            return self._parse_csv(self._parse_conjunction)
        return []

    def _parse_partition_bound_spec(self) -> exp.PartitionBoundSpec:
        def _parse_partition_bound_expr() -> t.Optional[exp.Expression]:
            if self._match_text_seq("MINVALUE"):
                return exp.var("MINVALUE")
            if self._match_text_seq("MAXVALUE"):
                return exp.var("MAXVALUE")
            return self._parse_bitwise()

        this: t.Optional[exp.Expression | t.List[exp.Expression]] = None
        expression = None
        from_expressions = None
        to_expressions = None

        if self._match(TokenType.IN):
            this = self._parse_wrapped_csv(self._parse_bitwise)
        elif self._match(TokenType.FROM):
            from_expressions = self._parse_wrapped_csv(_parse_partition_bound_expr)
            self._match_text_seq("TO")
            to_expressions = self._parse_wrapped_csv(_parse_partition_bound_expr)
        elif self._match_text_seq("WITH", "(", "MODULUS"):
            this = self._parse_number()
            self._match_text_seq(",", "REMAINDER")
            expression = self._parse_number()
            self._match_r_paren()
        else:
            self.raise_error("Failed to parse partition bound spec.")

        return self.expression(
            exp.PartitionBoundSpec,
            this=this,
            expression=expression,
            from_expressions=from_expressions,
            to_expressions=to_expressions,
        )

    # https://www.postgresql.org/docs/current/sql-createtable.html
    def _parse_partitioned_of(self) -> t.Optional[exp.PartitionedOfProperty]:
        if not self._match_text_seq("OF"):
            self._retreat(self._index - 1)
            return None

        this = self._parse_table(schema=True)

        if self._match(TokenType.DEFAULT):
            expression: exp.Var | exp.PartitionBoundSpec = exp.var("DEFAULT")
        elif self._match_text_seq("FOR", "VALUES"):
            expression = self._parse_partition_bound_spec()
        else:
            self.raise_error("Expecting either DEFAULT or FOR VALUES clause.")

        return self.expression(exp.PartitionedOfProperty, this=this, expression=expression)

    def _parse_partitioned_by(self) -> exp.PartitionedByProperty:
        self._match(TokenType.EQ)
        return self.expression(
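Together with the new "PARTITION" entry in PROPERTY_PARSERS above, this lets PostgreSQL declarative-partitioning DDL parse. A hedged round-trip sketch (not part of the diff; assumes sqlglot >= 19.0, and the exact output formatting may differ):

    import sqlglot

    sql = """
    CREATE TABLE measurement_y2020 PARTITION OF measurement
    FOR VALUES FROM ('2020-01-01') TO ('2021-01-01')
    """
    print(sqlglot.parse_one(sql, read="postgres").sql(dialect="postgres"))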
@ -2682,6 +2735,10 @@ class Parser(metaclass=_Parser):
            for join in iter(self._parse_join, None):
                this.append("joins", join)

        if self._match_pair(TokenType.WITH, TokenType.ORDINALITY):
            this.set("ordinality", True)
            this.set("alias", self._parse_table_alias())

        return this

    def _parse_version(self) -> t.Optional[exp.Version]:
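A hedged example of the syntax this enables (PostgreSQL-style WITH ORDINALITY; not part of the diff, assumes sqlglot >= 19.0):

    import sqlglot

    sql = "SELECT * FROM UNNEST(ARRAY['a', 'b']) WITH ORDINALITY AS t(val, idx)"
    print(sqlglot.parse_one(sql, read="postgres").sql(dialect="postgres"))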
@ -4189,17 +4246,12 @@ class Parser(metaclass=_Parser):
        fmt = None
        to = self._parse_types()

        if not to:
            self.raise_error("Expected TYPE after CAST")
        elif isinstance(to, exp.Identifier):
            to = exp.DataType.build(to.name, udt=True)
        elif to.this == exp.DataType.Type.CHAR:
            if self._match(TokenType.CHARACTER_SET):
                to = self.expression(exp.CharacterSet, this=self._parse_var_or_string())
        elif self._match(TokenType.FORMAT):
        if self._match(TokenType.FORMAT):
            fmt_string = self._parse_string()
            fmt = self._parse_at_time_zone(fmt_string)

            if not to:
                to = exp.DataType.build(exp.DataType.Type.UNKNOWN)
            if to.this in exp.DataType.TEMPORAL_TYPES:
                this = self.expression(
                    exp.StrToDate if to.this == exp.DataType.Type.DATE else exp.StrToTime,
@ -4215,8 +4267,14 @@ class Parser(metaclass=_Parser):

                if isinstance(fmt, exp.AtTimeZone) and isinstance(this, exp.StrToTime):
                    this.set("zone", fmt.args["zone"])

                return this
        elif not to:
            self.raise_error("Expected TYPE after CAST")
        elif isinstance(to, exp.Identifier):
            to = exp.DataType.build(to.name, udt=True)
        elif to.this == exp.DataType.Type.CHAR:
            if self._match(TokenType.CHARACTER_SET):
                to = self.expression(exp.CharacterSet, this=self._parse_var_or_string())

        return self.expression(
            exp.Cast if strict else exp.TryCast, this=this, to=to, format=fmt, safe=safe
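Taken together, these two CAST hunks check the FORMAT branch first, so a missing type no longer raises before FORMAT can be seen; temporal types still become StrToDate/StrToTime. A hedged sketch using Teradata-style CAST ... FORMAT (not part of the diff; assumes sqlglot >= 19.0 and that the target dialect maps StrToDate):

    import sqlglot

    print(sqlglot.transpile(
        "SELECT CAST(d AS DATE FORMAT 'YYYY-MM-DD') FROM t",
        read="teradata",
        write="snowflake",
    )[0])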
@ -4789,10 +4847,17 @@ class Parser(metaclass=_Parser):
        return self._parse_placeholder()

    def _parse_parameter(self) -> exp.Parameter:
        wrapped = self._match(TokenType.L_BRACE)
        this = self._parse_var() or self._parse_identifier() or self._parse_primary()
        def _parse_parameter_part() -> t.Optional[exp.Expression]:
            return (
                self._parse_identifier() or self._parse_primary() or self._parse_var(any_token=True)
            )

        self._match(TokenType.L_BRACE)
        this = _parse_parameter_part()
        expression = self._match(TokenType.COLON) and _parse_parameter_part()
        self._match(TokenType.R_BRACE)
        return self.expression(exp.Parameter, this=this, wrapped=wrapped)

        return self.expression(exp.Parameter, this=this, expression=expression)

    def _parse_placeholder(self) -> t.Optional[exp.Expression]:
        if self._match_set(self.PLACEHOLDER_PARSERS):
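The rewritten method also captures an optional part after a colon (e.g. typed placeholders such as {name: type}). A hedged sketch using a dialect where "$" tokenizes as a parameter, e.g. Databricks (not part of the diff; assumes sqlglot >= 19.0):

    import sqlglot
    from sqlglot import exp

    param = sqlglot.parse_one("SELECT ${some_var}", read="databricks").find(exp.Parameter)
    print(repr(param))  # exp.Parameter with this=some_var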
|
|
@ -3,10 +3,9 @@ from __future__ import annotations
import abc
import typing as t

import sqlglot
from sqlglot import expressions as exp
from sqlglot.dialects.dialect import Dialect
from sqlglot.errors import ParseError, SchemaError
from sqlglot.errors import SchemaError
from sqlglot.helper import dict_depth
from sqlglot.trie import TrieResult, in_trie, new_trie
@ -448,19 +447,16 @@ class MappingSchema(AbstractMappingSchema, Schema):


def normalize_name(
    name: str | exp.Identifier,
    identifier: str | exp.Identifier,
    dialect: DialectType = None,
    is_table: bool = False,
    normalize: t.Optional[bool] = True,
) -> str:
    try:
        identifier = sqlglot.maybe_parse(name, dialect=dialect, into=exp.Identifier)
    except ParseError:
        return name if isinstance(name, str) else name.name
    if isinstance(identifier, str):
        identifier = exp.parse_identifier(identifier, dialect=dialect)

    name = identifier.name
    if not normalize:
        return name
        return identifier.name

    # This can be useful for normalize_identifier
    identifier.meta["is_table"] = is_table
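A hedged sketch of the entry point this now relies on, exp.parse_identifier, instead of sqlglot.maybe_parse (not part of the diff; assumes sqlglot >= 19.0):

    from sqlglot import exp

    ident = exp.parse_identifier('"MiXeD"', dialect="snowflake")
    print(ident.name, ident.quoted)  # MiXeD True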
|
|
@ -67,7 +67,7 @@ def eliminate_distinct_on(expression: exp.Expression) -> exp.Expression:
        order = expression.args.get("order")

        if order:
            window.set("order", order.pop().copy())
            window.set("order", order.pop())
        else:
            window.set("order", exp.Order(expressions=[c.copy() for c in distinct_cols]))
@ -75,9 +75,9 @@ def eliminate_distinct_on(expression: exp.Expression) -> exp.Expression:
        expression.select(window, copy=False)

        return (
            exp.select(*outer_selects)
            .from_(expression.subquery("_t"))
            .where(exp.column(row_number).eq(1))
            exp.select(*outer_selects, copy=False)
            .from_(expression.subquery("_t", copy=False), copy=False)
            .where(exp.column(row_number).eq(1), copy=False)
        )

    return expression
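For reference, eliminate_distinct_on is the transform behind rewrites like this hedged example (not part of the diff; the target dialect must register the transform, e.g. BigQuery, and the output shape may vary by version):

    import sqlglot

    print(sqlglot.transpile(
        "SELECT DISTINCT ON (a) a, b FROM t ORDER BY a, b DESC",
        read="postgres",
        write="bigquery",
    )[0])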
@ -120,7 +120,9 @@ def eliminate_qualify(expression: exp.Expression) -> exp.Expression:
            elif expr.name not in expression.named_selects:
                expression.select(expr.copy(), copy=False)

        return outer_selects.from_(expression.subquery(alias="_t")).where(qualify_filters)
        return outer_selects.from_(expression.subquery(alias="_t", copy=False), copy=False).where(
            qualify_filters, copy=False
        )

    return expression
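A hedged example of the QUALIFY elimination in action (not part of the diff; Presto has no QUALIFY clause, so the filter should move to an outer query, assuming Presto registers this transform):

    import sqlglot

    print(sqlglot.transpile(
        "SELECT a, ROW_NUMBER() OVER (PARTITION BY a ORDER BY b) AS rn FROM t QUALIFY rn = 1",
        read="snowflake",
        write="presto",
    )[0])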
@ -189,7 +191,7 @@ def explode_to_unnest(index_offset: int = 0) -> t.Callable[[exp.Expression], exp
        )

        # we use list here because expression.selects is mutated inside the loop
        for select in expression.selects.copy():
        for select in list(expression.selects):
            explode = select.find(exp.Explode)

            if explode:
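explode_to_unnest is the transform that turns Spark's EXPLODE into a lateral UNNEST join; a hedged example (not part of the diff):

    import sqlglot

    print(sqlglot.transpile("SELECT EXPLODE(col) FROM tbl", read="spark", write="presto")[0])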
@ -374,6 +376,60 @@ def eliminate_semi_and_anti_joins(expression: exp.Expression) -> exp.Expression:
    return expression


def eliminate_full_outer_join(expression: exp.Expression) -> exp.Expression:
    """
    Converts a query with a FULL OUTER join to a union of identical queries that
    use LEFT/RIGHT OUTER joins instead. This transformation currently only works
    for queries that have a single FULL OUTER join.
    """
    if isinstance(expression, exp.Select):
        full_outer_joins = [
            (index, join)
            for index, join in enumerate(expression.args.get("joins") or [])
            if join.side == "FULL" and join.kind == "OUTER"
        ]

        if len(full_outer_joins) == 1:
            expression_copy = expression.copy()
            index, full_outer_join = full_outer_joins[0]
            full_outer_join.set("side", "left")
            expression_copy.args["joins"][index].set("side", "right")

            return exp.union(expression, expression_copy, copy=False)

    return expression


def move_ctes_to_top_level(expression: exp.Expression) -> exp.Expression:
    """
    Some dialects (e.g. Hive, T-SQL, Spark prior to version 3) only allow CTEs to be
    defined at the top-level, so for example queries like:

        SELECT * FROM (WITH t(c) AS (SELECT 1) SELECT * FROM t) AS subq

    are invalid in those dialects. This transformation can be used to ensure all CTEs are
    moved to the top level so that the final SQL code is valid from a syntax standpoint.

    TODO: handle name clashes whilst moving CTEs (it can get quite tricky & costly).
    """
    top_level_with = expression.args.get("with")
    for node in expression.find_all(exp.With):
        if node.parent is expression:
            continue

        inner_with = node.pop()
        if not top_level_with:
            top_level_with = inner_with
            expression.set("with", top_level_with)
        else:
            if inner_with.recursive:
                top_level_with.set("recursive", True)

            top_level_with.expressions.extend(inner_with.expressions)

    return expression


def preprocess(
    transforms: t.List[t.Callable[[exp.Expression], exp.Expression]],
) -> t.Callable[[Generator, exp.Expression], str]:
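A hedged sketch applying the new transform directly (eliminate_full_outer_join is added by this diff, so it should be importable from sqlglot.transforms in >= 19.0):

    import sqlglot
    from sqlglot.transforms import eliminate_full_outer_join

    query = sqlglot.parse_one("SELECT * FROM a FULL OUTER JOIN b ON a.id = b.id")
    # roughly: SELECT ... LEFT OUTER JOIN ... UNION SELECT ... RIGHT OUTER JOIN ...
    print(eliminate_full_outer_join(query).sql())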
@ -392,7 +448,7 @@ def preprocess(
    def _to_sql(self, expression: exp.Expression) -> str:
        expression_type = type(expression)

        expression = transforms[0](expression.copy())
        expression = transforms[0](expression)
        for t in transforms[1:]:
            expression = t(expression)
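preprocess chains transforms into a generator hook; with this merge the per-transform expression.copy() is dropped in favor of copying once further upstream. Dialects register it roughly like this hedged sketch (not part of the diff; transform choice is illustrative):

    from sqlglot import exp, transforms

    TRANSFORMS = {
        exp.Select: transforms.preprocess(
            [transforms.eliminate_qualify, transforms.eliminate_distinct_on]
        ),
    }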