1
0
Fork 0

Merging upstream version 19.0.1.

Signed-off-by: Daniel Baumann <daniel@debian.org>
This commit is contained in:
Daniel Baumann 2025-02-13 21:16:09 +01:00
parent 348b067e1b
commit 89acb78953
Signed by: daniel
GPG key ID: FBB4F0E80A80222F
91 changed files with 45416 additions and 43096 deletions

View file

@ -158,6 +158,6 @@ def transpile(
"""
write = (read if write is None else write) if identity else write
return [
Dialect.get_or_raise(write)().generate(expression, **opts)
Dialect.get_or_raise(write)().generate(expression, copy=False, **opts) if expression else ""
for expression in parse(sql, read, error_level=error_level)
]

View file

@ -69,7 +69,6 @@ def _create_sql(self: BigQuery.Generator, expression: exp.Create) -> str:
returns = expression.find(exp.ReturnsProperty)
if kind.upper() == "FUNCTION" and returns and returns.args.get("is_table"):
expression = expression.copy()
expression.set("kind", "TABLE FUNCTION")
if isinstance(expression.expression, (exp.Subquery, exp.Literal)):
@ -699,6 +698,5 @@ class BigQuery(Dialect):
def version_sql(self, expression: exp.Version) -> str:
if expression.name == "TIMESTAMP":
expression = expression.copy()
expression.set("this", "SYSTEM_TIME")
return super().version_sql(expression)

View file

@ -461,7 +461,6 @@ class ClickHouse(Dialect):
def safeconcat_sql(self, expression: exp.SafeConcat) -> str:
# Clickhouse errors out if we try to cast a NULL value to TEXT
expression = expression.copy()
return self.func(
"CONCAT",
*[

View file

@ -35,7 +35,7 @@ class Databricks(Spark):
exp.DatetimeSub: lambda self, e: self.func(
"TIMESTAMPADD",
e.text("unit"),
exp.Mul(this=e.expression.copy(), expression=exp.Literal.number(-1)),
exp.Mul(this=e.expression, expression=exp.Literal.number(-1)),
e.this,
),
exp.DatetimeDiff: lambda self, e: self.func(
@ -63,21 +63,14 @@ class Databricks(Spark):
and kind.this in exp.DataType.INTEGER_TYPES
):
# only BIGINT generated identity constraints are supported
expression = expression.copy()
expression.set("kind", exp.DataType.build("bigint"))
return super().columndef_sql(expression, sep)
def generatedasidentitycolumnconstraint_sql(
self, expression: exp.GeneratedAsIdentityColumnConstraint
) -> str:
expression = expression.copy()
expression.set("this", True) # trigger ALWAYS in super class
return super().generatedasidentitycolumnconstraint_sql(expression)
class Tokenizer(Spark.Tokenizer):
HEX_STRINGS = []
SINGLE_TOKENS = {
**Spark.Tokenizer.SINGLE_TOKENS,
"$": TokenType.PARAMETER,
}

View file

@ -315,11 +315,14 @@ class Dialect(metaclass=_Dialect):
) -> t.List[t.Optional[exp.Expression]]:
return self.parser(**opts).parse_into(expression_type, self.tokenize(sql), sql)
def generate(self, expression: t.Optional[exp.Expression], **opts) -> str:
return self.generator(**opts).generate(expression)
def generate(self, expression: exp.Expression, copy: bool = True, **opts) -> str:
return self.generator(**opts).generate(expression, copy=copy)
def transpile(self, sql: str, **opts) -> t.List[str]:
return [self.generate(expression, **opts) for expression in self.parse(sql)]
return [
self.generate(expression, copy=False, **opts) if expression else ""
for expression in self.parse(sql)
]
def tokenize(self, sql: str) -> t.List[Token]:
return self.tokenizer.tokenize(sql)
@ -380,9 +383,7 @@ def inline_array_sql(self: Generator, expression: exp.Array) -> str:
def no_ilike_sql(self: Generator, expression: exp.ILike) -> str:
return self.like_sql(
exp.Like(
this=exp.Lower(this=expression.this.copy()), expression=expression.expression.copy()
)
exp.Like(this=exp.Lower(this=expression.this), expression=expression.expression)
)
@ -518,7 +519,6 @@ def create_with_partitions_sql(self: Generator, expression: exp.Create) -> str:
is_partitionable = expression.args.get("kind") in ("TABLE", "VIEW")
if has_schema and is_partitionable:
expression = expression.copy()
prop = expression.find(exp.PartitionedByProperty)
if prop and prop.this and not isinstance(prop.this, exp.Schema):
schema = expression.this
@ -583,7 +583,7 @@ def date_add_interval_sql(
this = self.sql(expression, "this")
unit = expression.args.get("unit")
unit = exp.var(unit.name.upper() if unit else "DAY")
interval = exp.Interval(this=expression.expression.copy(), unit=unit)
interval = exp.Interval(this=expression.expression, unit=unit)
return f"{data_type}_{kind}({this}, {self.sql(interval)})"
return func
@ -621,7 +621,6 @@ def strposition_to_locate_sql(self: Generator, expression: exp.StrPosition) -> s
def left_to_substring_sql(self: Generator, expression: exp.Left) -> str:
expression = expression.copy()
return self.sql(
exp.Substring(
this=expression.this, start=exp.Literal.number(1), length=expression.expression
@ -630,7 +629,6 @@ def left_to_substring_sql(self: Generator, expression: exp.Left) -> str:
def right_to_substring_sql(self: Generator, expression: exp.Left) -> str:
expression = expression.copy()
return self.sql(
exp.Substring(
this=expression.this,
@ -675,7 +673,7 @@ def count_if_to_sum(self: Generator, expression: exp.CountIf) -> str:
cond = expression.this.expressions[0]
self.unsupported("DISTINCT is not supported when converting COUNT_IF to SUM")
return self.func("sum", exp.func("if", cond.copy(), 1, 0))
return self.func("sum", exp.func("if", cond, 1, 0))
def trim_sql(self: Generator, expression: exp.Trim) -> str:
@ -716,12 +714,10 @@ def ts_or_ds_to_date_sql(dialect: str) -> t.Callable:
def concat_to_dpipe_sql(self: Generator, expression: exp.Concat | exp.SafeConcat) -> str:
expression = expression.copy()
return self.sql(reduce(lambda x, y: exp.DPipe(this=x, expression=y), expression.expressions))
def concat_ws_to_dpipe_sql(self: Generator, expression: exp.ConcatWs) -> str:
expression = expression.copy()
delim, *rest_args = expression.expressions
return self.sql(
reduce(
@ -809,13 +805,6 @@ def isnull_to_is_null(args: t.List) -> exp.Expression:
return exp.Paren(this=exp.Is(this=seq_get(args, 0), expression=exp.null()))
def move_insert_cte_sql(self: Generator, expression: exp.Insert) -> str:
if expression.expression.args.get("with"):
expression = expression.copy()
expression.set("with", expression.expression.args["with"].pop())
return self.insert_sql(expression)
def generatedasidentitycolumnconstraint_sql(
self: Generator, expression: exp.GeneratedAsIdentityColumnConstraint
) -> str:

View file

@ -20,7 +20,9 @@ def _date_add_sql(kind: str) -> t.Callable[[Drill.Generator, exp.DateAdd | exp.D
def func(self: Drill.Generator, expression: exp.DateAdd | exp.DateSub) -> str:
this = self.sql(expression, "this")
unit = exp.var(expression.text("unit").upper() or "DAY")
return f"DATE_{kind}({this}, {self.sql(exp.Interval(this=expression.expression.copy(), unit=unit))})"
return (
f"DATE_{kind}({this}, {self.sql(exp.Interval(this=expression.expression, unit=unit))})"
)
return func
@ -147,7 +149,7 @@ class Drill(Dialect):
exp.TimeToUnix: rename_func("UNIX_TIMESTAMP"),
exp.ToChar: lambda self, e: self.function_fallback_sql(e),
exp.TryCast: no_trycast_sql,
exp.TsOrDsAdd: lambda self, e: f"DATE_ADD(CAST({self.sql(e, 'this')} AS DATE), {self.sql(exp.Interval(this=e.expression.copy(), unit=exp.var('DAY')))})",
exp.TsOrDsAdd: lambda self, e: f"DATE_ADD(CAST({self.sql(e, 'this')} AS DATE), {self.sql(exp.Interval(this=e.expression, unit=exp.var('DAY')))})",
exp.TsOrDsToDate: ts_or_ds_to_date_sql("drill"),
exp.TsOrDiToDi: lambda self, e: f"CAST(SUBSTR(REPLACE(CAST({self.sql(e, 'this')} AS VARCHAR), '-', ''), 1, 8) AS INT)",
}

View file

@ -36,14 +36,14 @@ from sqlglot.tokens import TokenType
def _ts_or_ds_add_sql(self: DuckDB.Generator, expression: exp.TsOrDsAdd) -> str:
this = self.sql(expression, "this")
unit = self.sql(expression, "unit").strip("'") or "DAY"
return f"CAST({this} AS DATE) + {self.sql(exp.Interval(this=expression.expression.copy(), unit=unit))}"
return f"CAST({this} AS DATE) + {self.sql(exp.Interval(this=expression.expression, unit=unit))}"
def _date_delta_sql(self: DuckDB.Generator, expression: exp.DateAdd | exp.DateSub) -> str:
this = self.sql(expression, "this")
unit = self.sql(expression, "unit").strip("'") or "DAY"
op = "+" if isinstance(expression, exp.DateAdd) else "-"
return f"{this} {op} {self.sql(exp.Interval(this=expression.expression.copy(), unit=unit))}"
return f"{this} {op} {self.sql(exp.Interval(this=expression.expression, unit=unit))}"
# BigQuery -> DuckDB conversion for the DATE function
@ -365,7 +365,7 @@ class DuckDB(Dialect):
multiplier = 90
if multiplier:
return f"({multiplier} * {super().interval_sql(exp.Interval(this=expression.this.copy(), unit=exp.var('day')))})"
return f"({multiplier} * {super().interval_sql(exp.Interval(this=expression.this, unit=exp.var('day')))})"
return super().interval_sql(expression)

View file

@ -53,8 +53,6 @@ DIFF_MONTH_SWITCH = ("YEAR", "QUARTER", "MONTH")
def _create_sql(self, expression: exp.Create) -> str:
expression = expression.copy()
# remove UNIQUE column constraints
for constraint in expression.find_all(exp.UniqueColumnConstraint):
if constraint.parent:
@ -88,7 +86,7 @@ def _add_date_sql(self: Hive.Generator, expression: exp.DateAdd | exp.DateSub) -
if expression.expression.is_number:
modified_increment = exp.Literal.number(int(expression.text("expression")) * multiplier)
else:
modified_increment = expression.expression.copy()
modified_increment = expression.expression
if multiplier != 1:
modified_increment = exp.Mul( # type: ignore
this=modified_increment, expression=exp.Literal.number(multiplier)
@ -229,6 +227,11 @@ class Hive(Dialect):
STRING_ESCAPES = ["\\"]
ENCODE = "utf-8"
SINGLE_TOKENS = {
**tokens.Tokenizer.SINGLE_TOKENS,
"$": TokenType.PARAMETER,
}
KEYWORDS = {
**tokens.Tokenizer.KEYWORDS,
"ADD ARCHIVE": TokenType.COMMAND,
@ -408,6 +411,7 @@ class Hive(Dialect):
INDEX_ON = "ON TABLE"
EXTRACT_ALLOWS_QUOTES = False
NVL2_SUPPORTED = False
SUPPORTS_NESTED_CTES = False
TYPE_MAPPING = {
**generator.Generator.TYPE_MAPPING,
@ -521,7 +525,10 @@ class Hive(Dialect):
def parameter_sql(self, expression: exp.Parameter) -> str:
this = self.sql(expression, "this")
expression_sql = self.sql(expression, "expression")
parent = expression.parent
this = f"{this}:{expression_sql}" if expression_sql else this
if isinstance(parent, exp.EQ) and isinstance(parent.parent, exp.SetItem):
# We need to produce SET key = value instead of SET ${key} = value
@ -530,8 +537,6 @@ class Hive(Dialect):
return f"${{{this}}}"
def schema_sql(self, expression: exp.Schema) -> str:
expression = expression.copy()
for ordered in expression.find_all(exp.Ordered):
if ordered.args.get("desc") is False:
ordered.set("desc", None)
@ -539,8 +544,6 @@ class Hive(Dialect):
return super().schema_sql(expression)
def constraint_sql(self, expression: exp.Constraint) -> str:
expression = expression.copy()
for prop in list(expression.find_all(exp.Properties)):
prop.pop()

View file

@ -60,9 +60,33 @@ def _date_trunc_sql(self: MySQL.Generator, expression: exp.DateTrunc) -> str:
return f"STR_TO_DATE({concat}, '{date_format}')"
def _str_to_date(args: t.List) -> exp.StrToDate:
date_format = MySQL.format_time(seq_get(args, 1))
return exp.StrToDate(this=seq_get(args, 0), format=date_format)
# Format specifiers that denote a time-of-day part rather than a date part.
# https://dev.mysql.com/doc/refman/8.0/en/date-and-time-functions.html#function_date-format
TIME_SPECIFIERS = {"f", "H", "h", "I", "i", "k", "l", "p", "r", "S", "s", "T"}


def _has_time_specifier(date_format: str) -> bool:
    """Return True if `date_format` contains a `%`-prefixed specifier from TIME_SPECIFIERS.

    The character following each `%` is always consumed together with it, so an
    escaped percent (`%%`) never causes the character after it to be read as a
    specifier. A trailing `%` with nothing after it is ignored.
    """
    chars = iter(date_format)
    for char in chars:
        if char != "%":
            continue
        # Pull the specifier character off the same iterator so it is skipped
        # by the outer loop; None signals a dangling "%" at the end of the string.
        specifier = next(chars, None)
        if specifier in TIME_SPECIFIERS:
            return True
    return False
def _str_to_date(args: t.List) -> exp.StrToDate | exp.StrToTime:
    """Build a StrToTime or StrToDate node from a positional argument list.

    Args:
        args: The argument list; the first item is the value to convert and the
            second (optional) item is the format expression.

    Returns:
        `exp.StrToTime` when the format argument is present and its name contains a
        time specifier, otherwise `exp.StrToDate`. In both cases the `format` arg is
        the result of `MySQL.format_time` applied to the format argument.
    """
    raw_format = seq_get(args, 1)
    converted_format = MySQL.format_time(raw_format)
    target = seq_get(args, 0)

    # Only inspect the format's name when a format argument was actually supplied.
    wants_time = raw_format and _has_time_specifier(raw_format.name)
    node_type = exp.StrToTime if wants_time else exp.StrToDate
    return node_type(this=target, format=converted_format)
def _str_to_date_sql(
@ -93,7 +117,9 @@ def _date_add_sql(
def func(self: MySQL.Generator, expression: exp.Expression) -> str:
this = self.sql(expression, "this")
unit = expression.text("unit").upper() or "DAY"
return f"DATE_{kind}({this}, {self.sql(exp.Interval(this=expression.expression.copy(), unit=unit))})"
return (
f"DATE_{kind}({this}, {self.sql(exp.Interval(this=expression.expression, unit=unit))})"
)
return func
@ -110,8 +136,6 @@ def _remove_ts_or_ds_to_date(
args: t.Tuple[str, ...] = ("this",),
) -> t.Callable[[MySQL.Generator, exp.Func], str]:
def func(self: MySQL.Generator, expression: exp.Func) -> str:
expression = expression.copy()
for arg_key in args:
arg = expression.args.get(arg_key)
if isinstance(arg, exp.TsOrDsToDate) and not arg.args.get("format"):
@ -629,6 +653,7 @@ class MySQL(Dialect):
transforms.eliminate_distinct_on,
transforms.eliminate_semi_and_anti_joins,
transforms.eliminate_qualify,
transforms.eliminate_full_outer_join,
]
),
exp.StrPosition: strposition_to_locate_sql,
@ -728,7 +753,6 @@ class MySQL(Dialect):
to = self.CAST_MAPPING.get(expression.to.this)
if to:
expression = expression.copy()
expression.to.set("this", to)
return super().cast_sql(expression)

View file

@ -43,8 +43,6 @@ DATE_DIFF_FACTOR = {
def _date_add_sql(kind: str) -> t.Callable[[Postgres.Generator, exp.DateAdd | exp.DateSub], str]:
def func(self: Postgres.Generator, expression: exp.DateAdd | exp.DateSub) -> str:
expression = expression.copy()
this = self.sql(expression, "this")
unit = expression.args.get("unit")
@ -96,7 +94,6 @@ def _substring_sql(self: Postgres.Generator, expression: exp.Substring) -> str:
def _string_agg_sql(self: Postgres.Generator, expression: exp.GroupConcat) -> str:
expression = expression.copy()
separator = expression.args.get("separator") or exp.Literal.string(",")
order = ""
@ -119,7 +116,6 @@ def _auto_increment_to_serial(expression: exp.Expression) -> exp.Expression:
auto = expression.find(exp.AutoIncrementColumnConstraint)
if auto:
expression = expression.copy()
expression.args["constraints"].remove(auto.parent)
kind = expression.args["kind"]
@ -134,7 +130,9 @@ def _auto_increment_to_serial(expression: exp.Expression) -> exp.Expression:
def _serial_to_generated(expression: exp.Expression) -> exp.Expression:
kind = expression.args["kind"]
kind = expression.args.get("kind")
if not kind:
return expression
if kind.this == exp.DataType.Type.SERIAL:
data_type = exp.DataType(this=exp.DataType.Type.INT)
@ -146,7 +144,6 @@ def _serial_to_generated(expression: exp.Expression) -> exp.Expression:
data_type = None
if data_type:
expression = expression.copy()
expression.args["kind"].replace(data_type)
constraints = expression.args["constraints"]
generated = exp.ColumnConstraint(kind=exp.GeneratedAsIdentityColumnConstraint(this=False))
@ -409,6 +406,7 @@ class Postgres(Dialect):
exp.MapFromEntries: no_map_from_entries_sql,
exp.Min: min_or_least,
exp.Merge: transforms.preprocess([_remove_target_from_merge]),
exp.PartitionedByProperty: lambda self, e: f"PARTITION BY {self.sql(e, 'this')}",
exp.PercentileCont: transforms.preprocess(
[transforms.add_within_group_for_percentiles]
),
@ -445,6 +443,7 @@ class Postgres(Dialect):
PROPERTIES_LOCATION = {
**generator.Generator.PROPERTIES_LOCATION,
exp.PartitionedByProperty: exp.Properties.Location.POST_SCHEMA,
exp.TransientProperty: exp.Properties.Location.UNSUPPORTED,
exp.VolatileProperty: exp.Properties.Location.UNSUPPORTED,
}
@ -452,7 +451,6 @@ class Postgres(Dialect):
def bracket_sql(self, expression: exp.Bracket) -> str:
"""Forms like ARRAY[1, 2, 3][3] aren't allowed; we need to wrap the ARRAY."""
if isinstance(expression.this, exp.Array):
expression = expression.copy()
expression.set("this", exp.paren(expression.this, copy=False))
return super().bracket_sql(expression)

View file

@ -36,7 +36,6 @@ def _approx_distinct_sql(self: Presto.Generator, expression: exp.ApproxDistinct)
def _explode_to_unnest_sql(self: Presto.Generator, expression: exp.Lateral) -> str:
if isinstance(expression.this, exp.Explode):
expression = expression.copy()
return self.sql(
exp.Join(
this=exp.Unnest(
@ -72,7 +71,6 @@ def _schema_sql(self: Presto.Generator, expression: exp.Schema) -> str:
for schema in expression.parent.find_all(exp.Schema):
column_defs = schema.find_all(exp.ColumnDef)
if column_defs and isinstance(schema.parent, exp.Property):
expression = expression.copy()
expression.expressions.extend(column_defs)
return self.schema_sql(expression)
@ -407,12 +405,10 @@ class Presto(Dialect):
target_type = None
if target_type and target_type.is_type("timestamp"):
to = target_type.copy()
if target_type is start.to:
end = exp.cast(end, to)
end = exp.cast(end, target_type)
else:
start = exp.cast(start, to)
start = exp.cast(start, target_type)
return self.func("SEQUENCE", start, end, step)
@ -432,6 +428,5 @@ class Presto(Dialect):
kind = expression.args["kind"]
schema = expression.this
if kind == "VIEW" and schema.expressions:
expression = expression.copy()
expression.this.set("expressions", None)
return super().create_sql(expression)

View file

@ -27,6 +27,14 @@ def _parse_date_add(args: t.List) -> exp.DateAdd:
)
def _parse_datediff(args: t.List) -> exp.DateDiff:
    """Build a DateDiff node from a positional argument list.

    The third argument becomes `this` and the second becomes `expression`, each
    wrapped in `exp.TsOrDsToDate`; the first argument is passed through as `unit`.
    """
    this_operand = exp.TsOrDsToDate(this=seq_get(args, 2))
    expression_operand = exp.TsOrDsToDate(this=seq_get(args, 1))
    unit = seq_get(args, 0)
    return exp.DateDiff(this=this_operand, expression=expression_operand, unit=unit)
class Redshift(Postgres):
# https://docs.aws.amazon.com/redshift/latest/dg/r_names.html
RESOLVES_IDENTIFIERS_AS_UPPERCASE = None
@ -51,11 +59,9 @@ class Redshift(Postgres):
),
"DATEADD": _parse_date_add,
"DATE_ADD": _parse_date_add,
"DATEDIFF": lambda args: exp.DateDiff(
this=exp.TsOrDsToDate(this=seq_get(args, 2)),
expression=exp.TsOrDsToDate(this=seq_get(args, 1)),
unit=seq_get(args, 0),
),
"DATEDIFF": _parse_datediff,
"DATE_DIFF": _parse_datediff,
"LISTAGG": exp.GroupConcat.from_arg_list,
"STRTOL": exp.FromBase.from_arg_list,
}
@ -175,6 +181,7 @@ class Redshift(Postgres):
exp.GeneratedAsIdentityColumnConstraint: generatedasidentitycolumnconstraint_sql,
exp.JSONExtract: _json_sql,
exp.JSONExtractScalar: _json_sql,
exp.GroupConcat: rename_func("LISTAGG"),
exp.ParseJSON: rename_func("JSON_PARSE"),
exp.SafeConcat: concat_to_dpipe_sql,
exp.Select: transforms.preprocess(
@ -207,7 +214,6 @@ class Redshift(Postgres):
`TEXT` to `VARCHAR`.
"""
if expression.is_type("text"):
expression = expression.copy()
expression.set("this", exp.DataType.Type.VARCHAR)
precision = expression.args.get("expressions")

View file

@ -32,7 +32,7 @@ def _check_int(s: str) -> bool:
# from https://docs.snowflake.com/en/sql-reference/functions/to_timestamp.html
def _parse_to_timestamp(args: t.List) -> t.Union[exp.StrToTime, exp.UnixToTime]:
def _parse_to_timestamp(args: t.List) -> t.Union[exp.StrToTime, exp.UnixToTime, exp.TimeStrToTime]:
if len(args) == 2:
first_arg, second_arg = args
if second_arg.is_string:
@ -60,8 +60,8 @@ def _parse_to_timestamp(args: t.List) -> t.Union[exp.StrToTime, exp.UnixToTime]:
# reduce it using `simplify_literals` first and then check if it's a Literal.
first_arg = seq_get(args, 0)
if not isinstance(simplify_literals(first_arg, root=True), Literal):
# case: <variant_expr>
return format_time_lambda(exp.StrToTime, "snowflake", default=True)(args)
# case: <variant_expr> or other expressions such as columns
return exp.TimeStrToTime.from_arg_list(args)
if first_arg.is_string:
if _check_int(first_arg.this):
@ -560,7 +560,6 @@ class Snowflake(Dialect):
offset = expression.args.get("offset")
if offset:
if unnest_alias:
expression = expression.copy()
unnest_alias.append("columns", offset.pop())
selects.append("index")

View file

@ -63,6 +63,8 @@ class Spark(Spark2):
return this
class Generator(Spark2.Generator):
SUPPORTS_NESTED_CTES = True
TYPE_MAPPING = {
**Spark2.Generator.TYPE_MAPPING,
exp.DataType.Type.MONEY: "DECIMAL(15, 4)",

View file

@ -7,7 +7,6 @@ from sqlglot.dialects.dialect import (
binary_from_function,
format_time_lambda,
is_parse_json,
move_insert_cte_sql,
pivot_column_names,
rename_func,
trim_sql,
@ -70,7 +69,9 @@ def _unalias_pivot(expression: exp.Expression) -> exp.Expression:
alias = pivot.args["alias"].pop()
return exp.From(
this=expression.this.replace(
exp.select("*").from_(expression.this.copy()).subquery(alias=alias)
exp.select("*")
.from_(expression.this.copy(), copy=False)
.subquery(alias=alias, copy=False)
)
)
@ -188,7 +189,6 @@ class Spark2(Hive):
exp.DayOfYear: rename_func("DAYOFYEAR"),
exp.FileFormatProperty: lambda self, e: f"USING {e.name.upper()}",
exp.From: transforms.preprocess([_unalias_pivot]),
exp.Insert: move_insert_cte_sql,
exp.LogicalAnd: rename_func("BOOL_AND"),
exp.LogicalOr: rename_func("BOOL_OR"),
exp.Map: _map_sql,

View file

@ -50,7 +50,7 @@ def _transform_create(expression: exp.Expression) -> exp.Expression:
else:
for column in defs.values():
auto_increment = None
for constraint in column.constraints.copy():
for constraint in column.constraints:
if isinstance(constraint.kind, exp.PrimaryKeyColumnConstraint):
break
if isinstance(constraint.kind, exp.AutoIncrementColumnConstraint):

View file

@ -38,12 +38,15 @@ class Teradata(Dialect):
"^=": TokenType.NEQ,
"BYTEINT": TokenType.SMALLINT,
"COLLECT": TokenType.COMMAND,
"DEL": TokenType.DELETE,
"EQ": TokenType.EQ,
"GE": TokenType.GTE,
"GT": TokenType.GT,
"HELP": TokenType.COMMAND,
"INS": TokenType.INSERT,
"LE": TokenType.LTE,
"LT": TokenType.LT,
"MINUS": TokenType.EXCEPT,
"MOD": TokenType.MOD,
"NE": TokenType.NEQ,
"NOT=": TokenType.NEQ,
@ -51,6 +54,7 @@ class Teradata(Dialect):
"SEL": TokenType.SELECT,
"ST_GEOMETRY": TokenType.GEOMETRY,
"TOP": TokenType.TOP,
"UPD": TokenType.UPDATE,
}
# Teradata does not support % as a modulo operator
@ -181,6 +185,13 @@ class Teradata(Dialect):
exp.Use: lambda self, e: f"DATABASE {self.sql(e, 'this')}",
}
def cast_sql(self, expression: exp.Cast, safe_prefix: t.Optional[str] = None) -> str:
    """Generate SQL for a cast, dropping an UNKNOWN target type when a format is set.

    Args:
        expression: The cast expression; its target type node may be removed in place.
        safe_prefix: Forwarded unchanged to the parent implementation.

    Returns:
        The SQL produced by the parent class's `cast_sql`.
    """
    target_type = expression.to
    type_is_unknown = target_type.this == exp.DataType.Type.UNKNOWN

    if type_is_unknown and expression.args.get("format"):
        # We don't actually want to print the unknown type in CAST(<value> AS FORMAT <format>)
        target_type.pop()

    return super().cast_sql(expression, safe_prefix=safe_prefix)
def tablesample_sql(
self, expression: exp.TableSample, seed_prefix: str = "SEED", sep=" AS "
) -> str:

View file

@ -11,7 +11,6 @@ from sqlglot.dialects.dialect import (
generatedasidentitycolumnconstraint_sql,
max_or_greatest,
min_or_least,
move_insert_cte_sql,
parse_date_delta,
rename_func,
timestrtotime_sql,
@ -158,8 +157,6 @@ def _format_sql(self: TSQL.Generator, expression: exp.NumberToStr | exp.TimeToSt
def _string_agg_sql(self: TSQL.Generator, expression: exp.GroupConcat) -> str:
expression = expression.copy()
this = expression.this
distinct = expression.find(exp.Distinct)
if distinct:
@ -246,6 +243,7 @@ class TSQL(Dialect):
"MMM": "%b",
"MM": "%m",
"M": "%-m",
"dddd": "%A",
"dd": "%d",
"d": "%-d",
"HH": "%H",
@ -596,6 +594,8 @@ class TSQL(Dialect):
ALTER_TABLE_ADD_COLUMN_KEYWORD = False
LIMIT_FETCH = "FETCH"
COMPUTED_COLUMN_WITH_TYPE = False
SUPPORTS_NESTED_CTES = False
CTE_RECURSIVE_KEYWORD_REQUIRED = False
TYPE_MAPPING = {
**generator.Generator.TYPE_MAPPING,
@ -622,7 +622,6 @@ class TSQL(Dialect):
exp.GeneratedAsIdentityColumnConstraint: generatedasidentitycolumnconstraint_sql,
exp.GroupConcat: _string_agg_sql,
exp.If: rename_func("IIF"),
exp.Insert: move_insert_cte_sql,
exp.Max: max_or_greatest,
exp.MD5: lambda self, e: self.func("HASHBYTES", exp.Literal.string("MD5"), e.this),
exp.Min: min_or_least,
@ -685,7 +684,6 @@ class TSQL(Dialect):
return sql
def create_sql(self, expression: exp.Create) -> str:
expression = expression.copy()
kind = self.sql(expression, "kind").upper()
exists = expression.args.pop("exists", None)
sql = super().create_sql(expression)
@ -714,7 +712,7 @@ class TSQL(Dialect):
elif expression.args.get("replace"):
sql = sql.replace("CREATE OR REPLACE ", "CREATE OR ALTER ", 1)
return sql
return self.prepend_ctes(expression, sql)
def offset_sql(self, expression: exp.Offset) -> str:
return f"{super().offset_sql(expression)} ROWS"

View file

@ -2145,6 +2145,22 @@ class PartitionedByProperty(Property):
arg_types = {"this": True}
# https://www.postgresql.org/docs/current/sql-createtable.html
class PartitionBoundSpec(Expression):
# this -> IN / MODULUS, expression -> REMAINDER, from_expressions -> FROM (...), to_expressions -> TO (...)
arg_types = {
"this": False,
"expression": False,
"from_expressions": False,
"to_expressions": False,
}
class PartitionedOfProperty(Property):
# this -> parent_table (schema), expression -> FOR VALUES ... / DEFAULT
arg_types = {"this": True, "expression": True}
class RemoteWithConnectionModelProperty(Property):
arg_types = {"this": True}
@ -2486,6 +2502,7 @@ class Table(Expression):
"format": False,
"pattern": False,
"index": False,
"ordinality": False,
}
@property
@ -2649,11 +2666,7 @@ class Update(Expression):
class Values(UDTF):
arg_types = {
"expressions": True,
"ordinality": False,
"alias": False,
}
arg_types = {"expressions": True, "alias": False}
class Var(Expression):
@ -3501,7 +3514,7 @@ class Star(Expression):
class Parameter(Condition):
arg_types = {"this": True, "wrapped": False}
arg_types = {"this": True, "expression": False}
class SessionParameter(Condition):
@ -5036,7 +5049,7 @@ class FromBase(Func):
class Struct(Func):
arg_types = {"expressions": True}
arg_types = {"expressions": False}
is_var_len_args = True
@ -5171,7 +5184,7 @@ class Use(Expression):
class Merge(Expression):
arg_types = {"this": True, "using": True, "on": True, "expressions": True}
arg_types = {"this": True, "using": True, "on": True, "expressions": True, "with": False}
class When(Func):
@ -5459,7 +5472,12 @@ def _wrap(expression: E, kind: t.Type[Expression]) -> E | Paren:
def union(
left: ExpOrStr, right: ExpOrStr, distinct: bool = True, dialect: DialectType = None, **opts
left: ExpOrStr,
right: ExpOrStr,
distinct: bool = True,
dialect: DialectType = None,
copy: bool = True,
**opts,
) -> Union:
"""
Initializes a syntax tree from one UNION expression.
@ -5475,19 +5493,25 @@ def union(
If an `Expression` instance is passed, it will be used as-is.
distinct: set the DISTINCT flag if and only if this is true.
dialect: the dialect used to parse the input expression.
copy: whether or not to copy the expression.
opts: other options to use to parse the input expressions.
Returns:
The new Union instance.
"""
left = maybe_parse(sql_or_expression=left, dialect=dialect, **opts)
right = maybe_parse(sql_or_expression=right, dialect=dialect, **opts)
left = maybe_parse(sql_or_expression=left, dialect=dialect, copy=copy, **opts)
right = maybe_parse(sql_or_expression=right, dialect=dialect, copy=copy, **opts)
return Union(this=left, expression=right, distinct=distinct)
def intersect(
left: ExpOrStr, right: ExpOrStr, distinct: bool = True, dialect: DialectType = None, **opts
left: ExpOrStr,
right: ExpOrStr,
distinct: bool = True,
dialect: DialectType = None,
copy: bool = True,
**opts,
) -> Intersect:
"""
Initializes a syntax tree from one INTERSECT expression.
@ -5503,19 +5527,25 @@ def intersect(
If an `Expression` instance is passed, it will be used as-is.
distinct: set the DISTINCT flag if and only if this is true.
dialect: the dialect used to parse the input expression.
copy: whether or not to copy the expression.
opts: other options to use to parse the input expressions.
Returns:
The new Intersect instance.
"""
left = maybe_parse(sql_or_expression=left, dialect=dialect, **opts)
right = maybe_parse(sql_or_expression=right, dialect=dialect, **opts)
left = maybe_parse(sql_or_expression=left, dialect=dialect, copy=copy, **opts)
right = maybe_parse(sql_or_expression=right, dialect=dialect, copy=copy, **opts)
return Intersect(this=left, expression=right, distinct=distinct)
def except_(
left: ExpOrStr, right: ExpOrStr, distinct: bool = True, dialect: DialectType = None, **opts
left: ExpOrStr,
right: ExpOrStr,
distinct: bool = True,
dialect: DialectType = None,
copy: bool = True,
**opts,
) -> Except:
"""
Initializes a syntax tree from one EXCEPT expression.
@ -5531,13 +5561,14 @@ def except_(
If an `Expression` instance is passed, it will be used as-is.
distinct: set the DISTINCT flag if and only if this is true.
dialect: the dialect used to parse the input expression.
copy: whether or not to copy the expression.
opts: other options to use to parse the input expressions.
Returns:
The new Except instance.
"""
left = maybe_parse(sql_or_expression=left, dialect=dialect, **opts)
right = maybe_parse(sql_or_expression=right, dialect=dialect, **opts)
left = maybe_parse(sql_or_expression=left, dialect=dialect, copy=copy, **opts)
right = maybe_parse(sql_or_expression=right, dialect=dialect, copy=copy, **opts)
return Except(this=left, expression=right, distinct=distinct)
@ -5861,7 +5892,7 @@ def to_identifier(name, quoted=None, copy=True):
Args:
name: The name to turn into an identifier.
quoted: Whether or not force quote the identifier.
copy: Whether or not to copy a passed in Identefier node.
copy: Whether or not to copy name if it's an Identifier.
Returns:
The identifier ast node.
@ -5882,6 +5913,25 @@ def to_identifier(name, quoted=None, copy=True):
return identifier
def parse_identifier(name: str, dialect: DialectType = None) -> Identifier:
    """
    Turns a string into an identifier node by parsing it.

    Args:
        name: The string to parse into an identifier.
        dialect: The dialect used when parsing `name`.

    Returns:
        The parsed identifier node, or the result of `to_identifier(name)` if parsing
        raises a `ParseError`.
    """
    try:
        return maybe_parse(name, dialect=dialect, into=Identifier)
    except ParseError:
        # Parsing failed, so fall back to building the identifier from the raw string.
        return to_identifier(name)
INTERVAL_STRING_RE = re.compile(r"\s*([0-9]+)\s*([a-zA-Z]+)\s*")

View file

@ -230,6 +230,12 @@ class Generator:
# Whether or not data types support additional specifiers like e.g. CHAR or BYTE (oracle)
DATA_TYPE_SPECIFIERS_ALLOWED = False
# Whether or not nested CTEs (e.g. defined inside of subqueries) are allowed
SUPPORTS_NESTED_CTES = True
# Whether or not the "RECURSIVE" keyword is required when defining recursive CTEs
CTE_RECURSIVE_KEYWORD_REQUIRED = True
TYPE_MAPPING = {
exp.DataType.Type.NCHAR: "CHAR",
exp.DataType.Type.NVARCHAR: "VARCHAR",
@ -304,6 +310,7 @@ class Generator:
exp.Order: exp.Properties.Location.POST_SCHEMA,
exp.OutputModelProperty: exp.Properties.Location.POST_SCHEMA,
exp.PartitionedByProperty: exp.Properties.Location.POST_WITH,
exp.PartitionedOfProperty: exp.Properties.Location.POST_SCHEMA,
exp.PrimaryKey: exp.Properties.Location.POST_SCHEMA,
exp.Property: exp.Properties.Location.POST_WITH,
exp.RemoteWithConnectionModelProperty: exp.Properties.Location.POST_SCHEMA,
@ -407,7 +414,6 @@ class Generator:
"unsupported_messages",
"_escaped_quote_end",
"_escaped_identifier_end",
"_cache",
)
def __init__(
@ -447,30 +453,38 @@ class Generator:
self._escaped_identifier_end: str = (
self.TOKENIZER_CLASS.IDENTIFIER_ESCAPES[0] + self.IDENTIFIER_END
)
self._cache: t.Optional[t.Dict[int, str]] = None
def generate(
self,
expression: t.Optional[exp.Expression],
cache: t.Optional[t.Dict[int, str]] = None,
) -> str:
def generate(self, expression: exp.Expression, copy: bool = True) -> str:
"""
Generates the SQL string corresponding to the given syntax tree.
Args:
expression: The syntax tree.
cache: An optional sql string cache. This leverages the hash of an Expression
which can be slow to compute, so only use it if you set _hash on each node.
copy: Whether or not to copy the expression. The generator performs mutations so
it is safer to copy.
Returns:
The SQL string corresponding to `expression`.
"""
if cache is not None:
self._cache = cache
if copy:
expression = expression.copy()
# Some dialects only support CTEs at the top level expression, so we need to bubble up nested
# CTEs to that level in order to produce a syntactically valid expression. This transformation
# happens here to minimize code duplication, since many expressions support CTEs.
if (
not self.SUPPORTS_NESTED_CTES
and isinstance(expression, exp.Expression)
and not expression.parent
and "with" in expression.arg_types
and any(node.parent is not expression for node in expression.find_all(exp.With))
):
from sqlglot.transforms import move_ctes_to_top_level
expression = move_ctes_to_top_level(expression)
self.unsupported_messages = []
sql = self.sql(expression).strip()
self._cache = None
if self.unsupported_level == ErrorLevel.IGNORE:
return sql
@ -595,12 +609,6 @@ class Generator:
return self.sql(value)
return ""
if self._cache is not None:
expression_id = hash(expression)
if expression_id in self._cache:
return self._cache[expression_id]
transform = self.TRANSFORMS.get(expression.__class__)
if callable(transform):
@ -621,11 +629,7 @@ class Generator:
else:
raise ValueError(f"Expected an Expression. Received {type(expression)}: {expression}")
sql = self.maybe_comment(sql, expression) if self.comments and comment else sql
if self._cache is not None:
self._cache[expression_id] = sql
return sql
return self.maybe_comment(sql, expression) if self.comments and comment else sql
def uncache_sql(self, expression: exp.Uncache) -> str:
table = self.sql(expression, "this")
@ -879,7 +883,11 @@ class Generator:
def with_sql(self, expression: exp.With) -> str:
sql = self.expressions(expression, flat=True)
recursive = "RECURSIVE " if expression.args.get("recursive") else ""
recursive = (
"RECURSIVE "
if self.CTE_RECURSIVE_KEYWORD_REQUIRED and expression.args.get("recursive")
else ""
)
return f"WITH {recursive}{sql}"
@ -1022,7 +1030,7 @@ class Generator:
where = self.sql(expression, "expression").strip()
return f"{this} FILTER({where})"
agg = expression.this.copy()
agg = expression.this
agg_arg = agg.this
cond = expression.expression.this
agg_arg.replace(exp.If(this=cond.copy(), true=agg_arg.copy()))
@ -1088,9 +1096,9 @@ class Generator:
for p in expression.expressions:
p_loc = self.PROPERTIES_LOCATION[p.__class__]
if p_loc == exp.Properties.Location.POST_WITH:
with_properties.append(p.copy())
with_properties.append(p)
elif p_loc == exp.Properties.Location.POST_SCHEMA:
root_properties.append(p.copy())
root_properties.append(p)
return self.root_properties(
exp.Properties(expressions=root_properties)
@ -1124,7 +1132,7 @@ class Generator:
for p in properties.expressions:
p_loc = self.PROPERTIES_LOCATION[p.__class__]
if p_loc != exp.Properties.Location.UNSUPPORTED:
properties_locs[p_loc].append(p.copy())
properties_locs[p_loc].append(p)
else:
self.unsupported(f"Unsupported property {p.key}")
@ -1238,6 +1246,29 @@ class Generator:
for_ = " FOR NONE"
return f"WITH{no}{concurrent} ISOLATED LOADING{for_}"
def partitionboundspec_sql(self, expression: exp.PartitionBoundSpec) -> str:
    """Render a partition bound spec.

    Three shapes are produced, chosen from the node's `this` arg:
      - a list          -> `IN (<values>)`
      - any other truthy value -> `WITH (MODULUS <this>, REMAINDER <expression>)`
      - missing / falsy -> `FROM (<from_expressions>) TO (<to_expressions>)`
    """
    bound = expression.this

    if isinstance(bound, list):
        values = self.expressions(expression, key="this", flat=True)
        return f"IN ({values})"

    if not bound:
        # No `this` at all: this is the range form.
        lower = self.expressions(expression, key="from_expressions", flat=True)
        upper = self.expressions(expression, key="to_expressions", flat=True)
        return f"FROM ({lower}) TO ({upper})"

    modulus = self.sql(expression, "this")
    remainder = self.sql(expression, "expression")
    return f"WITH (MODULUS {modulus}, REMAINDER {remainder})"
def partitionedofproperty_sql(self, expression: exp.PartitionedOfProperty) -> str:
    """Render `PARTITION OF <parent>` followed by either `FOR VALUES <spec>` or `DEFAULT`.

    The `FOR VALUES` form is used only when the node's `expression` arg is a
    PartitionBoundSpec; anything else is rendered as `DEFAULT`.
    """
    parent = self.sql(expression, "this")
    bound = expression.expression

    if isinstance(bound, exp.PartitionBoundSpec):
        suffix = f" FOR VALUES {self.sql(bound)}"
    else:
        suffix = " DEFAULT"

    return f"PARTITION OF {parent}{suffix}"
def lockingproperty_sql(self, expression: exp.LockingProperty) -> str:
kind = expression.args.get("kind")
this = f" {self.sql(expression, 'this')}" if expression.this else ""
@ -1385,7 +1416,12 @@ class Generator:
index = self.sql(expression, "index")
index = f" AT {index}" if index else ""
return f"{table}{version}{file_format}{alias}{index}{hints}{pivots}{joins}{laterals}"
ordinality = expression.args.get("ordinality") or ""
if ordinality:
ordinality = f" WITH ORDINALITY{alias}"
alias = ""
return f"{table}{version}{file_format}{alias}{index}{hints}{pivots}{joins}{laterals}{ordinality}"
def tablesample_sql(
self, expression: exp.TableSample, seed_prefix: str = "SEED", sep=" AS "
@ -1489,7 +1525,6 @@ class Generator:
return f"{values} AS {alias}" if alias else values
# Converts `VALUES...` expression into a series of select unions.
expression = expression.copy()
alias_node = expression.args.get("alias")
column_names = alias_node and alias_node.columns
@ -1972,8 +2007,7 @@ class Generator:
if self.UNNEST_WITH_ORDINALITY:
if alias and isinstance(offset, exp.Expression):
alias = alias.copy()
alias.append("columns", offset.copy())
alias.append("columns", offset)
if alias and self.UNNEST_COLUMN_ONLY:
columns = alias.columns
@ -2138,7 +2172,6 @@ class Generator:
return f"PRIMARY KEY ({expressions}){options}"
def if_sql(self, expression: exp.If) -> str:
expression = expression.copy()
return self.case_sql(exp.Case(ifs=[expression], default=expression.args.get("false")))
def matchagainst_sql(self, expression: exp.MatchAgainst) -> str:
@ -2367,7 +2400,9 @@ class Generator:
def cast_sql(self, expression: exp.Cast, safe_prefix: t.Optional[str] = None) -> str:
format_sql = self.sql(expression, "format")
format_sql = f" FORMAT {format_sql}" if format_sql else ""
return f"{safe_prefix or ''}CAST({self.sql(expression, 'this')} AS {self.sql(expression, 'to')}{format_sql})"
to_sql = self.sql(expression, "to")
to_sql = f" {to_sql}" if to_sql else ""
return f"{safe_prefix or ''}CAST({self.sql(expression, 'this')} AS{to_sql}{format_sql})"
def currentdate_sql(self, expression: exp.CurrentDate) -> str:
zone = self.sql(expression, "this")
@ -2510,7 +2545,7 @@ class Generator:
def intdiv_sql(self, expression: exp.IntDiv) -> str:
return self.sql(
exp.Cast(
this=exp.Div(this=expression.this.copy(), expression=expression.expression.copy()),
this=exp.Div(this=expression.this, expression=expression.expression),
to=exp.DataType(this=exp.DataType.Type.INT),
)
)
@ -2779,7 +2814,6 @@ class Generator:
hints = table.args.get("hints")
if hints and table.alias and isinstance(hints[0], exp.WithTableHint):
# T-SQL syntax is MERGE ... <target_table> [WITH (<merge_hint>)] [[AS] table_alias]
table = table.copy()
table_alias = f" AS {self.sql(table.args['alias'].pop())}"
this = self.sql(table)
@ -2787,7 +2821,9 @@ class Generator:
on = f"ON {self.sql(expression, 'on')}"
expressions = self.expressions(expression, sep=" ")
return f"MERGE INTO {this}{table_alias} {using} {on} {expressions}"
return self.prepend_ctes(
expression, f"MERGE INTO {this}{table_alias} {using} {on} {expressions}"
)
def tochar_sql(self, expression: exp.ToChar) -> str:
if expression.args.get("format"):
@ -2896,12 +2932,12 @@ class Generator:
case = exp.Case().when(
expression.this.is_(exp.null()).not_(copy=False),
expression.args["true"].copy(),
expression.args["true"],
copy=False,
)
else_cond = expression.args.get("false")
if else_cond:
case.else_(else_cond.copy(), copy=False)
case.else_(else_cond, copy=False)
return self.sql(case)
@ -2931,15 +2967,6 @@ class Generator:
if not isinstance(expression, exp.Literal):
from sqlglot.optimizer.simplify import simplify
expression = simplify(expression.copy())
expression = simplify(expression)
return expression
def cached_generator(
cache: t.Optional[t.Dict[int, str]] = None
) -> t.Callable[[exp.Expression], str]:
"""Returns a cached generator."""
cache = {} if cache is None else cache
generator = Generator(normalize=True, identify="safe")
return lambda e: generator.generate(e, cache)

View file

@ -184,9 +184,7 @@ def apply_index_offset(
annotate_types(expression)
if t.cast(exp.DataType, expression.type).this in exp.DataType.INTEGER_TYPES:
logger.warning("Applying array index offset (%s)", offset)
expression = simplify(
exp.Add(this=expression.copy(), expression=exp.Literal.number(offset))
)
expression = simplify(exp.Add(this=expression, expression=exp.Literal.number(offset)))
return [expression]
return expressions

View file

@ -4,7 +4,6 @@ import logging
from sqlglot import exp
from sqlglot.errors import OptimizeError
from sqlglot.generator import cached_generator
from sqlglot.helper import while_changing
from sqlglot.optimizer.scope import find_all_in_scope
from sqlglot.optimizer.simplify import flatten, rewrite_between, uniq_sort
@ -29,8 +28,6 @@ def normalize(expression: exp.Expression, dnf: bool = False, max_distance: int =
Returns:
sqlglot.Expression: normalized expression
"""
generate = cached_generator()
for node, *_ in tuple(expression.walk(prune=lambda e, *_: isinstance(e, exp.Connector))):
if isinstance(node, exp.Connector):
if normalized(node, dnf=dnf):
@ -49,7 +46,7 @@ def normalize(expression: exp.Expression, dnf: bool = False, max_distance: int =
try:
node = node.replace(
while_changing(node, lambda e: distributive_law(e, dnf, max_distance, generate))
while_changing(node, lambda e: distributive_law(e, dnf, max_distance))
)
except OptimizeError as e:
logger.info(e)
@ -133,7 +130,7 @@ def _predicate_lengths(expression, dnf):
return _predicate_lengths(left, dnf) + _predicate_lengths(right, dnf)
def distributive_law(expression, dnf, max_distance, generate):
def distributive_law(expression, dnf, max_distance):
"""
x OR (y AND z) -> (x OR y) AND (x OR z)
(x AND y) OR (y AND z) -> (x OR y) AND (x OR z) AND (y OR y) AND (y OR z)
@ -146,7 +143,7 @@ def distributive_law(expression, dnf, max_distance, generate):
if distance > max_distance:
raise OptimizeError(f"Normalization distance {distance} exceeds max {max_distance}")
exp.replace_children(expression, lambda e: distributive_law(e, dnf, max_distance, generate))
exp.replace_children(expression, lambda e: distributive_law(e, dnf, max_distance))
to_exp, from_exp = (exp.Or, exp.And) if dnf else (exp.And, exp.Or)
if isinstance(expression, from_exp):
@ -157,30 +154,30 @@ def distributive_law(expression, dnf, max_distance, generate):
if isinstance(a, to_exp) and isinstance(b, to_exp):
if len(tuple(a.find_all(exp.Connector))) > len(tuple(b.find_all(exp.Connector))):
return _distribute(a, b, from_func, to_func, generate)
return _distribute(b, a, from_func, to_func, generate)
return _distribute(a, b, from_func, to_func)
return _distribute(b, a, from_func, to_func)
if isinstance(a, to_exp):
return _distribute(b, a, from_func, to_func, generate)
return _distribute(b, a, from_func, to_func)
if isinstance(b, to_exp):
return _distribute(a, b, from_func, to_func, generate)
return _distribute(a, b, from_func, to_func)
return expression
def _distribute(a, b, from_func, to_func, generate):
def _distribute(a, b, from_func, to_func):
if isinstance(a, exp.Connector):
exp.replace_children(
a,
lambda c: to_func(
uniq_sort(flatten(from_func(c, b.left)), generate),
uniq_sort(flatten(from_func(c, b.right)), generate),
uniq_sort(flatten(from_func(c, b.left))),
uniq_sort(flatten(from_func(c, b.right))),
copy=False,
),
)
else:
a = to_func(
uniq_sort(flatten(from_func(a, b.left)), generate),
uniq_sort(flatten(from_func(a, b.right)), generate),
uniq_sort(flatten(from_func(a, b.left))),
uniq_sort(flatten(from_func(a, b.right))),
copy=False,
)

View file

@ -2,7 +2,7 @@ from __future__ import annotations
import typing as t
from sqlglot import exp, parse_one
from sqlglot import exp
from sqlglot._typing import E
from sqlglot.dialects.dialect import Dialect, DialectType
@ -49,7 +49,7 @@ def normalize_identifiers(expression, dialect=None):
The transformed expression.
"""
if isinstance(expression, str):
expression = parse_one(expression, dialect=dialect, into=exp.Identifier)
expression = exp.parse_identifier(expression, dialect=dialect)
dialect = Dialect.get_or_raise(dialect)

View file

@ -62,7 +62,7 @@ def qualify_tables(
if isinstance(source.this, exp.Identifier):
if not source.args.get("db"):
source.set("db", exp.to_identifier(db))
if not source.args.get("catalog"):
if not source.args.get("catalog") and source.args.get("db"):
source.set("catalog", exp.to_identifier(catalog))
if not source.alias:

View file

@ -7,8 +7,7 @@ from decimal import Decimal
import sqlglot
from sqlglot import exp
from sqlglot.generator import cached_generator
from sqlglot.helper import first, merge_ranges, while_changing
from sqlglot.helper import first, is_iterable, merge_ranges, while_changing
from sqlglot.optimizer.scope import find_all_in_scope, walk_in_scope
# Final means that an expression should not be simplified
@ -37,8 +36,6 @@ def simplify(expression, constant_propagation=False):
sqlglot.Expression: simplified expression
"""
generate = cached_generator()
# group by expressions cannot be simplified, for example
# select x + 1 + 1 FROM y GROUP BY x + 1 + 1
# the projection must exactly match the group by key
@ -67,7 +64,7 @@ def simplify(expression, constant_propagation=False):
# Pre-order transformations
node = expression
node = rewrite_between(node)
node = uniq_sort(node, generate, root)
node = uniq_sort(node, root)
node = absorb_and_eliminate(node, root)
node = simplify_concat(node)
node = simplify_conditionals(node)
@ -311,7 +308,7 @@ def remove_complements(expression, root=True):
return expression
def uniq_sort(expression, generate, root=True):
def uniq_sort(expression, root=True):
"""
Uniq and sort a connector.
@ -320,7 +317,7 @@ def uniq_sort(expression, generate, root=True):
if isinstance(expression, exp.Connector) and (root or not expression.same_parent):
result_func = exp.and_ if isinstance(expression, exp.And) else exp.or_
flattened = tuple(expression.flatten())
deduped = {generate(e): e for e in flattened}
deduped = {gen(e): e for e in flattened}
arr = tuple(deduped.items())
# check if the operands are already sorted, if not sort them
@ -1070,3 +1067,69 @@ def _flat_simplify(expression, simplifier, root=True):
lambda a, b: expression.__class__(this=a, expression=b), operands
)
return expression
def gen(expression: t.Any) -> str:
    """Produce a cheap pseudo-SQL string that is good enough for sorting and deduplication.

    Running the real generator is expensive, and optimizer passes only need a
    stable, comparable key per expression, so this is a deliberately minimal renderer.

    - None renders as "_".
    - Anything `is_iterable` accepts renders as its items joined by ",".
    - Non-Expression values render via `str`.
    - Expressions use their GEN_MAP renderer when one exists, otherwise
      "<key> <rendered args>".
    """
    if expression is None:
        return "_"

    if is_iterable(expression):
        return ",".join(map(gen, expression))

    if isinstance(expression, exp.Expression):
        render = GEN_MAP.get(type(expression))
        if render is not None:
            return render(expression)
        # Generic fallback: node key followed by all of its args, in arg order.
        return f"{expression.key} {gen(expression.args.values())}"

    return str(expression)
# Per-node-type renderers consulted by `gen`. Node types that are not listed
# here fall back to gen's generic "<key> <args>" rendering.
# NOTE(review): the DataType and In entries skip the first arg via
# `tuple(e.args.values())[1:]`, which presumably relies on `this` being the
# first key in `args` -- confirm against how those nodes are constructed.
GEN_MAP = {
    exp.Add: lambda e: _binary(e, "+"),
    exp.And: lambda e: _binary(e, "AND"),
    # The inner loop variable is deliberately not named `e`, so it cannot be
    # confused with (or shadow) the lambda's own parameter.
    exp.Anonymous: lambda e: f"{e.this} {','.join(gen(arg) for arg in e.expressions)}",
    exp.Between: lambda e: f"{gen(e.this)} BETWEEN {gen(e.args.get('low'))} AND {gen(e.args.get('high'))}",
    exp.Boolean: lambda e: "TRUE" if e.this else "FALSE",
    exp.Bracket: lambda e: f"{gen(e.this)}[{gen(e.expressions)}]",
    exp.Column: lambda e: ".".join(gen(p) for p in e.parts),
    exp.DataType: lambda e: f"{e.this.name} {gen(tuple(e.args.values())[1:])}",
    exp.Div: lambda e: _binary(e, "/"),
    exp.Dot: lambda e: _binary(e, "."),
    exp.DPipe: lambda e: _binary(e, "||"),
    exp.SafeDPipe: lambda e: _binary(e, "||"),
    exp.EQ: lambda e: _binary(e, "="),
    exp.GT: lambda e: _binary(e, ">"),
    exp.GTE: lambda e: _binary(e, ">="),
    exp.Identifier: lambda e: f'"{e.name}"' if e.quoted else e.name,
    exp.ILike: lambda e: _binary(e, "ILIKE"),
    exp.In: lambda e: f"{gen(e.this)} IN ({gen(tuple(e.args.values())[1:])})",
    exp.Is: lambda e: _binary(e, "IS"),
    exp.Like: lambda e: _binary(e, "LIKE"),
    exp.Literal: lambda e: f"'{e.name}'" if e.is_string else e.name,
    exp.LT: lambda e: _binary(e, "<"),
    exp.LTE: lambda e: _binary(e, "<="),
    exp.Mod: lambda e: _binary(e, "%"),
    exp.Mul: lambda e: _binary(e, "*"),
    exp.Neg: lambda e: _unary(e, "-"),
    exp.NEQ: lambda e: _binary(e, "<>"),
    exp.Not: lambda e: _unary(e, "NOT"),
    exp.Null: lambda e: "NULL",
    exp.Or: lambda e: _binary(e, "OR"),
    exp.Paren: lambda e: f"({gen(e.this)})",
    exp.Sub: lambda e: _binary(e, "-"),
    exp.Subquery: lambda e: f"({gen(e.args.values())})",
    exp.Table: lambda e: gen(e.args.values()),
    exp.Var: lambda e: e.name,
}
def _binary(e: exp.Binary, op: str) -> str:
    """Render a binary node as "<left> <op> <right>", with both operands rendered by `gen`."""
    lhs = gen(e.left)
    rhs = gen(e.right)
    return f"{lhs} {op} {rhs}"
def _unary(e: exp.Unary, op: str) -> str:
    """Render a unary node as "<op> <operand>", with the operand rendered by `gen`."""
    operand = gen(e.this)
    return f"{op} {operand}"

View file

@ -674,6 +674,7 @@ class Parser(metaclass=_Parser):
"ON": lambda self: self._parse_on_property(),
"ORDER BY": lambda self: self._parse_order(skip_order_token=True),
"OUTPUT": lambda self: self.expression(exp.OutputModelProperty, this=self._parse_schema()),
"PARTITION": lambda self: self._parse_partitioned_of(),
"PARTITION BY": lambda self: self._parse_partitioned_by(),
"PARTITIONED BY": lambda self: self._parse_partitioned_by(),
"PARTITIONED_BY": lambda self: self._parse_partitioned_by(),
@ -1743,6 +1744,58 @@ class Parser(metaclass=_Parser):
return self._parse_csv(self._parse_conjunction)
return []
def _parse_partition_bound_spec(self) -> exp.PartitionBoundSpec:
    """Parse a partition bound spec in one of three forms.

    - `IN (<expr>, ...)`                       -> stored as a list in `this`
    - `FROM (<bound>, ...) TO (<bound>, ...)`  -> `from_expressions` / `to_expressions`
    - `WITH (MODULUS <n>, REMAINDER <n>)`      -> `this` / `expression`

    Any other input is reported through `raise_error`.
    """

    def _parse_bound() -> t.Optional[exp.Expression]:
        # A range bound is either one of the two sentinel keywords or an ordinary expression.
        for keyword in ("MINVALUE", "MAXVALUE"):
            if self._match_text_seq(keyword):
                return exp.var(keyword)
        return self._parse_bitwise()

    # Keyword arguments for the resulting node; key order matches the node's arg order.
    spec: t.Dict[str, t.Any] = {
        "this": None,
        "expression": None,
        "from_expressions": None,
        "to_expressions": None,
    }

    if self._match(TokenType.IN):
        spec["this"] = self._parse_wrapped_csv(self._parse_bitwise)
    elif self._match(TokenType.FROM):
        spec["from_expressions"] = self._parse_wrapped_csv(_parse_bound)
        self._match_text_seq("TO")
        spec["to_expressions"] = self._parse_wrapped_csv(_parse_bound)
    elif self._match_text_seq("WITH", "(", "MODULUS"):
        spec["this"] = self._parse_number()
        self._match_text_seq(",", "REMAINDER")
        spec["expression"] = self._parse_number()
        self._match_r_paren()
    else:
        self.raise_error("Failed to parse partition bound spec.")

    return self.expression(exp.PartitionBoundSpec, **spec)
# https://www.postgresql.org/docs/current/sql-createtable.html
def _parse_partitioned_of(self) -> t.Optional[exp.PartitionedOfProperty]:
    """Parse `OF <parent table> {DEFAULT | FOR VALUES <partition bound spec>}`.

    Returns:
        The parsed PartitionedOfProperty, or None when the next token is not
        `OF` (in which case the parser is moved back by one token).
    """
    if not self._match_text_seq("OF"):
        # Not a PARTITION OF clause: step back one token so the keyword that
        # led here (presumably PARTITION) can be handled by another rule.
        self._retreat(self._index - 1)
        return None

    this = self._parse_table(schema=True)

    # Bind up front so the final return is always well-defined. Without this,
    # a raise_error() call that records the problem and returns instead of
    # raising would leave `expression` unbound and trigger UnboundLocalError.
    # NOTE(review): confirm raise_error() can return under lenient error levels.
    expression: t.Optional[exp.Var | exp.PartitionBoundSpec] = None

    if self._match(TokenType.DEFAULT):
        expression = exp.var("DEFAULT")
    elif self._match_text_seq("FOR", "VALUES"):
        expression = self._parse_partition_bound_spec()
    else:
        self.raise_error("Expecting either DEFAULT or FOR VALUES clause.")

    return self.expression(exp.PartitionedOfProperty, this=this, expression=expression)
def _parse_partitioned_by(self) -> exp.PartitionedByProperty:
self._match(TokenType.EQ)
return self.expression(
@ -2682,6 +2735,10 @@ class Parser(metaclass=_Parser):
for join in iter(self._parse_join, None):
this.append("joins", join)
if self._match_pair(TokenType.WITH, TokenType.ORDINALITY):
this.set("ordinality", True)
this.set("alias", self._parse_table_alias())
return this
def _parse_version(self) -> t.Optional[exp.Version]:
@ -4189,17 +4246,12 @@ class Parser(metaclass=_Parser):
fmt = None
to = self._parse_types()
if not to:
self.raise_error("Expected TYPE after CAST")
elif isinstance(to, exp.Identifier):
to = exp.DataType.build(to.name, udt=True)
elif to.this == exp.DataType.Type.CHAR:
if self._match(TokenType.CHARACTER_SET):
to = self.expression(exp.CharacterSet, this=self._parse_var_or_string())
elif self._match(TokenType.FORMAT):
if self._match(TokenType.FORMAT):
fmt_string = self._parse_string()
fmt = self._parse_at_time_zone(fmt_string)
if not to:
to = exp.DataType.build(exp.DataType.Type.UNKNOWN)
if to.this in exp.DataType.TEMPORAL_TYPES:
this = self.expression(
exp.StrToDate if to.this == exp.DataType.Type.DATE else exp.StrToTime,
@ -4215,8 +4267,14 @@ class Parser(metaclass=_Parser):
if isinstance(fmt, exp.AtTimeZone) and isinstance(this, exp.StrToTime):
this.set("zone", fmt.args["zone"])
return this
elif not to:
self.raise_error("Expected TYPE after CAST")
elif isinstance(to, exp.Identifier):
to = exp.DataType.build(to.name, udt=True)
elif to.this == exp.DataType.Type.CHAR:
if self._match(TokenType.CHARACTER_SET):
to = self.expression(exp.CharacterSet, this=self._parse_var_or_string())
return self.expression(
exp.Cast if strict else exp.TryCast, this=this, to=to, format=fmt, safe=safe
@ -4789,10 +4847,17 @@ class Parser(metaclass=_Parser):
return self._parse_placeholder()
def _parse_parameter(self) -> exp.Parameter:
wrapped = self._match(TokenType.L_BRACE)
this = self._parse_var() or self._parse_identifier() or self._parse_primary()
def _parse_parameter_part() -> t.Optional[exp.Expression]:
return (
self._parse_identifier() or self._parse_primary() or self._parse_var(any_token=True)
)
self._match(TokenType.L_BRACE)
this = _parse_parameter_part()
expression = self._match(TokenType.COLON) and _parse_parameter_part()
self._match(TokenType.R_BRACE)
return self.expression(exp.Parameter, this=this, wrapped=wrapped)
return self.expression(exp.Parameter, this=this, expression=expression)
def _parse_placeholder(self) -> t.Optional[exp.Expression]:
if self._match_set(self.PLACEHOLDER_PARSERS):

View file

@ -3,10 +3,9 @@ from __future__ import annotations
import abc
import typing as t
import sqlglot
from sqlglot import expressions as exp
from sqlglot.dialects.dialect import Dialect
from sqlglot.errors import ParseError, SchemaError
from sqlglot.errors import SchemaError
from sqlglot.helper import dict_depth
from sqlglot.trie import TrieResult, in_trie, new_trie
@ -448,19 +447,16 @@ class MappingSchema(AbstractMappingSchema, Schema):
def normalize_name(
name: str | exp.Identifier,
identifier: str | exp.Identifier,
dialect: DialectType = None,
is_table: bool = False,
normalize: t.Optional[bool] = True,
) -> str:
try:
identifier = sqlglot.maybe_parse(name, dialect=dialect, into=exp.Identifier)
except ParseError:
return name if isinstance(name, str) else name.name
if isinstance(identifier, str):
identifier = exp.parse_identifier(identifier, dialect=dialect)
name = identifier.name
if not normalize:
return name
return identifier.name
# This can be useful for normalize_identifier
identifier.meta["is_table"] = is_table

View file

@ -67,7 +67,7 @@ def eliminate_distinct_on(expression: exp.Expression) -> exp.Expression:
order = expression.args.get("order")
if order:
window.set("order", order.pop().copy())
window.set("order", order.pop())
else:
window.set("order", exp.Order(expressions=[c.copy() for c in distinct_cols]))
@ -75,9 +75,9 @@ def eliminate_distinct_on(expression: exp.Expression) -> exp.Expression:
expression.select(window, copy=False)
return (
exp.select(*outer_selects)
.from_(expression.subquery("_t"))
.where(exp.column(row_number).eq(1))
exp.select(*outer_selects, copy=False)
.from_(expression.subquery("_t", copy=False), copy=False)
.where(exp.column(row_number).eq(1), copy=False)
)
return expression
@ -120,7 +120,9 @@ def eliminate_qualify(expression: exp.Expression) -> exp.Expression:
elif expr.name not in expression.named_selects:
expression.select(expr.copy(), copy=False)
return outer_selects.from_(expression.subquery(alias="_t")).where(qualify_filters)
return outer_selects.from_(expression.subquery(alias="_t", copy=False), copy=False).where(
qualify_filters, copy=False
)
return expression
@ -189,7 +191,7 @@ def explode_to_unnest(index_offset: int = 0) -> t.Callable[[exp.Expression], exp
)
# we use list here because expression.selects is mutated inside the loop
for select in expression.selects.copy():
for select in list(expression.selects):
explode = select.find(exp.Explode)
if explode:
@ -374,6 +376,60 @@ def eliminate_semi_and_anti_joins(expression: exp.Expression) -> exp.Expression:
return expression
def eliminate_full_outer_join(expression: exp.Expression) -> exp.Expression:
    """
    Rewrite a SELECT containing exactly one FULL OUTER join into a union of two
    otherwise identical queries, one using a LEFT OUTER join and the other a
    RIGHT OUTER join.

    Anything that is not a SELECT, or that has zero or several FULL OUTER joins,
    is returned untouched. The input expression is modified in place and becomes
    the left-hand side of the returned union.
    """
    if not isinstance(expression, exp.Select):
        return expression

    candidates = []
    for position, join in enumerate(expression.args.get("joins") or []):
        if join.side == "FULL" and join.kind == "OUTER":
            candidates.append((position, join))

    if len(candidates) != 1:
        return expression

    position, left_join = candidates[0]

    # Copy before touching the join so both sides start from the same query.
    right_query = expression.copy()
    left_join.set("side", "left")
    right_query.args["joins"][position].set("side", "right")

    return exp.union(expression, right_query, copy=False)
def move_ctes_to_top_level(expression: exp.Expression) -> exp.Expression:
    """
    Some dialects (e.g. Hive, T-SQL, Spark prior to version 3) only allow CTEs to be
    defined at the top-level, so for example queries like:

        SELECT * FROM (WITH t(c) AS (SELECT 1) SELECT * FROM t) AS subq

    are invalid in those dialects. This transformation can be used to ensure all CTEs are
    moved to the top level so that the final SQL code is valid from a syntax standpoint.

    The expression is modified in place and also returned. If any moved WITH is
    recursive, the top-level WITH is marked recursive as well.

    TODO: handle name clashes whilst moving CTEs (it can get quite tricky & costly).
    """
    top_level_with = expression.args.get("with")

    # Materialize the traversal first: the loop detaches nodes from the very
    # tree that find_all() walks, so iterating lazily would mutate it mid-walk.
    for inner_with in tuple(expression.find_all(exp.With)):
        if inner_with.parent is expression:
            continue

        inner_with.pop()

        if not top_level_with:
            # No top-level WITH yet: promote this one wholesale.
            top_level_with = inner_with
            expression.set("with", top_level_with)
            continue

        if inner_with.recursive:
            top_level_with.set("recursive", True)

        # Move CTEs through append() rather than extending the list in place,
        # so each CTE is re-attached to the top-level WITH instead of keeping
        # a parent link to the detached inner WITH.
        for cte in inner_with.expressions:
            top_level_with.append("expressions", cte)

    return expression
def preprocess(
transforms: t.List[t.Callable[[exp.Expression], exp.Expression]],
) -> t.Callable[[Generator, exp.Expression], str]:
@ -392,7 +448,7 @@ def preprocess(
def _to_sql(self, expression: exp.Expression) -> str:
expression_type = type(expression)
expression = transforms[0](expression.copy())
expression = transforms[0](expression)
for t in transforms[1:]:
expression = t(expression)