
Merging upstream version 16.7.3.

Signed-off-by: Daniel Baumann <daniel@debian.org>
Daniel Baumann 2025-02-13 20:21:40 +01:00
parent 331a760a3d
commit 088f137198
Signed by: daniel
GPG key ID: FBB4F0E80A80222F
75 changed files with 33866 additions and 31988 deletions

View file

@@ -1119,7 +1119,7 @@ def map_entries(col: ColumnOrName) -> Column:
def map_from_entries(col: ColumnOrName) -> Column:
return Column.invoke_anonymous_function(col, "MAP_FROM_ENTRIES")
return Column.invoke_expression_over_column(col, expression.MapFromEntries)
def array_repeat(col: ColumnOrName, count: t.Union[ColumnOrName, int]) -> Column:

View file

@@ -1,5 +1,6 @@
from __future__ import annotations
import logging
import re
import typing as t
@@ -21,6 +22,8 @@ from sqlglot.dialects.dialect import (
from sqlglot.helper import seq_get, split_num_words
from sqlglot.tokens import TokenType
logger = logging.getLogger("sqlglot")
def _date_add_sql(
data_type: str, kind: str
@@ -104,12 +107,70 @@ def _unqualify_unnest(expression: exp.Expression) -> exp.Expression:
return expression
# https://issuetracker.google.com/issues/162294746
# workaround for a BigQuery bug when grouping by an expression and then ordering by its alias
# WITH x AS (SELECT 1 y)
# SELECT y + 1 z
# FROM x
# GROUP BY y + 1
# ORDER BY z
def _alias_ordered_group(expression: exp.Expression) -> exp.Expression:
if isinstance(expression, exp.Select):
group = expression.args.get("group")
order = expression.args.get("order")
if group and order:
aliases = {
select.this: select.args["alias"]
for select in expression.selects
if isinstance(select, exp.Alias)
}
for e in group.expressions:
alias = aliases.get(e)
if alias:
e.replace(exp.column(alias))
return expression
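As a minimal sketch (reusing the documented example), the transform rewrites the GROUP BY key to the projection alias before generation:

    import sqlglot

    sql = "WITH x AS (SELECT 1 AS y) SELECT y + 1 AS z FROM x GROUP BY y + 1 ORDER BY z"
    # The grouped expression y + 1 is replaced by its alias z, sidestepping the bug:
    # roughly 'WITH x AS (SELECT 1 AS y) SELECT y + 1 AS z FROM x GROUP BY z ORDER BY z'
    print(sqlglot.transpile(sql, write="bigquery")[0])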
def _pushdown_cte_column_names(expression: exp.Expression) -> exp.Expression:
"""BigQuery doesn't allow column names when defining a CTE, so we try to push them down."""
if isinstance(expression, exp.CTE) and expression.alias_column_names:
cte_query = expression.this
if cte_query.is_star:
logger.warning(
"Can't push down CTE column names for star queries. Run the query through"
" the optimizer or use 'qualify' to expand the star projections first."
)
return expression
column_names = expression.alias_column_names
expression.args["alias"].set("columns", None)
for name, select in zip(column_names, cte_query.selects):
to_replace = select
if isinstance(select, exp.Alias):
select = select.this
# Inner aliases are shadowed by the CTE column names
to_replace.replace(exp.alias_(select, name))
return expression
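A hedged example of the pushdown (CTE and column names invented):

    import sqlglot

    # The alias column list t(c) is folded into the inner projection:
    # roughly 'WITH t AS (SELECT 1 AS c) SELECT c FROM t'
    print(sqlglot.transpile("WITH t(c) AS (SELECT 1) SELECT c FROM t", write="bigquery")[0])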
class BigQuery(Dialect):
UNNEST_COLUMN_ONLY = True
# https://cloud.google.com/bigquery/docs/reference/standard-sql/lexical#case_sensitivity
RESOLVES_IDENTIFIERS_AS_UPPERCASE = None
# bigquery udfs are case sensitive
NORMALIZE_FUNCTIONS = False
TIME_MAPPING = {
"%D": "%m/%d/%y",
}
@@ -135,12 +196,16 @@ class BigQuery(Dialect):
# In BigQuery, CTEs aren't case-sensitive, but table names are (by default, at least).
# The following check is essentially a heuristic to detect tables based on whether or
# not they're qualified.
if (
isinstance(expression, exp.Identifier)
and not (isinstance(expression.parent, exp.Table) and expression.parent.db)
and not expression.meta.get("is_table")
):
expression.set("this", expression.this.lower())
if isinstance(expression, exp.Identifier):
parent = expression.parent
while isinstance(parent, exp.Dot):
parent = parent.parent
if not (isinstance(parent, exp.Table) and parent.db) and not expression.meta.get(
"is_table"
):
expression.set("this", expression.this.lower())
return expression
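A sketch of the effect through the normalize_identifiers rule (identifiers invented):

    from sqlglot import parse_one
    from sqlglot.optimizer.normalize_identifiers import normalize_identifiers

    # The unqualified column is lowercased, while the qualified table name,
    # detected through the Dot/Table parent chain, keeps its casing.
    sql = "SELECT Col FROM My_Project.My_Dataset.My_Tbl"
    print(normalize_identifiers(parse_one(sql), dialect="bigquery").sql("bigquery"))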
@@ -298,10 +363,8 @@ class BigQuery(Dialect):
**generator.Generator.TRANSFORMS,
exp.ApproxDistinct: rename_func("APPROX_COUNT_DISTINCT"),
exp.ArraySize: rename_func("ARRAY_LENGTH"),
exp.AtTimeZone: lambda self, e: self.func(
"TIMESTAMP", self.func("DATETIME", e.this, e.args.get("zone"))
),
exp.Cast: transforms.preprocess([transforms.remove_precision_parameterized_types]),
exp.CTE: transforms.preprocess([_pushdown_cte_column_names]),
exp.DateAdd: _date_add_sql("DATE", "ADD"),
exp.DateSub: _date_add_sql("DATE", "SUB"),
exp.DatetimeAdd: _date_add_sql("DATETIME", "ADD"),
@@ -325,7 +388,12 @@ class BigQuery(Dialect):
),
exp.RegexpLike: rename_func("REGEXP_CONTAINS"),
exp.Select: transforms.preprocess(
[_unqualify_unnest, transforms.eliminate_distinct_on]
[
transforms.explode_to_unnest,
_unqualify_unnest,
transforms.eliminate_distinct_on,
_alias_ordered_group,
]
),
exp.StrToDate: lambda self, e: f"PARSE_DATE({self.format_time(e)}, {self.sql(e, 'this')})",
exp.StrToTime: lambda self, e: f"PARSE_TIMESTAMP({self.format_time(e)}, {self.sql(e, 'this')})",
@@ -334,7 +402,6 @@ class BigQuery(Dialect):
exp.TimestampAdd: _date_add_sql("TIMESTAMP", "ADD"),
exp.TimestampSub: _date_add_sql("TIMESTAMP", "SUB"),
exp.TimeStrToTime: timestrtotime_sql,
exp.TryCast: lambda self, e: f"SAFE_CAST({self.sql(e, 'this')} AS {self.sql(e, 'to')})",
exp.TsOrDsToDate: ts_or_ds_to_date_sql("bigquery"),
exp.TsOrDsAdd: _date_add_sql("DATE", "ADD"),
exp.PartitionedByProperty: lambda self, e: f"PARTITION BY {self.sql(e, 'this')}",
@@ -378,7 +445,121 @@ class BigQuery(Dialect):
exp.VolatileProperty: exp.Properties.Location.UNSUPPORTED,
}
RESERVED_KEYWORDS = {*generator.Generator.RESERVED_KEYWORDS, "hash"}
# from: https://cloud.google.com/bigquery/docs/reference/standard-sql/lexical#reserved_keywords
RESERVED_KEYWORDS = {
*generator.Generator.RESERVED_KEYWORDS,
"all",
"and",
"any",
"array",
"as",
"asc",
"assert_rows_modified",
"at",
"between",
"by",
"case",
"cast",
"collate",
"contains",
"create",
"cross",
"cube",
"current",
"default",
"define",
"desc",
"distinct",
"else",
"end",
"enum",
"escape",
"except",
"exclude",
"exists",
"extract",
"false",
"fetch",
"following",
"for",
"from",
"full",
"group",
"grouping",
"groups",
"hash",
"having",
"if",
"ignore",
"in",
"inner",
"intersect",
"interval",
"into",
"is",
"join",
"lateral",
"left",
"like",
"limit",
"lookup",
"merge",
"natural",
"new",
"no",
"not",
"null",
"nulls",
"of",
"on",
"or",
"order",
"outer",
"over",
"partition",
"preceding",
"proto",
"qualify",
"range",
"recursive",
"respect",
"right",
"rollup",
"rows",
"select",
"set",
"some",
"struct",
"tablesample",
"then",
"to",
"treat",
"true",
"unbounded",
"union",
"unnest",
"using",
"when",
"where",
"window",
"with",
"within",
}
def attimezone_sql(self, expression: exp.AtTimeZone) -> str:
if not isinstance(expression.parent, exp.Cast):
return self.func(
"TIMESTAMP", self.func("DATETIME", expression.this, expression.args.get("zone"))
)
return super().attimezone_sql(expression)
def trycast_sql(self, expression: exp.TryCast) -> str:
return self.cast_sql(expression, safe_prefix="SAFE_")
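For instance, TRY_CAST now picks up the SAFE_ prefix (a small sketch):

    import sqlglot

    # TryCast renders as SAFE_CAST on BigQuery:
    print(sqlglot.transpile("SELECT TRY_CAST(x AS INT)", write="bigquery")[0])
    # roughly 'SELECT SAFE_CAST(x AS INT64)'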
def cte_sql(self, expression: exp.CTE) -> str:
if expression.alias_column_names:
self.unsupported("Column names in CTE definition are not supported.")
return super().cte_sql(expression)
def array_sql(self, expression: exp.Array) -> str:
first_arg = seq_get(expression.expressions, 0)

View file

@@ -388,6 +388,11 @@ def no_comment_column_constraint_sql(
return ""
def no_map_from_entries_sql(self: Generator, expression: exp.MapFromEntries) -> str:
self.unsupported("MAP_FROM_ENTRIES unsupported")
return ""
def str_position_sql(self: Generator, expression: exp.StrPosition) -> str:
this = self.sql(expression, "this")
substr = self.sql(expression, "substr")

View file

@@ -132,6 +132,10 @@ class MySQL(Dialect):
"SEPARATOR": TokenType.SEPARATOR,
"ENUM": TokenType.ENUM,
"START": TokenType.BEGIN,
"SIGNED": TokenType.BIGINT,
"SIGNED INTEGER": TokenType.BIGINT,
"UNSIGNED": TokenType.UBIGINT,
"UNSIGNED INTEGER": TokenType.UBIGINT,
"_ARMSCII8": TokenType.INTRODUCER,
"_ASCII": TokenType.INTRODUCER,
"_BIG5": TokenType.INTRODUCER,
@@ -441,6 +445,17 @@ class MySQL(Dialect):
LIMIT_FETCH = "LIMIT"
def cast_sql(self, expression: exp.Cast, safe_prefix: t.Optional[str] = None) -> str:
"""(U)BIGINT is not allowed in a CAST expression, so we use (UN)SIGNED instead."""
if expression.to.this == exp.DataType.Type.BIGINT:
to = "SIGNED"
elif expression.to.this == exp.DataType.Type.UBIGINT:
to = "UNSIGNED"
else:
return super().cast_sql(expression)
return f"CAST({self.sql(expression, 'this')} AS {to})"
def show_sql(self, expression: exp.Show) -> str:
this = f" {expression.name}"
full = " FULL" if expression.args.get("full") else ""

View file

@@ -11,6 +11,7 @@ from sqlglot.dialects.dialect import (
format_time_lambda,
max_or_greatest,
min_or_least,
no_map_from_entries_sql,
no_paren_current_date_sql,
no_pivot_sql,
no_tablesample_sql,
@@ -346,6 +347,7 @@ class Postgres(Dialect):
exp.LogicalOr: rename_func("BOOL_OR"),
exp.LogicalAnd: rename_func("BOOL_AND"),
exp.Max: max_or_greatest,
exp.MapFromEntries: no_map_from_entries_sql,
exp.Min: min_or_least,
exp.ArrayOverlaps: lambda self, e: self.binary(e, "&&"),
exp.ArrayContains: lambda self, e: self.binary(e, "@>"),
@@ -378,3 +380,11 @@ class Postgres(Dialect):
exp.TransientProperty: exp.Properties.Location.UNSUPPORTED,
exp.VolatileProperty: exp.Properties.Location.UNSUPPORTED,
}
def bracket_sql(self, expression: exp.Bracket) -> str:
"""Forms like ARRAY[1, 2, 3][3] aren't allowed; we need to wrap the ARRAY."""
if isinstance(expression.this, exp.Array):
expression = expression.copy()
expression.set("this", exp.paren(expression.this, copy=False))
return super().bracket_sql(expression)
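A minimal illustration of the wrapping (a sketch):

    import sqlglot

    # The array literal is parenthesized before being indexed:
    print(sqlglot.transpile("SELECT ARRAY[1, 2, 3][3]", read="postgres", write="postgres")[0])
    # roughly 'SELECT (ARRAY[1, 2, 3])[3]'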

View file

@@ -20,7 +20,7 @@ from sqlglot.dialects.dialect import (
)
from sqlglot.dialects.mysql import MySQL
from sqlglot.errors import UnsupportedError
from sqlglot.helper import seq_get
from sqlglot.helper import apply_index_offset, seq_get
from sqlglot.tokens import TokenType
@@ -154,6 +154,13 @@ def _from_unixtime(args: t.List) -> exp.Expression:
return exp.UnixToTime.from_arg_list(args)
def _parse_element_at(args: t.List) -> exp.SafeBracket:
this = seq_get(args, 0)
index = seq_get(args, 1)
assert isinstance(this, exp.Expression) and isinstance(index, exp.Expression)
return exp.SafeBracket(this=this, expressions=apply_index_offset(this, [index], -1))
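A round-trip sketch (column name invented): ELEMENT_AT is 1-indexed in Presto, so the index is shifted by -1 into the internal SafeBracket on parse and shifted back on generation:

    import sqlglot

    print(sqlglot.transpile("SELECT ELEMENT_AT(arr, 1)", read="presto", write="presto")[0])
    # roughly 'SELECT ELEMENT_AT(arr, 1)'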
def _unnest_sequence(expression: exp.Expression) -> exp.Expression:
if isinstance(expression, exp.Table):
if isinstance(expression.this, exp.GenerateSeries):
@@ -201,6 +208,7 @@ class Presto(Dialect):
"DATE_FORMAT": format_time_lambda(exp.TimeToStr, "presto"),
"DATE_PARSE": format_time_lambda(exp.StrToTime, "presto"),
"DATE_TRUNC": date_trunc_to_time,
"ELEMENT_AT": _parse_element_at,
"FROM_HEX": exp.Unhex.from_arg_list,
"FROM_UNIXTIME": _from_unixtime,
"FROM_UTF8": lambda args: exp.Decode(
@@ -285,6 +293,9 @@ class Presto(Dialect):
exp.Pivot: no_pivot_sql,
exp.Quantile: _quantile_sql,
exp.Right: right_to_substring_sql,
exp.SafeBracket: lambda self, e: self.func(
"ELEMENT_AT", e.this, seq_get(apply_index_offset(e.this, e.expressions, 1), 0)
),
exp.SafeDivide: no_safe_divide_sql,
exp.Schema: _schema_sql,
exp.Select: transforms.preprocess(

View file

@@ -41,8 +41,6 @@ class Redshift(Postgres):
"STRTOL": exp.FromBase.from_arg_list,
}
CONVERT_TYPE_FIRST = True
def _parse_types(
self, check_func: bool = False, schema: bool = False
) -> t.Optional[exp.Expression]:
@@ -58,6 +56,12 @@ class Redshift(Postgres):
return this
def _parse_convert(self, strict: bool) -> t.Optional[exp.Expression]:
to = self._parse_types()
self._match(TokenType.COMMA)
this = self._parse_bitwise()
return self.expression(exp.TryCast, this=this, to=to)
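A sketch of the parse result; Redshift's CONVERT takes the target type first, so the override reads the type before the expression:

    from sqlglot import parse_one

    print(parse_one("SELECT CONVERT(INTEGER, x)", read="redshift").sql())
    # roughly 'SELECT TRY_CAST(x AS INT)'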
class Tokenizer(Postgres.Tokenizer):
BIT_STRINGS = []
HEX_STRINGS = []

View file

@@ -258,14 +258,29 @@ class Snowflake(Dialect):
ALTER_PARSERS = {
**parser.Parser.ALTER_PARSERS,
"UNSET": lambda self: self._parse_alter_table_set_tag(unset=True),
"SET": lambda self: self._parse_alter_table_set_tag(),
"SET": lambda self: self._parse_set(tag=self._match_text_seq("TAG")),
"UNSET": lambda self: self.expression(
exp.Set,
tag=self._match_text_seq("TAG"),
expressions=self._parse_csv(self._parse_id_var),
unset=True,
),
}
def _parse_alter_table_set_tag(self, unset: bool = False) -> exp.Expression:
self._match_text_seq("TAG")
parser = t.cast(t.Callable, self._parse_id_var if unset else self._parse_conjunction)
return self.expression(exp.SetTag, expressions=self._parse_csv(parser), unset=unset)
def _parse_id_var(
self,
any_token: bool = True,
tokens: t.Optional[t.Collection[TokenType]] = None,
) -> t.Optional[exp.Expression]:
if self._match_text_seq("IDENTIFIER", "("):
identifier = (
super()._parse_id_var(any_token=any_token, tokens=tokens)
or self._parse_string()
)
self._match_r_paren()
return self.expression(exp.Anonymous, this="IDENTIFIER", expressions=[identifier])
return super()._parse_id_var(any_token=any_token, tokens=tokens)
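Two hedged round-trip examples (object names invented):

    import sqlglot

    # Tag assignment via ALTER, and IDENTIFIER(...) as a table reference:
    sqlglot.transpile("ALTER TABLE t SET TAG team = 'data'", read="snowflake", write="snowflake")
    sqlglot.transpile("SELECT * FROM IDENTIFIER('my_table')", read="snowflake", write="snowflake")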
class Tokenizer(tokens.Tokenizer):
QUOTES = ["'", "$$"]
@@ -380,10 +395,6 @@ class Snowflake(Dialect):
self.unsupported("INTERSECT with All is not supported in Snowflake")
return super().intersect_op(expression)
def settag_sql(self, expression: exp.SetTag) -> str:
action = "UNSET" if expression.args.get("unset") else "SET"
return f"{action} TAG {self.expressions(expression)}"
def describe_sql(self, expression: exp.Describe) -> str:
# Default to table if kind is unknown
kind_value = expression.args.get("kind") or "TABLE"

View file

@@ -43,6 +43,7 @@ class Spark(Spark2):
class Generator(Spark2.Generator):
TRANSFORMS = Spark2.Generator.TRANSFORMS.copy()
TRANSFORMS.pop(exp.DateDiff)
TRANSFORMS.pop(exp.Group)
def datediff_sql(self, expression: exp.DateDiff) -> str:
unit = self.sql(expression, "unit")

View file

@@ -231,14 +231,14 @@ class Spark2(Hive):
WRAP_DERIVED_VALUES = False
CREATE_FUNCTION_RETURN_AS = False
def cast_sql(self, expression: exp.Cast) -> str:
def cast_sql(self, expression: exp.Cast, safe_prefix: t.Optional[str] = None) -> str:
if isinstance(expression.this, exp.Cast) and expression.this.is_type("json"):
schema = f"'{self.sql(expression, 'to')}'"
return self.func("FROM_JSON", expression.this.this, schema)
if expression.is_type("json"):
return self.func("TO_JSON", expression.this)
return super(Hive.Generator, self).cast_sql(expression)
return super(Hive.Generator, self).cast_sql(expression, safe_prefix=safe_prefix)
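For instance (a sketch):

    import sqlglot

    # A cast to JSON still becomes TO_JSON on Spark:
    print(sqlglot.transpile("SELECT CAST(x AS JSON)", write="spark")[0])
    # roughly 'SELECT TO_JSON(x)'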
def columndef_sql(self, expression: exp.ColumnDef, sep: str = " ") -> str:
return super().columndef_sql(

View file

@@ -1,5 +1,7 @@
from __future__ import annotations
import typing as t
from sqlglot import exp, generator, parser, tokens, transforms
from sqlglot.dialects.dialect import (
Dialect,
@@ -133,7 +135,7 @@ class SQLite(Dialect):
LIMIT_FETCH = "LIMIT"
def cast_sql(self, expression: exp.Cast) -> str:
def cast_sql(self, expression: exp.Cast, safe_prefix: t.Optional[str] = None) -> str:
if expression.is_type("date"):
return self.func("DATE", expression.this)

View file

@@ -166,6 +166,7 @@ def _string_agg_sql(self: generator.Generator, expression: exp.GroupConcat) -> str:
class TSQL(Dialect):
RESOLVES_IDENTIFIERS_AS_UPPERCASE = None
NULL_ORDERING = "nulls_are_small"
TIME_FORMAT = "'yyyy-mm-dd hh:mm:ss'"

View file

@@ -63,11 +63,9 @@ class Context:
reader = table[i]
yield reader, self
def table_iter(self, table: str) -> t.Iterator[t.Tuple[TableIter, Context]]:
def table_iter(self, table: str) -> TableIter:
self.env["scope"] = self.row_readers
for reader in self.tables[table]:
yield reader, self
return iter(self.tables[table])
def filter(self, condition) -> None:
rows = [reader.row for reader, _ in self if self.eval(condition)]

View file

@@ -276,11 +276,9 @@ class PythonExecutor:
end = 1
length = len(context.table)
table = self.table(list(step.group) + step.aggregations)
condition = self.generate(step.condition)
def add_row():
if not condition or context.eval(condition):
table.append(group + context.eval_tuple(aggregations))
table.append(group + context.eval_tuple(aggregations))
if length:
for i in range(length):
@@ -304,7 +302,7 @@ class PythonExecutor:
context = self.context({step.name: table, **{name: table for name in context.tables}})
if step.projections:
if step.projections or step.condition:
return self.scan(step, context)
return context
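A small sketch with an invented table: HAVING is now evaluated as a filtering scan over the aggregated table rather than as a per-row condition during aggregation.

    from sqlglot.executor import execute

    result = execute(
        "SELECT a, COUNT(*) AS cnt FROM t GROUP BY a HAVING COUNT(*) > 1",
        tables={"t": [{"a": 1}, {"a": 1}, {"a": 2}]},
    )
    print(result.rows)  # roughly [(1, 2)]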

View file

@@ -1013,7 +1013,7 @@ class Pragma(Expression):
class Set(Expression):
arg_types = {"expressions": False}
arg_types = {"expressions": False, "unset": False, "tag": False}
class SetItem(Expression):
@@ -1168,10 +1168,6 @@ class RenameTable(Expression):
pass
class SetTag(Expression):
arg_types = {"expressions": True, "unset": False}
class Comment(Expression):
arg_types = {"this": True, "kind": True, "expression": True, "exists": False}
@@ -1934,6 +1930,11 @@ class LanguageProperty(Property):
arg_types = {"this": True}
# spark ddl
class ClusteredByProperty(Property):
arg_types = {"expressions": True, "sorted_by": False, "buckets": True}
class DictProperty(Property):
arg_types = {"this": True, "kind": True, "settings": False}
@@ -2074,6 +2075,7 @@ class Properties(Expression):
"ALGORITHM": AlgorithmProperty,
"AUTO_INCREMENT": AutoIncrementProperty,
"CHARACTER SET": CharacterSetProperty,
"CLUSTERED_BY": ClusteredByProperty,
"COLLATE": CollateProperty,
"COMMENT": SchemaCommentProperty,
"DEFINER": DefinerProperty,
@@ -2280,6 +2282,12 @@ class Table(Expression):
"system_time": False,
}
@property
def name(self) -> str:
if isinstance(self.this, Func):
return ""
return self.this.name
@property
def db(self) -> str:
return self.text("db")
@@ -3716,6 +3724,10 @@ class Bracket(Condition):
arg_types = {"this": True, "expressions": True}
class SafeBracket(Bracket):
"""Represents array lookup where OOB index yields NULL instead of causing a failure."""
class Distinct(Expression):
arg_types = {"expressions": False, "on": False}
@@ -3934,7 +3946,7 @@ class Case(Func):
class Cast(Func):
arg_types = {"this": True, "to": True}
arg_types = {"this": True, "to": True, "format": False}
@property
def name(self) -> str:
@@ -4292,6 +4304,10 @@ class Map(Func):
arg_types = {"keys": False, "values": False}
class MapFromEntries(Func):
pass
class StarMap(Func):
pass

View file

@@ -188,6 +188,7 @@ class Generator:
exp.CollateProperty: exp.Properties.Location.POST_SCHEMA,
exp.CopyGrantsProperty: exp.Properties.Location.POST_SCHEMA,
exp.Cluster: exp.Properties.Location.POST_SCHEMA,
exp.ClusteredByProperty: exp.Properties.Location.POST_SCHEMA,
exp.DataBlocksizeProperty: exp.Properties.Location.POST_NAME,
exp.DefinerProperty: exp.Properties.Location.POST_CREATE,
exp.DictRange: exp.Properties.Location.POST_SCHEMA,
@@ -1408,7 +1409,8 @@ class Generator:
expressions = (
f" {self.expressions(expression, flat=True)}" if expression.expressions else ""
)
return f"SET{expressions}"
tag = " TAG" if expression.args.get("tag") else ""
return f"{'UNSET' if expression.args.get('unset') else 'SET'}{tag}{expressions}"
def pragma_sql(self, expression: exp.Pragma) -> str:
return f"PRAGMA {self.sql(expression, 'this')}"
@@ -1749,6 +1751,9 @@ class Generator:
return f"{self.sql(expression, 'this')}[{expressions_sql}]"
def safebracket_sql(self, expression: exp.SafeBracket) -> str:
return self.bracket_sql(expression)
def all_sql(self, expression: exp.All) -> str:
return f"ALL {self.wrap(expression)}"
@@ -2000,8 +2005,10 @@ class Generator:
def bitwisexor_sql(self, expression: exp.BitwiseXor) -> str:
return self.binary(expression, "^")
def cast_sql(self, expression: exp.Cast) -> str:
return f"CAST({self.sql(expression, 'this')} AS {self.sql(expression, 'to')})"
def cast_sql(self, expression: exp.Cast, safe_prefix: t.Optional[str] = None) -> str:
format_sql = self.sql(expression, "format")
format_sql = f" FORMAT {format_sql}" if format_sql else ""
return f"{safe_prefix or ''}CAST({self.sql(expression, 'this')} AS {self.sql(expression, 'to')}{format_sql})"
def currentdate_sql(self, expression: exp.CurrentDate) -> str:
zone = self.sql(expression, "this")
@@ -2227,7 +2234,7 @@ class Generator:
return self.binary(expression, "-")
def trycast_sql(self, expression: exp.TryCast) -> str:
return f"TRY_CAST({self.sql(expression, 'this')} AS {self.sql(expression, 'to')})"
return self.cast_sql(expression, safe_prefix="TRY_")
def use_sql(self, expression: exp.Use) -> str:
kind = self.sql(expression, "kind")
@@ -2409,6 +2416,13 @@ class Generator:
def oncluster_sql(self, expression: exp.OnCluster) -> str:
return ""
def clusteredbyproperty_sql(self, expression: exp.ClusteredByProperty) -> str:
expressions = self.expressions(expression, key="expressions", flat=True)
sorted_by = self.expressions(expression, key="sorted_by", flat=True)
sorted_by = f" SORTED BY ({sorted_by})" if sorted_by else ""
buckets = self.sql(expression, "buckets")
return f"CLUSTERED BY ({expressions}){sorted_by} INTO {buckets} BUCKETS"
def cached_generator(
cache: t.Optional[t.Dict[int, str]] = None

View file

@@ -60,8 +60,8 @@ def qualify(
The qualified expression.
"""
schema = ensure_schema(schema, dialect=dialect)
expression = normalize_identifiers(expression, dialect=dialect)
expression = qualify_tables(expression, db=db, catalog=catalog, schema=schema)
expression = normalize_identifiers(expression, dialect=dialect)
if isolate_tables:
expression = isolate_table_selects(expression, schema=schema)

View file

@@ -56,13 +56,13 @@ def qualify_columns(
if not isinstance(scope.expression, exp.UDTF):
_expand_stars(scope, resolver, using_column_tables)
_qualify_outputs(scope)
_expand_group_by(scope, resolver)
_expand_order_by(scope)
_expand_group_by(scope)
_expand_order_by(scope, resolver)
return expression
def validate_qualify_columns(expression):
def validate_qualify_columns(expression: E) -> E:
"""Raise an `OptimizeError` if any columns aren't qualified"""
unqualified_columns = []
for scope in traverse_scope(expression):
@@ -79,7 +79,7 @@ def validate_qualify_columns(expression):
return expression
def _pop_table_column_aliases(derived_tables):
def _pop_table_column_aliases(derived_tables: t.List[exp.CTE | exp.Subquery]) -> None:
"""
Remove table column aliases.
@@ -91,13 +91,13 @@ def _pop_table_column_aliases(derived_tables):
table_alias.args.pop("columns", None)
def _expand_using(scope, resolver):
def _expand_using(scope: Scope, resolver: Resolver) -> t.Dict[str, t.Any]:
joins = list(scope.find_all(exp.Join))
names = {join.alias_or_name for join in joins}
ordered = [key for key in scope.selected_sources if key not in names]
# Mapping of automatically joined column names to an ordered set of source names (dict).
column_tables = {}
column_tables: t.Dict[str, t.Dict[str, t.Any]] = {}
for join in joins:
using = join.args.get("using")
@@ -172,20 +172,25 @@ def _expand_alias_refs(scope: Scope, resolver: Resolver) -> None:
alias_to_expression: t.Dict[str, exp.Expression] = {}
def replace_columns(
node: t.Optional[exp.Expression], expand: bool = True, resolve_agg: bool = False
):
def replace_columns(node: t.Optional[exp.Expression], resolve_table: bool = False) -> None:
if not node:
return
for column, *_ in walk_in_scope(node):
if not isinstance(column, exp.Column):
continue
table = resolver.get_table(column.name) if resolve_agg and not column.table else None
if table and column.find_ancestor(exp.AggFunc):
table = resolver.get_table(column.name) if resolve_table and not column.table else None
alias_expr = alias_to_expression.get(column.name)
double_agg = (
(alias_expr.find(exp.AggFunc) and column.find_ancestor(exp.AggFunc))
if alias_expr
else False
)
if table and (not alias_expr or double_agg):
column.set("table", table)
elif expand and not column.table and column.name in alias_to_expression:
column.replace(alias_to_expression[column.name].copy())
elif not column.table and alias_expr and not double_agg:
column.replace(alias_expr.copy())
for projection in scope.selects:
replace_columns(projection)
@@ -195,22 +200,41 @@ def _expand_alias_refs(scope: Scope, resolver: Resolver) -> None:
replace_columns(expression.args.get("where"))
replace_columns(expression.args.get("group"))
replace_columns(expression.args.get("having"), resolve_agg=True)
replace_columns(expression.args.get("qualify"), resolve_agg=True)
replace_columns(expression.args.get("order"), expand=False, resolve_agg=True)
replace_columns(expression.args.get("having"), resolve_table=True)
replace_columns(expression.args.get("qualify"), resolve_table=True)
scope.clear_cache()
def _expand_group_by(scope, resolver):
group = scope.expression.args.get("group")
def _expand_group_by(scope: Scope):
expression = scope.expression
group = expression.args.get("group")
if not group:
return
group.set("expressions", _expand_positional_references(scope, group.expressions))
scope.expression.set("group", group)
expression.set("group", group)
# group by expressions cannot be simplified, for example
# select x + 1 + 1 FROM y GROUP BY x + 1 + 1
# the projection must exactly match the group by key
groups = set(group.expressions)
group.meta["final"] = True
for e in expression.selects:
for node, *_ in e.walk():
if node in groups:
e.meta["final"] = True
break
having = expression.args.get("having")
if having:
for node, *_ in having.walk():
if node in groups:
having.meta["final"] = True
break
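A sketch of why the final flag matters (schema invented): without it, simplification would fold the key to x + 2 and the projection would no longer match the GROUP BY expression.

    from sqlglot import parse_one
    from sqlglot.optimizer import optimize

    sql = "SELECT x + 1 + 1 AS y FROM t GROUP BY x + 1 + 1"
    # The grouped expression stays x + 1 + 1 in both SELECT and GROUP BY.
    print(optimize(parse_one(sql), schema={"t": {"x": "int"}}).sql())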
def _expand_order_by(scope):
def _expand_order_by(scope: Scope, resolver: Resolver):
order = scope.expression.args.get("order")
if not order:
return
@@ -220,10 +244,21 @@ def _expand_order_by(scope):
ordereds,
_expand_positional_references(scope, (o.this for o in ordereds)),
):
for agg in ordered.find_all(exp.AggFunc):
for col in agg.find_all(exp.Column):
if not col.table:
col.set("table", resolver.get_table(col.name))
ordered.set("this", new_expression)
if scope.expression.args.get("group"):
selects = {s.this: exp.column(s.alias_or_name) for s in scope.selects}
def _expand_positional_references(scope, expressions):
for ordered in ordereds:
ordered.set("this", selects.get(ordered.this, ordered.this))
def _expand_positional_references(scope: Scope, expressions: t.Iterable[E]) -> t.List[E]:
new_nodes = []
for node in expressions:
if node.is_int:
@@ -241,7 +276,7 @@ def _expand_positional_references(scope, expressions):
return new_nodes
def _qualify_columns(scope, resolver):
def _qualify_columns(scope: Scope, resolver: Resolver) -> None:
"""Disambiguate columns, ensuring each column specifies a source"""
for column in scope.columns:
column_table = column.table
@@ -290,21 +325,23 @@ def _qualify_columns(scope, resolver):
column.set("table", column_table)
def _expand_stars(scope, resolver, using_column_tables):
def _expand_stars(
scope: Scope, resolver: Resolver, using_column_tables: t.Dict[str, t.Any]
) -> None:
"""Expand stars to lists of column selections"""
new_selections = []
except_columns = {}
replace_columns = {}
except_columns: t.Dict[int, t.Set[str]] = {}
replace_columns: t.Dict[int, t.Dict[str, str]] = {}
coalesced_columns = set()
# TODO: handle optimization of multiple PIVOTs (and possibly UNPIVOTs) in the future
pivot_columns = None
pivot_output_columns = None
pivot = seq_get(scope.pivots, 0)
pivot = t.cast(t.Optional[exp.Pivot], seq_get(scope.pivots, 0))
has_pivoted_source = pivot and not pivot.args.get("unpivot")
if has_pivoted_source:
if pivot and has_pivoted_source:
pivot_columns = set(col.output_name for col in pivot.find_all(exp.Column))
pivot_output_columns = [col.output_name for col in pivot.args.get("columns", [])]
@@ -330,8 +367,17 @@ def _expand_stars(scope, resolver, using_column_tables):
columns = resolver.get_source_columns(table, only_visible=True)
# The _PARTITIONTIME and _PARTITIONDATE pseudo-columns are not returned by a SELECT * statement
# https://cloud.google.com/bigquery/docs/querying-partitioned-tables#query_an_ingestion-time_partitioned_table
if resolver.schema.dialect == "bigquery":
columns = [
name
for name in columns
if name.upper() not in ("_PARTITIONTIME", "_PARTITIONDATE")
]
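A hedged sketch with an invented schema:

    from sqlglot import parse_one
    from sqlglot.optimizer.qualify import qualify

    expression = qualify(
        parse_one("SELECT * FROM tbl"),
        schema={"tbl": {"col": "int", "_PARTITIONTIME": "timestamp"}},
        dialect="bigquery",
    )
    # Only col appears in the expanded star; _PARTITIONTIME is skipped.
    print(expression.sql("bigquery"))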
if columns and "*" not in columns:
if has_pivoted_source:
if pivot and has_pivoted_source and pivot_columns and pivot_output_columns:
implicit_columns = [col for col in columns if col not in pivot_columns]
new_selections.extend(
exp.alias_(exp.column(name, table=pivot.alias), name, copy=False)
@@ -368,7 +414,9 @@ def _expand_stars(scope, resolver, using_column_tables):
scope.expression.set("expressions", new_selections)
def _add_except_columns(expression, tables, except_columns):
def _add_except_columns(
expression: exp.Expression, tables, except_columns: t.Dict[int, t.Set[str]]
) -> None:
except_ = expression.args.get("except")
if not except_:
@@ -380,7 +428,9 @@ def _add_except_columns(expression, tables, except_columns):
except_columns[id(table)] = columns
def _add_replace_columns(expression, tables, replace_columns):
def _add_replace_columns(
expression: exp.Expression, tables, replace_columns: t.Dict[int, t.Dict[str, str]]
) -> None:
replace = expression.args.get("replace")
if not replace:
@@ -392,7 +442,7 @@ def _add_replace_columns(expression, tables, replace_columns):
replace_columns[id(table)] = columns
def _qualify_outputs(scope):
def _qualify_outputs(scope: Scope):
"""Ensure all output columns are aliased"""
new_selections = []
@@ -429,7 +479,7 @@ class Resolver:
This is a class so we can lazily load some things and easily share them across functions.
"""
def __init__(self, scope, schema, infer_schema: bool = True):
def __init__(self, scope: Scope, schema: Schema, infer_schema: bool = True):
self.scope = scope
self.schema = schema
self._source_columns = None

View file

@@ -28,6 +28,8 @@ def simplify(expression):
generate = cached_generator()
def _simplify(expression, root=True):
if expression.meta.get("final"):
return expression
node = expression
node = rewrite_between(node)
node = uniq_sort(node, generate, root)

View file

@@ -585,6 +585,7 @@ class Parser(metaclass=_Parser):
"CHARACTER SET": lambda self: self._parse_character_set(),
"CHECKSUM": lambda self: self._parse_checksum(),
"CLUSTER BY": lambda self: self._parse_cluster(),
"CLUSTERED": lambda self: self._parse_clustered_by(),
"COLLATE": lambda self: self._parse_property_assignment(exp.CollateProperty),
"COMMENT": lambda self: self._parse_property_assignment(exp.SchemaCommentProperty),
"COPY": lambda self: self._parse_copy_property(),
@@ -794,8 +795,6 @@ class Parser(metaclass=_Parser):
# A NULL arg in CONCAT yields NULL by default
CONCAT_NULL_OUTPUTS_STRING = False
CONVERT_TYPE_FIRST = False
PREFIXED_PIVOT_COLUMNS = False
IDENTIFY_PIVOT_STRINGS = False
@@ -1426,9 +1425,34 @@ class Parser(metaclass=_Parser):
return self.expression(exp.ChecksumProperty, on=on, default=self._match(TokenType.DEFAULT))
def _parse_cluster(self) -> t.Optional[exp.Cluster]:
def _parse_cluster(self) -> exp.Cluster:
return self.expression(exp.Cluster, expressions=self._parse_csv(self._parse_ordered))
def _parse_clustered_by(self) -> exp.ClusteredByProperty:
self._match_text_seq("BY")
self._match_l_paren()
expressions = self._parse_csv(self._parse_column)
self._match_r_paren()
if self._match_text_seq("SORTED", "BY"):
self._match_l_paren()
sorted_by = self._parse_csv(self._parse_ordered)
self._match_r_paren()
else:
sorted_by = None
self._match(TokenType.INTO)
buckets = self._parse_number()
self._match_text_seq("BUCKETS")
return self.expression(
exp.ClusteredByProperty,
expressions=expressions,
sorted_by=sorted_by,
buckets=buckets,
)
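A round-trip sketch of the new property (DDL invented):

    import sqlglot

    sql = "CREATE TABLE t (a INT, b STRING) CLUSTERED BY (a) SORTED BY (a) INTO 4 BUCKETS"
    print(sqlglot.transpile(sql, read="hive", write="hive")[0])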
def _parse_copy_property(self) -> t.Optional[exp.CopyGrantsProperty]:
if not self._match_text_seq("GRANTS"):
self._retreat(self._index - 1)
@@ -2863,7 +2887,11 @@ class Parser(metaclass=_Parser):
if not self._match(TokenType.INTERVAL):
return None
this = self._parse_primary() or self._parse_term()
if self._match(TokenType.STRING, advance=False):
this = self._parse_primary()
else:
this = self._parse_term()
unit = self._parse_function() or self._parse_var()
# Most dialects support, e.g., the form INTERVAL '5' day, thus we try to parse
@@ -3661,6 +3689,7 @@ class Parser(metaclass=_Parser):
else:
self.raise_error("Expected AS after CAST")
fmt = None
to = self._parse_types()
if not to:
@@ -3668,22 +3697,23 @@ class Parser(metaclass=_Parser):
elif to.this == exp.DataType.Type.CHAR:
if self._match(TokenType.CHARACTER_SET):
to = self.expression(exp.CharacterSet, this=self._parse_var_or_string())
elif to.this in exp.DataType.TEMPORAL_TYPES and self._match(TokenType.FORMAT):
fmt = self._parse_string()
elif self._match(TokenType.FORMAT):
fmt = self._parse_at_time_zone(self._parse_string())
return self.expression(
exp.StrToDate if to.this == exp.DataType.Type.DATE else exp.StrToTime,
this=this,
format=exp.Literal.string(
format_time(
fmt.this if fmt else "",
self.FORMAT_MAPPING or self.TIME_MAPPING,
self.FORMAT_TRIE or self.TIME_TRIE,
)
),
)
if to.this in exp.DataType.TEMPORAL_TYPES:
return self.expression(
exp.StrToDate if to.this == exp.DataType.Type.DATE else exp.StrToTime,
this=this,
format=exp.Literal.string(
format_time(
fmt.this if fmt else "",
self.FORMAT_MAPPING or self.TIME_MAPPING,
self.FORMAT_TRIE or self.TIME_TRIE,
)
),
)
return self.expression(exp.Cast if strict else exp.TryCast, this=this, to=to)
return self.expression(exp.Cast if strict else exp.TryCast, this=this, to=to, format=fmt)
def _parse_concat(self) -> t.Optional[exp.Expression]:
args = self._parse_csv(self._parse_conjunction)
@@ -3704,20 +3734,23 @@ class Parser(metaclass=_Parser):
)
def _parse_string_agg(self) -> exp.Expression:
expression: t.Optional[exp.Expression]
if self._match(TokenType.DISTINCT):
args = self._parse_csv(self._parse_conjunction)
expression = self.expression(exp.Distinct, expressions=[seq_get(args, 0)])
args: t.List[t.Optional[exp.Expression]] = [
self.expression(exp.Distinct, expressions=[self._parse_conjunction()])
]
if self._match(TokenType.COMMA):
args.extend(self._parse_csv(self._parse_conjunction))
else:
args = self._parse_csv(self._parse_conjunction)
expression = seq_get(args, 0)
index = self._index
if not self._match(TokenType.R_PAREN):
# postgres: STRING_AGG([DISTINCT] expression, separator [ORDER BY expression1 {ASC | DESC} [, ...]])
order = self._parse_order(this=expression)
return self.expression(exp.GroupConcat, this=order, separator=seq_get(args, 1))
return self.expression(
exp.GroupConcat,
this=seq_get(args, 0),
separator=self._parse_order(this=seq_get(args, 1)),
)
# Checks if we can parse an order clause: WITHIN GROUP (ORDER BY <order_by_expression_list> [ASC | DESC]).
# This is done "manually", instead of letting _parse_window parse it into an exp.WithinGroup node, so that
@@ -3727,24 +3760,21 @@ class Parser(metaclass=_Parser):
return self.validate_expression(exp.GroupConcat.from_arg_list(args), args)
self._match_l_paren() # The corresponding match_r_paren will be called in parse_function (caller)
order = self._parse_order(this=expression)
order = self._parse_order(this=seq_get(args, 0))
return self.expression(exp.GroupConcat, this=order, separator=seq_get(args, 1))
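A hedged round-trip example; the ORDER BY is now parsed onto the separator argument, matching Postgres's STRING_AGG signature:

    import sqlglot

    print(sqlglot.transpile("SELECT STRING_AGG(x, ',' ORDER BY y) FROM t", read="postgres", write="postgres")[0])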
def _parse_convert(self, strict: bool) -> t.Optional[exp.Expression]:
to: t.Optional[exp.Expression]
this = self._parse_bitwise()
if self._match(TokenType.USING):
to = self.expression(exp.CharacterSet, this=self._parse_var())
to: t.Optional[exp.Expression] = self.expression(
exp.CharacterSet, this=self._parse_var()
)
elif self._match(TokenType.COMMA):
to = self._parse_bitwise()
to = self._parse_types()
else:
to = None
# Swap the argument order if needed to produce the correct AST
if self.CONVERT_TYPE_FIRST:
this, to = to, this
return self.expression(exp.Cast if strict else exp.TryCast, this=this, to=to)
def _parse_decode(self) -> t.Optional[exp.Decode | exp.Case]:
@@ -4394,8 +4424,8 @@ class Parser(metaclass=_Parser):
if self._next:
self._advance()
parser = self.ALTER_PARSERS.get(self._prev.text.upper()) if self._prev else None
if parser:
actions = ensure_list(parser(self))
@@ -4516,9 +4546,11 @@ class Parser(metaclass=_Parser):
parser = self._find_parser(self.SET_PARSERS, self.SET_TRIE)
return parser(self) if parser else self._parse_set_item_assignment(kind=None)
def _parse_set(self) -> exp.Set | exp.Command:
def _parse_set(self, unset: bool = False, tag: bool = False) -> exp.Set | exp.Command:
index = self._index
set_ = self.expression(exp.Set, expressions=self._parse_csv(self._parse_set_item))
set_ = self.expression(
exp.Set, expressions=self._parse_csv(self._parse_set_item), unset=unset, tag=tag
)
if self._curr:
self._retreat(index)
@@ -4683,12 +4715,8 @@ class Parser(metaclass=_Parser):
exp.replace_children(this, self._replace_columns_with_dots)
table = this.args.get("table")
this = (
self.expression(exp.Dot, this=table, expression=this.this)
if table
else self.expression(exp.Var, this=this.name)
self.expression(exp.Dot, this=table, expression=this.this) if table else this.this
)
elif isinstance(this, exp.Identifier):
this = self.expression(exp.Var, this=this.name)
return this

View file

@@ -91,6 +91,7 @@ class Step:
A Step DAG corresponding to `expression`.
"""
ctes = ctes or {}
expression = expression.unnest()
with_ = expression.args.get("with")
# CTEs break the mold of scope and introduce themselves to all in the context.
@@ -120,22 +121,25 @@ class Step:
projections = [] # final selects in this chain of steps representing a select
operands = {} # intermediate computations of agg funcs eg x + 1 in SUM(x + 1)
aggregations = []
aggregations = set()
next_operand_name = name_sequence("_a_")
def extract_agg_operands(expression):
for agg in expression.find_all(exp.AggFunc):
agg_funcs = tuple(expression.find_all(exp.AggFunc))
if agg_funcs:
aggregations.add(expression)
for agg in agg_funcs:
for operand in agg.unnest_operands():
if isinstance(operand, exp.Column):
continue
if operand not in operands:
operands[operand] = next_operand_name()
operand.replace(exp.column(operands[operand], quoted=True))
return bool(agg_funcs)
for e in expression.expressions:
if e.find(exp.AggFunc):
projections.append(exp.column(e.alias_or_name, step.name, quoted=True))
aggregations.append(e)
extract_agg_operands(e)
else:
projections.append(e)
@@ -155,22 +159,38 @@ class Step:
having = expression.args.get("having")
if having:
extract_agg_operands(having)
aggregate.condition = having.this
if extract_agg_operands(exp.alias_(having.this, "_h", quoted=True)):
aggregate.condition = exp.column("_h", step.name, quoted=True)
else:
aggregate.condition = having.this
aggregate.operands = tuple(
alias(operand, alias_) for operand, alias_ in operands.items()
)
aggregate.aggregations = aggregations
aggregate.aggregations = list(aggregations)
# give aggregates names and replace projections with references to them
aggregate.group = {
f"_g{i}": e for i, e in enumerate(group.expressions if group else [])
}
intermediate: t.Dict[str | exp.Expression, str] = {}
for k, v in aggregate.group.items():
intermediate[v] = k
if isinstance(v, exp.Column):
intermediate[v.alias_or_name] = k
for projection in projections:
for i, e in aggregate.group.items():
for child, *_ in projection.walk():
if child == e:
child.replace(exp.column(i, step.name))
for node, *_ in projection.walk():
name = intermediate.get(node)
if name:
node.replace(exp.column(name, step.name))
if aggregate.condition:
for node, *_ in aggregate.condition.walk():
name = intermediate.get(node) or intermediate.get(node.name)
if name:
node.replace(exp.column(name, step.name))
aggregate.add_dependency(step)
step = aggregate

View file

@@ -159,10 +159,11 @@ def explode_to_unnest(expression: exp.Expression) -> exp.Expression:
if isinstance(expression, exp.Select):
from sqlglot.optimizer.scope import build_scope
taken_select_names = set(expression.named_selects)
scope = build_scope(expression)
if not scope:
return expression
taken_select_names = set(expression.named_selects)
taken_source_names = set(scope.selected_sources)
for select in expression.selects: