Merging upstream version 16.7.3.
Signed-off-by: Daniel Baumann <daniel@debian.org>
This commit is contained in:
parent 331a760a3d
commit 088f137198
75 changed files with 33866 additions and 31988 deletions
sqlglot/dataframe/sql/functions.py

@@ -1119,7 +1119,7 @@ def map_entries(col: ColumnOrName) -> Column:
 
 
 def map_from_entries(col: ColumnOrName) -> Column:
-    return Column.invoke_anonymous_function(col, "MAP_FROM_ENTRIES")
+    return Column.invoke_expression_over_column(col, expression.MapFromEntries)
 
 
 def array_repeat(col: ColumnOrName, count: t.Union[ColumnOrName, int]) -> Column:
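The change above swaps an anonymous function call for a typed MapFromEntries expression, which the dialect hunks further down can target. A minimal sketch of the observable effect (the printed output is an assumption, not part of this diff):

    import sqlglot

    # MAP_FROM_ENTRIES now parses into exp.MapFromEntries and is regenerated from it
    print(sqlglot.transpile("SELECT MAP_FROM_ENTRIES(pairs)", read="spark", write="spark")[0])
    # expected: SELECT MAP_FROM_ENTRIES(pairs)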
sqlglot/dialects/bigquery.py

@@ -1,5 +1,6 @@
 from __future__ import annotations
 
+import logging
 import re
 import typing as t
 
@@ -21,6 +22,8 @@ from sqlglot.dialects.dialect import (
 from sqlglot.helper import seq_get, split_num_words
 from sqlglot.tokens import TokenType
 
+logger = logging.getLogger("sqlglot")
+
 
 def _date_add_sql(
     data_type: str, kind: str
@@ -104,12 +107,70 @@ def _unqualify_unnest(expression: exp.Expression) -> exp.Expression:
     return expression
 
 
+# https://issuetracker.google.com/issues/162294746
+# workaround for bigquery bug when grouping by an expression and then ordering
+# WITH x AS (SELECT 1 y)
+# SELECT y + 1 z
+# FROM x
+# GROUP BY x + 1
+# ORDER by z
+def _alias_ordered_group(expression: exp.Expression) -> exp.Expression:
+    if isinstance(expression, exp.Select):
+        group = expression.args.get("group")
+        order = expression.args.get("order")
+
+        if group and order:
+            aliases = {
+                select.this: select.args["alias"]
+                for select in expression.selects
+                if isinstance(select, exp.Alias)
+            }
+
+            for e in group.expressions:
+                alias = aliases.get(e)
+
+                if alias:
+                    e.replace(exp.column(alias))
+
+    return expression
+
+
+def _pushdown_cte_column_names(expression: exp.Expression) -> exp.Expression:
+    """BigQuery doesn't allow column names when defining a CTE, so we try to push them down."""
+    if isinstance(expression, exp.CTE) and expression.alias_column_names:
+        cte_query = expression.this
+
+        if cte_query.is_star:
+            logger.warning(
+                "Can't push down CTE column names for star queries. Run the query through"
+                " the optimizer or use 'qualify' to expand the star projections first."
+            )
+            return expression
+
+        column_names = expression.alias_column_names
+        expression.args["alias"].set("columns", None)
+
+        for name, select in zip(column_names, cte_query.selects):
+            to_replace = select
+
+            if isinstance(select, exp.Alias):
+                select = select.this
+
+            # Inner aliases are shadowed by the CTE column names
+            to_replace.replace(exp.alias_(select, name))
+
+    return expression
+
+
 class BigQuery(Dialect):
     UNNEST_COLUMN_ONLY = True
 
     # https://cloud.google.com/bigquery/docs/reference/standard-sql/lexical#case_sensitivity
     RESOLVES_IDENTIFIERS_AS_UPPERCASE = None
 
+    # bigquery udfs are case sensitive
+    NORMALIZE_FUNCTIONS = False
+
     TIME_MAPPING = {
         "%D": "%m/%d/%y",
     }
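Both helpers added above are wired into BigQuery's CTE and Select transforms later in this diff. A hedged sketch of the CTE pushdown (the exact output string is an assumption):

    import sqlglot

    # BigQuery rejects column names on a CTE, so they are pushed into the projections
    print(sqlglot.transpile("WITH t(c) AS (SELECT 1) SELECT c FROM t", write="bigquery")[0])
    # expected: WITH t AS (SELECT 1 AS c) SELECT c FROM t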
@@ -135,12 +196,16 @@ class BigQuery(Dialect):
 
         # In BigQuery, CTEs aren't case-sensitive, but table names are (by default, at least).
         # The following check is essentially a heuristic to detect tables based on whether or
         # not they're qualified.
-        if (
-            isinstance(expression, exp.Identifier)
-            and not (isinstance(expression.parent, exp.Table) and expression.parent.db)
-            and not expression.meta.get("is_table")
-        ):
-            expression.set("this", expression.this.lower())
+        if isinstance(expression, exp.Identifier):
+            parent = expression.parent
+
+            while isinstance(parent, exp.Dot):
+                parent = parent.parent
+
+            if not (isinstance(parent, exp.Table) and parent.db) and not expression.meta.get(
+                "is_table"
+            ):
+                expression.set("this", expression.this.lower())
 
         return expression
@@ -298,10 +363,8 @@ class BigQuery(Dialect):
             **generator.Generator.TRANSFORMS,
             exp.ApproxDistinct: rename_func("APPROX_COUNT_DISTINCT"),
             exp.ArraySize: rename_func("ARRAY_LENGTH"),
-            exp.AtTimeZone: lambda self, e: self.func(
-                "TIMESTAMP", self.func("DATETIME", e.this, e.args.get("zone"))
-            ),
             exp.Cast: transforms.preprocess([transforms.remove_precision_parameterized_types]),
+            exp.CTE: transforms.preprocess([_pushdown_cte_column_names]),
             exp.DateAdd: _date_add_sql("DATE", "ADD"),
             exp.DateSub: _date_add_sql("DATE", "SUB"),
             exp.DatetimeAdd: _date_add_sql("DATETIME", "ADD"),
@@ -325,7 +388,12 @@ class BigQuery(Dialect):
             ),
             exp.RegexpLike: rename_func("REGEXP_CONTAINS"),
             exp.Select: transforms.preprocess(
-                [_unqualify_unnest, transforms.eliminate_distinct_on]
+                [
+                    transforms.explode_to_unnest,
+                    _unqualify_unnest,
+                    transforms.eliminate_distinct_on,
+                    _alias_ordered_group,
+                ]
             ),
             exp.StrToDate: lambda self, e: f"PARSE_DATE({self.format_time(e)}, {self.sql(e, 'this')})",
             exp.StrToTime: lambda self, e: f"PARSE_TIMESTAMP({self.format_time(e)}, {self.sql(e, 'this')})",
@@ -334,7 +402,6 @@ class BigQuery(Dialect):
             exp.TimestampAdd: _date_add_sql("TIMESTAMP", "ADD"),
             exp.TimestampSub: _date_add_sql("TIMESTAMP", "SUB"),
             exp.TimeStrToTime: timestrtotime_sql,
-            exp.TryCast: lambda self, e: f"SAFE_CAST({self.sql(e, 'this')} AS {self.sql(e, 'to')})",
             exp.TsOrDsToDate: ts_or_ds_to_date_sql("bigquery"),
             exp.TsOrDsAdd: _date_add_sql("DATE", "ADD"),
             exp.PartitionedByProperty: lambda self, e: f"PARTITION BY {self.sql(e, 'this')}",
@@ -378,7 +445,121 @@ class BigQuery(Dialect):
             exp.VolatileProperty: exp.Properties.Location.UNSUPPORTED,
         }
 
-        RESERVED_KEYWORDS = {*generator.Generator.RESERVED_KEYWORDS, "hash"}
+        # from: https://cloud.google.com/bigquery/docs/reference/standard-sql/lexical#reserved_keywords
+        RESERVED_KEYWORDS = {
+            *generator.Generator.RESERVED_KEYWORDS,
+            "all",
+            "and",
+            "any",
+            "array",
+            "as",
+            "asc",
+            "assert_rows_modified",
+            "at",
+            "between",
+            "by",
+            "case",
+            "cast",
+            "collate",
+            "contains",
+            "create",
+            "cross",
+            "cube",
+            "current",
+            "default",
+            "define",
+            "desc",
+            "distinct",
+            "else",
+            "end",
+            "enum",
+            "escape",
+            "except",
+            "exclude",
+            "exists",
+            "extract",
+            "false",
+            "fetch",
+            "following",
+            "for",
+            "from",
+            "full",
+            "group",
+            "grouping",
+            "groups",
+            "hash",
+            "having",
+            "if",
+            "ignore",
+            "in",
+            "inner",
+            "intersect",
+            "interval",
+            "into",
+            "is",
+            "join",
+            "lateral",
+            "left",
+            "like",
+            "limit",
+            "lookup",
+            "merge",
+            "natural",
+            "new",
+            "no",
+            "not",
+            "null",
+            "nulls",
+            "of",
+            "on",
+            "or",
+            "order",
+            "outer",
+            "over",
+            "partition",
+            "preceding",
+            "proto",
+            "qualify",
+            "range",
+            "recursive",
+            "respect",
+            "right",
+            "rollup",
+            "rows",
+            "select",
+            "set",
+            "some",
+            "struct",
+            "tablesample",
+            "then",
+            "to",
+            "treat",
+            "true",
+            "unbounded",
+            "union",
+            "unnest",
+            "using",
+            "when",
+            "where",
+            "window",
+            "with",
+            "within",
+        }
+
+        def attimezone_sql(self, expression: exp.AtTimeZone) -> str:
+            if not isinstance(expression.parent, exp.Cast):
+                return self.func(
+                    "TIMESTAMP", self.func("DATETIME", expression.this, expression.args.get("zone"))
+                )
+            return super().attimezone_sql(expression)
+
+        def trycast_sql(self, expression: exp.TryCast) -> str:
+            return self.cast_sql(expression, safe_prefix="SAFE_")
+
+        def cte_sql(self, expression: exp.CTE) -> str:
+            if expression.alias_column_names:
+                self.unsupported("Column names in CTE definition are not supported.")
+            return super().cte_sql(expression)
 
         def array_sql(self, expression: exp.Array) -> str:
             first_arg = seq_get(expression.expressions, 0)
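The generator quotes any identifier whose lowercased name appears in RESERVED_KEYWORDS, so the expanded set forces backtick-quoting well beyond the old "hash" entry. A hedged sketch:

    import sqlglot

    # identifiers that collide with BigQuery reserved words should come out quoted
    print(sqlglot.transpile("SELECT t.lookup FROM t", write="bigquery")[0])
    # expected: SELECT t.`lookup` FROM t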
sqlglot/dialects/dialect.py

@@ -388,6 +388,11 @@ def no_comment_column_constraint_sql(
     return ""
 
 
+def no_map_from_entries_sql(self: Generator, expression: exp.MapFromEntries) -> str:
+    self.unsupported("MAP_FROM_ENTRIES unsupported")
+    return ""
+
+
 def str_position_sql(self: Generator, expression: exp.StrPosition) -> str:
     this = self.sql(expression, "this")
     substr = self.sql(expression, "substr")
sqlglot/dialects/mysql.py

@@ -132,6 +132,10 @@ class MySQL(Dialect):
             "SEPARATOR": TokenType.SEPARATOR,
             "ENUM": TokenType.ENUM,
             "START": TokenType.BEGIN,
+            "SIGNED": TokenType.BIGINT,
+            "SIGNED INTEGER": TokenType.BIGINT,
+            "UNSIGNED": TokenType.UBIGINT,
+            "UNSIGNED INTEGER": TokenType.UBIGINT,
             "_ARMSCII8": TokenType.INTRODUCER,
             "_ASCII": TokenType.INTRODUCER,
             "_BIG5": TokenType.INTRODUCER,
@@ -441,6 +445,17 @@ class MySQL(Dialect):
 
         LIMIT_FETCH = "LIMIT"
 
+        def cast_sql(self, expression: exp.Cast, safe_prefix: t.Optional[str] = None) -> str:
+            """(U)BIGINT is not allowed in a CAST expression, so we use (UN)SIGNED instead."""
+            if expression.to.this == exp.DataType.Type.BIGINT:
+                to = "SIGNED"
+            elif expression.to.this == exp.DataType.Type.UBIGINT:
+                to = "UNSIGNED"
+            else:
+                return super().cast_sql(expression)
+
+            return f"CAST({self.sql(expression, 'this')} AS {to})"
+
         def show_sql(self, expression: exp.Show) -> str:
             this = f" {expression.name}"
             full = " FULL" if expression.args.get("full") else ""
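The tokenizer keywords and the cast override work as a pair: SIGNED/UNSIGNED parse into (U)BIGINT types, and (U)BIGINT casts render back in MySQL's spelling. A hedged sketch (expected outputs are assumptions):

    import sqlglot

    print(sqlglot.transpile("SELECT CAST(x AS BIGINT)", write="mysql")[0])
    # expected: SELECT CAST(x AS SIGNED)
    print(sqlglot.transpile("SELECT CAST(x AS UNSIGNED)", read="mysql", write="mysql")[0])
    # expected: SELECT CAST(x AS UNSIGNED)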
sqlglot/dialects/postgres.py

@@ -11,6 +11,7 @@ from sqlglot.dialects.dialect import (
     format_time_lambda,
     max_or_greatest,
     min_or_least,
+    no_map_from_entries_sql,
     no_paren_current_date_sql,
     no_pivot_sql,
     no_tablesample_sql,
@@ -346,6 +347,7 @@ class Postgres(Dialect):
             exp.LogicalOr: rename_func("BOOL_OR"),
             exp.LogicalAnd: rename_func("BOOL_AND"),
             exp.Max: max_or_greatest,
+            exp.MapFromEntries: no_map_from_entries_sql,
             exp.Min: min_or_least,
             exp.ArrayOverlaps: lambda self, e: self.binary(e, "&&"),
             exp.ArrayContains: lambda self, e: self.binary(e, "@>"),
@@ -378,3 +380,11 @@ class Postgres(Dialect):
             exp.TransientProperty: exp.Properties.Location.UNSUPPORTED,
             exp.VolatileProperty: exp.Properties.Location.UNSUPPORTED,
         }
+
+        def bracket_sql(self, expression: exp.Bracket) -> str:
+            """Forms like ARRAY[1, 2, 3][3] aren't allowed; we need to wrap the ARRAY."""
+            if isinstance(expression.this, exp.Array):
+                expression = expression.copy()
+                expression.set("this", exp.paren(expression.this, copy=False))
+
+            return super().bracket_sql(expression)
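A hedged sketch of the bracket wrapping (exact spacing in the output is an assumption):

    import sqlglot

    # subscripting an array literal now gets parenthesized for Postgres
    print(sqlglot.transpile("SELECT ARRAY[1, 2, 3][2]", read="postgres", write="postgres")[0])
    # expected: SELECT (ARRAY[1, 2, 3])[2]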
sqlglot/dialects/presto.py

@@ -20,7 +20,7 @@ from sqlglot.dialects.dialect import (
 )
 from sqlglot.dialects.mysql import MySQL
 from sqlglot.errors import UnsupportedError
-from sqlglot.helper import seq_get
+from sqlglot.helper import apply_index_offset, seq_get
 from sqlglot.tokens import TokenType
 
 
@@ -154,6 +154,13 @@ def _from_unixtime(args: t.List) -> exp.Expression:
     return exp.UnixToTime.from_arg_list(args)
 
 
+def _parse_element_at(args: t.List) -> exp.SafeBracket:
+    this = seq_get(args, 0)
+    index = seq_get(args, 1)
+    assert isinstance(this, exp.Expression) and isinstance(index, exp.Expression)
+    return exp.SafeBracket(this=this, expressions=apply_index_offset(this, [index], -1))
+
+
 def _unnest_sequence(expression: exp.Expression) -> exp.Expression:
     if isinstance(expression, exp.Table):
         if isinstance(expression.this, exp.GenerateSeries):
@@ -201,6 +208,7 @@ class Presto(Dialect):
             "DATE_FORMAT": format_time_lambda(exp.TimeToStr, "presto"),
             "DATE_PARSE": format_time_lambda(exp.StrToTime, "presto"),
             "DATE_TRUNC": date_trunc_to_time,
+            "ELEMENT_AT": _parse_element_at,
             "FROM_HEX": exp.Unhex.from_arg_list,
             "FROM_UNIXTIME": _from_unixtime,
             "FROM_UTF8": lambda args: exp.Decode(
@@ -285,6 +293,9 @@ class Presto(Dialect):
             exp.Pivot: no_pivot_sql,
             exp.Quantile: _quantile_sql,
             exp.Right: right_to_substring_sql,
+            exp.SafeBracket: lambda self, e: self.func(
+                "ELEMENT_AT", e.this, seq_get(apply_index_offset(e.this, e.expressions, 1), 0)
+            ),
             exp.SafeDivide: no_safe_divide_sql,
             exp.Schema: _schema_sql,
             exp.Select: transforms.preprocess(
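ELEMENT_AT(arr, 1) parses into exp.SafeBracket with a zero-based index (apply_index_offset(..., -1)), and the generator re-applies the +1 offset, so a Presto round trip is lossless. Sketch, with the expected output as an assumption:

    import sqlglot

    print(sqlglot.transpile("SELECT ELEMENT_AT(arr, 1)", read="presto", write="presto")[0])
    # expected: SELECT ELEMENT_AT(arr, 1)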
sqlglot/dialects/redshift.py

@@ -41,8 +41,6 @@ class Redshift(Postgres):
             "STRTOL": exp.FromBase.from_arg_list,
         }
 
-        CONVERT_TYPE_FIRST = True
-
         def _parse_types(
             self, check_func: bool = False, schema: bool = False
         ) -> t.Optional[exp.Expression]:
@@ -58,6 +56,12 @@ class Redshift(Postgres):
 
             return this
 
+        def _parse_convert(self, strict: bool) -> t.Optional[exp.Expression]:
+            to = self._parse_types()
+            self._match(TokenType.COMMA)
+            this = self._parse_bitwise()
+            return self.expression(exp.TryCast, this=this, to=to)
+
     class Tokenizer(Postgres.Tokenizer):
         BIT_STRINGS = []
        HEX_STRINGS = []
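With the CONVERT_TYPE_FIRST flag gone (see the parser.py hunk below), Redshift now overrides _parse_convert directly: type first, then value. A hedged sketch of the resulting AST (the printed rendering is an assumption):

    from sqlglot import exp, parse_one

    node = parse_one("SELECT CONVERT(INTEGER, x)", read="redshift").find(exp.TryCast)
    print(node)  # expected: TRY_CAST(x AS INT)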
sqlglot/dialects/snowflake.py

@@ -258,14 +258,29 @@ class Snowflake(Dialect):
 
         ALTER_PARSERS = {
             **parser.Parser.ALTER_PARSERS,
-            "UNSET": lambda self: self._parse_alter_table_set_tag(unset=True),
-            "SET": lambda self: self._parse_alter_table_set_tag(),
+            "SET": lambda self: self._parse_set(tag=self._match_text_seq("TAG")),
+            "UNSET": lambda self: self.expression(
+                exp.Set,
+                tag=self._match_text_seq("TAG"),
+                expressions=self._parse_csv(self._parse_id_var),
+                unset=True,
+            ),
         }
 
-        def _parse_alter_table_set_tag(self, unset: bool = False) -> exp.Expression:
-            self._match_text_seq("TAG")
-            parser = t.cast(t.Callable, self._parse_id_var if unset else self._parse_conjunction)
-            return self.expression(exp.SetTag, expressions=self._parse_csv(parser), unset=unset)
+        def _parse_id_var(
+            self,
+            any_token: bool = True,
+            tokens: t.Optional[t.Collection[TokenType]] = None,
+        ) -> t.Optional[exp.Expression]:
+            if self._match_text_seq("IDENTIFIER", "("):
+                identifier = (
+                    super()._parse_id_var(any_token=any_token, tokens=tokens)
+                    or self._parse_string()
+                )
+                self._match_r_paren()
+                return self.expression(exp.Anonymous, this="IDENTIFIER", expressions=[identifier])
+
+            return super()._parse_id_var(any_token=any_token, tokens=tokens)
 
     class Tokenizer(tokens.Tokenizer):
         QUOTES = ["'", "$$"]
@@ -380,10 +395,6 @@ class Snowflake(Dialect):
                 self.unsupported("INTERSECT with All is not supported in Snowflake")
             return super().intersect_op(expression)
 
-        def settag_sql(self, expression: exp.SetTag) -> str:
-            action = "UNSET" if expression.args.get("unset") else "SET"
-            return f"{action} TAG {self.expressions(expression)}"
-
         def describe_sql(self, expression: exp.Describe) -> str:
             # Default to table if kind is unknown
             kind_value = expression.args.get("kind") or "TABLE"
sqlglot/dialects/spark.py

@@ -43,6 +43,7 @@ class Spark(Spark2):
     class Generator(Spark2.Generator):
         TRANSFORMS = Spark2.Generator.TRANSFORMS.copy()
         TRANSFORMS.pop(exp.DateDiff)
+        TRANSFORMS.pop(exp.Group)
 
         def datediff_sql(self, expression: exp.DateDiff) -> str:
             unit = self.sql(expression, "unit")
sqlglot/dialects/spark2.py

@@ -231,14 +231,14 @@ class Spark2(Hive):
         WRAP_DERIVED_VALUES = False
         CREATE_FUNCTION_RETURN_AS = False
 
-        def cast_sql(self, expression: exp.Cast) -> str:
+        def cast_sql(self, expression: exp.Cast, safe_prefix: t.Optional[str] = None) -> str:
             if isinstance(expression.this, exp.Cast) and expression.this.is_type("json"):
                 schema = f"'{self.sql(expression, 'to')}'"
                 return self.func("FROM_JSON", expression.this.this, schema)
             if expression.is_type("json"):
                 return self.func("TO_JSON", expression.this)
 
-            return super(Hive.Generator, self).cast_sql(expression)
+            return super(Hive.Generator, self).cast_sql(expression, safe_prefix=safe_prefix)
 
         def columndef_sql(self, expression: exp.ColumnDef, sep: str = " ") -> str:
             return super().columndef_sql(
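The new safe_prefix parameter is threaded through so TRY_CAST keeps working, while the JSON special-casing is unchanged. A hedged sketch:

    import sqlglot

    # casting to JSON maps onto Spark's TO_JSON
    print(sqlglot.transpile("SELECT CAST(col AS JSON)", write="spark")[0])
    # expected: SELECT TO_JSON(col)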
sqlglot/dialects/sqlite.py

@@ -1,5 +1,7 @@
 from __future__ import annotations
 
+import typing as t
+
 from sqlglot import exp, generator, parser, tokens, transforms
 from sqlglot.dialects.dialect import (
     Dialect,
@@ -133,7 +135,7 @@ class SQLite(Dialect):
 
         LIMIT_FETCH = "LIMIT"
 
-        def cast_sql(self, expression: exp.Cast) -> str:
+        def cast_sql(self, expression: exp.Cast, safe_prefix: t.Optional[str] = None) -> str:
             if expression.is_type("date"):
                 return self.func("DATE", expression.this)
 
sqlglot/dialects/tsql.py

@@ -166,6 +166,7 @@ def _string_agg_sql(self: generator.Generator, expression: exp.GroupConcat) -> str:
 
 
 class TSQL(Dialect):
+    RESOLVES_IDENTIFIERS_AS_UPPERCASE = None
     NULL_ORDERING = "nulls_are_small"
     TIME_FORMAT = "'yyyy-mm-dd hh:mm:ss'"
 
sqlglot/executor/context.py

@@ -63,11 +63,9 @@ class Context:
             reader = table[i]
             yield reader, self
 
-    def table_iter(self, table: str) -> t.Iterator[t.Tuple[TableIter, Context]]:
+    def table_iter(self, table: str) -> TableIter:
         self.env["scope"] = self.row_readers
-
-        for reader in self.tables[table]:
-            yield reader, self
+        return iter(self.tables[table])
 
     def filter(self, condition) -> None:
         rows = [reader.row for reader, _ in self if self.eval(condition)]
sqlglot/executor/python.py

@@ -276,11 +276,9 @@ class PythonExecutor:
        end = 1
        length = len(context.table)
        table = self.table(list(step.group) + step.aggregations)
-        condition = self.generate(step.condition)
 
        def add_row():
-            if not condition or context.eval(condition):
-                table.append(group + context.eval_tuple(aggregations))
+            table.append(group + context.eval_tuple(aggregations))
 
        if length:
            for i in range(length):
@@ -304,7 +302,7 @@ class PythonExecutor:
 
         context = self.context({step.name: table, **{name: table for name in context.tables}})
 
-        if step.projections:
+        if step.projections or step.condition:
             return self.scan(step, context)
         return context
 
sqlglot/expressions.py

@@ -1013,7 +1013,7 @@ class Pragma(Expression):
 
 
 class Set(Expression):
-    arg_types = {"expressions": False}
+    arg_types = {"expressions": False, "unset": False, "tag": False}
 
 
 class SetItem(Expression):
@@ -1168,10 +1168,6 @@ class RenameTable(Expression):
     pass
 
 
-class SetTag(Expression):
-    arg_types = {"expressions": True, "unset": False}
-
-
 class Comment(Expression):
     arg_types = {"this": True, "kind": True, "expression": True, "exists": False}
 
@@ -1934,6 +1930,11 @@ class LanguageProperty(Property):
     arg_types = {"this": True}
 
 
+# spark ddl
+class ClusteredByProperty(Property):
+    arg_types = {"expressions": True, "sorted_by": False, "buckets": True}
+
+
 class DictProperty(Property):
     arg_types = {"this": True, "kind": True, "settings": False}
 
@@ -2074,6 +2075,7 @@ class Properties(Expression):
         "ALGORITHM": AlgorithmProperty,
         "AUTO_INCREMENT": AutoIncrementProperty,
         "CHARACTER SET": CharacterSetProperty,
+        "CLUSTERED_BY": ClusteredByProperty,
         "COLLATE": CollateProperty,
         "COMMENT": SchemaCommentProperty,
         "DEFINER": DefinerProperty,
@@ -2280,6 +2282,12 @@ class Table(Expression):
         "system_time": False,
     }
 
+    @property
+    def name(self) -> str:
+        if isinstance(self.this, Func):
+            return ""
+        return self.this.name
+
     @property
     def db(self) -> str:
         return self.text("db")
@@ -3716,6 +3724,10 @@ class Bracket(Condition):
     arg_types = {"this": True, "expressions": True}
 
 
+class SafeBracket(Bracket):
+    """Represents array lookup where OOB index yields NULL instead of causing a failure."""
+
+
 class Distinct(Expression):
     arg_types = {"expressions": False, "on": False}
 
@@ -3934,7 +3946,7 @@ class Case(Func):
 
 
 class Cast(Func):
-    arg_types = {"this": True, "to": True}
+    arg_types = {"this": True, "to": True, "format": False}
 
     @property
     def name(self) -> str:
@@ -4292,6 +4304,10 @@ class Map(Func):
     arg_types = {"keys": False, "values": False}
 
 
+class MapFromEntries(Func):
+    pass
+
+
 class StarMap(Func):
     pass
 
sqlglot/generator.py

@@ -188,6 +188,7 @@ class Generator:
         exp.CollateProperty: exp.Properties.Location.POST_SCHEMA,
         exp.CopyGrantsProperty: exp.Properties.Location.POST_SCHEMA,
         exp.Cluster: exp.Properties.Location.POST_SCHEMA,
+        exp.ClusteredByProperty: exp.Properties.Location.POST_SCHEMA,
         exp.DataBlocksizeProperty: exp.Properties.Location.POST_NAME,
         exp.DefinerProperty: exp.Properties.Location.POST_CREATE,
         exp.DictRange: exp.Properties.Location.POST_SCHEMA,
@@ -1408,7 +1409,8 @@ class Generator:
         expressions = (
             f" {self.expressions(expression, flat=True)}" if expression.expressions else ""
         )
-        return f"SET{expressions}"
+        tag = " TAG" if expression.args.get("tag") else ""
+        return f"{'UNSET' if expression.args.get('unset') else 'SET'}{tag}{expressions}"
 
     def pragma_sql(self, expression: exp.Pragma) -> str:
         return f"PRAGMA {self.sql(expression, 'this')}"
@@ -1749,6 +1751,9 @@ class Generator:
 
         return f"{self.sql(expression, 'this')}[{expressions_sql}]"
 
+    def safebracket_sql(self, expression: exp.SafeBracket) -> str:
+        return self.bracket_sql(expression)
+
     def all_sql(self, expression: exp.All) -> str:
         return f"ALL {self.wrap(expression)}"
 
@@ -2000,8 +2005,10 @@ class Generator:
     def bitwisexor_sql(self, expression: exp.BitwiseXor) -> str:
         return self.binary(expression, "^")
 
-    def cast_sql(self, expression: exp.Cast) -> str:
-        return f"CAST({self.sql(expression, 'this')} AS {self.sql(expression, 'to')})"
+    def cast_sql(self, expression: exp.Cast, safe_prefix: t.Optional[str] = None) -> str:
+        format_sql = self.sql(expression, "format")
+        format_sql = f" FORMAT {format_sql}" if format_sql else ""
+        return f"{safe_prefix or ''}CAST({self.sql(expression, 'this')} AS {self.sql(expression, 'to')}{format_sql})"
 
     def currentdate_sql(self, expression: exp.CurrentDate) -> str:
         zone = self.sql(expression, "this")
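With this change TRY_CAST is literally CAST with a prefix (see trycast_sql below), and dialects can substitute their own prefix, e.g. BigQuery's SAFE_ seen earlier in this diff. Sketch:

    import sqlglot

    print(sqlglot.transpile("SELECT TRY_CAST(x AS INT)")[0])
    # SELECT TRY_CAST(x AS INT)
    print(sqlglot.transpile("SELECT TRY_CAST(x AS INT)", write="bigquery")[0])
    # expected: SELECT SAFE_CAST(x AS INT)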
@@ -2227,7 +2234,7 @@ class Generator:
         return self.binary(expression, "-")
 
     def trycast_sql(self, expression: exp.TryCast) -> str:
-        return f"TRY_CAST({self.sql(expression, 'this')} AS {self.sql(expression, 'to')})"
+        return self.cast_sql(expression, safe_prefix="TRY_")
 
     def use_sql(self, expression: exp.Use) -> str:
         kind = self.sql(expression, "kind")
@@ -2409,6 +2416,13 @@ class Generator:
     def oncluster_sql(self, expression: exp.OnCluster) -> str:
         return ""
 
+    def clusteredbyproperty_sql(self, expression: exp.ClusteredByProperty) -> str:
+        expressions = self.expressions(expression, key="expressions", flat=True)
+        sorted_by = self.expressions(expression, key="sorted_by", flat=True)
+        sorted_by = f" SORTED BY ({sorted_by})" if sorted_by else ""
+        buckets = self.sql(expression, "buckets")
+        return f"CLUSTERED BY ({expressions}){sorted_by} INTO {buckets} BUCKETS"
+
 
 def cached_generator(
     cache: t.Optional[t.Dict[int, str]] = None
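Together with the ClusteredByProperty expression and the _parse_clustered_by parser added elsewhere in this diff, Hive/Spark-style bucketing DDL should now round-trip. A hedged sketch (output formatting may differ slightly):

    import sqlglot

    sql = "CREATE TABLE t (a INT, b INT) CLUSTERED BY (a) SORTED BY (b) INTO 4 BUCKETS"
    print(sqlglot.transpile(sql, read="hive", write="hive")[0])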
sqlglot/optimizer/qualify.py

@@ -60,8 +60,8 @@ def qualify(
         The qualified expression.
     """
     schema = ensure_schema(schema, dialect=dialect)
-    expression = normalize_identifiers(expression, dialect=dialect)
     expression = qualify_tables(expression, db=db, catalog=catalog, schema=schema)
+    expression = normalize_identifiers(expression, dialect=dialect)
 
     if isolate_tables:
         expression = isolate_table_selects(expression, schema=schema)
sqlglot/optimizer/qualify_columns.py

@@ -56,13 +56,13 @@ def qualify_columns(
         if not isinstance(scope.expression, exp.UDTF):
             _expand_stars(scope, resolver, using_column_tables)
             _qualify_outputs(scope)
-        _expand_group_by(scope, resolver)
-        _expand_order_by(scope)
+        _expand_group_by(scope)
+        _expand_order_by(scope, resolver)
 
     return expression
 
 
-def validate_qualify_columns(expression):
+def validate_qualify_columns(expression: E) -> E:
     """Raise an `OptimizeError` if any columns aren't qualified"""
     unqualified_columns = []
     for scope in traverse_scope(expression):
@@ -79,7 +79,7 @@ def validate_qualify_columns(expression):
     return expression
 
 
-def _pop_table_column_aliases(derived_tables):
+def _pop_table_column_aliases(derived_tables: t.List[exp.CTE | exp.Subquery]) -> None:
     """
     Remove table column aliases.
 
@@ -91,13 +91,13 @@ def _pop_table_column_aliases(derived_tables):
         table_alias.args.pop("columns", None)
 
 
-def _expand_using(scope, resolver):
+def _expand_using(scope: Scope, resolver: Resolver) -> t.Dict[str, t.Any]:
     joins = list(scope.find_all(exp.Join))
     names = {join.alias_or_name for join in joins}
     ordered = [key for key in scope.selected_sources if key not in names]
 
     # Mapping of automatically joined column names to an ordered set of source names (dict).
-    column_tables = {}
+    column_tables: t.Dict[str, t.Dict[str, t.Any]] = {}
 
     for join in joins:
         using = join.args.get("using")
@@ -172,20 +172,25 @@ def _expand_alias_refs(scope: Scope, resolver: Resolver) -> None:
 
     alias_to_expression: t.Dict[str, exp.Expression] = {}
 
-    def replace_columns(
-        node: t.Optional[exp.Expression], expand: bool = True, resolve_agg: bool = False
-    ):
+    def replace_columns(node: t.Optional[exp.Expression], resolve_table: bool = False) -> None:
         if not node:
             return
 
         for column, *_ in walk_in_scope(node):
             if not isinstance(column, exp.Column):
                 continue
-            table = resolver.get_table(column.name) if resolve_agg and not column.table else None
-            if table and column.find_ancestor(exp.AggFunc):
+            table = resolver.get_table(column.name) if resolve_table and not column.table else None
+            alias_expr = alias_to_expression.get(column.name)
+            double_agg = (
+                (alias_expr.find(exp.AggFunc) and column.find_ancestor(exp.AggFunc))
+                if alias_expr
+                else False
+            )
+
+            if table and (not alias_expr or double_agg):
                 column.set("table", table)
-            elif expand and not column.table and column.name in alias_to_expression:
-                column.replace(alias_to_expression[column.name].copy())
+            elif not column.table and alias_expr and not double_agg:
+                column.replace(alias_expr.copy())
 
     for projection in scope.selects:
         replace_columns(projection)
@@ -195,22 +200,41 @@ def _expand_alias_refs(scope: Scope, resolver: Resolver) -> None:
 
     replace_columns(expression.args.get("where"))
     replace_columns(expression.args.get("group"))
-    replace_columns(expression.args.get("having"), resolve_agg=True)
-    replace_columns(expression.args.get("qualify"), resolve_agg=True)
-    replace_columns(expression.args.get("order"), expand=False, resolve_agg=True)
+    replace_columns(expression.args.get("having"), resolve_table=True)
+    replace_columns(expression.args.get("qualify"), resolve_table=True)
     scope.clear_cache()
 
 
-def _expand_group_by(scope, resolver):
-    group = scope.expression.args.get("group")
+def _expand_group_by(scope: Scope):
+    expression = scope.expression
+    group = expression.args.get("group")
     if not group:
         return
 
     group.set("expressions", _expand_positional_references(scope, group.expressions))
-    scope.expression.set("group", group)
+    expression.set("group", group)
+
+    # group by expressions cannot be simplified, for example
+    # select x + 1 + 1 FROM y GROUP BY x + 1 + 1
+    # the projection must exactly match the group by key
+    groups = set(group.expressions)
+    group.meta["final"] = True
+
+    for e in expression.selects:
+        for node, *_ in e.walk():
+            if node in groups:
+                e.meta["final"] = True
+                break
+
+    having = expression.args.get("having")
+    if having:
+        for node, *_ in having.walk():
+            if node in groups:
+                having.meta["final"] = True
+                break
 
 
-def _expand_order_by(scope):
+def _expand_order_by(scope: Scope, resolver: Resolver):
     order = scope.expression.args.get("order")
     if not order:
         return
@@ -220,10 +244,21 @@ def _expand_order_by(scope):
         ordereds,
         _expand_positional_references(scope, (o.this for o in ordereds)),
     ):
+        for agg in ordered.find_all(exp.AggFunc):
+            for col in agg.find_all(exp.Column):
+                if not col.table:
+                    col.set("table", resolver.get_table(col.name))
+
         ordered.set("this", new_expression)
 
+    if scope.expression.args.get("group"):
+        selects = {s.this: exp.column(s.alias_or_name) for s in scope.selects}
+
+        for ordered in ordereds:
+            ordered.set("this", selects.get(ordered.this, ordered.this))
+
 
-def _expand_positional_references(scope, expressions):
+def _expand_positional_references(scope: Scope, expressions: t.Iterable[E]) -> t.List[E]:
     new_nodes = []
     for node in expressions:
         if node.is_int:
@@ -241,7 +276,7 @@ def _expand_positional_references(scope, expressions):
     return new_nodes
 
 
-def _qualify_columns(scope, resolver):
+def _qualify_columns(scope: Scope, resolver: Resolver) -> None:
     """Disambiguate columns, ensuring each column specifies a source"""
     for column in scope.columns:
         column_table = column.table
@@ -290,21 +325,23 @@ def _qualify_columns(scope, resolver):
             column.set("table", column_table)
 
 
-def _expand_stars(scope, resolver, using_column_tables):
+def _expand_stars(
+    scope: Scope, resolver: Resolver, using_column_tables: t.Dict[str, t.Any]
+) -> None:
     """Expand stars to lists of column selections"""
 
     new_selections = []
-    except_columns = {}
-    replace_columns = {}
+    except_columns: t.Dict[int, t.Set[str]] = {}
+    replace_columns: t.Dict[int, t.Dict[str, str]] = {}
     coalesced_columns = set()
 
     # TODO: handle optimization of multiple PIVOTs (and possibly UNPIVOTs) in the future
     pivot_columns = None
     pivot_output_columns = None
-    pivot = seq_get(scope.pivots, 0)
+    pivot = t.cast(t.Optional[exp.Pivot], seq_get(scope.pivots, 0))
 
     has_pivoted_source = pivot and not pivot.args.get("unpivot")
-    if has_pivoted_source:
+    if pivot and has_pivoted_source:
         pivot_columns = set(col.output_name for col in pivot.find_all(exp.Column))
 
         pivot_output_columns = [col.output_name for col in pivot.args.get("columns", [])]
@@ -330,8 +367,17 @@ def _expand_stars(scope, resolver, using_column_tables):
 
             columns = resolver.get_source_columns(table, only_visible=True)
 
+            # The _PARTITIONTIME and _PARTITIONDATE pseudo-columns are not returned by a SELECT * statement
+            # https://cloud.google.com/bigquery/docs/querying-partitioned-tables#query_an_ingestion-time_partitioned_table
+            if resolver.schema.dialect == "bigquery":
+                columns = [
+                    name
+                    for name in columns
+                    if name.upper() not in ("_PARTITIONTIME", "_PARTITIONDATE")
+                ]
+
             if columns and "*" not in columns:
-                if has_pivoted_source:
+                if pivot and has_pivoted_source and pivot_columns and pivot_output_columns:
                     implicit_columns = [col for col in columns if col not in pivot_columns]
                     new_selections.extend(
                         exp.alias_(exp.column(name, table=pivot.alias), name, copy=False)
@@ -368,7 +414,9 @@ def _expand_stars(scope, resolver, using_column_tables):
     scope.expression.set("expressions", new_selections)
 
 
-def _add_except_columns(expression, tables, except_columns):
+def _add_except_columns(
+    expression: exp.Expression, tables, except_columns: t.Dict[int, t.Set[str]]
+) -> None:
     except_ = expression.args.get("except")
 
     if not except_:
@@ -380,7 +428,9 @@ def _add_except_columns(expression, tables, except_columns):
     except_columns[id(table)] = columns
 
 
-def _add_replace_columns(expression, tables, replace_columns):
+def _add_replace_columns(
+    expression: exp.Expression, tables, replace_columns: t.Dict[int, t.Dict[str, str]]
+) -> None:
     replace = expression.args.get("replace")
 
     if not replace:
@@ -392,7 +442,7 @@ def _add_replace_columns(expression, tables, replace_columns):
     replace_columns[id(table)] = columns
 
 
-def _qualify_outputs(scope):
+def _qualify_outputs(scope: Scope):
     """Ensure all output columns are aliased"""
     new_selections = []
 
@@ -429,7 +479,7 @@ class Resolver:
     This is a class so we can lazily load some things and easily share them across functions.
     """
 
-    def __init__(self, scope, schema, infer_schema: bool = True):
+    def __init__(self, scope: Scope, schema: Schema, infer_schema: bool = True):
         self.scope = scope
         self.schema = schema
         self._source_columns = None
sqlglot/optimizer/simplify.py

@@ -28,6 +28,8 @@ def simplify(expression):
     generate = cached_generator()
 
     def _simplify(expression, root=True):
+        if expression.meta.get("final"):
+            return expression
         node = expression
         node = rewrite_between(node)
         node = uniq_sort(node, generate, root)
sqlglot/parser.py

@@ -585,6 +585,7 @@ class Parser(metaclass=_Parser):
         "CHARACTER SET": lambda self: self._parse_character_set(),
         "CHECKSUM": lambda self: self._parse_checksum(),
         "CLUSTER BY": lambda self: self._parse_cluster(),
+        "CLUSTERED": lambda self: self._parse_clustered_by(),
         "COLLATE": lambda self: self._parse_property_assignment(exp.CollateProperty),
         "COMMENT": lambda self: self._parse_property_assignment(exp.SchemaCommentProperty),
         "COPY": lambda self: self._parse_copy_property(),
@@ -794,8 +795,6 @@ class Parser(metaclass=_Parser):
     # A NULL arg in CONCAT yields NULL by default
     CONCAT_NULL_OUTPUTS_STRING = False
 
-    CONVERT_TYPE_FIRST = False
-
     PREFIXED_PIVOT_COLUMNS = False
     IDENTIFY_PIVOT_STRINGS = False
 
@@ -1426,9 +1425,34 @@ class Parser(metaclass=_Parser):
 
         return self.expression(exp.ChecksumProperty, on=on, default=self._match(TokenType.DEFAULT))
 
-    def _parse_cluster(self) -> t.Optional[exp.Cluster]:
+    def _parse_cluster(self) -> exp.Cluster:
         return self.expression(exp.Cluster, expressions=self._parse_csv(self._parse_ordered))
 
+    def _parse_clustered_by(self) -> exp.ClusteredByProperty:
+        self._match_text_seq("BY")
+
+        self._match_l_paren()
+        expressions = self._parse_csv(self._parse_column)
+        self._match_r_paren()
+
+        if self._match_text_seq("SORTED", "BY"):
+            self._match_l_paren()
+            sorted_by = self._parse_csv(self._parse_ordered)
+            self._match_r_paren()
+        else:
+            sorted_by = None
+
+        self._match(TokenType.INTO)
+        buckets = self._parse_number()
+        self._match_text_seq("BUCKETS")
+
+        return self.expression(
+            exp.ClusteredByProperty,
+            expressions=expressions,
+            sorted_by=sorted_by,
+            buckets=buckets,
+        )
+
     def _parse_copy_property(self) -> t.Optional[exp.CopyGrantsProperty]:
         if not self._match_text_seq("GRANTS"):
             self._retreat(self._index - 1)
@@ -2863,7 +2887,11 @@ class Parser(metaclass=_Parser):
         if not self._match(TokenType.INTERVAL):
             return None
 
-        this = self._parse_primary() or self._parse_term()
+        if self._match(TokenType.STRING, advance=False):
+            this = self._parse_primary()
+        else:
+            this = self._parse_term()
+
         unit = self._parse_function() or self._parse_var()
 
         # Most dialects support, e.g., the form INTERVAL '5' day, thus we try to parse
@@ -3661,6 +3689,7 @@ class Parser(metaclass=_Parser):
         else:
             self.raise_error("Expected AS after CAST")
 
+        fmt = None
         to = self._parse_types()
 
         if not to:
@@ -3668,22 +3697,23 @@ class Parser(metaclass=_Parser):
         elif to.this == exp.DataType.Type.CHAR:
             if self._match(TokenType.CHARACTER_SET):
                 to = self.expression(exp.CharacterSet, this=self._parse_var_or_string())
-        elif to.this in exp.DataType.TEMPORAL_TYPES and self._match(TokenType.FORMAT):
-            fmt = self._parse_string()
+        elif self._match(TokenType.FORMAT):
+            fmt = self._parse_at_time_zone(self._parse_string())
 
-            return self.expression(
-                exp.StrToDate if to.this == exp.DataType.Type.DATE else exp.StrToTime,
-                this=this,
-                format=exp.Literal.string(
-                    format_time(
-                        fmt.this if fmt else "",
-                        self.FORMAT_MAPPING or self.TIME_MAPPING,
-                        self.FORMAT_TRIE or self.TIME_TRIE,
-                    )
-                ),
-            )
+            if to.this in exp.DataType.TEMPORAL_TYPES:
+                return self.expression(
+                    exp.StrToDate if to.this == exp.DataType.Type.DATE else exp.StrToTime,
+                    this=this,
+                    format=exp.Literal.string(
+                        format_time(
+                            fmt.this if fmt else "",
+                            self.FORMAT_MAPPING or self.TIME_MAPPING,
+                            self.FORMAT_TRIE or self.TIME_TRIE,
+                        )
+                    ),
+                )
 
-        return self.expression(exp.Cast if strict else exp.TryCast, this=this, to=to)
+        return self.expression(exp.Cast if strict else exp.TryCast, this=this, to=to, format=fmt)
 
     def _parse_concat(self) -> t.Optional[exp.Expression]:
         args = self._parse_csv(self._parse_conjunction)
@@ -3704,20 +3734,23 @@ class Parser(metaclass=_Parser):
         )
 
     def _parse_string_agg(self) -> exp.Expression:
-        expression: t.Optional[exp.Expression]
-
         if self._match(TokenType.DISTINCT):
-            args = self._parse_csv(self._parse_conjunction)
-            expression = self.expression(exp.Distinct, expressions=[seq_get(args, 0)])
+            args: t.List[t.Optional[exp.Expression]] = [
+                self.expression(exp.Distinct, expressions=[self._parse_conjunction()])
+            ]
+            if self._match(TokenType.COMMA):
+                args.extend(self._parse_csv(self._parse_conjunction))
         else:
             args = self._parse_csv(self._parse_conjunction)
-            expression = seq_get(args, 0)
 
         index = self._index
         if not self._match(TokenType.R_PAREN):
             # postgres: STRING_AGG([DISTINCT] expression, separator [ORDER BY expression1 {ASC | DESC} [, ...]])
-            order = self._parse_order(this=expression)
-            return self.expression(exp.GroupConcat, this=order, separator=seq_get(args, 1))
+            return self.expression(
+                exp.GroupConcat,
+                this=seq_get(args, 0),
+                separator=self._parse_order(this=seq_get(args, 1)),
+            )
 
         # Checks if we can parse an order clause: WITHIN GROUP (ORDER BY <order_by_expression_list> [ASC | DESC]).
         # This is done "manually", instead of letting _parse_window parse it into an exp.WithinGroup node, so that
@@ -3727,24 +3760,21 @@ class Parser(metaclass=_Parser):
         return self.validate_expression(exp.GroupConcat.from_arg_list(args), args)
 
         self._match_l_paren()  # The corresponding match_r_paren will be called in parse_function (caller)
-        order = self._parse_order(this=expression)
+        order = self._parse_order(this=seq_get(args, 0))
         return self.expression(exp.GroupConcat, this=order, separator=seq_get(args, 1))
 
     def _parse_convert(self, strict: bool) -> t.Optional[exp.Expression]:
-        to: t.Optional[exp.Expression]
         this = self._parse_bitwise()
 
         if self._match(TokenType.USING):
-            to = self.expression(exp.CharacterSet, this=self._parse_var())
+            to: t.Optional[exp.Expression] = self.expression(
+                exp.CharacterSet, this=self._parse_var()
+            )
         elif self._match(TokenType.COMMA):
-            to = self._parse_bitwise()
+            to = self._parse_types()
         else:
             to = None
 
-        # Swap the argument order if needed to produce the correct AST
-        if self.CONVERT_TYPE_FIRST:
-            this, to = to, this
-
         return self.expression(exp.Cast if strict else exp.TryCast, this=this, to=to)
 
     def _parse_decode(self) -> t.Optional[exp.Decode | exp.Case]:
@@ -4394,8 +4424,8 @@ class Parser(metaclass=_Parser):
 
         if self._next:
             self._advance()
-            parser = self.ALTER_PARSERS.get(self._prev.text.upper()) if self._prev else None
 
+        parser = self.ALTER_PARSERS.get(self._prev.text.upper()) if self._prev else None
         if parser:
             actions = ensure_list(parser(self))
 
@@ -4516,9 +4546,11 @@ class Parser(metaclass=_Parser):
         parser = self._find_parser(self.SET_PARSERS, self.SET_TRIE)
         return parser(self) if parser else self._parse_set_item_assignment(kind=None)
 
-    def _parse_set(self) -> exp.Set | exp.Command:
+    def _parse_set(self, unset: bool = False, tag: bool = False) -> exp.Set | exp.Command:
         index = self._index
-        set_ = self.expression(exp.Set, expressions=self._parse_csv(self._parse_set_item))
+        set_ = self.expression(
+            exp.Set, expressions=self._parse_csv(self._parse_set_item), unset=unset, tag=tag
+        )
 
         if self._curr:
             self._retreat(index)
@@ -4683,12 +4715,8 @@ class Parser(metaclass=_Parser):
             exp.replace_children(this, self._replace_columns_with_dots)
             table = this.args.get("table")
             this = (
-                self.expression(exp.Dot, this=table, expression=this.this)
-                if table
-                else self.expression(exp.Var, this=this.name)
+                self.expression(exp.Dot, this=table, expression=this.this) if table else this.this
             )
-        elif isinstance(this, exp.Identifier):
-            this = self.expression(exp.Var, this=this.name)
 
         return this
 
sqlglot/planner.py

@@ -91,6 +91,7 @@ class Step:
             A Step DAG corresponding to `expression`.
         """
+        ctes = ctes or {}
         expression = expression.unnest()
         with_ = expression.args.get("with")
 
         # CTEs break the mold of scope and introduce themselves to all in the context.
@@ -120,22 +121,25 @@ class Step:
 
         projections = []  # final selects in this chain of steps representing a select
         operands = {}  # intermediate computations of agg funcs eg x + 1 in SUM(x + 1)
-        aggregations = []
+        aggregations = set()
        next_operand_name = name_sequence("_a_")
 
         def extract_agg_operands(expression):
-            for agg in expression.find_all(exp.AggFunc):
+            agg_funcs = tuple(expression.find_all(exp.AggFunc))
+            if agg_funcs:
+                aggregations.add(expression)
+            for agg in agg_funcs:
                 for operand in agg.unnest_operands():
                     if isinstance(operand, exp.Column):
                         continue
                     if operand not in operands:
                         operands[operand] = next_operand_name()
                     operand.replace(exp.column(operands[operand], quoted=True))
+            return bool(agg_funcs)
 
         for e in expression.expressions:
             if e.find(exp.AggFunc):
                 projections.append(exp.column(e.alias_or_name, step.name, quoted=True))
-                aggregations.append(e)
                 extract_agg_operands(e)
             else:
                 projections.append(e)
@@ -155,22 +159,38 @@ class Step:
             having = expression.args.get("having")
 
             if having:
-                extract_agg_operands(having)
-                aggregate.condition = having.this
+                if extract_agg_operands(exp.alias_(having.this, "_h", quoted=True)):
+                    aggregate.condition = exp.column("_h", step.name, quoted=True)
+                else:
+                    aggregate.condition = having.this
 
             aggregate.operands = tuple(
                 alias(operand, alias_) for operand, alias_ in operands.items()
             )
-            aggregate.aggregations = aggregations
+            aggregate.aggregations = list(aggregations)
 
             # give aggregates names and replace projections with references to them
             aggregate.group = {
                 f"_g{i}": e for i, e in enumerate(group.expressions if group else [])
             }
 
+            intermediate: t.Dict[str | exp.Expression, str] = {}
+            for k, v in aggregate.group.items():
+                intermediate[v] = k
+                if isinstance(v, exp.Column):
+                    intermediate[v.alias_or_name] = k
+
             for projection in projections:
-                for i, e in aggregate.group.items():
-                    for child, *_ in projection.walk():
-                        if child == e:
-                            child.replace(exp.column(i, step.name))
+                for node, *_ in projection.walk():
+                    name = intermediate.get(node)
+                    if name:
+                        node.replace(exp.column(name, step.name))
+
+            if aggregate.condition:
+                for node, *_ in aggregate.condition.walk():
+                    name = intermediate.get(node) or intermediate.get(node.name)
+                    if name:
+                        node.replace(exp.column(name, step.name))
 
             aggregate.add_dependency(step)
             step = aggregate
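The HAVING expression is now computed once as a hidden aggregation column ("_h") instead of being re-evaluated per output row, which the executor changes earlier in this diff rely on. A small end-to-end check using sqlglot's own executor:

    from sqlglot.executor import execute

    result = execute(
        "SELECT a, SUM(b) AS s FROM t GROUP BY a HAVING SUM(b) > 3",
        tables={"t": [{"a": 1, "b": 2}, {"a": 1, "b": 3}, {"a": 2, "b": 1}]},
    )
    print(result.rows)  # expected: [(1, 5)]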
sqlglot/transforms.py

@@ -159,10 +159,11 @@ def explode_to_unnest(expression: exp.Expression) -> exp.Expression:
     if isinstance(expression, exp.Select):
         from sqlglot.optimizer.scope import build_scope
 
-        taken_select_names = set(expression.named_selects)
         scope = build_scope(expression)
+        if not scope:
+            return expression
+
+        taken_select_names = set(expression.named_selects)
         taken_source_names = set(scope.selected_sources)
 
         for select in expression.selects:
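build_scope can return None for expressions it cannot scope, so the early return guards the transform, which is now also part of BigQuery's Select preprocessing (see the bigquery.py hunk above). A hedged sketch of the transform itself, without asserting the exact rewritten SQL:

    import sqlglot

    # Spark's EXPLODE is rewritten into an UNNEST-based query for dialects without it
    print(sqlglot.transpile("SELECT EXPLODE(xs) FROM t", read="spark", write="presto")[0])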