Merging upstream version 23.12.1.
Signed-off-by: Daniel Baumann <daniel@debian.org>
This commit is contained in:
parent
1271e5fe1c
commit
740634a4e8
93 changed files with 55455 additions and 52777 deletions
|
@ -222,7 +222,6 @@ class BigQuery(Dialect):
|
|||
# https://cloud.google.com/bigquery/docs/reference/standard-sql/format-elements#format_elements_date_time
|
||||
TIME_MAPPING = {
|
||||
"%D": "%m/%d/%y",
|
||||
"%E*S": "%S.%f",
|
||||
"%E6S": "%S.%f",
|
||||
}
|
||||
|
||||
|
@ -474,11 +473,31 @@ class BigQuery(Dialect):
|
|||
if rest and this:
|
||||
this = exp.Dot.build([this, *rest]) # type: ignore
|
||||
|
||||
table = exp.Table(this=this, db=db, catalog=catalog)
|
||||
table = exp.Table(
|
||||
this=this, db=db, catalog=catalog, pivots=table.args.get("pivots")
|
||||
)
|
||||
table.meta["quoted_table"] = True
|
||||
|
||||
return table
|
||||
|
||||
def _parse_column(self) -> t.Optional[exp.Expression]:
|
||||
column = super()._parse_column()
|
||||
if isinstance(column, exp.Column):
|
||||
parts = column.parts
|
||||
if any("." in p.name for p in parts):
|
||||
catalog, db, table, this, *rest = (
|
||||
exp.to_identifier(p, quoted=True)
|
||||
for p in split_num_words(".".join(p.name for p in parts), ".", 4)
|
||||
)
|
||||
|
||||
if rest and this:
|
||||
this = exp.Dot.build([this, *rest]) # type: ignore
|
||||
|
||||
column = exp.Column(this=this, table=table, db=db, catalog=catalog)
|
||||
column.meta["quoted_column"] = True
|
||||
|
||||
return column
|
||||
|
||||
@t.overload
|
||||
def _parse_json_object(self, agg: Lit[False]) -> exp.JSONObject: ...
|
||||
|
||||
|
@ -670,6 +689,7 @@ class BigQuery(Dialect):
|
|||
exp.DataType.Type.TIMESTAMPLTZ: "TIMESTAMP",
|
||||
exp.DataType.Type.TINYINT: "INT64",
|
||||
exp.DataType.Type.VARBINARY: "BYTES",
|
||||
exp.DataType.Type.ROWVERSION: "BYTES",
|
||||
exp.DataType.Type.VARCHAR: "STRING",
|
||||
exp.DataType.Type.VARIANT: "ANY TYPE",
|
||||
}
|
||||
|
@ -781,6 +801,16 @@ class BigQuery(Dialect):
|
|||
"within",
|
||||
}
|
||||
|
||||
def column_parts(self, expression: exp.Column) -> str:
|
||||
if expression.meta.get("quoted_column"):
|
||||
# If a column reference is of the form `dataset.table`.name, we need
|
||||
# to preserve the quoted table path, otherwise the reference breaks
|
||||
table_parts = ".".join(p.name for p in expression.parts[:-1])
|
||||
table_path = self.sql(exp.Identifier(this=table_parts, quoted=True))
|
||||
return f"{table_path}.{self.sql(expression, 'this')}"
|
||||
|
||||
return super().column_parts(expression)
|
||||
|
||||
def table_parts(self, expression: exp.Table) -> str:
|
||||
# Depending on the context, `x.y` may not resolve to the same data source as `x`.`y`, so
|
||||
# we need to make sure the correct quoting is used in each case.
|
||||
|
|
|
@ -6,6 +6,7 @@ from sqlglot import exp, generator, parser, tokens, transforms
|
|||
from sqlglot.dialects.dialect import (
|
||||
Dialect,
|
||||
arg_max_or_min_no_count,
|
||||
build_formatted_time,
|
||||
date_delta_sql,
|
||||
inline_array_sql,
|
||||
json_extract_segments,
|
||||
|
@ -19,6 +20,16 @@ from sqlglot.helper import is_int, seq_get
|
|||
from sqlglot.tokens import Token, TokenType
|
||||
|
||||
|
||||
def _build_date_format(args: t.List) -> exp.TimeToStr:
|
||||
expr = build_formatted_time(exp.TimeToStr, "clickhouse")(args)
|
||||
|
||||
timezone = seq_get(args, 2)
|
||||
if timezone:
|
||||
expr.set("timezone", timezone)
|
||||
|
||||
return expr
|
||||
|
||||
|
||||
def _lower_func(sql: str) -> str:
|
||||
index = sql.index("(")
|
||||
return sql[:index].lower() + sql[index:]
|
||||
|
@ -124,6 +135,8 @@ class ClickHouse(Dialect):
|
|||
"DATEDIFF": lambda args: exp.DateDiff(
|
||||
this=seq_get(args, 2), expression=seq_get(args, 1), unit=seq_get(args, 0)
|
||||
),
|
||||
"DATE_FORMAT": _build_date_format,
|
||||
"FORMATDATETIME": _build_date_format,
|
||||
"JSONEXTRACTSTRING": build_json_extract_path(
|
||||
exp.JSONExtractScalar, zero_based_indexing=False
|
||||
),
|
||||
|
@ -241,6 +254,14 @@ class ClickHouse(Dialect):
|
|||
"sparkBar",
|
||||
"sumCount",
|
||||
"largestTriangleThreeBuckets",
|
||||
"histogram",
|
||||
"sequenceMatch",
|
||||
"sequenceCount",
|
||||
"windowFunnel",
|
||||
"retention",
|
||||
"uniqUpTo",
|
||||
"sequenceNextNode",
|
||||
"exponentialTimeDecayedAvg",
|
||||
}
|
||||
|
||||
AGG_FUNCTIONS_SUFFIXES = [
|
||||
|
@ -383,6 +404,7 @@ class ClickHouse(Dialect):
|
|||
alias_tokens: t.Optional[t.Collection[TokenType]] = None,
|
||||
parse_bracket: bool = False,
|
||||
is_db_reference: bool = False,
|
||||
parse_partition: bool = False,
|
||||
) -> t.Optional[exp.Expression]:
|
||||
this = super()._parse_table(
|
||||
schema=schema,
|
||||
|
@ -447,46 +469,53 @@ class ClickHouse(Dialect):
|
|||
functions: t.Optional[t.Dict[str, t.Callable]] = None,
|
||||
anonymous: bool = False,
|
||||
optional_parens: bool = True,
|
||||
any_token: bool = False,
|
||||
) -> t.Optional[exp.Expression]:
|
||||
func = super()._parse_function(
|
||||
functions=functions, anonymous=anonymous, optional_parens=optional_parens
|
||||
expr = super()._parse_function(
|
||||
functions=functions,
|
||||
anonymous=anonymous,
|
||||
optional_parens=optional_parens,
|
||||
any_token=any_token,
|
||||
)
|
||||
|
||||
if isinstance(func, exp.Anonymous):
|
||||
parts = self.AGG_FUNC_MAPPING.get(func.this)
|
||||
func = expr.this if isinstance(expr, exp.Window) else expr
|
||||
|
||||
# Aggregate functions can be split in 2 parts: <func_name><suffix>
|
||||
parts = (
|
||||
self.AGG_FUNC_MAPPING.get(func.this) if isinstance(func, exp.Anonymous) else None
|
||||
)
|
||||
|
||||
if parts:
|
||||
params = self._parse_func_params(func)
|
||||
|
||||
kwargs = {
|
||||
"this": func.this,
|
||||
"expressions": func.expressions,
|
||||
}
|
||||
if parts[1]:
|
||||
kwargs["parts"] = parts
|
||||
exp_class = exp.CombinedParameterizedAgg if params else exp.CombinedAggFunc
|
||||
else:
|
||||
exp_class = exp.ParameterizedAgg if params else exp.AnonymousAggFunc
|
||||
|
||||
kwargs["exp_class"] = exp_class
|
||||
if params:
|
||||
if parts and parts[1]:
|
||||
return self.expression(
|
||||
exp.CombinedParameterizedAgg,
|
||||
this=func.this,
|
||||
expressions=func.expressions,
|
||||
params=params,
|
||||
parts=parts,
|
||||
)
|
||||
return self.expression(
|
||||
exp.ParameterizedAgg,
|
||||
this=func.this,
|
||||
expressions=func.expressions,
|
||||
params=params,
|
||||
)
|
||||
kwargs["params"] = params
|
||||
|
||||
if parts:
|
||||
if parts[1]:
|
||||
return self.expression(
|
||||
exp.CombinedAggFunc,
|
||||
this=func.this,
|
||||
expressions=func.expressions,
|
||||
parts=parts,
|
||||
)
|
||||
return self.expression(
|
||||
exp.AnonymousAggFunc,
|
||||
this=func.this,
|
||||
expressions=func.expressions,
|
||||
)
|
||||
func = self.expression(**kwargs)
|
||||
|
||||
return func
|
||||
if isinstance(expr, exp.Window):
|
||||
# The window's func was parsed as Anonymous in base parser, fix its
|
||||
# type to be CH style CombinedAnonymousAggFunc / AnonymousAggFunc
|
||||
expr.set("this", func)
|
||||
elif params:
|
||||
# Params have blocked super()._parse_function() from parsing the following window
|
||||
# (if that exists) as they're standing between the function call and the window spec
|
||||
expr = self._parse_window(func)
|
||||
else:
|
||||
expr = func
|
||||
|
||||
return expr
|
||||
|
||||
def _parse_func_params(
|
||||
self, this: t.Optional[exp.Func] = None
|
||||
|
@ -653,6 +682,9 @@ class ClickHouse(Dialect):
|
|||
exp.StrPosition: lambda self, e: self.func(
|
||||
"position", e.this, e.args.get("substr"), e.args.get("position")
|
||||
),
|
||||
exp.TimeToStr: lambda self, e: self.func(
|
||||
"DATE_FORMAT", e.this, self.format_time(e), e.args.get("timezone")
|
||||
),
|
||||
exp.VarMap: lambda self, e: _lower_func(var_map_sql(self, e)),
|
||||
exp.Xor: lambda self, e: self.func("xor", e.this, e.expression, *e.expressions),
|
||||
}
|
||||
|
|
|
@ -568,7 +568,7 @@ def arrow_json_extract_sql(self: Generator, expression: JSON_EXTRACT_TYPE) -> st
|
|||
|
||||
|
||||
def inline_array_sql(self: Generator, expression: exp.Array) -> str:
|
||||
return f"[{self.expressions(expression, flat=True)}]"
|
||||
return f"[{self.expressions(expression, dynamic=True, new_line=True, skip_first=True, skip_last=True)}]"
|
||||
|
||||
|
||||
def inline_array_unless_query(self: Generator, expression: exp.Array) -> str:
|
||||
|
|
|
@ -28,7 +28,7 @@ from sqlglot.dialects.dialect import (
|
|||
timestrtotime_sql,
|
||||
unit_to_var,
|
||||
)
|
||||
from sqlglot.helper import flatten, seq_get
|
||||
from sqlglot.helper import seq_get
|
||||
from sqlglot.tokens import TokenType
|
||||
|
||||
|
||||
|
@ -155,16 +155,6 @@ def _unix_to_time_sql(self: DuckDB.Generator, expression: exp.UnixToTime) -> str
|
|||
return self.func("TO_TIMESTAMP", exp.Div(this=timestamp, expression=exp.func("POW", 10, scale)))
|
||||
|
||||
|
||||
def _rename_unless_within_group(
|
||||
a: str, b: str
|
||||
) -> t.Callable[[DuckDB.Generator, exp.Expression], str]:
|
||||
return lambda self, expression: (
|
||||
self.func(a, *flatten(expression.args.values()))
|
||||
if isinstance(expression.find_ancestor(exp.Select, exp.WithinGroup), exp.WithinGroup)
|
||||
else self.func(b, *flatten(expression.args.values()))
|
||||
)
|
||||
|
||||
|
||||
class DuckDB(Dialect):
|
||||
NULL_ORDERING = "nulls_are_last"
|
||||
SUPPORTS_USER_DEFINED_TYPES = False
|
||||
|
@ -425,8 +415,8 @@ class DuckDB(Dialect):
|
|||
exp.cast(e.this, exp.DataType.Type.TIMESTAMP, copy=True),
|
||||
),
|
||||
exp.ParseJSON: rename_func("JSON"),
|
||||
exp.PercentileCont: _rename_unless_within_group("PERCENTILE_CONT", "QUANTILE_CONT"),
|
||||
exp.PercentileDisc: _rename_unless_within_group("PERCENTILE_DISC", "QUANTILE_DISC"),
|
||||
exp.PercentileCont: rename_func("QUANTILE_CONT"),
|
||||
exp.PercentileDisc: rename_func("QUANTILE_DISC"),
|
||||
# DuckDB doesn't allow qualified columns inside of PIVOT expressions.
|
||||
# See: https://github.com/duckdb/duckdb/blob/671faf92411182f81dce42ac43de8bfb05d9909e/src/planner/binder/tableref/bind_pivot.cpp#L61-L62
|
||||
exp.Pivot: transforms.preprocess([transforms.unqualify_columns]),
|
||||
|
@ -499,6 +489,7 @@ class DuckDB(Dialect):
|
|||
exp.DataType.Type.NVARCHAR: "TEXT",
|
||||
exp.DataType.Type.UINT: "UINTEGER",
|
||||
exp.DataType.Type.VARBINARY: "BLOB",
|
||||
exp.DataType.Type.ROWVERSION: "BLOB",
|
||||
exp.DataType.Type.VARCHAR: "TEXT",
|
||||
exp.DataType.Type.TIMESTAMP_S: "TIMESTAMP_S",
|
||||
exp.DataType.Type.TIMESTAMP_MS: "TIMESTAMP_MS",
|
||||
|
@ -616,3 +607,19 @@ class DuckDB(Dialect):
|
|||
bracket = f"({bracket})[1]"
|
||||
|
||||
return bracket
|
||||
|
||||
def withingroup_sql(self, expression: exp.WithinGroup) -> str:
|
||||
expression_sql = self.sql(expression, "expression")
|
||||
|
||||
func = expression.this
|
||||
if isinstance(func, exp.PERCENTILES):
|
||||
# Make the order key the first arg and slide the fraction to the right
|
||||
# https://duckdb.org/docs/sql/aggregates#ordered-set-aggregate-functions
|
||||
order_col = expression.find(exp.Ordered)
|
||||
if order_col:
|
||||
func.set("expression", func.this)
|
||||
func.set("this", order_col.this)
|
||||
|
||||
this = self.sql(expression, "this").rstrip(")")
|
||||
|
||||
return f"{this}{expression_sql})"
|
||||
|
|
|
@ -457,6 +457,7 @@ class Hive(Dialect):
|
|||
exp.DataType.Type.TIME: "TIMESTAMP",
|
||||
exp.DataType.Type.TIMESTAMPTZ: "TIMESTAMP",
|
||||
exp.DataType.Type.VARBINARY: "BINARY",
|
||||
exp.DataType.Type.ROWVERSION: "BINARY",
|
||||
}
|
||||
|
||||
TRANSFORMS = {
|
||||
|
|
|
@ -443,6 +443,7 @@ class MySQL(Dialect):
|
|||
LOG_DEFAULTS_TO_LN = True
|
||||
STRING_ALIASES = True
|
||||
VALUES_FOLLOWED_BY_PAREN = False
|
||||
SUPPORTS_PARTITION_SELECTION = True
|
||||
|
||||
def _parse_primary_key_part(self) -> t.Optional[exp.Expression]:
|
||||
this = self._parse_id_var()
|
||||
|
|
|
@ -222,6 +222,7 @@ class Oracle(Dialect):
|
|||
exp.DataType.Type.TIMESTAMPTZ: "TIMESTAMP",
|
||||
exp.DataType.Type.BINARY: "BLOB",
|
||||
exp.DataType.Type.VARBINARY: "BLOB",
|
||||
exp.DataType.Type.ROWVERSION: "BLOB",
|
||||
}
|
||||
|
||||
TRANSFORMS = {
|
||||
|
|
|
@ -431,6 +431,7 @@ class Postgres(Dialect):
|
|||
exp.DataType.Type.DOUBLE: "DOUBLE PRECISION",
|
||||
exp.DataType.Type.BINARY: "BYTEA",
|
||||
exp.DataType.Type.VARBINARY: "BYTEA",
|
||||
exp.DataType.Type.ROWVERSION: "BYTEA",
|
||||
exp.DataType.Type.DATETIME: "TIMESTAMP",
|
||||
}
|
||||
|
||||
|
|
|
@ -443,6 +443,67 @@ class Presto(Dialect):
|
|||
exp.Xor: bool_xor_sql,
|
||||
}
|
||||
|
||||
RESERVED_KEYWORDS = {
|
||||
"alter",
|
||||
"and",
|
||||
"as",
|
||||
"between",
|
||||
"by",
|
||||
"case",
|
||||
"cast",
|
||||
"constraint",
|
||||
"create",
|
||||
"cross",
|
||||
"current_time",
|
||||
"current_timestamp",
|
||||
"deallocate",
|
||||
"delete",
|
||||
"describe",
|
||||
"distinct",
|
||||
"drop",
|
||||
"else",
|
||||
"end",
|
||||
"escape",
|
||||
"except",
|
||||
"execute",
|
||||
"exists",
|
||||
"extract",
|
||||
"false",
|
||||
"for",
|
||||
"from",
|
||||
"full",
|
||||
"group",
|
||||
"having",
|
||||
"in",
|
||||
"inner",
|
||||
"insert",
|
||||
"intersect",
|
||||
"into",
|
||||
"is",
|
||||
"join",
|
||||
"left",
|
||||
"like",
|
||||
"natural",
|
||||
"not",
|
||||
"null",
|
||||
"on",
|
||||
"or",
|
||||
"order",
|
||||
"outer",
|
||||
"prepare",
|
||||
"right",
|
||||
"select",
|
||||
"table",
|
||||
"then",
|
||||
"true",
|
||||
"union",
|
||||
"using",
|
||||
"values",
|
||||
"when",
|
||||
"where",
|
||||
"with",
|
||||
}
|
||||
|
||||
def strtounix_sql(self, expression: exp.StrToUnix) -> str:
|
||||
# Since `TO_UNIXTIME` requires a `TIMESTAMP`, we need to parse the argument into one.
|
||||
# To do this, we first try to `DATE_PARSE` it, but since this can fail when there's a
|
||||
|
|
|
@ -55,6 +55,20 @@ class PRQL(Dialect):
|
|||
"SORT": lambda self, query: self._parse_order_by(query),
|
||||
}
|
||||
|
||||
def _parse_equality(self) -> t.Optional[exp.Expression]:
|
||||
eq = self._parse_tokens(self._parse_comparison, self.EQUALITY)
|
||||
if not isinstance(eq, (exp.EQ, exp.NEQ)):
|
||||
return eq
|
||||
|
||||
# https://prql-lang.org/book/reference/spec/null.html
|
||||
if isinstance(eq.expression, exp.Null):
|
||||
is_exp = exp.Is(this=eq.this, expression=eq.expression)
|
||||
return is_exp if isinstance(eq, exp.EQ) else exp.Not(this=is_exp)
|
||||
if isinstance(eq.this, exp.Null):
|
||||
is_exp = exp.Is(this=eq.expression, expression=eq.this)
|
||||
return is_exp if isinstance(eq, exp.EQ) else exp.Not(this=is_exp)
|
||||
return eq
|
||||
|
||||
def _parse_statement(self) -> t.Optional[exp.Expression]:
|
||||
expression = self._parse_expression()
|
||||
expression = expression if expression else self._parse_query()
|
||||
|
@ -136,6 +150,7 @@ class PRQL(Dialect):
|
|||
alias_tokens: t.Optional[t.Collection[TokenType]] = None,
|
||||
parse_bracket: bool = False,
|
||||
is_db_reference: bool = False,
|
||||
parse_partition: bool = False,
|
||||
) -> t.Optional[exp.Expression]:
|
||||
return self._parse_table_parts()
|
||||
|
||||
|
|
|
@ -79,6 +79,7 @@ class Redshift(Postgres):
|
|||
alias_tokens: t.Optional[t.Collection[TokenType]] = None,
|
||||
parse_bracket: bool = False,
|
||||
is_db_reference: bool = False,
|
||||
parse_partition: bool = False,
|
||||
) -> t.Optional[exp.Expression]:
|
||||
# Redshift supports UNPIVOTing SUPER objects, e.g. `UNPIVOT foo.obj[0] AS val AT attr`
|
||||
unpivot = self._match(TokenType.UNPIVOT)
|
||||
|
@ -145,6 +146,7 @@ class Redshift(Postgres):
|
|||
exp.DataType.Type.TIMETZ: "TIME",
|
||||
exp.DataType.Type.TIMESTAMPTZ: "TIMESTAMP",
|
||||
exp.DataType.Type.VARBINARY: "VARBYTE",
|
||||
exp.DataType.Type.ROWVERSION: "VARBYTE",
|
||||
}
|
||||
|
||||
TRANSFORMS = {
|
||||
|
@ -196,7 +198,165 @@ class Redshift(Postgres):
|
|||
# Redshift supports LAST_DAY(..)
|
||||
TRANSFORMS.pop(exp.LastDay)
|
||||
|
||||
RESERVED_KEYWORDS = {*Postgres.Generator.RESERVED_KEYWORDS, "snapshot", "type"}
|
||||
RESERVED_KEYWORDS = {
|
||||
"aes128",
|
||||
"aes256",
|
||||
"all",
|
||||
"allowoverwrite",
|
||||
"analyse",
|
||||
"analyze",
|
||||
"and",
|
||||
"any",
|
||||
"array",
|
||||
"as",
|
||||
"asc",
|
||||
"authorization",
|
||||
"az64",
|
||||
"backup",
|
||||
"between",
|
||||
"binary",
|
||||
"blanksasnull",
|
||||
"both",
|
||||
"bytedict",
|
||||
"bzip2",
|
||||
"case",
|
||||
"cast",
|
||||
"check",
|
||||
"collate",
|
||||
"column",
|
||||
"constraint",
|
||||
"create",
|
||||
"credentials",
|
||||
"cross",
|
||||
"current_date",
|
||||
"current_time",
|
||||
"current_timestamp",
|
||||
"current_user",
|
||||
"current_user_id",
|
||||
"default",
|
||||
"deferrable",
|
||||
"deflate",
|
||||
"defrag",
|
||||
"delta",
|
||||
"delta32k",
|
||||
"desc",
|
||||
"disable",
|
||||
"distinct",
|
||||
"do",
|
||||
"else",
|
||||
"emptyasnull",
|
||||
"enable",
|
||||
"encode",
|
||||
"encrypt ",
|
||||
"encryption",
|
||||
"end",
|
||||
"except",
|
||||
"explicit",
|
||||
"false",
|
||||
"for",
|
||||
"foreign",
|
||||
"freeze",
|
||||
"from",
|
||||
"full",
|
||||
"globaldict256",
|
||||
"globaldict64k",
|
||||
"grant",
|
||||
"group",
|
||||
"gzip",
|
||||
"having",
|
||||
"identity",
|
||||
"ignore",
|
||||
"ilike",
|
||||
"in",
|
||||
"initially",
|
||||
"inner",
|
||||
"intersect",
|
||||
"interval",
|
||||
"into",
|
||||
"is",
|
||||
"isnull",
|
||||
"join",
|
||||
"leading",
|
||||
"left",
|
||||
"like",
|
||||
"limit",
|
||||
"localtime",
|
||||
"localtimestamp",
|
||||
"lun",
|
||||
"luns",
|
||||
"lzo",
|
||||
"lzop",
|
||||
"minus",
|
||||
"mostly16",
|
||||
"mostly32",
|
||||
"mostly8",
|
||||
"natural",
|
||||
"new",
|
||||
"not",
|
||||
"notnull",
|
||||
"null",
|
||||
"nulls",
|
||||
"off",
|
||||
"offline",
|
||||
"offset",
|
||||
"oid",
|
||||
"old",
|
||||
"on",
|
||||
"only",
|
||||
"open",
|
||||
"or",
|
||||
"order",
|
||||
"outer",
|
||||
"overlaps",
|
||||
"parallel",
|
||||
"partition",
|
||||
"percent",
|
||||
"permissions",
|
||||
"pivot",
|
||||
"placing",
|
||||
"primary",
|
||||
"raw",
|
||||
"readratio",
|
||||
"recover",
|
||||
"references",
|
||||
"rejectlog",
|
||||
"resort",
|
||||
"respect",
|
||||
"restore",
|
||||
"right",
|
||||
"select",
|
||||
"session_user",
|
||||
"similar",
|
||||
"snapshot",
|
||||
"some",
|
||||
"sysdate",
|
||||
"system",
|
||||
"table",
|
||||
"tag",
|
||||
"tdes",
|
||||
"text255",
|
||||
"text32k",
|
||||
"then",
|
||||
"timestamp",
|
||||
"to",
|
||||
"top",
|
||||
"trailing",
|
||||
"true",
|
||||
"truncatecolumns",
|
||||
"type",
|
||||
"union",
|
||||
"unique",
|
||||
"unnest",
|
||||
"unpivot",
|
||||
"user",
|
||||
"using",
|
||||
"verbose",
|
||||
"wallet",
|
||||
"when",
|
||||
"where",
|
||||
"with",
|
||||
"without",
|
||||
}
|
||||
|
||||
def unnest_sql(self, expression: exp.Unnest) -> str:
|
||||
args = expression.expressions
|
||||
|
|
|
@ -33,10 +33,9 @@ def _build_datetime(
|
|||
) -> t.Callable[[t.List], exp.Func]:
|
||||
def _builder(args: t.List) -> exp.Func:
|
||||
value = seq_get(args, 0)
|
||||
int_value = value is not None and is_int(value.name)
|
||||
|
||||
if isinstance(value, exp.Literal):
|
||||
int_value = is_int(value.this)
|
||||
|
||||
# Converts calls like `TO_TIME('01:02:03')` into casts
|
||||
if len(args) == 1 and value.is_string and not int_value:
|
||||
return exp.cast(value, kind)
|
||||
|
@ -49,7 +48,7 @@ def _build_datetime(
|
|||
if not is_float(value.this):
|
||||
return build_formatted_time(exp.StrToTime, "snowflake")(args)
|
||||
|
||||
if len(args) == 2 and kind == exp.DataType.Type.DATE:
|
||||
if kind == exp.DataType.Type.DATE and not int_value:
|
||||
formatted_exp = build_formatted_time(exp.TsOrDsToDate, "snowflake")(args)
|
||||
formatted_exp.set("safe", safe)
|
||||
return formatted_exp
|
||||
|
@ -749,6 +748,7 @@ class Snowflake(Dialect):
|
|||
|
||||
TRANSFORMS = {
|
||||
**generator.Generator.TRANSFORMS,
|
||||
exp.ApproxDistinct: rename_func("APPROX_COUNT_DISTINCT"),
|
||||
exp.ArgMax: rename_func("MAX_BY"),
|
||||
exp.ArgMin: rename_func("MIN_BY"),
|
||||
exp.Array: inline_array_sql,
|
||||
|
|
|
@ -464,6 +464,7 @@ class TSQL(Dialect):
|
|||
"SMALLMONEY": TokenType.SMALLMONEY,
|
||||
"SQL_VARIANT": TokenType.VARIANT,
|
||||
"TOP": TokenType.TOP,
|
||||
"TIMESTAMP": TokenType.ROWVERSION,
|
||||
"UNIQUEIDENTIFIER": TokenType.UNIQUEIDENTIFIER,
|
||||
"UPDATE STATISTICS": TokenType.COMMAND,
|
||||
"XML": TokenType.XML,
|
||||
|
@ -755,6 +756,7 @@ class TSQL(Dialect):
|
|||
exp.DataType.Type.TIMESTAMP: "DATETIME2",
|
||||
exp.DataType.Type.TIMESTAMPTZ: "DATETIMEOFFSET",
|
||||
exp.DataType.Type.VARIANT: "SQL_VARIANT",
|
||||
exp.DataType.Type.ROWVERSION: "ROWVERSION",
|
||||
}
|
||||
|
||||
TYPE_MAPPING.pop(exp.DataType.Type.NCHAR)
|
||||
|
@ -1052,3 +1054,9 @@ class TSQL(Dialect):
|
|||
|
||||
def partition_sql(self, expression: exp.Partition) -> str:
|
||||
return f"WITH (PARTITIONS({self.expressions(expression, flat=True)}))"
|
||||
|
||||
def altertable_sql(self, expression: exp.AlterTable) -> str:
|
||||
action = seq_get(expression.args.get("actions") or [], 0)
|
||||
if isinstance(action, exp.RenameTable):
|
||||
return f"EXEC sp_rename '{self.sql(expression.this)}', '{action.this.name}'"
|
||||
return super().altertable_sql(expression)
|
||||
|
|
|
@ -301,9 +301,10 @@ class Expression(metaclass=_Expression):
|
|||
"""
|
||||
return deepcopy(self)
|
||||
|
||||
def add_comments(self, comments: t.Optional[t.List[str]]) -> None:
|
||||
def add_comments(self, comments: t.Optional[t.List[str]] = None) -> None:
|
||||
if self.comments is None:
|
||||
self.comments = []
|
||||
|
||||
if comments:
|
||||
for comment in comments:
|
||||
_, *meta = comment.split(SQLGLOT_META)
|
||||
|
@ -314,6 +315,11 @@ class Expression(metaclass=_Expression):
|
|||
self.meta[k.strip()] = value
|
||||
self.comments.append(comment)
|
||||
|
||||
def pop_comments(self) -> t.List[str]:
|
||||
comments = self.comments or []
|
||||
self.comments = None
|
||||
return comments
|
||||
|
||||
def append(self, arg_key: str, value: t.Any) -> None:
|
||||
"""
|
||||
Appends value to arg_key if it's a list or sets it as a new list.
|
||||
|
@ -2058,11 +2064,11 @@ class Insert(DDL, DML):
|
|||
"returning": False,
|
||||
"overwrite": False,
|
||||
"exists": False,
|
||||
"partition": False,
|
||||
"alternative": False,
|
||||
"where": False,
|
||||
"ignore": False,
|
||||
"by_name": False,
|
||||
"stored": False,
|
||||
}
|
||||
|
||||
def with_(
|
||||
|
@ -2911,6 +2917,7 @@ class Table(Expression):
|
|||
"ordinality": False,
|
||||
"when": False,
|
||||
"only": False,
|
||||
"partition": False,
|
||||
}
|
||||
|
||||
@property
|
||||
|
@ -5683,7 +5690,7 @@ class StddevSamp(AggFunc):
|
|||
|
||||
|
||||
class TimeToStr(Func):
|
||||
arg_types = {"this": True, "format": True, "culture": False}
|
||||
arg_types = {"this": True, "format": True, "culture": False, "timezone": False}
|
||||
|
||||
|
||||
class TimeToTimeStr(Func):
|
||||
|
@ -5873,6 +5880,8 @@ FUNCTION_BY_NAME = {name: func for func in ALL_FUNCTIONS for name in func.sql_na
|
|||
|
||||
JSON_PATH_PARTS = subclasses(__name__, JSONPathPart, (JSONPathPart,))
|
||||
|
||||
PERCENTILES = (PercentileCont, PercentileDisc)
|
||||
|
||||
|
||||
# Helpers
|
||||
@t.overload
|
||||
|
|
|
@ -349,6 +349,7 @@ class Generator(metaclass=_Generator):
|
|||
exp.DataType.Type.LONGBLOB: "BLOB",
|
||||
exp.DataType.Type.TINYBLOB: "BLOB",
|
||||
exp.DataType.Type.INET: "INET",
|
||||
exp.DataType.Type.ROWVERSION: "VARBINARY",
|
||||
}
|
||||
|
||||
STAR_MAPPING = {
|
||||
|
@ -644,6 +645,7 @@ class Generator(metaclass=_Generator):
|
|||
sql: str,
|
||||
expression: t.Optional[exp.Expression] = None,
|
||||
comments: t.Optional[t.List[str]] = None,
|
||||
separated: bool = False,
|
||||
) -> str:
|
||||
comments = (
|
||||
((expression and expression.comments) if comments is None else comments) # type: ignore
|
||||
|
@ -661,7 +663,9 @@ class Generator(metaclass=_Generator):
|
|||
if not comments_sql:
|
||||
return sql
|
||||
|
||||
if isinstance(expression, self.WITH_SEPARATED_COMMENTS):
|
||||
comments_sql = self._replace_line_breaks(comments_sql)
|
||||
|
||||
if separated or isinstance(expression, self.WITH_SEPARATED_COMMENTS):
|
||||
return (
|
||||
f"{self.sep()}{comments_sql}{sql}"
|
||||
if not sql or sql[0].isspace()
|
||||
|
@ -778,14 +782,8 @@ class Generator(metaclass=_Generator):
|
|||
default = "DEFAULT " if expression.args.get("default") else ""
|
||||
return f"{default}CHARACTER SET={self.sql(expression, 'this')}"
|
||||
|
||||
def column_sql(self, expression: exp.Column) -> str:
|
||||
join_mark = " (+)" if expression.args.get("join_mark") else ""
|
||||
|
||||
if join_mark and not self.COLUMN_JOIN_MARKS_SUPPORTED:
|
||||
join_mark = ""
|
||||
self.unsupported("Outer join syntax using the (+) operator is not supported.")
|
||||
|
||||
column = ".".join(
|
||||
def column_parts(self, expression: exp.Column) -> str:
|
||||
return ".".join(
|
||||
self.sql(part)
|
||||
for part in (
|
||||
expression.args.get("catalog"),
|
||||
|
@ -796,7 +794,14 @@ class Generator(metaclass=_Generator):
|
|||
if part
|
||||
)
|
||||
|
||||
return f"{column}{join_mark}"
|
||||
def column_sql(self, expression: exp.Column) -> str:
|
||||
join_mark = " (+)" if expression.args.get("join_mark") else ""
|
||||
|
||||
if join_mark and not self.COLUMN_JOIN_MARKS_SUPPORTED:
|
||||
join_mark = ""
|
||||
self.unsupported("Outer join syntax using the (+) operator is not supported.")
|
||||
|
||||
return f"{self.column_parts(expression)}{join_mark}"
|
||||
|
||||
def columnposition_sql(self, expression: exp.ColumnPosition) -> str:
|
||||
this = self.sql(expression, "this")
|
||||
|
@ -1520,6 +1525,8 @@ class Generator(metaclass=_Generator):
|
|||
else:
|
||||
this = self.INSERT_OVERWRITE if overwrite else " INTO"
|
||||
|
||||
stored = self.sql(expression, "stored")
|
||||
stored = f" {stored}" if stored else ""
|
||||
alternative = expression.args.get("alternative")
|
||||
alternative = f" OR {alternative}" if alternative else ""
|
||||
ignore = " IGNORE" if expression.args.get("ignore") else ""
|
||||
|
@ -1529,9 +1536,6 @@ class Generator(metaclass=_Generator):
|
|||
this = f"{this} {self.sql(expression, 'this')}"
|
||||
|
||||
exists = " IF EXISTS" if expression.args.get("exists") else ""
|
||||
partition_sql = (
|
||||
f" {self.sql(expression, 'partition')}" if expression.args.get("partition") else ""
|
||||
)
|
||||
where = self.sql(expression, "where")
|
||||
where = f"{self.sep()}REPLACE WHERE {where}" if where else ""
|
||||
expression_sql = f"{self.sep()}{self.sql(expression, 'expression')}"
|
||||
|
@ -1545,7 +1549,7 @@ class Generator(metaclass=_Generator):
|
|||
else:
|
||||
expression_sql = f"{returning}{expression_sql}{on_conflict}"
|
||||
|
||||
sql = f"INSERT{hint}{alternative}{ignore}{this}{by_name}{exists}{partition_sql}{where}{expression_sql}"
|
||||
sql = f"INSERT{hint}{alternative}{ignore}{this}{stored}{by_name}{exists}{where}{expression_sql}"
|
||||
return self.prepend_ctes(expression, sql)
|
||||
|
||||
def intersect_sql(self, expression: exp.Intersect) -> str:
|
||||
|
@ -1634,6 +1638,8 @@ class Generator(metaclass=_Generator):
|
|||
def table_sql(self, expression: exp.Table, sep: str = " AS ") -> str:
|
||||
table = self.table_parts(expression)
|
||||
only = "ONLY " if expression.args.get("only") else ""
|
||||
partition = self.sql(expression, "partition")
|
||||
partition = f" {partition}" if partition else ""
|
||||
version = self.sql(expression, "version")
|
||||
version = f" {version}" if version else ""
|
||||
alias = self.sql(expression, "alias")
|
||||
|
@ -1662,7 +1668,7 @@ class Generator(metaclass=_Generator):
|
|||
if when:
|
||||
table = f"{table} {when}"
|
||||
|
||||
return f"{only}{table}{version}{file_format}{alias}{hints}{pivots}{joins}{laterals}{ordinality}"
|
||||
return f"{only}{table}{partition}{version}{file_format}{alias}{hints}{pivots}{joins}{laterals}{ordinality}"
|
||||
|
||||
def tablesample_sql(
|
||||
self,
|
||||
|
@ -2017,10 +2023,9 @@ class Generator(metaclass=_Generator):
|
|||
to_escaped.get(ch, ch) if escape_backslash or ch != "\\" else ch for ch in text
|
||||
)
|
||||
|
||||
if self.pretty:
|
||||
text = text.replace("\n", self.SENTINEL_LINE_BREAK)
|
||||
|
||||
return text.replace(self.dialect.QUOTE_END, self._escaped_quote_end)
|
||||
return self._replace_line_breaks(text).replace(
|
||||
self.dialect.QUOTE_END, self._escaped_quote_end
|
||||
)
|
||||
|
||||
def loaddata_sql(self, expression: exp.LoadData) -> str:
|
||||
local = " LOCAL" if expression.args.get("local") else ""
|
||||
|
@ -2341,8 +2346,8 @@ class Generator(metaclass=_Generator):
|
|||
stack.append(
|
||||
self.maybe_comment(
|
||||
getattr(self, f"{node.key}_op")(node),
|
||||
expression=node.this,
|
||||
comments=node.comments,
|
||||
separated=True,
|
||||
)
|
||||
)
|
||||
stack.append(node.this)
|
||||
|
@ -2486,7 +2491,7 @@ class Generator(metaclass=_Generator):
|
|||
|
||||
statements.append("END")
|
||||
|
||||
if self.pretty and self.text_width(statements) > self.max_text_width:
|
||||
if self.pretty and self.too_wide(statements):
|
||||
return self.indent("\n".join(statements), skip_first=True, skip_last=True)
|
||||
|
||||
return " ".join(statements)
|
||||
|
@ -2847,7 +2852,7 @@ class Generator(metaclass=_Generator):
|
|||
else:
|
||||
sqls.append(sql)
|
||||
|
||||
sep = "\n" if self.pretty and self.text_width(sqls) > self.max_text_width else " "
|
||||
sep = "\n" if self.pretty and self.too_wide(sqls) else " "
|
||||
return sep.join(sqls)
|
||||
|
||||
def bitwiseand_sql(self, expression: exp.BitwiseAnd) -> str:
|
||||
|
@ -3208,12 +3213,12 @@ class Generator(metaclass=_Generator):
|
|||
|
||||
def format_args(self, *args: t.Optional[str | exp.Expression]) -> str:
|
||||
arg_sqls = tuple(self.sql(arg) for arg in args if arg is not None)
|
||||
if self.pretty and self.text_width(arg_sqls) > self.max_text_width:
|
||||
if self.pretty and self.too_wide(arg_sqls):
|
||||
return self.indent("\n" + ",\n".join(arg_sqls) + "\n", skip_first=True, skip_last=True)
|
||||
return ", ".join(arg_sqls)
|
||||
|
||||
def text_width(self, args: t.Iterable) -> int:
|
||||
return sum(len(arg) for arg in args)
|
||||
def too_wide(self, args: t.Iterable) -> bool:
|
||||
return sum(len(arg) for arg in args) > self.max_text_width
|
||||
|
||||
def format_time(
|
||||
self,
|
||||
|
@ -3235,8 +3240,11 @@ class Generator(metaclass=_Generator):
|
|||
flat: bool = False,
|
||||
indent: bool = True,
|
||||
skip_first: bool = False,
|
||||
skip_last: bool = False,
|
||||
sep: str = ", ",
|
||||
prefix: str = "",
|
||||
dynamic: bool = False,
|
||||
new_line: bool = False,
|
||||
) -> str:
|
||||
expressions = expression.args.get(key or "expressions") if expression else sqls
|
||||
|
||||
|
@ -3270,8 +3278,18 @@ class Generator(metaclass=_Generator):
|
|||
else:
|
||||
result_sqls.append(f"{prefix}{sql}{comments}{sep if i + 1 < num_sqls else ''}")
|
||||
|
||||
result_sql = "\n".join(result_sqls) if self.pretty else "".join(result_sqls)
|
||||
return self.indent(result_sql, skip_first=skip_first) if indent else result_sql
|
||||
if self.pretty and (not dynamic or self.too_wide(result_sqls)):
|
||||
if new_line:
|
||||
result_sqls.insert(0, "")
|
||||
result_sqls.append("")
|
||||
result_sql = "\n".join(result_sqls)
|
||||
else:
|
||||
result_sql = "".join(result_sqls)
|
||||
return (
|
||||
self.indent(result_sql, skip_first=skip_first, skip_last=skip_last)
|
||||
if indent
|
||||
else result_sql
|
||||
)
|
||||
|
||||
def op_expressions(self, op: str, expression: exp.Expression, flat: bool = False) -> str:
|
||||
flat = flat or isinstance(expression.parent, exp.Properties)
|
||||
|
@ -3733,3 +3751,9 @@ class Generator(metaclass=_Generator):
|
|||
return self.sql(agg_func)[:-1] + f" {text})"
|
||||
|
||||
return f"{self.sql(expression, 'this')} {text}"
|
||||
|
||||
def _replace_line_breaks(self, string: str) -> str:
|
||||
"""We don't want to extra indent line breaks so we temporarily replace them with sentinels."""
|
||||
if self.pretty:
|
||||
return string.replace("\n", self.SENTINEL_LINE_BREAK)
|
||||
return string
|
||||
|
|
|
@ -351,55 +351,57 @@ class TypeAnnotator(metaclass=_TypeAnnotator):
|
|||
|
||||
def annotate(self, expression: E) -> E:
|
||||
for scope in traverse_scope(expression):
|
||||
selects = {}
|
||||
for name, source in scope.sources.items():
|
||||
if not isinstance(source, Scope):
|
||||
continue
|
||||
if isinstance(source.expression, exp.UDTF):
|
||||
values = []
|
||||
|
||||
if isinstance(source.expression, exp.Lateral):
|
||||
if isinstance(source.expression.this, exp.Explode):
|
||||
values = [source.expression.this.this]
|
||||
elif isinstance(source.expression, exp.Unnest):
|
||||
values = [source.expression]
|
||||
else:
|
||||
values = source.expression.expressions[0].expressions
|
||||
|
||||
if not values:
|
||||
continue
|
||||
|
||||
selects[name] = {
|
||||
alias: column
|
||||
for alias, column in zip(
|
||||
source.expression.alias_column_names,
|
||||
values,
|
||||
)
|
||||
}
|
||||
else:
|
||||
selects[name] = {
|
||||
select.alias_or_name: select for select in source.expression.selects
|
||||
}
|
||||
|
||||
# First annotate the current scope's column references
|
||||
for col in scope.columns:
|
||||
if not col.table:
|
||||
continue
|
||||
|
||||
source = scope.sources.get(col.table)
|
||||
if isinstance(source, exp.Table):
|
||||
self._set_type(col, self.schema.get_column_type(source, col))
|
||||
elif source:
|
||||
if col.table in selects and col.name in selects[col.table]:
|
||||
self._set_type(col, selects[col.table][col.name].type)
|
||||
elif isinstance(source.expression, exp.Unnest):
|
||||
self._set_type(col, source.expression.type)
|
||||
|
||||
# Then (possibly) annotate the remaining expressions in the scope
|
||||
self._maybe_annotate(scope.expression)
|
||||
|
||||
self.annotate_scope(scope)
|
||||
return self._maybe_annotate(expression) # This takes care of non-traversable expressions
|
||||
|
||||
def annotate_scope(self, scope: Scope) -> None:
|
||||
selects = {}
|
||||
for name, source in scope.sources.items():
|
||||
if not isinstance(source, Scope):
|
||||
continue
|
||||
if isinstance(source.expression, exp.UDTF):
|
||||
values = []
|
||||
|
||||
if isinstance(source.expression, exp.Lateral):
|
||||
if isinstance(source.expression.this, exp.Explode):
|
||||
values = [source.expression.this.this]
|
||||
elif isinstance(source.expression, exp.Unnest):
|
||||
values = [source.expression]
|
||||
else:
|
||||
values = source.expression.expressions[0].expressions
|
||||
|
||||
if not values:
|
||||
continue
|
||||
|
||||
selects[name] = {
|
||||
alias: column
|
||||
for alias, column in zip(
|
||||
source.expression.alias_column_names,
|
||||
values,
|
||||
)
|
||||
}
|
||||
else:
|
||||
selects[name] = {
|
||||
select.alias_or_name: select for select in source.expression.selects
|
||||
}
|
||||
|
||||
# First annotate the current scope's column references
|
||||
for col in scope.columns:
|
||||
if not col.table:
|
||||
continue
|
||||
|
||||
source = scope.sources.get(col.table)
|
||||
if isinstance(source, exp.Table):
|
||||
self._set_type(col, self.schema.get_column_type(source, col))
|
||||
elif source:
|
||||
if col.table in selects and col.name in selects[col.table]:
|
||||
self._set_type(col, selects[col.table][col.name].type)
|
||||
elif isinstance(source.expression, exp.Unnest):
|
||||
self._set_type(col, source.expression.type)
|
||||
|
||||
# Then (possibly) annotate the remaining expressions in the scope
|
||||
self._maybe_annotate(scope.expression)
|
||||
|
||||
def _maybe_annotate(self, expression: E) -> E:
|
||||
if id(expression) in self._visited:
|
||||
return expression # We've already inferred the expression's type
|
||||
|
@ -601,7 +603,13 @@ class TypeAnnotator(metaclass=_TypeAnnotator):
|
|||
def _annotate_unnest(self, expression: exp.Unnest) -> exp.Unnest:
|
||||
self._annotate_args(expression)
|
||||
child = seq_get(expression.expressions, 0)
|
||||
self._set_type(expression, child and seq_get(child.type.expressions, 0))
|
||||
|
||||
if child and child.is_type(exp.DataType.Type.ARRAY):
|
||||
expr_type = seq_get(child.type.expressions, 0)
|
||||
else:
|
||||
expr_type = None
|
||||
|
||||
self._set_type(expression, expr_type)
|
||||
return expression
|
||||
|
||||
def _annotate_struct_value(
|
||||
|
|
|
@ -1,5 +1,6 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import inspect
|
||||
import typing as t
|
||||
|
||||
import sqlglot
|
||||
|
@ -85,7 +86,7 @@ def optimize(
|
|||
optimized = exp.maybe_parse(expression, dialect=dialect, copy=True)
|
||||
for rule in rules:
|
||||
# Find any additional rule parameters, beyond `expression`
|
||||
rule_params = rule.__code__.co_varnames
|
||||
rule_params = inspect.getfullargspec(rule).args
|
||||
rule_kwargs = {
|
||||
param: possible_kwargs[param] for param in rule_params if param in possible_kwargs
|
||||
}
|
||||
|
|
|
@ -7,7 +7,7 @@ from sqlglot import alias, exp
|
|||
from sqlglot.dialects.dialect import Dialect, DialectType
|
||||
from sqlglot.errors import OptimizeError
|
||||
from sqlglot.helper import seq_get, SingleValuedMapping
|
||||
from sqlglot.optimizer.annotate_types import annotate_types
|
||||
from sqlglot.optimizer.annotate_types import TypeAnnotator
|
||||
from sqlglot.optimizer.scope import Scope, build_scope, traverse_scope, walk_in_scope
|
||||
from sqlglot.optimizer.simplify import simplify_parens
|
||||
from sqlglot.schema import Schema, ensure_schema
|
||||
|
@ -49,8 +49,10 @@ def qualify_columns(
|
|||
- Currently only handles a single PIVOT or UNPIVOT operator
|
||||
"""
|
||||
schema = ensure_schema(schema)
|
||||
annotator = TypeAnnotator(schema)
|
||||
infer_schema = schema.empty if infer_schema is None else infer_schema
|
||||
pseudocolumns = Dialect.get_or_raise(schema.dialect).PSEUDOCOLUMNS
|
||||
dialect = Dialect.get_or_raise(schema.dialect)
|
||||
pseudocolumns = dialect.PSEUDOCOLUMNS
|
||||
|
||||
for scope in traverse_scope(expression):
|
||||
resolver = Resolver(scope, schema, infer_schema=infer_schema)
|
||||
|
@ -74,6 +76,9 @@ def qualify_columns(
|
|||
_expand_group_by(scope)
|
||||
_expand_order_by(scope, resolver)
|
||||
|
||||
if dialect == "bigquery":
|
||||
annotator.annotate_scope(scope)
|
||||
|
||||
return expression
|
||||
|
||||
|
||||
|
@ -660,11 +665,8 @@ class Resolver:
|
|||
# directly select a struct field in a query.
|
||||
# this handles the case where the unnest is statically defined.
|
||||
if self.schema.dialect == "bigquery":
|
||||
expression = source.expression
|
||||
annotate_types(expression)
|
||||
|
||||
if expression.is_type(exp.DataType.Type.STRUCT):
|
||||
for k in expression.type.expressions: # type: ignore
|
||||
if source.expression.is_type(exp.DataType.Type.STRUCT):
|
||||
for k in source.expression.type.expressions: # type: ignore
|
||||
columns.append(k.name)
|
||||
else:
|
||||
columns = source.expression.named_selects
|
||||
|
|
|
@ -6,6 +6,7 @@ import itertools
|
|||
import typing as t
|
||||
from collections import deque
|
||||
from decimal import Decimal
|
||||
from functools import reduce
|
||||
|
||||
import sqlglot
|
||||
from sqlglot import Dialect, exp
|
||||
|
@ -658,17 +659,21 @@ def simplify_parens(expression):
|
|||
parent = expression.parent
|
||||
parent_is_predicate = isinstance(parent, exp.Predicate)
|
||||
|
||||
if not isinstance(this, exp.Select) and (
|
||||
not isinstance(parent, (exp.Condition, exp.Binary))
|
||||
or isinstance(parent, exp.Paren)
|
||||
or (
|
||||
not isinstance(this, exp.Binary)
|
||||
and not (isinstance(this, (exp.Not, exp.Is)) and parent_is_predicate)
|
||||
if (
|
||||
not isinstance(this, exp.Select)
|
||||
and not isinstance(parent, exp.SubqueryPredicate)
|
||||
and (
|
||||
not isinstance(parent, (exp.Condition, exp.Binary))
|
||||
or isinstance(parent, exp.Paren)
|
||||
or (
|
||||
not isinstance(this, exp.Binary)
|
||||
and not (isinstance(this, (exp.Not, exp.Is)) and parent_is_predicate)
|
||||
)
|
||||
or (isinstance(this, exp.Predicate) and not parent_is_predicate)
|
||||
or (isinstance(this, exp.Add) and isinstance(parent, exp.Add))
|
||||
or (isinstance(this, exp.Mul) and isinstance(parent, exp.Mul))
|
||||
or (isinstance(this, exp.Mul) and isinstance(parent, (exp.Add, exp.Sub)))
|
||||
)
|
||||
or (isinstance(this, exp.Predicate) and not parent_is_predicate)
|
||||
or (isinstance(this, exp.Add) and isinstance(parent, exp.Add))
|
||||
or (isinstance(this, exp.Mul) and isinstance(parent, exp.Mul))
|
||||
or (isinstance(this, exp.Mul) and isinstance(parent, (exp.Add, exp.Sub)))
|
||||
):
|
||||
return this
|
||||
return expression
|
||||
|
@ -779,6 +784,8 @@ def simplify_concat(expression):
|
|||
|
||||
if concat_type is exp.ConcatWs:
|
||||
new_args = [sep_expr] + new_args
|
||||
elif isinstance(expression, exp.DPipe):
|
||||
return reduce(lambda x, y: exp.DPipe(this=x, expression=y), new_args)
|
||||
|
||||
return concat_type(expressions=new_args, **args)
|
||||
|
||||
|
|
|
@ -289,7 +289,7 @@ class Parser(metaclass=_Parser):
|
|||
RESERVED_TOKENS = {
|
||||
*Tokenizer.SINGLE_TOKENS.values(),
|
||||
TokenType.SELECT,
|
||||
}
|
||||
} - {TokenType.IDENTIFIER}
|
||||
|
||||
DB_CREATABLES = {
|
||||
TokenType.DATABASE,
|
||||
|
@ -1109,6 +1109,9 @@ class Parser(metaclass=_Parser):
|
|||
# Whether or not interval spans are supported, INTERVAL 1 YEAR TO MONTHS
|
||||
INTERVAL_SPANS = True
|
||||
|
||||
# Whether a PARTITION clause can follow a table reference
|
||||
SUPPORTS_PARTITION_SELECTION = False
|
||||
|
||||
__slots__ = (
|
||||
"error_level",
|
||||
"error_message_context",
|
||||
|
@ -1764,7 +1767,11 @@ class Parser(metaclass=_Parser):
|
|||
def _parse_property_assignment(self, exp_class: t.Type[E], **kwargs: t.Any) -> E:
|
||||
self._match(TokenType.EQ)
|
||||
self._match(TokenType.ALIAS)
|
||||
return self.expression(exp_class, this=self._parse_field(), **kwargs)
|
||||
field = self._parse_field()
|
||||
if isinstance(field, exp.Identifier) and not field.quoted:
|
||||
field = exp.var(field)
|
||||
|
||||
return self.expression(exp_class, this=field, **kwargs)
|
||||
|
||||
def _parse_properties(self, before: t.Optional[bool] = None) -> t.Optional[exp.Properties]:
|
||||
properties = []
|
||||
|
@ -2234,7 +2241,11 @@ class Parser(metaclass=_Parser):
|
|||
self._match(TokenType.TABLE)
|
||||
is_function = self._match(TokenType.FUNCTION)
|
||||
|
||||
this = self._parse_table(schema=True) if not is_function else self._parse_function()
|
||||
this = (
|
||||
self._parse_table(schema=True, parse_partition=True)
|
||||
if not is_function
|
||||
else self._parse_function()
|
||||
)
|
||||
|
||||
returning = self._parse_returning()
|
||||
|
||||
|
@ -2244,9 +2255,9 @@ class Parser(metaclass=_Parser):
|
|||
hint=hint,
|
||||
is_function=is_function,
|
||||
this=this,
|
||||
stored=self._match_text_seq("STORED") and self._parse_stored(),
|
||||
by_name=self._match_text_seq("BY", "NAME"),
|
||||
exists=self._parse_exists(),
|
||||
partition=self._parse_partition(),
|
||||
where=self._match_pair(TokenType.REPLACE, TokenType.WHERE)
|
||||
and self._parse_conjunction(),
|
||||
expression=self._parse_derived_table_values() or self._parse_ddl_select(),
|
||||
|
@ -3098,6 +3109,9 @@ class Parser(metaclass=_Parser):
|
|||
else:
|
||||
table = exp.Identifier(this="*")
|
||||
|
||||
# We bubble up comments from the Identifier to the Table
|
||||
comments = table.pop_comments() if isinstance(table, exp.Expression) else None
|
||||
|
||||
if is_db_reference:
|
||||
catalog = db
|
||||
db = table
|
||||
|
@ -3109,7 +3123,12 @@ class Parser(metaclass=_Parser):
|
|||
self.raise_error(f"Expected database name but got {self._curr}")
|
||||
|
||||
return self.expression(
|
||||
exp.Table, this=table, db=db, catalog=catalog, pivots=self._parse_pivots()
|
||||
exp.Table,
|
||||
comments=comments,
|
||||
this=table,
|
||||
db=db,
|
||||
catalog=catalog,
|
||||
pivots=self._parse_pivots(),
|
||||
)
|
||||
|
||||
def _parse_table(
|
||||
|
@ -3119,6 +3138,7 @@ class Parser(metaclass=_Parser):
|
|||
alias_tokens: t.Optional[t.Collection[TokenType]] = None,
|
||||
parse_bracket: bool = False,
|
||||
is_db_reference: bool = False,
|
||||
parse_partition: bool = False,
|
||||
) -> t.Optional[exp.Expression]:
|
||||
lateral = self._parse_lateral()
|
||||
if lateral:
|
||||
|
@ -3157,6 +3177,10 @@ class Parser(metaclass=_Parser):
|
|||
# Postgres supports a wildcard (table) suffix operator, which is a no-op in this context
|
||||
self._match_text_seq("*")
|
||||
|
||||
parse_partition = parse_partition or self.SUPPORTS_PARTITION_SELECTION
|
||||
if parse_partition and self._match(TokenType.PARTITION, advance=False):
|
||||
this.set("partition", self._parse_partition())
|
||||
|
||||
if schema:
|
||||
return self._parse_schema(this=this)
|
||||
|
||||
|
@ -4200,7 +4224,11 @@ class Parser(metaclass=_Parser):
|
|||
):
|
||||
this = self._parse_id_var()
|
||||
|
||||
return self.expression(exp.Column, this=this) if isinstance(this, exp.Identifier) else this
|
||||
if isinstance(this, exp.Identifier):
|
||||
# We bubble up comments from the Identifier to the Column
|
||||
this = self.expression(exp.Column, comments=this.pop_comments(), this=this)
|
||||
|
||||
return this
|
||||
|
||||
def _parse_column_ops(self, this: t.Optional[exp.Expression]) -> t.Optional[exp.Expression]:
|
||||
this = self._parse_bracket(this)
|
||||
|
@ -4216,7 +4244,7 @@ class Parser(metaclass=_Parser):
|
|||
elif op and self._curr:
|
||||
field = self._parse_column_reference()
|
||||
else:
|
||||
field = self._parse_field(anonymous_func=True, any_token=True)
|
||||
field = self._parse_field(any_token=True, anonymous_func=True)
|
||||
|
||||
if isinstance(field, exp.Func) and this:
|
||||
# bigquery allows function calls like x.y.count(...)
|
||||
|
@ -4285,7 +4313,7 @@ class Parser(metaclass=_Parser):
|
|||
this = self._parse_subquery(
|
||||
this=self._parse_set_operations(this), parse_alias=False
|
||||
)
|
||||
elif len(expressions) > 1:
|
||||
elif len(expressions) > 1 or self._prev.token_type == TokenType.COMMA:
|
||||
this = self.expression(exp.Tuple, expressions=expressions)
|
||||
else:
|
||||
this = self.expression(exp.Paren, this=this)
|
||||
|
@ -4304,17 +4332,23 @@ class Parser(metaclass=_Parser):
|
|||
tokens: t.Optional[t.Collection[TokenType]] = None,
|
||||
anonymous_func: bool = False,
|
||||
) -> t.Optional[exp.Expression]:
|
||||
return (
|
||||
self._parse_primary()
|
||||
or self._parse_function(anonymous=anonymous_func)
|
||||
or self._parse_id_var(any_token=any_token, tokens=tokens)
|
||||
)
|
||||
if anonymous_func:
|
||||
field = (
|
||||
self._parse_function(anonymous=anonymous_func, any_token=any_token)
|
||||
or self._parse_primary()
|
||||
)
|
||||
else:
|
||||
field = self._parse_primary() or self._parse_function(
|
||||
anonymous=anonymous_func, any_token=any_token
|
||||
)
|
||||
return field or self._parse_id_var(any_token=any_token, tokens=tokens)
|
||||
|
||||
def _parse_function(
|
||||
self,
|
||||
functions: t.Optional[t.Dict[str, t.Callable]] = None,
|
||||
anonymous: bool = False,
|
||||
optional_parens: bool = True,
|
||||
any_token: bool = False,
|
||||
) -> t.Optional[exp.Expression]:
|
||||
# This allows us to also parse {fn <function>} syntax (Snowflake, MySQL support this)
|
||||
# See: https://community.snowflake.com/s/article/SQL-Escape-Sequences
|
||||
|
@ -4328,7 +4362,10 @@ class Parser(metaclass=_Parser):
|
|||
fn_syntax = True
|
||||
|
||||
func = self._parse_function_call(
|
||||
functions=functions, anonymous=anonymous, optional_parens=optional_parens
|
||||
functions=functions,
|
||||
anonymous=anonymous,
|
||||
optional_parens=optional_parens,
|
||||
any_token=any_token,
|
||||
)
|
||||
|
||||
if fn_syntax:
|
||||
|
@ -4341,6 +4378,7 @@ class Parser(metaclass=_Parser):
|
|||
functions: t.Optional[t.Dict[str, t.Callable]] = None,
|
||||
anonymous: bool = False,
|
||||
optional_parens: bool = True,
|
||||
any_token: bool = False,
|
||||
) -> t.Optional[exp.Expression]:
|
||||
if not self._curr:
|
||||
return None
|
||||
|
@ -4362,7 +4400,10 @@ class Parser(metaclass=_Parser):
|
|||
|
||||
return None
|
||||
|
||||
if token_type not in self.FUNC_TOKENS:
|
||||
if any_token:
|
||||
if token_type in self.RESERVED_TOKENS:
|
||||
return None
|
||||
elif token_type not in self.FUNC_TOKENS:
|
||||
return None
|
||||
|
||||
self._advance(2)
|
||||
|
@ -4501,7 +4542,6 @@ class Parser(metaclass=_Parser):
|
|||
|
||||
def _parse_schema(self, this: t.Optional[exp.Expression] = None) -> t.Optional[exp.Expression]:
|
||||
index = self._index
|
||||
|
||||
if not self._match(TokenType.L_PAREN):
|
||||
return this
|
||||
|
||||
|
@ -4510,9 +4550,7 @@ class Parser(metaclass=_Parser):
|
|||
if self._match_set(self.SELECT_START_TOKENS):
|
||||
self._retreat(index)
|
||||
return this
|
||||
|
||||
args = self._parse_csv(lambda: self._parse_constraint() or self._parse_field_def())
|
||||
|
||||
self._match_r_paren()
|
||||
return self.expression(exp.Schema, this=this, expressions=args)
|
||||
|
||||
|
@ -5378,8 +5416,8 @@ class Parser(metaclass=_Parser):
|
|||
else:
|
||||
over = self._prev.text.upper()
|
||||
|
||||
if comments:
|
||||
func.comments = None # type: ignore
|
||||
if comments and isinstance(func, exp.Expression):
|
||||
func.pop_comments()
|
||||
|
||||
if not self._match(TokenType.L_PAREN):
|
||||
return self.expression(
|
||||
|
@ -5457,7 +5495,7 @@ class Parser(metaclass=_Parser):
|
|||
self, this: t.Optional[exp.Expression], explicit: bool = False
|
||||
) -> t.Optional[exp.Expression]:
|
||||
any_token = self._match(TokenType.ALIAS)
|
||||
comments = self._prev_comments
|
||||
comments = self._prev_comments or []
|
||||
|
||||
if explicit and not any_token:
|
||||
return this
|
||||
|
@ -5477,13 +5515,13 @@ class Parser(metaclass=_Parser):
|
|||
)
|
||||
|
||||
if alias:
|
||||
comments.extend(alias.pop_comments())
|
||||
this = self.expression(exp.Alias, comments=comments, this=this, alias=alias)
|
||||
column = this.this
|
||||
|
||||
# Moves the comment next to the alias in `expr /* comment */ AS alias`
|
||||
if not this.comments and column and column.comments:
|
||||
this.comments = column.comments
|
||||
column.comments = None
|
||||
this.comments = column.pop_comments()
|
||||
|
||||
return this
|
||||
|
||||
|
@ -5492,16 +5530,14 @@ class Parser(metaclass=_Parser):
|
|||
any_token: bool = True,
|
||||
tokens: t.Optional[t.Collection[TokenType]] = None,
|
||||
) -> t.Optional[exp.Expression]:
|
||||
identifier = self._parse_identifier()
|
||||
|
||||
if identifier:
|
||||
return identifier
|
||||
|
||||
if (any_token and self._advance_any()) or self._match_set(tokens or self.ID_VAR_TOKENS):
|
||||
expression = self._parse_identifier()
|
||||
if not expression and (
|
||||
(any_token and self._advance_any()) or self._match_set(tokens or self.ID_VAR_TOKENS)
|
||||
):
|
||||
quoted = self._prev.token_type == TokenType.STRING
|
||||
return exp.Identifier(this=self._prev.text, quoted=quoted)
|
||||
expression = self.expression(exp.Identifier, this=self._prev.text, quoted=quoted)
|
||||
|
||||
return None
|
||||
return expression
|
||||
|
||||
def _parse_string(self) -> t.Optional[exp.Expression]:
|
||||
if self._match_set(self.STRING_PARSERS):
|
||||
|
|
|
@ -567,11 +567,11 @@ class Tokenizer(metaclass=_Tokenizer):
|
|||
"~": TokenType.TILDA,
|
||||
"?": TokenType.PLACEHOLDER,
|
||||
"@": TokenType.PARAMETER,
|
||||
# Used for breaking a var like x'y' but nothing else the token type doesn't matter
|
||||
"'": TokenType.QUOTE,
|
||||
"`": TokenType.IDENTIFIER,
|
||||
'"': TokenType.IDENTIFIER,
|
||||
"#": TokenType.HASH,
|
||||
# Used for breaking a var like x'y' but nothing else the token type doesn't matter
|
||||
"'": TokenType.UNKNOWN,
|
||||
"`": TokenType.UNKNOWN,
|
||||
'"': TokenType.UNKNOWN,
|
||||
}
|
||||
|
||||
BIT_STRINGS: t.List[str | t.Tuple[str, str]] = []
|
||||
|
|
|
@ -93,7 +93,9 @@ def eliminate_qualify(expression: exp.Expression) -> exp.Expression:
|
|||
Some dialects don't support window functions in the WHERE clause, so we need to include them as
|
||||
projections in the subquery, in order to refer to them in the outer filter using aliases. Also,
|
||||
if a column is referenced in the QUALIFY clause but is not selected, we need to include it too,
|
||||
otherwise we won't be able to refer to it in the outer query's WHERE clause.
|
||||
otherwise we won't be able to refer to it in the outer query's WHERE clause. Finally, if a
|
||||
newly aliased projection is referenced in the QUALIFY clause, it will be replaced by the
|
||||
corresponding expression to avoid creating invalid column references.
|
||||
"""
|
||||
if isinstance(expression, exp.Select) and expression.args.get("qualify"):
|
||||
taken = set(expression.named_selects)
|
||||
|
@ -105,20 +107,31 @@ def eliminate_qualify(expression: exp.Expression) -> exp.Expression:
|
|||
|
||||
outer_selects = exp.select(*[select.alias_or_name for select in expression.selects])
|
||||
qualify_filters = expression.args["qualify"].pop().this
|
||||
expression_by_alias = {
|
||||
select.alias: select.this
|
||||
for select in expression.selects
|
||||
if isinstance(select, exp.Alias)
|
||||
}
|
||||
|
||||
select_candidates = exp.Window if expression.is_star else (exp.Window, exp.Column)
|
||||
for expr in qualify_filters.find_all(select_candidates):
|
||||
if isinstance(expr, exp.Window):
|
||||
for select_candidate in qualify_filters.find_all(select_candidates):
|
||||
if isinstance(select_candidate, exp.Window):
|
||||
if expression_by_alias:
|
||||
for column in select_candidate.find_all(exp.Column):
|
||||
expr = expression_by_alias.get(column.name)
|
||||
if expr:
|
||||
column.replace(expr)
|
||||
|
||||
alias = find_new_name(expression.named_selects, "_w")
|
||||
expression.select(exp.alias_(expr, alias), copy=False)
|
||||
expression.select(exp.alias_(select_candidate, alias), copy=False)
|
||||
column = exp.column(alias)
|
||||
|
||||
if isinstance(expr.parent, exp.Qualify):
|
||||
if isinstance(select_candidate.parent, exp.Qualify):
|
||||
qualify_filters = column
|
||||
else:
|
||||
expr.replace(column)
|
||||
elif expr.name not in expression.named_selects:
|
||||
expression.select(expr.copy(), copy=False)
|
||||
select_candidate.replace(column)
|
||||
elif select_candidate.name not in expression.named_selects:
|
||||
expression.select(select_candidate.copy(), copy=False)
|
||||
|
||||
return outer_selects.from_(expression.subquery(alias="_t", copy=False), copy=False).where(
|
||||
qualify_filters, copy=False
|
||||
|
@ -336,13 +349,10 @@ def explode_to_unnest(index_offset: int = 0) -> t.Callable[[exp.Expression], exp
|
|||
return _explode_to_unnest
|
||||
|
||||
|
||||
PERCENTILES = (exp.PercentileCont, exp.PercentileDisc)
|
||||
|
||||
|
||||
def add_within_group_for_percentiles(expression: exp.Expression) -> exp.Expression:
|
||||
"""Transforms percentiles by adding a WITHIN GROUP clause to them."""
|
||||
if (
|
||||
isinstance(expression, PERCENTILES)
|
||||
isinstance(expression, exp.PERCENTILES)
|
||||
and not isinstance(expression.parent, exp.WithinGroup)
|
||||
and expression.expression
|
||||
):
|
||||
|
@ -358,7 +368,7 @@ def remove_within_group_for_percentiles(expression: exp.Expression) -> exp.Expre
|
|||
"""Transforms percentiles by getting rid of their corresponding WITHIN GROUP clause."""
|
||||
if (
|
||||
isinstance(expression, exp.WithinGroup)
|
||||
and isinstance(expression.this, PERCENTILES)
|
||||
and isinstance(expression.this, exp.PERCENTILES)
|
||||
and isinstance(expression.expression, exp.Order)
|
||||
):
|
||||
quantile = expression.this.this
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue