1
0
Fork 0

Merging upstream version 23.12.1.

Signed-off-by: Daniel Baumann <daniel@debian.org>
This commit is contained in:
Daniel Baumann 2025-02-13 21:32:41 +01:00
parent 1271e5fe1c
commit 740634a4e8
Signed by: daniel
GPG key ID: FBB4F0E80A80222F
93 changed files with 55455 additions and 52777 deletions

View file

@ -222,7 +222,6 @@ class BigQuery(Dialect):
# https://cloud.google.com/bigquery/docs/reference/standard-sql/format-elements#format_elements_date_time
TIME_MAPPING = {
"%D": "%m/%d/%y",
"%E*S": "%S.%f",
"%E6S": "%S.%f",
}
@ -474,11 +473,31 @@ class BigQuery(Dialect):
if rest and this:
this = exp.Dot.build([this, *rest]) # type: ignore
table = exp.Table(this=this, db=db, catalog=catalog)
table = exp.Table(
this=this, db=db, catalog=catalog, pivots=table.args.get("pivots")
)
table.meta["quoted_table"] = True
return table
def _parse_column(self) -> t.Optional[exp.Expression]:
column = super()._parse_column()
if isinstance(column, exp.Column):
parts = column.parts
if any("." in p.name for p in parts):
catalog, db, table, this, *rest = (
exp.to_identifier(p, quoted=True)
for p in split_num_words(".".join(p.name for p in parts), ".", 4)
)
if rest and this:
this = exp.Dot.build([this, *rest]) # type: ignore
column = exp.Column(this=this, table=table, db=db, catalog=catalog)
column.meta["quoted_column"] = True
return column
@t.overload
def _parse_json_object(self, agg: Lit[False]) -> exp.JSONObject: ...
@ -670,6 +689,7 @@ class BigQuery(Dialect):
exp.DataType.Type.TIMESTAMPLTZ: "TIMESTAMP",
exp.DataType.Type.TINYINT: "INT64",
exp.DataType.Type.VARBINARY: "BYTES",
exp.DataType.Type.ROWVERSION: "BYTES",
exp.DataType.Type.VARCHAR: "STRING",
exp.DataType.Type.VARIANT: "ANY TYPE",
}
@ -781,6 +801,16 @@ class BigQuery(Dialect):
"within",
}
def column_parts(self, expression: exp.Column) -> str:
if expression.meta.get("quoted_column"):
# If a column reference is of the form `dataset.table`.name, we need
# to preserve the quoted table path, otherwise the reference breaks
table_parts = ".".join(p.name for p in expression.parts[:-1])
table_path = self.sql(exp.Identifier(this=table_parts, quoted=True))
return f"{table_path}.{self.sql(expression, 'this')}"
return super().column_parts(expression)
def table_parts(self, expression: exp.Table) -> str:
# Depending on the context, `x.y` may not resolve to the same data source as `x`.`y`, so
# we need to make sure the correct quoting is used in each case.

View file

