
Merging upstream version 25.0.3.

Signed-off-by: Daniel Baumann <daniel@debian.org>
Daniel Baumann 2025-02-13 21:37:40 +01:00
parent 03b67e2ec9
commit 021892b3ff
Signed by: daniel
GPG key ID: FBB4F0E80A80222F
84 changed files with 33016 additions and 31040 deletions

sqlglot/dialects/__init__.py

@@ -70,12 +70,14 @@ from sqlglot.dialects.doris import Doris
from sqlglot.dialects.drill import Drill
from sqlglot.dialects.duckdb import DuckDB
from sqlglot.dialects.hive import Hive
from sqlglot.dialects.materialize import Materialize
from sqlglot.dialects.mysql import MySQL
from sqlglot.dialects.oracle import Oracle
from sqlglot.dialects.postgres import Postgres
from sqlglot.dialects.presto import Presto
from sqlglot.dialects.prql import PRQL
from sqlglot.dialects.redshift import Redshift
from sqlglot.dialects.risingwave import RisingWave
from sqlglot.dialects.snowflake import Snowflake
from sqlglot.dialects.spark import Spark
from sqlglot.dialects.spark2 import Spark2

sqlglot/dialects/bigquery.py

@@ -705,7 +705,6 @@ class BigQuery(Dialect):
# from: https://cloud.google.com/bigquery/docs/reference/standard-sql/lexical#reserved_keywords
RESERVED_KEYWORDS = {
*generator.Generator.RESERVED_KEYWORDS,
"all",
"and",
"any",

sqlglot/dialects/clickhouse.py

@@ -367,7 +367,7 @@ class ClickHouse(Dialect):
**parser.Parser.QUERY_MODIFIER_PARSERS,
TokenType.SETTINGS: lambda self: (
"settings",
self._advance() or self._parse_csv(self._parse_conjunction),
self._advance() or self._parse_csv(self._parse_assignment),
),
TokenType.FORMAT: lambda self: ("format", self._advance() or self._parse_id_var()),
}
@@ -388,15 +388,15 @@ class ClickHouse(Dialect):
"INDEX",
}
def _parse_conjunction(self) -> t.Optional[exp.Expression]:
this = super()._parse_conjunction()
def _parse_assignment(self) -> t.Optional[exp.Expression]:
this = super()._parse_assignment()
if self._match(TokenType.PLACEHOLDER):
return self.expression(
exp.If,
this=this,
true=self._parse_conjunction(),
false=self._match(TokenType.COLON) and self._parse_conjunction(),
true=self._parse_assignment(),
false=self._match(TokenType.COLON) and self._parse_assignment(),
)
return this
@@ -461,7 +461,7 @@ class ClickHouse(Dialect):
# WITH <expression> AS <identifier>
cte = self.expression(
exp.CTE,
this=self._parse_conjunction(),
this=self._parse_assignment(),
alias=self._parse_table_alias(),
scalar=True,
)
@@ -592,7 +592,7 @@ class ClickHouse(Dialect):
) -> exp.IndexColumnConstraint:
# INDEX name1 expr TYPE type1(args) GRANULARITY value
this = self._parse_id_var()
expression = self._parse_conjunction()
expression = self._parse_assignment()
index_type = self._match_text_seq("TYPE") and (
self._parse_function() or self._parse_var()
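
The rename above also moves ClickHouse's ternary handling into _parse_assignment. A minimal sketch of the behavior this preserves, using sqlglot's public API (illustrative, not verified here):

from sqlglot import exp, parse_one

# ClickHouse's cond ? a : b ternary should come back as an exp.If node
ast = parse_one("SELECT x > 1 ? 'big' : 'small'", read="clickhouse")
assert isinstance(ast.selects[0], exp.If)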

sqlglot/dialects/dialect.py

@@ -50,12 +50,14 @@ class Dialects(str, Enum):
DRILL = "drill"
DUCKDB = "duckdb"
HIVE = "hive"
MATERIALIZE = "materialize"
MYSQL = "mysql"
ORACLE = "oracle"
POSTGRES = "postgres"
PRESTO = "presto"
PRQL = "prql"
REDSHIFT = "redshift"
RISINGWAVE = "risingwave"
SNOWFLAKE = "snowflake"
SPARK = "spark"
SPARK2 = "spark2"
@@ -593,7 +595,9 @@ def inline_array_unless_query(self: Generator, expression: exp.Array) -> str:
def no_ilike_sql(self: Generator, expression: exp.ILike) -> str:
return self.like_sql(
exp.Like(this=exp.Lower(this=expression.this), expression=expression.expression)
exp.Like(
this=exp.Lower(this=expression.this), expression=exp.Lower(this=expression.expression)
)
)
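
The no_ilike_sql fix lowercases both operands; previously only the left side was wrapped in LOWER, which silently changed matching semantics for targets without ILIKE. A hedged sketch, assuming Presto still routes ILIKE through this helper:

import sqlglot

# Expected: SELECT LOWER(a) LIKE LOWER('%X%')
print(sqlglot.transpile("SELECT a ILIKE '%X%'", read="postgres", write="presto")[0])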

sqlglot/dialects/doris.py

@@ -26,6 +26,9 @@ class Doris(MySQL):
"TO_DATE": exp.TsOrDsToDate.from_arg_list,
}
FUNCTION_PARSERS = MySQL.Parser.FUNCTION_PARSERS.copy()
FUNCTION_PARSERS.pop("GROUP_CONCAT")
class Generator(MySQL.Generator):
LAST_DAY_SUPPORTS_DATE_PART = False
@@ -49,6 +52,9 @@ class Doris(MySQL):
exp.ArrayUniqueAgg: rename_func("COLLECT_SET"),
exp.CurrentTimestamp: lambda self, _: self.func("NOW"),
exp.DateTrunc: lambda self, e: self.func("DATE_TRUNC", e.this, unit_to_str(e)),
exp.GroupConcat: lambda self, e: self.func(
"GROUP_CONCAT", e.this, e.args.get("separator") or exp.Literal.string(",")
),
exp.JSONExtractScalar: lambda self, e: self.func("JSON_EXTRACT", e.this, e.expression),
exp.Map: rename_func("ARRAY_MAP"),
exp.RegexpLike: rename_func("REGEXP"),
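
Doris takes the separator as a second argument rather than MySQL's SEPARATOR keyword, which appears to be why the inherited GROUP_CONCAT parser is popped and the transform always renders a two-argument call, defaulting to ','. A sketch of the intended rewrite (table and column names made up):

import sqlglot

# Expected: SELECT GROUP_CONCAT(x, '-') FROM t
print(sqlglot.transpile("SELECT GROUP_CONCAT(x SEPARATOR '-') FROM t", read="mysql", write="doris")[0])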

sqlglot/dialects/duckdb.py

@@ -341,7 +341,7 @@ class DuckDB(Dialect):
if self._match(TokenType.L_BRACE, advance=False):
return self.expression(exp.ToMap, this=self._parse_bracket())
args = self._parse_wrapped_csv(self._parse_conjunction)
args = self._parse_wrapped_csv(self._parse_assignment)
return self.expression(exp.Map, keys=seq_get(args, 0), values=seq_get(args, 1))
def _parse_struct_types(self, type_required: bool = False) -> t.Optional[exp.Expression]:
@@ -503,11 +503,93 @@ class DuckDB(Dialect):
exp.DataType.Type.VARBINARY: "BLOB",
exp.DataType.Type.ROWVERSION: "BLOB",
exp.DataType.Type.VARCHAR: "TEXT",
exp.DataType.Type.TIMESTAMPNTZ: "TIMESTAMP",
exp.DataType.Type.TIMESTAMP_S: "TIMESTAMP_S",
exp.DataType.Type.TIMESTAMP_MS: "TIMESTAMP_MS",
exp.DataType.Type.TIMESTAMP_NS: "TIMESTAMP_NS",
}
# https://github.com/duckdb/duckdb/blob/ff7f24fd8e3128d94371827523dae85ebaf58713/third_party/libpg_query/grammar/keywords/reserved_keywords.list#L1-L77
RESERVED_KEYWORDS = {
"array",
"analyse",
"union",
"all",
"when",
"in_p",
"default",
"create_p",
"window",
"asymmetric",
"to",
"else",
"localtime",
"from",
"end_p",
"select",
"current_date",
"foreign",
"with",
"grant",
"session_user",
"or",
"except",
"references",
"fetch",
"limit",
"group_p",
"leading",
"into",
"collate",
"offset",
"do",
"then",
"localtimestamp",
"check_p",
"lateral_p",
"current_role",
"where",
"asc_p",
"placing",
"desc_p",
"user",
"unique",
"initially",
"column",
"both",
"some",
"as",
"any",
"only",
"deferrable",
"null_p",
"current_time",
"true_p",
"table",
"case",
"trailing",
"variadic",
"for",
"on",
"distinct",
"false_p",
"not",
"constraint",
"current_timestamp",
"returning",
"primary",
"intersect",
"having",
"analyze",
"current_user",
"and",
"cast",
"symmetric",
"using",
"order",
"current_catalog",
}
UNWRAPPED_INTERVAL_VALUES = (exp.Literal, exp.Paren)
# DuckDB doesn't generally support CREATE TABLE .. properties
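
Since the generator quotes any identifier whose lowercased name appears in RESERVED_KEYWORDS, importing this list (taken verbatim from DuckDB's libpg_query grammar, _p-suffixed internal names included) should make such collisions safe to emit. A quick sketch:

from sqlglot import exp

# Expected: "all" (quoted, because it is now reserved for DuckDB)
print(exp.column("all").sql(dialect="duckdb"))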

sqlglot/dialects/hive.py

@@ -29,6 +29,7 @@ from sqlglot.dialects.dialect import (
struct_extract_sql,
time_format,
timestrtotime_sql,
unit_to_str,
var_map_sql,
)
from sqlglot.transforms import (
@@ -318,6 +319,7 @@ class Hive(Dialect):
),
"TO_DATE": build_formatted_time(exp.TsOrDsToDate, "hive"),
"TO_JSON": exp.JSONFormat.from_arg_list,
"TRUNC": exp.TimestampTrunc.from_arg_list,
"UNBASE64": exp.FromBase64.from_arg_list,
"UNIX_TIMESTAMP": lambda args: build_formatted_time(exp.StrToUnix, "hive", True)(
args or [exp.CurrentTimestamp()]
@@ -415,7 +417,7 @@ class Hive(Dialect):
) -> t.Tuple[t.List[exp.Expression], t.Optional[exp.Expression]]:
return (
(
self._parse_csv(self._parse_conjunction)
self._parse_csv(self._parse_assignment)
if self._match_set({TokenType.PARTITION_BY, TokenType.DISTRIBUTE_BY})
else []
),
@@ -548,6 +550,7 @@ class Hive(Dialect):
exp.TimeStrToDate: rename_func("TO_DATE"),
exp.TimeStrToTime: timestrtotime_sql,
exp.TimeStrToUnix: rename_func("UNIX_TIMESTAMP"),
exp.TimestampTrunc: lambda self, e: self.func("TRUNC", e.this, unit_to_str(e)),
exp.TimeToStr: lambda self, e: self.func("DATE_FORMAT", e.this, self.format_time(e)),
exp.TimeToUnix: rename_func("UNIX_TIMESTAMP"),
exp.ToBase64: rename_func("BASE64"),
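
With TRUNC mapped to exp.TimestampTrunc in both the parser and the generator, Hive date truncation should now survive a round trip. A sketch (untested; ts is a made-up column):

from sqlglot import exp, parse_one

ast = parse_one("SELECT TRUNC(ts, 'MM') FROM t", read="hive")
assert isinstance(ast.selects[0], exp.TimestampTrunc)
print(ast.sql(dialect="hive"))  # expected: SELECT TRUNC(ts, 'MM') FROM t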

sqlglot/dialects/materialize.py

@@ -0,0 +1,94 @@
from __future__ import annotations
from sqlglot import exp
from sqlglot.helper import seq_get
from sqlglot.dialects.postgres import Postgres
from sqlglot.tokens import TokenType
from sqlglot.transforms import (
remove_unique_constraints,
ctas_with_tmp_tables_to_create_tmp_view,
preprocess,
)
import typing as t
class Materialize(Postgres):
class Parser(Postgres.Parser):
NO_PAREN_FUNCTION_PARSERS = {
**Postgres.Parser.NO_PAREN_FUNCTION_PARSERS,
"MAP": lambda self: self._parse_map(),
}
LAMBDAS = {
**Postgres.Parser.LAMBDAS,
TokenType.FARROW: lambda self, expressions: self.expression(
exp.Kwarg, this=seq_get(expressions, 0), expression=self._parse_assignment()
),
}
def _parse_lambda_arg(self) -> t.Optional[exp.Expression]:
return self._parse_field()
def _parse_map(self) -> exp.ToMap:
if self._match(TokenType.L_PAREN):
to_map = self.expression(exp.ToMap, this=self._parse_select())
self._match_r_paren()
return to_map
if not self._match(TokenType.L_BRACKET):
self.raise_error("Expecting [")
entries = [
exp.PropertyEQ(this=e.this, expression=e.expression)
for e in self._parse_csv(self._parse_lambda)
]
if not self._match(TokenType.R_BRACKET):
self.raise_error("Expecting ]")
return self.expression(exp.ToMap, this=self.expression(exp.Struct, expressions=entries))
class Generator(Postgres.Generator):
SUPPORTS_CREATE_TABLE_LIKE = False
TRANSFORMS = {
**Postgres.Generator.TRANSFORMS,
exp.AutoIncrementColumnConstraint: lambda self, e: "",
exp.Create: preprocess(
[
remove_unique_constraints,
ctas_with_tmp_tables_to_create_tmp_view,
]
),
exp.GeneratedAsIdentityColumnConstraint: lambda self, e: "",
exp.OnConflict: lambda self, e: "",
exp.PrimaryKeyColumnConstraint: lambda self, e: "",
}
TRANSFORMS.pop(exp.ToMap)
def propertyeq_sql(self, expression: exp.PropertyEQ) -> str:
return self.binary(expression, "=>")
def datatype_sql(self, expression: exp.DataType) -> str:
if expression.is_type(exp.DataType.Type.LIST):
if expression.expressions:
return f"{self.expressions(expression, flat=True)} LIST"
return "LIST"
if expression.is_type(exp.DataType.Type.MAP) and len(expression.expressions) == 2:
key, value = expression.expressions
return f"MAP[{self.sql(key)} => {self.sql(value)}]"
return super().datatype_sql(expression)
def list_sql(self, expression: exp.List) -> str:
if isinstance(seq_get(expression.expressions, 0), exp.Select):
return self.func("LIST", seq_get(expression.expressions, 0))
return f"{self.normalize_func('LIST')}[{self.expressions(expression, flat=True)}]"
def tomap_sql(self, expression: exp.ToMap) -> str:
if isinstance(expression.this, exp.Select):
return self.func("MAP", expression.this)
return f"{self.normalize_func('MAP')}[{self.expressions(expression.this)}]"

sqlglot/dialects/mysql.py

@@ -279,6 +279,10 @@ class MySQL(Dialect):
**parser.Parser.CONJUNCTION,
TokenType.DAMP: exp.And,
TokenType.XOR: exp.Xor,
}
DISJUNCTION = {
**parser.Parser.DISJUNCTION,
TokenType.DPIPE: exp.Or,
}
@@ -625,7 +629,7 @@
)
def _parse_chr(self) -> t.Optional[exp.Expression]:
expressions = self._parse_csv(self._parse_conjunction)
expressions = self._parse_csv(self._parse_assignment)
kwargs: t.Dict[str, t.Any] = {"this": seq_get(expressions, 0)}
if len(expressions) > 1:
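
Moving DPIPE into the new DISJUNCTION table keeps MySQL's default reading of || as logical OR (it is only string concatenation under PIPES_AS_CONCAT). For instance:

from sqlglot import exp, parse_one

assert isinstance(parse_one("a || b", read="mysql"), exp.Or)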

sqlglot/dialects/postgres.py

@@ -114,15 +114,6 @@ def _string_agg_sql(self: Postgres.Generator, expression: exp.GroupConcat) -> str:
return f"STRING_AGG({self.format_args(this, separator)}{order})"
def _datatype_sql(self: Postgres.Generator, expression: exp.DataType) -> str:
if expression.is_type("array"):
if expression.expressions:
values = self.expressions(expression, key="values", flat=True)
return f"{self.expressions(expression, flat=True)}[{values}]"
return "ARRAY"
return self.datatype_sql(expression)
def _auto_increment_to_serial(expression: exp.Expression) -> exp.Expression:
auto = expression.find(exp.AutoIncrementColumnConstraint)
@@ -500,7 +491,6 @@ class Postgres(Dialect):
exp.DateAdd: _date_add_sql("+"),
exp.DateDiff: _date_diff_sql,
exp.DateStrToDate: datestrtodate_sql,
exp.DataType: _datatype_sql,
exp.DateSub: _date_add_sql("-"),
exp.Explode: rename_func("UNNEST"),
exp.GroupConcat: _string_agg_sql,
@@ -623,3 +613,11 @@ class Postgres(Dialect):
option = self.sql(expression, "option")
return f"SET {exprs}{access_method}{tablespace}{option}"
def datatype_sql(self, expression: exp.DataType) -> str:
if expression.is_type(exp.DataType.Type.ARRAY):
if expression.expressions:
values = self.expressions(expression, key="values", flat=True)
return f"{self.expressions(expression, flat=True)}[{values}]"
return "ARRAY"
return super().datatype_sql(expression)
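
Turning the module-level _datatype_sql helper into a datatype_sql override means subclasses of Postgres.Generator (such as the new Materialize dialect) inherit the ARRAY handling instead of it being pinned in the TRANSFORMS table. Sized arrays should keep their bounds; a sketch:

import sqlglot

# Expected: CREATE TABLE t (x INT[3])
print(sqlglot.transpile("CREATE TABLE t (x INT[3])", read="postgres", write="postgres")[0])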

sqlglot/dialects/prql.py

@@ -36,6 +36,10 @@ class PRQL(Dialect):
CONJUNCTION = {
**parser.Parser.CONJUNCTION,
TokenType.DAMP: exp.And,
}
DISJUNCTION = {
**parser.Parser.DISJUNCTION,
TokenType.DPIPE: exp.Or,
}
@@ -43,7 +47,7 @@ class PRQL(Dialect):
"DERIVE": lambda self, query: self._parse_selection(query),
"SELECT": lambda self, query: self._parse_selection(query, append=False),
"TAKE": lambda self, query: self._parse_take(query),
"FILTER": lambda self, query: query.where(self._parse_conjunction()),
"FILTER": lambda self, query: query.where(self._parse_assignment()),
"APPEND": lambda self, query: query.union(
_select_all(self._parse_table()), distinct=False, copy=False
),
@@ -174,8 +178,8 @@ class PRQL(Dialect):
if self._next and self._next.token_type == TokenType.ALIAS:
alias = self._parse_id_var(True)
self._match(TokenType.ALIAS)
return self.expression(exp.Alias, this=self._parse_conjunction(), alias=alias)
return self._parse_conjunction()
return self.expression(exp.Alias, this=self._parse_assignment(), alias=alias)
return self._parse_assignment()
def _parse_table(
self,

sqlglot/dialects/risingwave.py

@@ -0,0 +1,6 @@
from sqlglot.dialects.postgres import Postgres
class RisingWave(Postgres):
class Generator(Postgres.Generator):
LOCKING_READS_SUPPORTED = False
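
RisingWave piggybacks on Postgres and, so far, only disables locking reads. The FOR UPDATE clause should therefore be dropped (with an "unsupported" warning logged); a sketch:

import sqlglot

# Expected: SELECT * FROM t
print(sqlglot.transpile("SELECT * FROM t FOR UPDATE", read="postgres", write="risingwave")[0])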

sqlglot/dialects/snowflake.py

@@ -498,7 +498,7 @@ class Snowflake(Dialect):
TokenType.ARROW: lambda self, expressions: self.expression(
exp.Lambda,
this=self._replace_lambda(
self._parse_conjunction(),
self._parse_assignment(),
expressions,
),
expressions=[e.this if isinstance(e, exp.Cast) else e for e in expressions],
@@ -576,7 +576,7 @@ class Snowflake(Dialect):
# - https://docs.snowflake.com/en/sql-reference/functions/object_construct
return self._parse_slice(self._parse_string())
return self._parse_slice(self._parse_alias(self._parse_conjunction(), explicit=True))
return self._parse_slice(self._parse_alias(self._parse_assignment(), explicit=True))
def _parse_lateral(self) -> t.Optional[exp.Lateral]:
lateral = super()._parse_lateral()
@@ -714,7 +714,7 @@ class Snowflake(Dialect):
def _parse_file_location(self) -> t.Optional[exp.Expression]:
# Parse either a subquery or a staged file
return (
self._parse_select(table=True)
self._parse_select(table=True, parse_subquery_alias=False)
if self._match(TokenType.L_PAREN, advance=False)
else self._parse_table_parts()
)

sqlglot/dialects/teradata.py

@@ -164,7 +164,7 @@ class Teradata(Dialect):
}
def _parse_translate(self, strict: bool) -> exp.Expression:
this = self._parse_conjunction()
this = self._parse_assignment()
if not self._match(TokenType.USING):
self.raise_error("Expected USING in TRANSLATE")
@@ -195,8 +195,8 @@
this = self._parse_id_var()
self._match(TokenType.BETWEEN)
expressions = self._parse_csv(self._parse_conjunction)
each = self._match_text_seq("EACH") and self._parse_conjunction()
expressions = self._parse_csv(self._parse_assignment)
each = self._match_text_seq("EACH") and self._parse_assignment()
return self.expression(exp.RangeN, this=this, expressions=expressions, each=each)

sqlglot/dialects/tsql.py

@@ -625,7 +625,7 @@ class TSQL(Dialect):
) -> t.Optional[exp.Expression]:
this = self._parse_types()
self._match(TokenType.COMMA)
args = [this, *self._parse_csv(self._parse_conjunction)]
args = [this, *self._parse_csv(self._parse_assignment)]
convert = exp.Convert.from_arg_list(args)
convert.set("safe", safe)
convert.set("strict", strict)

sqlglot/expressions.py

@@ -3977,6 +3977,7 @@ class DataType(Expression):
IPV6 = auto()
JSON = auto()
JSONB = auto()
LIST = auto()
LONGBLOB = auto()
LONGTEXT = auto()
LOWCARDINALITY = auto()
@@ -4768,6 +4769,12 @@ class ToArray(Func):
pass
# https://materialize.com/docs/sql/types/list/
class List(Func):
arg_types = {"expressions": False}
is_var_len_args = True
# https://docs.snowflake.com/en/sql-reference/functions/to_char
# https://docs.oracle.com/en/database/oracle/oracle-database/23/sqlrf/TO_CHAR-number.html
class ToChar(Func):
@@ -5245,6 +5252,18 @@ class ToBase64(Func):
pass
class GapFill(Func):
arg_types = {
"this": True,
"ts_column": True,
"bucket_width": True,
"partitioning_columns": False,
"value_columns": False,
"origin": False,
"ignore_nulls": False,
}
class GenerateDateArray(Func):
arg_types = {"start": True, "end": True, "interval": False}
@@ -6175,6 +6194,8 @@ def _apply_child_list_builder(
):
instance = maybe_copy(instance, copy)
parsed = []
properties = {} if properties is None else properties
for expression in expressions:
if expression is not None:
if _is_wrong_expression(expression, into):
@@ -6187,14 +6208,18 @@
prefix=prefix,
**opts,
)
parsed.extend(expression.expressions)
for k, v in expression.args.items():
if k == "expressions":
parsed.extend(v)
else:
properties[k] = v
existing = instance.args.get(arg)
if append and existing:
parsed = existing.expressions + parsed
child = into(expressions=parsed)
for k, v in (properties or {}).items():
for k, v in properties.items():
child.set(k, v)
instance.set(arg, child)

sqlglot/generator.py

@@ -3955,3 +3955,8 @@ class Generator(metaclass=_Generator):
expressions = self.expressions(expression, flat=True)
expressions = f" USING ({expressions})" if expressions else ""
return f"MASKING POLICY {this}{expressions}"
def gapfill_sql(self, expression: exp.GapFill) -> str:
this = self.sql(expression, "this")
this = f"TABLE {this}"
return self.func("GAP_FILL", this, *[v for k, v in expression.args.items() if k != "this"])

sqlglot/optimizer/normalize.py

@@ -36,7 +36,7 @@ def normalize(expression: exp.Expression, dnf: bool = False, max_distance: int =
original = node.copy()
node.transform(rewrite_between, copy=False)
distance = normalization_distance(node, dnf=dnf)
distance = normalization_distance(node, dnf=dnf, max_=max_distance)
if distance > max_distance:
logger.info(
@@ -85,7 +85,9 @@ def normalized(expression: exp.Expression, dnf: bool = False) -> bool:
)
def normalization_distance(expression: exp.Expression, dnf: bool = False) -> int:
def normalization_distance(
expression: exp.Expression, dnf: bool = False, max_: float = float("inf")
) -> int:
"""
The difference in the number of predicates between a given expression and its normalized form.
@@ -101,33 +103,47 @@ def normalization_distance(expression: exp.Expression, dnf: bool = False) -> int:
expression: The expression to compute the normalization distance for.
dnf: Whether to check if the expression is in Disjunctive Normal Form (DNF).
Default: False, i.e. we check if it's in Conjunctive Normal Form (CNF).
max_: stop early if count exceeds this.
Returns:
The normalization distance.
"""
return sum(_predicate_lengths(expression, dnf)) - (
sum(1 for _ in expression.find_all(exp.Connector)) + 1
)
total = -(sum(1 for _ in expression.find_all(exp.Connector)) + 1)
for length in _predicate_lengths(expression, dnf, max_):
total += length
if total > max_:
return total
return total
def _predicate_lengths(expression, dnf):
def _predicate_lengths(expression, dnf, max_=float("inf"), depth=0):
"""
Returns a list of predicate lengths when expanded to normalized form.
(A AND B) OR C -> [2, 2] because len(A OR C), len(B OR C).
"""
if depth > max_:
yield depth
return
expression = expression.unnest()
if not isinstance(expression, exp.Connector):
return (1,)
yield 1
return
depth += 1
left, right = expression.args.values()
if isinstance(expression, exp.And if dnf else exp.Or):
return tuple(
a + b for a in _predicate_lengths(left, dnf) for b in _predicate_lengths(right, dnf)
)
return _predicate_lengths(left, dnf) + _predicate_lengths(right, dnf)
for a in _predicate_lengths(left, dnf, max_, depth):
for b in _predicate_lengths(right, dnf, max_, depth):
yield a + b
else:
yield from _predicate_lengths(left, dnf, max_, depth)
yield from _predicate_lengths(right, dnf, max_, depth)
def distributive_law(expression, dnf, max_distance):
@@ -138,7 +154,7 @@ def distributive_law(expression, dnf, max_distance):
if normalized(expression, dnf=dnf):
return expression
distance = normalization_distance(expression, dnf=dnf)
distance = normalization_distance(expression, dnf=dnf, max_=max_distance)
if distance > max_distance:
raise OptimizeError(f"Normalization distance {distance} exceeds max {max_distance}")

sqlglot/optimizer/qualify_columns.py

@@ -80,7 +80,7 @@ def qualify_columns(
)
qualify_outputs(scope)
_expand_group_by(scope)
_expand_group_by(scope, dialect)
_expand_order_by(scope, resolver)
if dialect == "bigquery":
@@ -266,13 +266,13 @@ def _expand_alias_refs(scope: Scope, resolver: Resolver) -> None:
scope.clear_cache()
def _expand_group_by(scope: Scope) -> None:
def _expand_group_by(scope: Scope, dialect: DialectType) -> None:
expression = scope.expression
group = expression.args.get("group")
if not group:
return
group.set("expressions", _expand_positional_references(scope, group.expressions))
group.set("expressions", _expand_positional_references(scope, group.expressions, dialect))
expression.set("group", group)
@@ -284,7 +284,9 @@ def _expand_order_by(scope: Scope, resolver: Resolver) -> None:
ordereds = order.expressions
for ordered, new_expression in zip(
ordereds,
_expand_positional_references(scope, (o.this for o in ordereds), alias=True),
_expand_positional_references(
scope, (o.this for o in ordereds), resolver.schema.dialect, alias=True
),
):
for agg in ordered.find_all(exp.AggFunc):
for col in agg.find_all(exp.Column):
@@ -307,9 +309,11 @@ def _expand_order_by(scope: Scope, resolver: Resolver) -> None:
def _expand_positional_references(
scope: Scope, expressions: t.Iterable[exp.Expression], alias: bool = False
scope: Scope, expressions: t.Iterable[exp.Expression], dialect: DialectType, alias: bool = False
) -> t.List[exp.Expression]:
new_nodes: t.List[exp.Expression] = []
ambiguous_projections = None
for node in expressions:
if node.is_int:
select = _select_by_pos(scope, t.cast(exp.Literal, node))
@@ -319,7 +323,28 @@
else:
select = select.this
if isinstance(select, exp.CONSTANTS) or select.find(exp.Explode, exp.Unnest):
if dialect == "bigquery":
if ambiguous_projections is None:
# When a projection name is also a source name and it is referenced in the
# GROUP BY clause, BQ can't understand what the identifier corresponds to
ambiguous_projections = {
s.alias_or_name
for s in scope.expression.selects
if s.alias_or_name in scope.selected_sources
}
ambiguous = any(
column.parts[0].name in ambiguous_projections
for column in select.find_all(exp.Column)
)
else:
ambiguous = False
if (
isinstance(select, exp.CONSTANTS)
or select.find(exp.Explode, exp.Unnest)
or ambiguous
):
new_nodes.append(node)
else:
new_nodes.append(select.copy())
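
The BigQuery branch keeps a positional reference as-is whenever the projection it points at mentions a name that is simultaneously a projection alias and a source, since expanding it would produce an identifier BigQuery cannot disambiguate. A sketch of the guarded case (all names made up, output not verified):

import sqlglot
from sqlglot.optimizer.qualify import qualify

# "t" is both the projection alias and the source, so GROUP BY should stay positional
sql = "SELECT t.x AS t FROM t GROUP BY 1"
print(qualify(sqlglot.parse_one(sql, read="bigquery"), dialect="bigquery").sql("bigquery"))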

sqlglot/optimizer/simplify.py

@@ -1,10 +1,11 @@
from __future__ import annotations
import datetime
import logging
import functools
import itertools
import typing as t
from collections import deque
from collections import deque, defaultdict
from decimal import Decimal
from functools import reduce
@@ -20,6 +21,8 @@ if t.TYPE_CHECKING:
[exp.Expression, datetime.date, str, Dialect, exp.DataType], t.Optional[exp.Expression]
]
logger = logging.getLogger("sqlglot")
# Final means that an expression should not be simplified
FINAL = "final"
@@ -35,7 +38,10 @@ class UnsupportedUnit(Exception):
def simplify(
expression: exp.Expression, constant_propagation: bool = False, dialect: DialectType = None
expression: exp.Expression,
constant_propagation: bool = False,
dialect: DialectType = None,
max_depth: t.Optional[int] = None,
):
"""
Rewrite sqlglot AST to simplify expressions.
@@ -47,9 +53,9 @@
'TRUE'
Args:
expression (sqlglot.Expression): expression to simplify
expression: expression to simplify
constant_propagation: whether the constant propagation rule should be used
max_depth: Chains of Connectors (AND, OR, etc) exceeding `max_depth` will be skipped
Returns:
sqlglot.Expression: simplified expression
"""
@@ -57,6 +63,18 @@
dialect = Dialect.get_or_raise(dialect)
def _simplify(expression, root=True):
if (
max_depth
and isinstance(expression, exp.Connector)
and not isinstance(expression.parent, exp.Connector)
):
depth = connector_depth(expression)
if depth > max_depth:
logger.info(
f"Skipping simplification because connector depth {depth} exceeds max {max_depth}"
)
return expression
if expression.meta.get(FINAL):
return expression
@@ -118,6 +136,33 @@
return expression
def connector_depth(expression: exp.Expression) -> int:
"""
Determine the maximum depth of a tree of Connectors.
For example:
>>> from sqlglot import parse_one
>>> connector_depth(parse_one("a AND b AND c AND d"))
3
"""
stack = deque([(expression, 0)])
max_depth = 0
while stack:
expression, depth = stack.pop()
if not isinstance(expression, exp.Connector):
continue
depth += 1
max_depth = max(depth, max_depth)
stack.append((expression.left, depth))
stack.append((expression.right, depth))
return max_depth
def catch(*exceptions):
"""Decorator that ignores a simplification function if any of `exceptions` are raised"""
@@ -280,6 +325,7 @@ INVERSE_COMPARISONS: t.Dict[t.Type[exp.Expression], t.Type[exp.Expression]] = {
}
NONDETERMINISTIC = (exp.Rand, exp.Randn)
AND_OR = (exp.And, exp.Or)
def _simplify_comparison(expression, left, right, or_=False):
@@ -351,12 +397,12 @@ def remove_complements(expression, root=True):
A AND NOT A -> FALSE
A OR NOT A -> TRUE
"""
if isinstance(expression, exp.Connector) and (root or not expression.same_parent):
complement = exp.false() if isinstance(expression, exp.And) else exp.true()
if isinstance(expression, AND_OR) and (root or not expression.same_parent):
ops = set(expression.flatten())
for op in ops:
if isinstance(op, exp.Not) and op.this in ops:
return exp.false() if isinstance(expression, exp.And) else exp.true()
for a, b in itertools.permutations(expression.flatten(), 2):
if is_complement(a, b):
return complement
return expression
@@ -404,31 +450,63 @@ def absorb_and_eliminate(expression, root=True):
(A AND B) OR (A AND NOT B) -> A
(A OR B) AND (A OR NOT B) -> A
"""
if isinstance(expression, exp.Connector) and (root or not expression.same_parent):
if isinstance(expression, AND_OR) and (root or not expression.same_parent):
kind = exp.Or if isinstance(expression, exp.And) else exp.And
for a, b in itertools.permutations(expression.flatten(), 2):
if isinstance(a, kind):
aa, ab = a.unnest_operands()
ops = tuple(expression.flatten())
# absorb
if is_complement(b, aa):
aa.replace(exp.true() if kind == exp.And else exp.false())
elif is_complement(b, ab):
ab.replace(exp.true() if kind == exp.And else exp.false())
elif (set(b.flatten()) if isinstance(b, kind) else {b}) < set(a.flatten()):
a.replace(exp.false() if kind == exp.And else exp.true())
elif isinstance(b, kind):
# eliminate
rhs = b.unnest_operands()
ba, bb = rhs
# Initialize lookup tables:
# Set of all operands, used to find complements for absorption.
op_set = set()
# Sub-operands, used to find subsets for absorption.
subops = defaultdict(list)
# Pairs of complements, used for elimination.
pairs = defaultdict(list)
if aa in rhs and (is_complement(ab, ba) or is_complement(ab, bb)):
a.replace(aa)
b.replace(aa)
elif ab in rhs and (is_complement(aa, ba) or is_complement(aa, bb)):
a.replace(ab)
b.replace(ab)
# Populate the lookup tables
for op in ops:
op_set.add(op)
if not isinstance(op, kind):
# In cases like: A OR (A AND B)
# Subop will be: ^
subops[op].append({op})
continue
# In cases like: (A AND B) OR (A AND B AND C)
# Subops will be: ^ ^
subset = set(op.flatten())
for i in subset:
subops[i].append(subset)
a, b = op.unnest_operands()
if isinstance(a, exp.Not):
pairs[frozenset((a.this, b))].append((op, b))
if isinstance(b, exp.Not):
pairs[frozenset((a, b.this))].append((op, a))
for op in ops:
if not isinstance(op, kind):
continue
a, b = op.unnest_operands()
# Absorb
if isinstance(a, exp.Not) and a.this in op_set:
a.replace(exp.true() if kind == exp.And else exp.false())
continue
if isinstance(b, exp.Not) and b.this in op_set:
b.replace(exp.true() if kind == exp.And else exp.false())
continue
superset = set(op.flatten())
if any(any(subset < superset for subset in subops[i]) for i in superset):
op.replace(exp.false() if kind == exp.And else exp.true())
continue
# Eliminate
for other, complement in pairs[frozenset((a, b))]:
op.replace(complement)
other.replace(complement)
return expression
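
connector_depth walks the AND/OR tree iteratively, and simplify uses it to skip whole connector chains deeper than max_depth rather than recursing into them. A sketch:

from sqlglot import parse_one
from sqlglot.optimizer.simplify import connector_depth, simplify

print(connector_depth(parse_one("a AND b AND c AND d")))  # 3

# A 100-term OR chain should be returned untouched when max_depth is low,
# with a log line noting the skip
huge = parse_one(" OR ".join(f"c{i}" for i in range(100)))
simplify(huge, max_depth=8)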

sqlglot/parser.py

@@ -193,6 +193,7 @@ class Parser(metaclass=_Parser):
NESTED_TYPE_TOKENS = {
TokenType.ARRAY,
TokenType.LIST,
TokenType.LOWCARDINALITY,
TokenType.MAP,
TokenType.NULLABLE,
@@ -456,6 +457,11 @@
ALIAS_TOKENS = ID_VAR_TOKENS
ARRAY_CONSTRUCTORS = {
"ARRAY": exp.Array,
"LIST": exp.List,
}
COMMENT_TABLE_ALIAS_TOKENS = TABLE_ALIAS_TOKENS - {TokenType.IS}
UPDATE_ALIAS_TOKENS = TABLE_ALIAS_TOKENS - {TokenType.SET}
@@ -504,8 +510,15 @@
*SUBQUERY_PREDICATES,
}
CONJUNCTION = {
CONJUNCTION: t.Dict[TokenType, t.Type[exp.Expression]] = {
TokenType.AND: exp.And,
}
ASSIGNMENT: t.Dict[TokenType, t.Type[exp.Expression]] = {
TokenType.COLON_EQ: exp.PropertyEQ,
}
DISJUNCTION: t.Dict[TokenType, t.Type[exp.Expression]] = {
TokenType.OR: exp.Or,
}
@@ -588,7 +601,7 @@
TokenType.ARROW: lambda self, expressions: self.expression(
exp.Lambda,
this=self._replace_lambda(
self._parse_conjunction(),
self._parse_assignment(),
expressions,
),
expressions=expressions,
@@ -596,7 +609,7 @@
TokenType.FARROW: lambda self, expressions: self.expression(
exp.Kwarg,
this=exp.var(expressions[0].name),
expression=self._parse_conjunction(),
expression=self._parse_assignment(),
),
}
@@ -639,7 +652,7 @@
EXPRESSION_PARSERS = {
exp.Cluster: lambda self: self._parse_sort(exp.Cluster, TokenType.CLUSTER_BY),
exp.Column: lambda self: self._parse_column(),
exp.Condition: lambda self: self._parse_conjunction(),
exp.Condition: lambda self: self._parse_assignment(),
exp.DataType: lambda self: self._parse_types(allow_identifiers=False),
exp.Expression: lambda self: self._parse_expression(),
exp.From: lambda self: self._parse_from(joins=True),
@@ -890,11 +903,11 @@
),
"CHECK": lambda self: self.expression(
exp.CheckColumnConstraint,
this=self._parse_wrapped(self._parse_conjunction),
this=self._parse_wrapped(self._parse_assignment),
enforced=self._match_text_seq("ENFORCED"),
),
"COLLATE": lambda self: self.expression(
exp.CollateColumnConstraint, this=self._parse_var()
exp.CollateColumnConstraint, this=self._parse_var(any_token=True)
),
"COMMENT": lambda self: self.expression(
exp.CommentColumnConstraint, this=self._parse_string()
@@ -994,6 +1007,7 @@
"CONVERT": lambda self: self._parse_convert(self.STRICT_CAST),
"DECODE": lambda self: self._parse_decode(),
"EXTRACT": lambda self: self._parse_extract(),
"GAP_FILL": lambda self: self._parse_gap_fill(),
"JSON_OBJECT": lambda self: self._parse_json_object(),
"JSON_OBJECTAGG": lambda self: self._parse_json_object(agg=True),
"JSON_TABLE": lambda self: self._parse_json_table(),
@@ -2191,7 +2205,7 @@
def _parse_partition_by(self) -> t.List[exp.Expression]:
if self._match(TokenType.PARTITION_BY):
return self._parse_csv(self._parse_conjunction)
return self._parse_csv(self._parse_assignment)
return []
def _parse_partition_bound_spec(self) -> exp.PartitionBoundSpec:
@@ -2408,8 +2422,7 @@
stored=self._match_text_seq("STORED") and self._parse_stored(),
by_name=self._match_text_seq("BY", "NAME"),
exists=self._parse_exists(),
where=self._match_pair(TokenType.REPLACE, TokenType.WHERE)
and self._parse_conjunction(),
where=self._match_pair(TokenType.REPLACE, TokenType.WHERE) and self._parse_assignment(),
expression=self._parse_derived_table_values() or self._parse_ddl_select(),
conflict=self._parse_on_conflict(),
returning=returning or self._parse_returning(),
@@ -2619,7 +2632,7 @@
return None
return self.expression(
exp.Partition, expressions=self._parse_wrapped_csv(self._parse_conjunction)
exp.Partition, expressions=self._parse_wrapped_csv(self._parse_assignment)
)
def _parse_value(self) -> t.Optional[exp.Tuple]:
@@ -3115,7 +3128,7 @@
kwargs["match_condition"] = self._parse_wrapped(self._parse_comparison)
if self._match(TokenType.ON):
kwargs["on"] = self._parse_conjunction()
kwargs["on"] = self._parse_assignment()
elif self._match(TokenType.USING):
kwargs["using"] = self._parse_wrapped_id_vars()
elif not isinstance(kwargs["this"], exp.Unnest) and not (
@@ -3125,7 +3138,7 @@
joins: t.Optional[list] = list(self._parse_joins())
if joins and self._match(TokenType.ON):
kwargs["on"] = self._parse_conjunction()
kwargs["on"] = self._parse_assignment()
elif joins and self._match(TokenType.USING):
kwargs["using"] = self._parse_wrapped_id_vars()
else:
@@ -3138,7 +3151,7 @@
return self.expression(exp.Join, comments=comments, **kwargs)
def _parse_opclass(self) -> t.Optional[exp.Expression]:
this = self._parse_conjunction()
this = self._parse_assignment()
if self._match_texts(self.OPCLASS_FOLLOW_KEYWORDS, advance=False):
return this
@@ -3554,7 +3567,7 @@
def _parse_pivot_in(self) -> exp.In:
def _parse_aliased_expression() -> t.Optional[exp.Expression]:
this = self._parse_conjunction()
this = self._parse_assignment()
self._match(TokenType.ALIAS)
alias = self._parse_field()
@@ -3648,7 +3661,7 @@
return None
return self.expression(
exp.PreWhere, comments=self._prev_comments, this=self._parse_conjunction()
exp.PreWhere, comments=self._prev_comments, this=self._parse_assignment()
)
def _parse_where(self, skip_where_token: bool = False) -> t.Optional[exp.Where]:
@@ -3656,7 +3669,7 @@
return None
return self.expression(
exp.Where, comments=self._prev_comments, this=self._parse_conjunction()
exp.Where, comments=self._prev_comments, this=self._parse_assignment()
)
def _parse_group(self, skip_group_by_token: bool = False) -> t.Optional[exp.Group]:
@ -3674,7 +3687,7 @@ class Parser(metaclass=_Parser):
expressions = self._parse_csv(
lambda: None
if self._match(TokenType.ROLLUP, advance=False)
else self._parse_conjunction()
else self._parse_assignment()
)
if expressions:
elements["expressions"].extend(expressions)
@@ -3725,18 +3738,18 @@
def _parse_having(self, skip_having_token: bool = False) -> t.Optional[exp.Having]:
if not skip_having_token and not self._match(TokenType.HAVING):
return None
return self.expression(exp.Having, this=self._parse_conjunction())
return self.expression(exp.Having, this=self._parse_assignment())
def _parse_qualify(self) -> t.Optional[exp.Qualify]:
if not self._match(TokenType.QUALIFY):
return None
return self.expression(exp.Qualify, this=self._parse_conjunction())
return self.expression(exp.Qualify, this=self._parse_assignment())
def _parse_connect(self, skip_start_token: bool = False) -> t.Optional[exp.Connect]:
if skip_start_token:
start = None
elif self._match(TokenType.START_WITH):
start = self._parse_conjunction()
start = self._parse_assignment()
else:
return None
@@ -3745,11 +3758,11 @@
self.NO_PAREN_FUNCTION_PARSERS["PRIOR"] = lambda self: self.expression(
exp.Prior, this=self._parse_bitwise()
)
connect = self._parse_conjunction()
connect = self._parse_assignment()
self.NO_PAREN_FUNCTION_PARSERS.pop("PRIOR")
if not start and self._match(TokenType.START_WITH):
start = self._parse_conjunction()
start = self._parse_assignment()
return self.expression(exp.Connect, start=start, connect=connect, nocycle=nocycle)
@@ -3757,7 +3770,7 @@
return self.expression(
exp.Alias,
alias=self._parse_id_var(any_token=True),
this=self._match(TokenType.ALIAS) and self._parse_conjunction(),
this=self._match(TokenType.ALIAS) and self._parse_assignment(),
)
def _parse_interpolate(self) -> t.Optional[t.List[exp.Expression]]:
@@ -3791,7 +3804,7 @@
def _parse_ordered(
self, parse_method: t.Optional[t.Callable] = None
) -> t.Optional[exp.Ordered]:
this = parse_method() if parse_method else self._parse_conjunction()
this = parse_method() if parse_method else self._parse_assignment()
if not this:
return None
@@ -3970,27 +3983,26 @@
return this
def _parse_expression(self) -> t.Optional[exp.Expression]:
return self._parse_alias(self._parse_conjunction())
return self._parse_alias(self._parse_assignment())
def _parse_assignment(self) -> t.Optional[exp.Expression]:
this = self._parse_disjunction()
while self._match_set(self.ASSIGNMENT):
this = self.expression(
self.ASSIGNMENT[self._prev.token_type],
this=this,
comments=self._prev_comments,
expression=self._parse_assignment(),
)
return this
def _parse_disjunction(self) -> t.Optional[exp.Expression]:
return self._parse_tokens(self._parse_conjunction, self.DISJUNCTION)
def _parse_conjunction(self) -> t.Optional[exp.Expression]:
this = self._parse_equality()
if self._match(TokenType.COLON_EQ):
this = self.expression(
exp.PropertyEQ,
this=this,
comments=self._prev_comments,
expression=self._parse_conjunction(),
)
while self._match_set(self.CONJUNCTION):
this = self.expression(
self.CONJUNCTION[self._prev.token_type],
this=this,
comments=self._prev_comments,
expression=self._parse_equality(),
)
return this
return self._parse_tokens(self._parse_equality, self.CONJUNCTION)
def _parse_equality(self) -> t.Optional[exp.Expression]:
return self._parse_tokens(self._parse_comparison, self.EQUALITY)
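
The old _parse_conjunction looped over both AND and OR and special-cased := inline; the refactor splits this into a proper precedence chain (assignment -> disjunction -> conjunction -> equality) backed by ASSIGNMENT/DISJUNCTION/CONJUNCTION tables that dialects can extend, as MySQL and PRQL do above with DPIPE. One observable consequence, sketched:

from sqlglot import exp, parse_one

# := binds loosest, so the whole disjunction becomes the assigned value
node = parse_one("x := a OR b")
assert isinstance(node, exp.PropertyEQ)
assert isinstance(node.expression, exp.Or)
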
@@ -4172,12 +4184,16 @@
this = parse_method()
while self._match_set(self.FACTOR):
this = self.expression(
self.FACTOR[self._prev.token_type],
this=this,
comments=self._prev_comments,
expression=parse_method(),
)
klass = self.FACTOR[self._prev.token_type]
comments = self._prev_comments
expression = parse_method()
if not expression and klass is exp.IntDiv and self._prev.text.isalpha():
self._retreat(self._index - 1)
return this
this = self.expression(klass, this=this, comments=comments, expression=expression)
if isinstance(this, exp.Div):
this.args["typed"] = self.dialect.TYPED_DIVISION
this.args["safe"] = self.dialect.SAFE_DIVISION
@@ -4291,6 +4307,29 @@
if type_token == TokenType.OBJECT_IDENTIFIER:
return self.expression(exp.ObjectIdentifier, this=self._prev.text.upper())
# https://materialize.com/docs/sql/types/map/
if type_token == TokenType.MAP and self._match(TokenType.L_BRACKET):
key_type = self._parse_types(
check_func=check_func, schema=schema, allow_identifiers=allow_identifiers
)
if not self._match(TokenType.FARROW):
self._retreat(index)
return None
value_type = self._parse_types(
check_func=check_func, schema=schema, allow_identifiers=allow_identifiers
)
if not self._match(TokenType.R_BRACKET):
self._retreat(index)
return None
return exp.DataType(
this=exp.DataType.Type.MAP,
expressions=[key_type, value_type],
nested=True,
prefix=prefix,
)
nested = type_token in self.NESTED_TYPE_TOKENS
is_struct = type_token in self.STRUCT_TYPE_TOKENS
is_aggregate = type_token in self.AGGREGATE_TYPE_TOKENS
@@ -4345,7 +4384,7 @@
self.raise_error("Expecting >")
if self._match_set((TokenType.L_BRACKET, TokenType.L_PAREN)):
values = self._parse_csv(self._parse_conjunction)
values = self._parse_csv(self._parse_assignment)
self._match_set((TokenType.R_BRACKET, TokenType.R_PAREN))
if type_token in self.TIMESTAMPS:
@@ -4400,6 +4439,10 @@
elif expressions:
this.set("expressions", expressions)
# https://materialize.com/docs/sql/types/list/#type-name
while self._match(TokenType.LIST):
this = exp.DataType(this=exp.DataType.Type.LIST, expressions=[this], nested=True)
index = self._index
# Postgres supports the INT ARRAY[3] syntax as a synonym for INT[3]
@@ -4411,7 +4454,7 @@
break
matched_array = False
values = self._parse_csv(self._parse_conjunction) or None
values = self._parse_csv(self._parse_assignment) or None
if values and not schema:
self._retreat(index)
break
@@ -4818,7 +4861,7 @@
if self._match(TokenType.DISTINCT):
this = self.expression(
exp.Distinct, expressions=self._parse_csv(self._parse_conjunction)
exp.Distinct, expressions=self._parse_csv(self._parse_assignment)
)
else:
this = self._parse_select_or_expression(alias=alias)
@@ -4863,7 +4906,7 @@
constraints.append(
self.expression(
exp.ComputedColumnConstraint,
this=self._parse_conjunction(),
this=self._parse_assignment(),
persisted=persisted or self._match_text_seq("PERSISTED"),
not_null=self._match_pair(TokenType.NOT, TokenType.NULL),
)
@@ -5153,7 +5196,7 @@
return self.expression(exp.PrimaryKey, expressions=expressions, options=options)
def _parse_bracket_key_value(self, is_map: bool = False) -> t.Optional[exp.Expression]:
return self._parse_slice(self._parse_alias(self._parse_conjunction(), explicit=True))
return self._parse_slice(self._parse_alias(self._parse_assignment(), explicit=True))
def _parse_bracket(self, this: t.Optional[exp.Expression] = None) -> t.Optional[exp.Expression]:
if not self._match_set((TokenType.L_BRACKET, TokenType.L_BRACE)):
@@ -5172,9 +5215,13 @@
# https://duckdb.org/docs/sql/data_types/struct.html#creating-structs
if bracket_kind == TokenType.L_BRACE:
this = self.expression(exp.Struct, expressions=self._kv_to_prop_eq(expressions))
elif not this or this.name.upper() == "ARRAY":
elif not this:
this = self.expression(exp.Array, expressions=expressions)
else:
constructor_type = self.ARRAY_CONSTRUCTORS.get(this.name.upper())
if constructor_type:
return self.expression(constructor_type, expressions=expressions)
expressions = apply_index_offset(this, expressions, -self.dialect.INDEX_OFFSET)
this = self.expression(exp.Bracket, this=this, expressions=expressions)
@@ -5183,7 +5230,7 @@
def _parse_slice(self, this: t.Optional[exp.Expression]) -> t.Optional[exp.Expression]:
if self._match(TokenType.COLON):
return self.expression(exp.Slice, this=this, expression=self._parse_conjunction())
return self.expression(exp.Slice, this=this, expression=self._parse_assignment())
return this
def _parse_case(self) -> t.Optional[exp.Expression]:
@@ -5191,16 +5238,16 @@
default = None
comments = self._prev_comments
expression = self._parse_conjunction()
expression = self._parse_assignment()
while self._match(TokenType.WHEN):
this = self._parse_conjunction()
this = self._parse_assignment()
self._match(TokenType.THEN)
then = self._parse_conjunction()
then = self._parse_assignment()
ifs.append(self.expression(exp.If, this=this, true=then))
if self._match(TokenType.ELSE):
default = self._parse_conjunction()
default = self._parse_assignment()
if not self._match(TokenType.END):
if isinstance(default, exp.Interval) and default.this.sql().upper() == "END":
@@ -5214,7 +5261,7 @@
def _parse_if(self) -> t.Optional[exp.Expression]:
if self._match(TokenType.L_PAREN):
args = self._parse_csv(self._parse_conjunction)
args = self._parse_csv(self._parse_assignment)
this = self.validate_expression(exp.If.from_arg_list(args), args)
self._match_r_paren()
else:
@@ -5223,15 +5270,15 @@
if self.NO_PAREN_IF_COMMANDS and index == 0:
return self._parse_as_command(self._prev)
condition = self._parse_conjunction()
condition = self._parse_assignment()
if not condition:
self._retreat(index)
return None
self._match(TokenType.THEN)
true = self._parse_conjunction()
false = self._parse_conjunction() if self._match(TokenType.ELSE) else None
true = self._parse_assignment()
false = self._parse_assignment() if self._match(TokenType.ELSE) else None
self._match(TokenType.END)
this = self.expression(exp.If, this=condition, true=true, false=false)
@@ -5259,8 +5306,18 @@
return self.expression(exp.Extract, this=this, expression=self._parse_bitwise())
def _parse_gap_fill(self) -> exp.GapFill:
self._match(TokenType.TABLE)
this = self._parse_table()
self._match(TokenType.COMMA)
args = [this, *self._parse_csv(self._parse_lambda)]
gap_fill = exp.GapFill.from_arg_list(args)
return self.validate_expression(gap_fill, args)
def _parse_cast(self, strict: bool, safe: t.Optional[bool] = None) -> exp.Expression:
this = self._parse_conjunction()
this = self._parse_assignment()
if not self._match(TokenType.ALIAS):
if self._match(TokenType.COMMA):
@@ -5313,12 +5370,12 @@
def _parse_string_agg(self) -> exp.Expression:
if self._match(TokenType.DISTINCT):
args: t.List[t.Optional[exp.Expression]] = [
self.expression(exp.Distinct, expressions=[self._parse_conjunction()])
self.expression(exp.Distinct, expressions=[self._parse_assignment()])
]
if self._match(TokenType.COMMA):
args.extend(self._parse_csv(self._parse_conjunction))
args.extend(self._parse_csv(self._parse_assignment))
else:
args = self._parse_csv(self._parse_conjunction) # type: ignore
args = self._parse_csv(self._parse_assignment) # type: ignore
index = self._index
if not self._match(TokenType.R_PAREN) and args:
@@ -5365,7 +5422,7 @@
needs special treatment, since we need to explicitly check for it with `IS NULL`,
instead of relying on pattern matching.
"""
args = self._parse_csv(self._parse_conjunction)
args = self._parse_csv(self._parse_assignment)
if len(args) < 3:
return self.expression(exp.Decode, this=seq_get(args, 0), charset=seq_get(args, 1))
@@ -5965,7 +6022,7 @@
def _parse_select_or_expression(self, alias: bool = False) -> t.Optional[exp.Expression]:
return self._parse_select() or self._parse_set_operations(
self._parse_expression() if alias else self._parse_conjunction()
self._parse_expression() if alias else self._parse_assignment()
)
def _parse_ddl_select(self) -> t.Optional[exp.Expression]:
@@ -6077,7 +6134,7 @@
if self._match_pair(TokenType.DROP, TokenType.DEFAULT):
return self.expression(exp.AlterColumn, this=column, drop=True)
if self._match_pair(TokenType.SET, TokenType.DEFAULT):
return self.expression(exp.AlterColumn, this=column, default=self._parse_conjunction())
return self.expression(exp.AlterColumn, this=column, default=self._parse_assignment())
if self._match(TokenType.COMMENT):
return self.expression(exp.AlterColumn, this=column, comment=self._parse_string())
if self._match_text_seq("DROP", "NOT", "NULL"):
@@ -6100,7 +6157,7 @@
this=column,
dtype=self._parse_types(),
collate=self._match(TokenType.COLLATE) and self._parse_term(),
using=self._match(TokenType.USING) and self._parse_conjunction(),
using=self._match(TokenType.USING) and self._parse_assignment(),
)
def _parse_alter_diststyle(self) -> exp.AlterDistStyle:
@@ -6155,9 +6212,9 @@
if self._match(TokenType.L_PAREN, advance=False) or self._match_text_seq(
"TABLE", "PROPERTIES"
):
alter_set.set("expressions", self._parse_wrapped_csv(self._parse_conjunction))
alter_set.set("expressions", self._parse_wrapped_csv(self._parse_assignment))
elif self._match_text_seq("FILESTREAM_ON", advance=False):
alter_set.set("expressions", [self._parse_conjunction()])
alter_set.set("expressions", [self._parse_assignment()])
elif self._match_texts(("LOGGED", "UNLOGGED")):
alter_set.set("option", exp.var(self._prev.text.upper()))
elif self._match_text_seq("WITHOUT") and self._match_texts(("CLUSTER", "OIDS")):
@@ -6175,7 +6232,7 @@
elif self._match_text_seq("STAGE_COPY_OPTIONS"):
alter_set.set("copy_options", self._parse_wrapped_options())
elif self._match_text_seq("TAG") or self._match_text_seq("TAGS"):
alter_set.set("tag", self._parse_csv(self._parse_conjunction))
alter_set.set("tag", self._parse_csv(self._parse_assignment))
else:
if self._match_text_seq("SERDE"):
alter_set.set("serde", self._parse_field())
@@ -6227,7 +6284,7 @@
using = self._parse_table()
self._match(TokenType.ON)
on = self._parse_conjunction()
on = self._parse_assignment()
return self.expression(
exp.Merge,
@@ -6248,7 +6305,7 @@
if self._match_text_seq("BY", "TARGET")
else self._match_text_seq("BY", "SOURCE")
)
condition = self._parse_conjunction() if self._match(TokenType.AND) else None
condition = self._parse_assignment() if self._match(TokenType.AND) else None
self._match(TokenType.THEN)
@@ -6428,7 +6485,7 @@
self._retreat(index - 1)
return None
iterator = self._parse_column()
condition = self._parse_conjunction() if self._match_text_seq("IF") else None
condition = self._parse_assignment() if self._match_text_seq("IF") else None
return self.expression(
exp.Comprehension,
this=this,

sqlglot/tokens.py

@@ -294,6 +294,7 @@ class TokenType(AutoName):
LIKE = auto()
LIKE_ANY = auto()
LIMIT = auto()
LIST = auto()
LOAD = auto()
LOCK = auto()
MAP = auto()
@@ -813,6 +814,7 @@ class Tokenizer(metaclass=_Tokenizer):
"DECIMAL": TokenType.DECIMAL,
"BIGDECIMAL": TokenType.BIGDECIMAL,
"BIGNUMERIC": TokenType.BIGDECIMAL,
"LIST": TokenType.LIST,
"MAP": TokenType.MAP,
"NULLABLE": TokenType.NULLABLE,
"NUMBER": TokenType.DECIMAL,