@ -6,6 +6,7 @@ from sqlglot import exp, generator, parser, tokens, transforms
from sqlglot.dialects.dialect import (
Dialect,
arg_max_or_min_no_count,
build_formatted_time,
date_delta_sql,
inline_array_sql,
json_extract_segments,
@ -19,6 +20,16 @@ from sqlglot.helper import is_int, seq_get
from sqlglot.tokens import Token, TokenType
def _build_date_format(args: t.List) -> exp.TimeToStr:
expr = build_formatted_time(exp.TimeToStr, "clickhouse")(args)
timezone = seq_get(args, 2)
if timezone:
expr.set("timezone", timezone)
return expr
def _lower_func(sql: str) -> str:
index = sql.index("(")
return sql[:index].lower() + sql[index:]
@ -124,6 +135,8 @@ class ClickHouse(Dialect):
"DATEDIFF": lambda args: exp.DateDiff(
this=seq_get(args, 2), expression=seq_get(args, 1), unit=seq_get(args, 0)
),
"DATE_FORMAT": _build_date_format,
"FORMATDATETIME": _build_date_format,
"JSONEXTRACTSTRING": build_json_extract_path(
exp.JSONExtractScalar, zero_based_indexing=False
),
@ -241,6 +254,14 @@ class ClickHouse(Dialect):
"sparkBar",
"sumCount",
"largestTriangleThreeBuckets",
"histogram",
"sequenceMatch",
"sequenceCount",
"windowFunnel",
"retention",
"uniqUpTo",
"sequenceNextNode",
"exponentialTimeDecayedAvg",
}
AGG_FUNCTIONS_SUFFIXES = [
@ -383,6 +404,7 @@ class ClickHouse(Dialect):
alias_tokens: t.Optional[t.Collection[TokenType]] = None,
parse_bracket: bool = False,
is_db_reference: bool = False,
parse_partition: bool = False,
) -> t.Optional[exp.Expression]:
this = super()._parse_table(
schema=schema,
@ -447,46 +469,53 @@ class ClickHouse(Dialect):
functions: t.Optional[t.Dict[str, t.Callable]] = None,
anonymous: bool = False,
optional_parens: bool = True,
any_token: bool = False,
) -> t.Optional[exp.Expression]:
func = super()._parse_function(
functions=functions, anonymous=anonymous, optional_parens=optional_parens
expr = super()._parse_function(
functions=functions,
anonymous=anonymous,
optional_parens=optional_parens,
any_token=any_token,
)
if isinstance(func, exp.Anonymous):
parts = self.AGG_FUNC_MAPPING.get(func.this)
func = expr.this if isinstance(expr, exp.Window) else expr
# Aggregate functions can be split in 2 parts: <func_name><suffix>
parts = (
self.AGG_FUNC_MAPPING.get(func.this) if isinstance(func, exp.Anonymous) else None
)
if parts:
params = self._parse_func_params(func)
kwargs = {
"this": func.this,
"expressions": func.expressions,
}
if parts[1]:
kwargs["parts"] = parts
exp_class = exp.CombinedParameterizedAgg if params else exp.CombinedAggFunc
else:
exp_class = exp.ParameterizedAgg if params else exp.AnonymousAggFunc
kwargs["exp_class"] = exp_class
if params:
if parts and parts[1]:
return self.expression(
exp.CombinedParameterizedAgg,
this=func.this,
expressions=func.expressions,
params=params,
parts=parts,
)
return self.expression(
exp.ParameterizedAgg,
this=func.this,
expressions=func.expressions,
params=params,
)
kwargs["params"] = params
if parts:
if parts[1]:
return self.expression(
exp.CombinedAggFunc,
this=func.this,
expressions=func.expressions,
parts=parts,
)
return self.expression(
exp.AnonymousAggFunc,
this=func.this,
expressions=func.expressions,
)
func = self.expression(**kwargs)
return func
if isinstance(expr, exp.Window):
# The window's func was parsed as Anonymous in base parser, fix its
# type to be CH style CombinedAnonymousAggFunc / AnonymousAggFunc
expr.set("this", func)
elif params:
# Params have blocked super()._parse_function() from parsing the following window
# (if that exists) as they're standing between the function call and the window spec
expr = self._parse_window(func)
else:
expr = func
return expr
def _parse_func_params(
self, this: t.Optional[exp.Func] = None
@ -653,6 +682,9 @@ class ClickHouse(Dialect):
exp.StrPosition: lambda self, e: self.func(
"position", e.this, e.args.get("substr"), e.args.get("position")
),
exp.TimeToStr: lambda self, e: self.func(
"DATE_FORMAT", e.this, self.format_time(e), e.args.get("timezone")
),
exp.VarMap: lambda self, e: _lower_func(var_map_sql(self, e)),
exp.Xor: lambda self, e: self.func("xor", e.this, e.expression, *e.expressions),
}

View file

@ -568,7 +568,7 @@ def arrow_json_extract_sql(self: Generator, expression: JSON_EXTRACT_TYPE) -> st
def inline_array_sql(self: Generator, expression: exp.Array) -> str:
return f"[{self.expressions(expression, flat=True)}]"
return f"[{self.expressions(expression, dynamic=True, new_line=True, skip_first=True, skip_last=True)}]"
def inline_array_unless_query(self: Generator, expression: exp.Array) -> str:

View file

@ -28,7 +28,7 @@ from sqlglot.dialects.dialect import (
timestrtotime_sql,
unit_to_var,
)
from sqlglot.helper import flatten, seq_get
from sqlglot.helper import seq_get
from sqlglot.tokens import TokenType
@ -155,16 +155,6 @@ def _unix_to_time_sql(self: DuckDB.Generator, expression: exp.UnixToTime) -> str
return self.func("TO_TIMESTAMP", exp.Div(this=timestamp, expression=exp.func("POW", 10, scale)))
def _rename_unless_within_group(
a: str, b: str
) -> t.Callable[[DuckDB.Generator, exp.Expression], str]:
return lambda self, expression: (
self.func(a, *flatten(expression.args.values()))
if isinstance(expression.find_ancestor(exp.Select, exp.WithinGroup), exp.WithinGroup)
else self.func(b, *flatten(expression.args.values()))
)
class DuckDB(Dialect):
NULL_ORDERING = "nulls_are_last"
SUPPORTS_USER_DEFINED_TYPES = False
@ -425,8 +415,8 @@ class DuckDB(Dialect):
exp.cast(e.this, exp.DataType.Type.TIMESTAMP, copy=True),
),
exp.ParseJSON: rename_func("JSON"),
exp.PercentileCont: _rename_unless_within_group("PERCENTILE_CONT", "QUANTILE_CONT"),
exp.PercentileDisc: _rename_unless_within_group("PERCENTILE_DISC", "QUANTILE_DISC"),
exp.PercentileCont: rename_func("QUANTILE_CONT"),
exp.PercentileDisc: rename_func("QUANTILE_DISC"),
# DuckDB doesn't allow qualified columns inside of PIVOT expressions.
# See: https://github.com/duckdb/duckdb/blob/671faf92411182f81dce42ac43de8bfb05d9909e/src/planner/binder/tableref/bind_pivot.cpp#L61-L62
exp.Pivot: transforms.preprocess([transforms.unqualify_columns]),
@ -499,6 +489,7 @@ class DuckDB(Dialect):
exp.DataType.Type.NVARCHAR: "TEXT",
exp.DataType.Type.UINT: "UINTEGER",
exp.DataType.Type.VARBINARY: "BLOB",
exp.DataType.Type.ROWVERSION: "BLOB",
exp.DataType.Type.VARCHAR: "TEXT",
exp.DataType.Type.TIMESTAMP_S: "TIMESTAMP_S",
exp.DataType.Type.TIMESTAMP_MS: "TIMESTAMP_MS",
@ -616,3 +607,19 @@ class DuckDB(Dialect):
bracket = f"({bracket})[1]"
return bracket
def withingroup_sql(self, expression: exp.WithinGroup) -> str:
expression_sql = self.sql(expression, "expression")
func = expression.this
if isinstance(func, exp.PERCENTILES):
# Make the order key the first arg and slide the fraction to the right
# https://duckdb.org/docs/sql/aggregates#ordered-set-aggregate-functions
order_col = expression.find(exp.Ordered)
if order_col:
func.set("expression", func.this)
func.set("this", order_col.this)
this = self.sql(expression, "this").rstrip(")")
return f"{this}{expression_sql})"

View file

@ -457,6 +457,7 @@ class Hive(Dialect):
exp.DataType.Type.TIME: "TIMESTAMP",
exp.DataType.Type.TIMESTAMPTZ: "TIMESTAMP",
exp.DataType.Type.VARBINARY: "BINARY",
exp.DataType.Type.ROWVERSION: "BINARY",
}
TRANSFORMS = {

View file

@ -443,6 +443,7 @@ class MySQL(Dialect):
LOG_DEFAULTS_TO_LN = True
STRING_ALIASES = True
VALUES_FOLLOWED_BY_PAREN = False
SUPPORTS_PARTITION_SELECTION = True
def _parse_primary_key_part(self) -> t.Optional[exp.Expression]:
this = self._parse_id_var()

View file

@ -222,6 +222,7 @@ class Oracle(Dialect):
exp.DataType.Type.TIMESTAMPTZ: "TIMESTAMP",
exp.DataType.Type.BINARY: "BLOB",
exp.DataType.Type.VARBINARY: "BLOB",
exp.DataType.Type.ROWVERSION: "BLOB",
}
TRANSFORMS = {

View file

@ -431,6 +431,7 @@ class Postgres(Dialect):
exp.DataType.Type.DOUBLE: "DOUBLE PRECISION",
exp.DataType.Type.BINARY: "BYTEA",
exp.DataType.Type.VARBINARY: "BYTEA",
exp.DataType.Type.ROWVERSION: "BYTEA",
exp.DataType.Type.DATETIME: "TIMESTAMP",
}

View file

@ -443,6 +443,67 @@ class Presto(Dialect):
exp.Xor: bool_xor_sql,
}
RESERVED_KEYWORDS = {
"alter",
"and",
"as",
"between",
"by",
"case",
"cast",
"constraint",
"create",
"cross",
"current_time",
"current_timestamp",
"deallocate",
"delete",
"describe",
"distinct",
"drop",
"else",
"end",
"escape",
"except",
"execute",
"exists",
"extract",
"false",
"for",
"from",
"full",
"group",
"having",
"in",
"inner",
"insert",
"intersect",
"into",
"is",
"join",
"left",
"like",
"natural",
"not",
"null",
"on",
"or",
"order",
"outer",
"prepare",
"right",
"select",
"table",
"then",
"true",
"union",
"using",
"values",
"when",
"where",
"with",
}
def strtounix_sql(self, expression: exp.StrToUnix) -> str:
# Since `TO_UNIXTIME` requires a `TIMESTAMP`, we need to parse the argument into one.
# To do this, we first try to `DATE_PARSE` it, but since this can fail when there's a

View file

@ -55,6 +55,20 @@ class PRQL(Dialect):
"SORT": lambda self, query: self._parse_order_by(query),
}
def _parse_equality(self) -> t.Optional[exp.Expression]:
eq = self._parse_tokens(self._parse_comparison, self.EQUALITY)
if not isinstance(eq, (exp.EQ, exp.NEQ)):
return eq
# https://prql-lang.org/book/reference/spec/null.html
if isinstance(eq.expression, exp.Null):
is_exp = exp.Is(this=eq.this, expression=eq.expression)
return is_exp if isinstance(eq, exp.EQ) else exp.Not(this=is_exp)
if isinstance(eq.this, exp.Null):
is_exp = exp.Is(this=eq.expression, expression=eq.this)
return is_exp if isinstance(eq, exp.EQ) else exp.Not(this=is_exp)
return eq
def _parse_statement(self) -> t.Optional[exp.Expression]:
expression = self._parse_expression()
expression = expression if expression else self._parse_query()
@ -136,6 +150,7 @@ class PRQL(Dialect):
alias_tokens: t.Optional[t.Collection[TokenType]] = None,
parse_bracket: bool = False,
is_db_reference: bool = False,
parse_partition: bool = False,
) -> t.Optional[exp.Expression]:
return self._parse_table_parts()

View file

@ -79,6 +79,7 @@ class Redshift(Postgres):
alias_tokens: t.Optional[t.Collection[TokenType]] = None,
parse_bracket: bool = False,
is_db_reference: bool = False,
parse_partition: bool = False,
) -> t.Optional[exp.Expression]:
# Redshift supports UNPIVOTing SUPER objects, e.g. `UNPIVOT foo.obj[0] AS val AT attr`
unpivot = self._match(TokenType.UNPIVOT)
@ -145,6 +146,7 @@ class Redshift(Postgres):
exp.DataType.Type.TIMETZ: "TIME",
exp.DataType.Type.TIMESTAMPTZ: "TIMESTAMP",
exp.DataType.Type.VARBINARY: "VARBYTE",
exp.DataType.Type.ROWVERSION: "VARBYTE",
}
TRANSFORMS = {
@ -196,7 +198,165 @@ class Redshift(Postgres):
# Redshift supports LAST_DAY(..)
TRANSFORMS.pop(exp.LastDay)
RESERVED_KEYWORDS = {*Postgres.Generator.RESERVED_KEYWORDS, "snapshot", "type"}
RESERVED_KEYWORDS = {
"aes128",
"aes256",
"all",
"allowoverwrite",
"analyse",
"analyze",
"and",
"any",
"array",
"as",
"asc",
"authorization",
"az64",
"backup",
"between",
"binary",
"blanksasnull",
"both",
"bytedict",
"bzip2",
"case",
"cast",
"check",
"collate",
"column",
"constraint",
"create",
"credentials",
"cross",
"current_date",
"current_time",
"current_timestamp",
"current_user",
"current_user_id",
"default",
"deferrable",
"deflate",
"defrag",
"delta",
"delta32k",
"desc",
"disable",
"distinct",
"do",
"else",
"emptyasnull",
"enable",
"encode",
"encrypt ",
"encryption",
"end",
"except",
"explicit",
"false",
"for",
"foreign",
"freeze",
"from",
"full",
"globaldict256",
"globaldict64k",
"grant",
"group",
"gzip",
"having",
"identity",
"ignore",
"ilike",
"in",
"initially",
"inner",
"intersect",
"interval",
"into",
"is",
"isnull",
"join",
"leading",
"left",
"like",
"limit",
"localtime",
"localtimestamp",
"lun",
"luns",
"lzo",
"lzop",
"minus",
"mostly16",
"mostly32",
"mostly8",
"natural",
"new",
"not",
"notnull",
"null",
"nulls",
"off",
"offline",
"offset",
"oid",
"old",
"on",
"only",
"open",
"or",
"order",
"outer",
"overlaps",
"parallel",
"partition",
"percent",
"permissions",
"pivot",
"placing",
"primary",
"raw",
"readratio",
"recover",
"references",
"rejectlog",
"resort",
"respect",
"restore",
"right",
"select",
"session_user",
"similar",
"snapshot",
"some",
"sysdate",
"system",
"table",
"tag",
"tdes",
"text255",
"text32k",
"then",
"timestamp",
"to",
"top",
"trailing",
"true",
"truncatecolumns",
"type",
"union",
"unique",
"unnest",
"unpivot",
"user",
"using",
"verbose",
"wallet",
"when",
"where",
"with",
"without",
}
def unnest_sql(self, expression: exp.Unnest) -> str:
args = expression.expressions

View file

@ -33,10 +33,9 @@ def _build_datetime(
) -> t.Callable[[t.List], exp.Func]:
def _builder(args: t.List) -> exp.Func:
value = seq_get(args, 0)
int_value = value is not None and is_int(value.name)
if isinstance(value, exp.Literal):
int_value = is_int(value.this)
# Converts calls like `TO_TIME('01:02:03')` into casts
if len(args) == 1 and value.is_string and not int_value:
return exp.cast(value, kind)
@ -49,7 +48,7 @@ def _build_datetime(
if not is_float(value.this):
return build_formatted_time(exp.StrToTime, "snowflake")(args)
if len(args) == 2 and kind == exp.DataType.Type.DATE:
if kind == exp.DataType.Type.DATE and not int_value:
formatted_exp = build_formatted_time(exp.TsOrDsToDate, "snowflake")(args)
formatted_exp.set("safe", safe)
return formatted_exp
@ -749,6 +748,7 @@ class Snowflake(Dialect):
TRANSFORMS = {
**generator.Generator.TRANSFORMS,
exp.ApproxDistinct: rename_func("APPROX_COUNT_DISTINCT"),
exp.ArgMax: rename_func("MAX_BY"),
exp.ArgMin: rename_func("MIN_BY"),
exp.Array: inline_array_sql,

View file

@ -464,6 +464,7 @@ class TSQL(Dialect):
"SMALLMONEY": TokenType.SMALLMONEY,
"SQL_VARIANT": TokenType.VARIANT,
"TOP": TokenType.TOP,
"TIMESTAMP": TokenType.ROWVERSION,
"UNIQUEIDENTIFIER": TokenType.UNIQUEIDENTIFIER,
"UPDATE STATISTICS": TokenType.COMMAND,
"XML": TokenType.XML,
@ -755,6 +756,7 @@ class TSQL(Dialect):
exp.DataType.Type.TIMESTAMP: "DATETIME2",
exp.DataType.Type.TIMESTAMPTZ: "DATETIMEOFFSET",
exp.DataType.Type.VARIANT: "SQL_VARIANT",
exp.DataType.Type.ROWVERSION: "ROWVERSION",
}
TYPE_MAPPING.pop(exp.DataType.Type.NCHAR)
@ -1052,3 +1054,9 @@ class TSQL(Dialect):
def partition_sql(self, expression: exp.Partition) -> str:
return f"WITH (PARTITIONS({self.expressions(expression, flat=True)}))"
def altertable_sql(self, expression: exp.AlterTable) -> str:
action = seq_get(expression.args.get("actions") or [], 0)
if isinstance(action, exp.RenameTable):
return f"EXEC sp_rename '{self.sql(expression.this)}', '{action.this.name}'"
return super().altertable_sql(expression)

View file

@ -301,9 +301,10 @@ class Expression(metaclass=_Expression):
"""
return deepcopy(self)
def add_comments(self, comments: t.Optional[t.List[str]]) -> None:
def add_comments(self, comments: t.Optional[t.List[str]] = None) -> None:
if self.comments is None:
self.comments = []
if comments:
for comment in comments:
_, *meta = comment.split(SQLGLOT_META)
@ -314,6 +315,11 @@ class Expression(metaclass=_Expression):
self.meta[k.strip()] = value
self.comments.append(comment)
def pop_comments(self) -> t.List[str]:
comments = self.comments or []
self.comments = None
return comments
def append(self, arg_key: str, value: t.Any) -> None:
"""
Appends value to arg_key if it's a list or sets it as a new list.
@ -2058,11 +2064,11 @@ class Insert(DDL, DML):
"returning": False,
"overwrite": False,
"exists": False,
"partition": False,
"alternative": False,
"where": False,
"ignore": False,
"by_name": False,
"stored": False,
}
def with_(
@ -2911,6 +2917,7 @@ class Table(Expression):
"ordinality": False,
"when": False,
"only": False,
"partition": False,
}
@property
@ -5683,7 +5690,7 @@ class StddevSamp(AggFunc):
class TimeToStr(Func):
arg_types = {"this": True, "format": True, "culture": False}
arg_types = {"this": True, "format": True, "culture": False, "timezone": False}
class TimeToTimeStr(Func):
@ -5873,6 +5880,8 @@ FUNCTION_BY_NAME = {name: func for func in ALL_FUNCTIONS for name in func.sql_na
JSON_PATH_PARTS = subclasses(__name__, JSONPathPart, (JSONPathPart,))
PERCENTILES = (PercentileCont, PercentileDisc)
# Helpers
@t.overload

View file

@ -349,6 +349,7 @@ class Generator(metaclass=_Generator):
exp.DataType.Type.LONGBLOB: "BLOB",
exp.DataType.Type.TINYBLOB: "BLOB",
exp.DataType.Type.INET: "INET",
exp.DataType.Type.ROWVERSION: "VARBINARY",
}
STAR_MAPPING = {
@ -644,6 +645,7 @@ class Generator(metaclass=_Generator):
sql: str,
expression: t.Optional[exp.Expression] = None,
comments: t.Optional[t.List[str]] = None,
separated: bool = False,
) -> str:
comments = (
((expression and expression.comments) if comments is None else comments) # type: ignore
@ -661,7 +663,9 @@ class Generator(metaclass=_Generator):
if not comments_sql:
return sql
if isinstance(expression, self.WITH_SEPARATED_COMMENTS):
comments_sql = self._replace_line_breaks(comments_sql)
if separated or isinstance(expression, self.WITH_SEPARATED_COMMENTS):
return (
f"{self.sep()}{comments_sql}{sql}"
if not sql or sql[0].isspace()
@ -778,14 +782,8 @@ class Generator(metaclass=_Generator):
default = "DEFAULT " if expression.args.get("default") else ""
return f"{default}CHARACTER SET={self.sql(expression, 'this')}"
def column_sql(self, expression: exp.Column) -> str:
join_mark = " (+)" if expression.args.get("join_mark") else ""
if join_mark and not self.COLUMN_JOIN_MARKS_SUPPORTED:
join_mark = ""
self.unsupported("Outer join syntax using the (+) operator is not supported.")
column = ".".join(
def column_parts(self, expression: exp.Column) -> str:
return ".".join(
self.sql(part)
for part in (
expression.args.get("catalog"),
@ -796,7 +794,14 @@ class Generator(metaclass=_Generator):
if part
)
return f"{column}{join_mark}"
def column_sql(self, expression: exp.Column) -> str:
join_mark = " (+)" if expression.args.get("join_mark") else ""
if join_mark and not self.COLUMN_JOIN_MARKS_SUPPORTED:
join_mark = ""
self.unsupported("Outer join syntax using the (+) operator is not supported.")
return f"{self.column_parts(expression)}{join_mark}"
def columnposition_sql(self, expression: exp.ColumnPosition) -> str:
this = self.sql(expression, "this")
@ -1520,6 +1525,8 @@ class Generator(metaclass=_Generator):
else:
this = self.INSERT_OVERWRITE if overwrite else " INTO"
stored = self.sql(expression, "stored")
stored = f" {stored}" if stored else ""
alternative = expression.args.get("alternative")
alternative = f" OR {alternative}" if alternative else ""
ignore = " IGNORE" if expression.args.get("ignore") else ""
@ -1529,9 +1536,6 @@ class Generator(metaclass=_Generator):
this = f"{this} {self.sql(expression, 'this')}"
exists = " IF EXISTS" if expression.args.get("exists") else ""
partition_sql = (
f" {self.sql(expression, 'partition')}" if expression.args.get("partition") else ""
)
where = self.sql(expression, "where")
where = f"{self.sep()}REPLACE WHERE {where}" if where else ""
expression_sql = f"{self.sep()}{self.sql(expression, 'expression')}"
@ -1545,7 +1549,7 @@ class Generator(metaclass=_Generator):
else:
expression_sql = f"{returning}{expression_sql}{on_conflict}"
sql = f"INSERT{hint}{alternative}{ignore}{this}{by_name}{exists}{partition_sql}{where}{expression_sql}"
sql = f"INSERT{hint}{alternative}{ignore}{this}{stored}{by_name}{exists}{where}{expression_sql}"
return self.prepend_ctes(expression, sql)
def intersect_sql(self, expression: exp.Intersect) -> str:
@ -1634,6 +1638,8 @@ class Generator(metaclass=_Generator):
def table_sql(self, expression: exp.Table, sep: str = " AS ") -> str:
table = self.table_parts(expression)
only = "ONLY " if expression.args.get("only") else ""
partition = self.sql(expression, "partition")
partition = f" {partition}" if partition else ""
version = self.sql(expression, "version")
version = f" {version}" if version else ""
alias = self.sql(expression, "alias")
@ -1662,7 +1668,7 @@ class Generator(metaclass=_Generator):
if when:
table = f"{table} {when}"
return f"{only}{table}{version}{file_format}{alias}{hints}{pivots}{joins}{laterals}{ordinality}"
return f"{only}{table}{partition}{version}{file_format}{alias}{hints}{pivots}{joins}{laterals}{ordinality}"
def tablesample_sql(
self,
@ -2017,10 +2023,9 @@ class Generator(metaclass=_Generator):
to_escaped.get(ch, ch) if escape_backslash or ch != "\\" else ch for ch in text
)
if self.pretty:
text = text.replace("\n", self.SENTINEL_LINE_BREAK)
return text.replace(self.dialect.QUOTE_END, self._escaped_quote_end)
return self._replace_line_breaks(text).replace(
self.dialect.QUOTE_END, self._escaped_quote_end
)
def loaddata_sql(self, expression: exp.LoadData) -> str:
local = " LOCAL" if expression.args.get("local") else ""
@ -2341,8 +2346,8 @@ class Generator(metaclass=_Generator):
stack.append(
self.maybe_comment(
getattr(self, f"{node.key}_op")(node),
expression=node.this,
comments=node.comments,
separated=True,
)
)
stack.append(node.this)
@ -2486,7 +2491,7 @@ class Generator(metaclass=_Generator):
statements.append("END")
if self.pretty and self.text_width(statements) > self.max_text_width:
if self.pretty and self.too_wide(statements):
return self.indent("\n".join(statements), skip_first=True, skip_last=True)
return " ".join(statements)
@ -2847,7 +2852,7 @@ class Generator(metaclass=_Generator):
else:
sqls.append(sql)
sep = "\n" if self.pretty and self.text_width(sqls) > self.max_text_width else " "
sep = "\n" if self.pretty and self.too_wide(sqls) else " "
return sep.join(sqls)
def bitwiseand_sql(self, expression: exp.BitwiseAnd) -> str:
@ -3208,12 +3213,12 @@ class Generator(metaclass=_Generator):
def format_args(self, *args: t.Optional[str | exp.Expression]) -> str:
arg_sqls = tuple(self.sql(arg) for arg in args if arg is not None)
if self.pretty and self.text_width(arg_sqls) > self.max_text_width:
if self.pretty and self.too_wide(arg_sqls):
return self.indent("\n" + ",\n".join(arg_sqls) + "\n", skip_first=True, skip_last=True)
return ", ".join(arg_sqls)
def text_width(self, args: t.Iterable) -> int:
return sum(len(arg) for arg in args)
def too_wide(self, args: t.Iterable) -> bool:
return sum(len(arg) for arg in args) > self.max_text_width
def format_time(
self,
@ -3235,8 +3240,11 @@ class Generator(metaclass=_Generator):
flat: bool = False,
indent: bool = True,
skip_first: bool = False,
skip_last: bool = False,
sep: str = ", ",
prefix: str = "",
dynamic: bool = False,
new_line: bool = False,
) -> str:
expressions = expression.args.get(key or "expressions") if expression else sqls
@ -3270,8 +3278,18 @@ class Generator(metaclass=_Generator):
else:
result_sqls.append(f"{prefix}{sql}{comments}{sep if i + 1 < num_sqls else ''}")
result_sql = "\n".join(result_sqls) if self.pretty else "".join(result_sqls)
return self.indent(result_sql, skip_first=skip_first) if indent else result_sql
if self.pretty and (not dynamic or self.too_wide(result_sqls)):
if new_line:
result_sqls.insert(0, "")
result_sqls.append("")
result_sql = "\n".join(result_sqls)
else:
result_sql = "".join(result_sqls)
return (
self.indent(result_sql, skip_first=skip_first, skip_last=skip_last)
if indent
else result_sql
)
def op_expressions(self, op: str, expression: exp.Expression, flat: bool = False) -> str:
flat = flat or isinstance(expression.parent, exp.Properties)
@ -3733,3 +3751,9 @@ class Generator(metaclass=_Generator):
return self.sql(agg_func)[:-1] + f" {text})"
return f"{self.sql(expression, 'this')} {text}"
def _replace_line_breaks(self, string: str) -> str:
"""We don't want to extra indent line breaks so we temporarily replace them with sentinels."""
if self.pretty:
return string.replace("\n", self.SENTINEL_LINE_BREAK)
return string

View file

@ -351,55 +351,57 @@ class TypeAnnotator(metaclass=_TypeAnnotator):
def annotate(self, expression: E) -> E:
for scope in traverse_scope(expression):
selects = {}
for name, source in scope.sources.items():
if not isinstance(source, Scope):
continue
if isinstance(source.expression, exp.UDTF):
values = []
if isinstance(source.expression, exp.Lateral):
if isinstance(source.expression.this, exp.Explode):
values = [source.expression.this.this]
elif isinstance(source.expression, exp.Unnest):
values = [source.expression]
else:
values = source.expression.expressions[0].expressions
if not values:
continue
selects[name] = {
alias: column
for alias, column in zip(
source.expression.alias_column_names,
values,
)
}
else:
selects[name] = {
select.alias_or_name: select for select in source.expression.selects
}
# First annotate the current scope's column references
for col in scope.columns:
if not col.table:
continue
source = scope.sources.get(col.table)
if isinstance(source, exp.Table):
self._set_type(col, self.schema.get_column_type(source, col))
elif source:
if col.table in selects and col.name in selects[col.table]:
self._set_type(col, selects[col.table][col.name].type)
elif isinstance(source.expression, exp.Unnest):
self._set_type(col, source.expression.type)
# Then (possibly) annotate the remaining expressions in the scope
self._maybe_annotate(scope.expression)
self.annotate_scope(scope)
return self._maybe_annotate(expression) # This takes care of non-traversable expressions
def annotate_scope(self, scope: Scope) -> None:
selects = {}
for name, source in scope.sources.items():
if not isinstance(source, Scope):
continue
if isinstance(source.expression, exp.UDTF):
values = []
if isinstance(source.expression, exp.Lateral):
if isinstance(source.expression.this, exp.Explode):
values = [source.expression.this.this]
elif isinstance(source.expression, exp.Unnest):
values = [source.expression]
else:
values = source.expression.expressions[0].expressions
if not values:
continue
selects[name] = {
alias: column
for alias, column in zip(
source.expression.alias_column_names,
values,
)
}
else:
selects[name] = {
select.alias_or_name: select for select in source.expression.selects
}
# First annotate the current scope's column references
for col in scope.columns:
if not col.table:
continue
source = scope.sources.get(col.table)
if isinstance(source, exp.Table):
self._set_type(col, self.schema.get_column_type(source, col))
elif source:
if col.table in selects and col.name in selects[col.table]:
self._set_type(col, selects[col.table][col.name].type)
elif isinstance(source.expression, exp.Unnest):
self._set_type(col, source.expression.type)
# Then (possibly) annotate the remaining expressions in the scope
self._maybe_annotate(scope.expression)
def _maybe_annotate(self, expression: E) -> E:
if id(expression) in self._visited:
return expression # We've already inferred the expression's type
@ -601,7 +603,13 @@ class TypeAnnotator(metaclass=_TypeAnnotator):
def _annotate_unnest(self, expression: exp.Unnest) -> exp.Unnest:
self._annotate_args(expression)
child = seq_get(expression.expressions, 0)
self._set_type(expression, child and seq_get(child.type.expressions, 0))
if child and child.is_type(exp.DataType.Type.ARRAY):
expr_type = seq_get(child.type.expressions, 0)
else:
expr_type = None
self._set_type(expression, expr_type)
return expression
def _annotate_struct_value(

View file

@ -1,5 +1,6 @@
from __future__ import annotations
import inspect
import typing as t
import sqlglot
@ -85,7 +86,7 @@ def optimize(
optimized = exp.maybe_parse(expression, dialect=dialect, copy=True)
for rule in rules:
# Find any additional rule parameters, beyond `expression`
rule_params = rule.__code__.co_varnames
rule_params = inspect.getfullargspec(rule).args
rule_kwargs = {
param: possible_kwargs[param] for param in rule_params if param in possible_kwargs
}

View file

@ -7,7 +7,7 @@ from sqlglot import alias, exp
from sqlglot.dialects.dialect import Dialect, DialectType
from sqlglot.errors import OptimizeError
from sqlglot.helper import seq_get, SingleValuedMapping
from sqlglot.optimizer.annotate_types import annotate_types
from sqlglot.optimizer.annotate_types import TypeAnnotator
from sqlglot.optimizer.scope import Scope, build_scope, traverse_scope, walk_in_scope
from sqlglot.optimizer.simplify import simplify_parens
from sqlglot.schema import Schema, ensure_schema
@ -49,8 +49,10 @@ def qualify_columns(
- Currently only handles a single PIVOT or UNPIVOT operator
"""
schema = ensure_schema(schema)
annotator = TypeAnnotator(schema)
infer_schema = schema.empty if infer_schema is None else infer_schema
pseudocolumns = Dialect.get_or_raise(schema.dialect).PSEUDOCOLUMNS
dialect = Dialect.get_or_raise(schema.dialect)
pseudocolumns = dialect.PSEUDOCOLUMNS
for scope in traverse_scope(expression):
resolver = Resolver(scope, schema, infer_schema=infer_schema)
@ -74,6 +76,9 @@ def qualify_columns(
_expand_group_by(scope)
_expand_order_by(scope, resolver)
if dialect == "bigquery":
annotator.annotate_scope(scope)
return expression
@ -660,11 +665,8 @@ class Resolver:
# directly select a struct field in a query.
# this handles the case where the unnest is statically defined.
if self.schema.dialect == "bigquery":
expression = source.expression
annotate_types(expression)
if expression.is_type(exp.DataType.Type.STRUCT):
for k in expression.type.expressions: # type: ignore
if source.expression.is_type(exp.DataType.Type.STRUCT):
for k in source.expression.type.expressions: # type: ignore
columns.append(k.name)
else:
columns = source.expression.named_selects

View file

@ -6,6 +6,7 @@ import itertools
import typing as t
from collections import deque
from decimal import Decimal
from functools import reduce
import sqlglot
from sqlglot import Dialect, exp
@ -658,17 +659,21 @@ def simplify_parens(expression):
parent = expression.parent
parent_is_predicate = isinstance(parent, exp.Predicate)
if not isinstance(this, exp.Select) and (
not isinstance(parent, (exp.Condition, exp.Binary))
or isinstance(parent, exp.Paren)
or (
not isinstance(this, exp.Binary)
and not (isinstance(this, (exp.Not, exp.Is)) and parent_is_predicate)
if (
not isinstance(this, exp.Select)
and not isinstance(parent, exp.SubqueryPredicate)
and (
not isinstance(parent, (exp.Condition, exp.Binary))
or isinstance(parent, exp.Paren)
or (
not isinstance(this, exp.Binary)
and not (isinstance(this, (exp.Not, exp.Is)) and parent_is_predicate)
)
or (isinstance(this, exp.Predicate) and not parent_is_predicate)
or (isinstance(this, exp.Add) and isinstance(parent, exp.Add))
or (isinstance(this, exp.Mul) and isinstance(parent, exp.Mul))
or (isinstance(this, exp.Mul) and isinstance(parent, (exp.Add, exp.Sub)))
)
or (isinstance(this, exp.Predicate) and not parent_is_predicate)
or (isinstance(this, exp.Add) and isinstance(parent, exp.Add))
or (isinstance(this, exp.Mul) and isinstance(parent, exp.Mul))
or (isinstance(this, exp.Mul) and isinstance(parent, (exp.Add, exp.Sub)))
):
return this
return expression
@ -779,6 +784,8 @@ def simplify_concat(expression):
if concat_type is exp.ConcatWs:
new_args = [sep_expr] + new_args
elif isinstance(expression, exp.DPipe):
return reduce(lambda x, y: exp.DPipe(this=x, expression=y), new_args)
return concat_type(expressions=new_args, **args)

View file

@ -289,7 +289,7 @@ class Parser(metaclass=_Parser):
RESERVED_TOKENS = {
*Tokenizer.SINGLE_TOKENS.values(),
TokenType.SELECT,
}
} - {TokenType.IDENTIFIER}
DB_CREATABLES = {
TokenType.DATABASE,
@ -1109,6 +1109,9 @@ class Parser(metaclass=_Parser):
# Whether or not interval spans are supported, INTERVAL 1 YEAR TO MONTHS
INTERVAL_SPANS = True
# Whether a PARTITION clause can follow a table reference
SUPPORTS_PARTITION_SELECTION = False
__slots__ = (
"error_level",
"error_message_context",
@ -1764,7 +1767,11 @@ class Parser(metaclass=_Parser):
def _parse_property_assignment(self, exp_class: t.Type[E], **kwargs: t.Any) -> E:
self._match(TokenType.EQ)
self._match(TokenType.ALIAS)
return self.expression(exp_class, this=self._parse_field(), **kwargs)
field = self._parse_field()
if isinstance(field, exp.Identifier) and not field.quoted:
field = exp.var(field)
return self.expression(exp_class, this=field, **kwargs)
def _parse_properties(self, before: t.Optional[bool] = None) -> t.Optional[exp.Properties]:
properties = []
@ -2234,7 +2241,11 @@ class Parser(metaclass=_Parser):
self._match(TokenType.TABLE)
is_function = self._match(TokenType.FUNCTION)
this = self._parse_table(schema=True) if not is_function else self._parse_function()
this = (
self._parse_table(schema=True, parse_partition=True)
if not is_function
else self._parse_function()
)
returning = self._parse_returning()
@ -2244,9 +2255,9 @@ class Parser(metaclass=_Parser):
hint=hint,
is_function=is_function,
this=this,
stored=self._match_text_seq("STORED") and self._parse_stored(),
by_name=self._match_text_seq("BY", "NAME"),
exists=self._parse_exists(),
partition=self._parse_partition(),
where=self._match_pair(TokenType.REPLACE, TokenType.WHERE)
and self._parse_conjunction(),
expression=self._parse_derived_table_values() or self._parse_ddl_select(),
@ -3098,6 +3109,9 @@ class Parser(metaclass=_Parser):
else:
table = exp.Identifier(this="*")
# We bubble up comments from the Identifier to the Table
comments = table.pop_comments() if isinstance(table, exp.Expression) else None
if is_db_reference:
catalog = db
db = table
@ -3109,7 +3123,12 @@ class Parser(metaclass=_Parser):
self.raise_error(f"Expected database name but got {self._curr}")
return self.expression(
exp.Table, this=table, db=db, catalog=catalog, pivots=self._parse_pivots()
exp.Table,
comments=comments,
this=table,
db=db,
catalog=catalog,
pivots=self._parse_pivots(),
)
def _parse_table(
@ -3119,6 +3138,7 @@ class Parser(metaclass=_Parser):
alias_tokens: t.Optional[t.Collection[TokenType]] = None,
parse_bracket: bool = False,
is_db_reference: bool = False,
parse_partition: bool = False,
) -> t.Optional[exp.Expression]:
lateral = self._parse_lateral()
if lateral:
@ -3157,6 +3177,10 @@ class Parser(metaclass=_Parser):
# Postgres supports a wildcard (table) suffix operator, which is a no-op in this context
self._match_text_seq("*")
parse_partition = parse_partition or self.SUPPORTS_PARTITION_SELECTION
if parse_partition and self._match(TokenType.PARTITION, advance=False):
this.set("partition", self._parse_partition())
if schema:
return self._parse_schema(this=this)
@ -4200,7 +4224,11 @@ class Parser(metaclass=_Parser):
):
this = self._parse_id_var()
return self.expression(exp.Column, this=this) if isinstance(this, exp.Identifier) else this
if isinstance(this, exp.Identifier):
# We bubble up comments from the Identifier to the Column
this = self.expression(exp.Column, comments=this.pop_comments(), this=this)
return this
def _parse_column_ops(self, this: t.Optional[exp.Expression]) -> t.Optional[exp.Expression]:
this = self._parse_bracket(this)
@ -4216,7 +4244,7 @@ class Parser(metaclass=_Parser):
elif op and self._curr:
field = self._parse_column_reference()
else:
field = self._parse_field(anonymous_func=True, any_token=True)
field = self._parse_field(any_token=True, anonymous_func=True)
if isinstance(field, exp.Func) and this:
# bigquery allows function calls like x.y.count(...)
@ -4285,7 +4313,7 @@ class Parser(metaclass=_Parser):
this = self._parse_subquery(
this=self._parse_set_operations(this), parse_alias=False
)
elif len(expressions) > 1:
elif len(expressions) > 1 or self._prev.token_type == TokenType.COMMA:
this = self.expression(exp.Tuple, expressions=expressions)
else:
this = self.expression(exp.Paren, this=this)
@ -4304,17 +4332,23 @@ class Parser(metaclass=_Parser):
tokens: t.Optional[t.Collection[TokenType]] = None,
anonymous_func: bool = False,
) -> t.Optional[exp.Expression]:
return (
self._parse_primary()
or self._parse_function(anonymous=anonymous_func)
or self._parse_id_var(any_token=any_token, tokens=tokens)
)
if anonymous_func:
field = (
self._parse_function(anonymous=anonymous_func, any_token=any_token)
or self._parse_primary()
)
else:
field = self._parse_primary() or self._parse_function(
anonymous=anonymous_func, any_token=any_token
)
return field or self._parse_id_var(any_token=any_token, tokens=tokens)
def _parse_function(
self,
functions: t.Optional[t.Dict[str, t.Callable]] = None,
anonymous: bool = False,
optional_parens: bool = True,
any_token: bool = False,
) -> t.Optional[exp.Expression]:
# This allows us to also parse {fn <function>} syntax (Snowflake, MySQL support this)
# See: https://community.snowflake.com/s/article/SQL-Escape-Sequences
@ -4328,7 +4362,10 @@ class Parser(metaclass=_Parser):
fn_syntax = True
func = self._parse_function_call(
functions=functions, anonymous=anonymous, optional_parens=optional_parens
functions=functions,
anonymous=anonymous,
optional_parens=optional_parens,
any_token=any_token,
)
if fn_syntax:
@ -4341,6 +4378,7 @@ class Parser(metaclass=_Parser):
functions: t.Optional[t.Dict[str, t.Callable]] = None,
anonymous: bool = False,
optional_parens: bool = True,
any_token: bool = False,
) -> t.Optional[exp.Expression]:
if not self._curr:
return None
@ -4362,7 +4400,10 @@ class Parser(metaclass=_Parser):
return None
if token_type not in self.FUNC_TOKENS:
if any_token:
if token_type in self.RESERVED_TOKENS:
return None
elif token_type not in self.FUNC_TOKENS:
return None
self._advance(2)
@ -4501,7 +4542,6 @@ class Parser(metaclass=_Parser):
def _parse_schema(self, this: t.Optional[exp.Expression] = None) -> t.Optional[exp.Expression]:
index = self._index
if not self._match(TokenType.L_PAREN):
return this
@ -4510,9 +4550,7 @@ class Parser(metaclass=_Parser):
if self._match_set(self.SELECT_START_TOKENS):
self._retreat(index)
return this
args = self._parse_csv(lambda: self._parse_constraint() or self._parse_field_def())
self._match_r_paren()
return self.expression(exp.Schema, this=this, expressions=args)
@ -5378,8 +5416,8 @@ class Parser(metaclass=_Parser):
else:
over = self._prev.text.upper()
if comments:
func.comments = None # type: ignore
if comments and isinstance(func, exp.Expression):
func.pop_comments()
if not self._match(TokenType.L_PAREN):
return self.expression(
@ -5457,7 +5495,7 @@ class Parser(metaclass=_Parser):
self, this: t.Optional[exp.Expression], explicit: bool = False
) -> t.Optional[exp.Expression]:
any_token = self._match(TokenType.ALIAS)
comments = self._prev_comments
comments = self._prev_comments or []
if explicit and not any_token:
return this
@ -5477,13 +5515,13 @@ class Parser(metaclass=_Parser):
)
if alias:
comments.extend(alias.pop_comments())
this = self.expression(exp.Alias, comments=comments, this=this, alias=alias)
column = this.this
# Moves the comment next to the alias in `expr /* comment */ AS alias`
if not this.comments and column and column.comments:
this.comments = column.comments
column.comments = None
this.comments = column.pop_comments()
return this
@ -5492,16 +5530,14 @@ class Parser(metaclass=_Parser):
any_token: bool = True,
tokens: t.Optional[t.Collection[TokenType]] = None,
) -> t.Optional[exp.Expression]:
identifier = self._parse_identifier()
if identifier:
return identifier
if (any_token and self._advance_any()) or self._match_set(tokens or self.ID_VAR_TOKENS):
expression = self._parse_identifier()
if not expression and (
(any_token and self._advance_any()) or self._match_set(tokens or self.ID_VAR_TOKENS)
):
quoted = self._prev.token_type == TokenType.STRING
return exp.Identifier(this=self._prev.text, quoted=quoted)
expression = self.expression(exp.Identifier, this=self._prev.text, quoted=quoted)
return None
return expression
def _parse_string(self) -> t.Optional[exp.Expression]:
if self._match_set(self.STRING_PARSERS):

View file

@ -567,11 +567,11 @@ class Tokenizer(metaclass=_Tokenizer):
"~": TokenType.TILDA,
"?": TokenType.PLACEHOLDER,
"@": TokenType.PARAMETER,
# Used for breaking a var like x'y' but nothing else the token type doesn't matter
"'": TokenType.QUOTE,
"`": TokenType.IDENTIFIER,
'"': TokenType.IDENTIFIER,
"#": TokenType.HASH,
# Used for breaking a var like x'y' but nothing else the token type doesn't matter
"'": TokenType.UNKNOWN,
"`": TokenType.UNKNOWN,
'"': TokenType.UNKNOWN,
}
BIT_STRINGS: t.List[str | t.Tuple[str, str]] = []

View file

@ -93,7 +93,9 @@ def eliminate_qualify(expression: exp.Expression) -> exp.Expression:
Some dialects don't support window functions in the WHERE clause, so we need to include them as
projections in the subquery, in order to refer to them in the outer filter using aliases. Also,
if a column is referenced in the QUALIFY clause but is not selected, we need to include it too,
otherwise we won't be able to refer to it in the outer query's WHERE clause.
otherwise we won't be able to refer to it in the outer query's WHERE clause. Finally, if a
newly aliased projection is referenced in the QUALIFY clause, it will be replaced by the
corresponding expression to avoid creating invalid column references.
"""
if isinstance(expression, exp.Select) and expression.args.get("qualify"):
taken = set(expression.named_selects)
@ -105,20 +107,31 @@ def eliminate_qualify(expression: exp.Expression) -> exp.Expression:
outer_selects = exp.select(*[select.alias_or_name for select in expression.selects])
qualify_filters = expression.args["qualify"].pop().this
expression_by_alias = {
select.alias: select.this
for select in expression.selects
if isinstance(select, exp.Alias)
}
select_candidates = exp.Window if expression.is_star else (exp.Window, exp.Column)
for expr in qualify_filters.find_all(select_candidates):
if isinstance(expr, exp.Window):
for select_candidate in qualify_filters.find_all(select_candidates):
if isinstance(select_candidate, exp.Window):
if expression_by_alias:
for column in select_candidate.find_all(exp.Column):
expr = expression_by_alias.get(column.name)
if expr:
column.replace(expr)
alias = find_new_name(expression.named_selects, "_w")
expression.select(exp.alias_(expr, alias), copy=False)
expression.select(exp.alias_(select_candidate, alias), copy=False)
column = exp.column(alias)
if isinstance(expr.parent, exp.Qualify):
if isinstance(select_candidate.parent, exp.Qualify):
qualify_filters = column
else:
expr.replace(column)
elif expr.name not in expression.named_selects:
expression.select(expr.copy(), copy=False)
select_candidate.replace(column)
elif select_candidate.name not in expression.named_selects:
expression.select(select_candidate.copy(), copy=False)
return outer_selects.from_(expression.subquery(alias="_t", copy=False), copy=False).where(
qualify_filters, copy=False
@ -336,13 +349,10 @@ def explode_to_unnest(index_offset: int = 0) -> t.Callable[[exp.Expression], exp
return _explode_to_unnest
PERCENTILES = (exp.PercentileCont, exp.PercentileDisc)
def add_within_group_for_percentiles(expression: exp.Expression) -> exp.Expression:
"""Transforms percentiles by adding a WITHIN GROUP clause to them."""
if (
isinstance(expression, PERCENTILES)
isinstance(expression, exp.PERCENTILES)
and not isinstance(expression.parent, exp.WithinGroup)
and expression.expression
):
@ -358,7 +368,7 @@ def remove_within_group_for_percentiles(expression: exp.Expression) -> exp.Expre
"""Transforms percentiles by getting rid of their corresponding WITHIN GROUP clause."""
if (
isinstance(expression, exp.WithinGroup)
and isinstance(expression.this, PERCENTILES)
and isinstance(expression.this, exp.PERCENTILES)
and isinstance(expression.expression, exp.Order)
):
quantile = expression.this.this