Merging upstream version 18.13.0.
Signed-off-by: Daniel Baumann <daniel@debian.org>
This commit is contained in:
parent a56b8dde5c
commit 320822f1c4
76 changed files with 21248 additions and 19605 deletions
@@ -1019,11 +1019,11 @@ def posexplode(col: ColumnOrName) -> Column:


 def explode_outer(col: ColumnOrName) -> Column:
-    return Column.invoke_anonymous_function(col, "EXPLODE_OUTER")
+    return Column.invoke_expression_over_column(col, expression.ExplodeOuter)


 def posexplode_outer(col: ColumnOrName) -> Column:
-    return Column.invoke_anonymous_function(col, "POSEXPLODE_OUTER")
+    return Column.invoke_expression_over_column(col, expression.PosexplodeOuter)


 def get_json_object(col: ColumnOrName, path: str) -> Column:
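The two helpers above now build dedicated AST nodes instead of anonymous function calls, so dialects can pattern-match on them during generation. A minimal sketch of the effect (my example, not part of the diff; it assumes the new nodes are registered by name, as exp.Func subclasses are by default):

import sqlglot
from sqlglot import exp

# EXPLODE_OUTER should now parse to a typed node rather than exp.Anonymous
ast = sqlglot.parse_one("SELECT EXPLODE_OUTER(arr) FROM t", read="spark")
print(ast.find(exp.ExplodeOuter))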
@@ -10,6 +10,7 @@ from sqlglot.tokens import TokenType

 class Databricks(Spark):
     class Parser(Spark.Parser):
         LOG_DEFAULTS_TO_LN = True
+        STRICT_CAST = True

         FUNCTIONS = {
             **Spark.Parser.FUNCTIONS,
@@ -51,6 +52,8 @@ class Databricks(Spark):
            exp.ToChar: lambda self, e: self.function_fallback_sql(e),
        }

+        TRANSFORMS.pop(exp.TryCast)
+
        def columndef_sql(self, expression: exp.ColumnDef, sep: str = " ") -> str:
            constraint = expression.find(exp.GeneratedAsIdentityColumnConstraint)
            kind = expression.args.get("kind")
@@ -133,6 +133,10 @@ class DuckDB(Dialect):
            "UINTEGER": TokenType.UINT,
            "USMALLINT": TokenType.USMALLINT,
            "UTINYINT": TokenType.UTINYINT,
+            "TIMESTAMP_S": TokenType.TIMESTAMP_S,
+            "TIMESTAMP_MS": TokenType.TIMESTAMP_MS,
+            "TIMESTAMP_NS": TokenType.TIMESTAMP_NS,
+            "TIMESTAMP_US": TokenType.TIMESTAMP,
        }

    class Parser(parser.Parser):
@@ -321,6 +325,9 @@ class DuckDB(Dialect):
            exp.DataType.Type.UINT: "UINTEGER",
            exp.DataType.Type.VARBINARY: "BLOB",
            exp.DataType.Type.VARCHAR: "TEXT",
+            exp.DataType.Type.TIMESTAMP_S: "TIMESTAMP_S",
+            exp.DataType.Type.TIMESTAMP_MS: "TIMESTAMP_MS",
+            exp.DataType.Type.TIMESTAMP_NS: "TIMESTAMP_NS",
        }

        STAR_MAPPING = {**generator.Generator.STAR_MAPPING, "except": "EXCLUDE"}
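With dedicated tokens and data types, DuckDB's fixed-precision timestamps should now survive a round trip instead of collapsing to plain TIMESTAMP (TIMESTAMP_US stays an alias of TIMESTAMP). A hedged sketch:

import sqlglot

# TIMESTAMP_MS is expected to round-trip unchanged through the DuckDB dialect
print(sqlglot.transpile("SELECT CAST(x AS TIMESTAMP_MS)", read="duckdb", write="duckdb")[0])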
@@ -82,7 +82,6 @@ class Oracle(Dialect):
                this=self._parse_format_json(self._parse_bitwise()),
                order=self._parse_order(),
            ),
-            "JSON_TABLE": lambda self: self._parse_json_table(),
            "XMLTABLE": _parse_xml_table,
        }

@@ -96,29 +95,6 @@ class Oracle(Dialect):
        # Reference: https://stackoverflow.com/a/336455
        DISTINCT_TOKENS = {TokenType.DISTINCT, TokenType.UNIQUE}

-        # Note: this is currently incomplete; it only implements the "JSON_value_column" part
-        def _parse_json_column_def(self) -> exp.JSONColumnDef:
-            this = self._parse_id_var()
-            kind = self._parse_types(allow_identifiers=False)
-            path = self._match_text_seq("PATH") and self._parse_string()
-            return self.expression(exp.JSONColumnDef, this=this, kind=kind, path=path)
-
-        def _parse_json_table(self) -> exp.JSONTable:
-            this = self._parse_format_json(self._parse_bitwise())
-            path = self._match(TokenType.COMMA) and self._parse_string()
-            error_handling = self._parse_on_handling("ERROR", "ERROR", "NULL")
-            empty_handling = self._parse_on_handling("EMPTY", "ERROR", "NULL")
-            self._match(TokenType.COLUMN)
-            expressions = self._parse_wrapped_csv(self._parse_json_column_def, optional=True)
-
-            return exp.JSONTable(
-                this=this,
-                expressions=expressions,
-                path=path,
-                error_handling=error_handling,
-                empty_handling=empty_handling,
-            )
-
        def _parse_json_array(self, expr_type: t.Type[E], **kwargs) -> E:
            return self.expression(
                expr_type,
@@ -34,7 +34,7 @@ def _approx_distinct_sql(self: Presto.Generator, expression: exp.ApproxDistinct)


 def _explode_to_unnest_sql(self: Presto.Generator, expression: exp.Lateral) -> str:
-    if isinstance(expression.this, (exp.Explode, exp.Posexplode)):
+    if isinstance(expression.this, exp.Explode):
        expression = expression.copy()
        return self.sql(
            exp.Join(
@@ -58,6 +58,11 @@ class Redshift(Postgres):
            "STRTOL": exp.FromBase.from_arg_list,
        }

+        NO_PAREN_FUNCTION_PARSERS = {
+            **Postgres.Parser.NO_PAREN_FUNCTION_PARSERS,
+            "APPROXIMATE": lambda self: self._parse_approximate_count(),
+        }
+
        def _parse_table(
            self,
            schema: bool = False,
@@ -93,11 +98,22 @@ class Redshift(Postgres):

            return this

-        def _parse_convert(self, strict: bool) -> t.Optional[exp.Expression]:
+        def _parse_convert(
+            self, strict: bool, safe: t.Optional[bool] = None
+        ) -> t.Optional[exp.Expression]:
            to = self._parse_types()
            self._match(TokenType.COMMA)
            this = self._parse_bitwise()
-            return self.expression(exp.TryCast, this=this, to=to)
+            return self.expression(exp.TryCast, this=this, to=to, safe=safe)
+
+        def _parse_approximate_count(self) -> t.Optional[exp.ApproxDistinct]:
+            index = self._index - 1
+            func = self._parse_function()
+
+            if isinstance(func, exp.Count) and isinstance(func.this, exp.Distinct):
+                return self.expression(exp.ApproxDistinct, this=seq_get(func.this.expressions, 0))
+            self._retreat(index)
+            return None

    class Tokenizer(Postgres.Tokenizer):
        BIT_STRINGS = []
@@ -144,6 +160,7 @@ class Redshift(Postgres):
            **Postgres.Generator.TRANSFORMS,
            exp.Concat: concat_to_dpipe_sql,
            exp.ConcatWs: concat_ws_to_dpipe_sql,
+            exp.ApproxDistinct: lambda self, e: f"APPROXIMATE COUNT(DISTINCT {self.sql(e, 'this')})",
            exp.CurrentTimestamp: lambda self, e: "SYSDATE",
            exp.DateAdd: lambda self, e: self.func(
                "DATEADD", exp.var(e.text("unit") or "day"), e.expression, e.this
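The new no-paren parser and generator transform together map Redshift's APPROXIMATE COUNT(DISTINCT ...) onto exp.ApproxDistinct. A sketch of the intended behavior (my example, not from the diff):

import sqlglot

sql = "SELECT APPROXIMATE COUNT(DISTINCT user_id) FROM events"
# Redshift should round-trip its own syntax...
print(sqlglot.transpile(sql, read="redshift", write="redshift")[0])
# ...while other dialects can render their native form, e.g. APPROX_DISTINCT in Presto
print(sqlglot.transpile(sql, read="redshift", write="presto")[0])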
@@ -76,6 +76,9 @@ class Spark(Spark2):
            exp.TimestampAdd: lambda self, e: self.func(
                "DATEADD", e.args.get("unit") or "DAY", e.expression, e.this
            ),
+            exp.TryCast: lambda self, e: self.trycast_sql(e)
+            if e.args.get("safe")
+            else self.cast_sql(e),
        }
        TRANSFORMS.pop(exp.AnyValue)
        TRANSFORMS.pop(exp.DateDiff)
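This transform is the consumer of the new Cast "safe" flag: Spark's lenient CAST parses to exp.TryCast, but only an explicit TRY_CAST (parsed with safe=True) should be printed as TRY_CAST again. A sketch of the intended round trip:

import sqlglot

print(sqlglot.transpile("SELECT TRY_CAST(a AS INT)", read="spark", write="spark")[0])
# expected: SELECT TRY_CAST(a AS INT)
print(sqlglot.transpile("SELECT CAST(a AS INT)", read="spark", write="spark")[0])
# expected: SELECT CAST(a AS INT), no longer upgraded to TRY_CAST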
@@ -477,7 +477,9 @@ class TSQL(Dialect):
            returns.set("table", table)
            return returns

-        def _parse_convert(self, strict: bool) -> t.Optional[exp.Expression]:
+        def _parse_convert(
+            self, strict: bool, safe: t.Optional[bool] = None
+        ) -> t.Optional[exp.Expression]:
            to = self._parse_types()
            self._match(TokenType.COMMA)
            this = self._parse_conjunction()
@@ -513,12 +515,13 @@ class TSQL(Dialect):
                        exp.Cast if strict else exp.TryCast,
                        to=to,
                        this=self.expression(exp.TimeToStr, this=this, format=format_norm),
+                        safe=safe,
                    )
                elif to.this == DataType.Type.TEXT:
                    return self.expression(exp.TimeToStr, this=this, format=format_norm)

            # Entails a simple cast without any format requirement
-            return self.expression(exp.Cast if strict else exp.TryCast, this=this, to=to)
+            return self.expression(exp.Cast if strict else exp.TryCast, this=this, to=to, safe=safe)

        def _parse_user_defined_function(
            self, kind: t.Optional[TokenType] = None
@@ -105,7 +105,7 @@ class RowReader:
        return self.row[self.columns[column]]


-class Tables(AbstractMappingSchema[Table]):
+class Tables(AbstractMappingSchema):
    pass

@@ -487,7 +487,7 @@ class Expression(metaclass=_Expression):
        """
        for node, _, _ in self.dfs(prune=lambda n, p, *_: p and not type(n) is self.__class__):
            if not type(node) is self.__class__:
-                yield node.unnest() if unnest else node
+                yield node.unnest() if unnest and not isinstance(node, Subquery) else node

    def __str__(self) -> str:
        return self.sql()
@@ -2107,7 +2107,7 @@ class LockingProperty(Property):
    arg_types = {
        "this": False,
        "kind": True,
-        "for_or_in": True,
+        "for_or_in": False,
        "lock_type": True,
        "override": False,
    }
@@ -3605,6 +3605,9 @@ class DataType(Expression):
        TIMESTAMP = auto()
        TIMESTAMPLTZ = auto()
        TIMESTAMPTZ = auto()
+        TIMESTAMP_S = auto()
+        TIMESTAMP_MS = auto()
+        TIMESTAMP_NS = auto()
        TINYINT = auto()
        TSMULTIRANGE = auto()
        TSRANGE = auto()
@@ -3661,6 +3664,9 @@ class DataType(Expression):
        Type.TIMESTAMP,
        Type.TIMESTAMPTZ,
        Type.TIMESTAMPLTZ,
+        Type.TIMESTAMP_S,
+        Type.TIMESTAMP_MS,
+        Type.TIMESTAMP_NS,
        Type.DATE,
        Type.DATETIME,
        Type.DATETIME64,
@@ -4286,7 +4292,7 @@ class Case(Func):


 class Cast(Func):
-    arg_types = {"this": True, "to": True, "format": False}
+    arg_types = {"this": True, "to": True, "format": False, "safe": False}

    @property
    def name(self) -> str:
@@ -4538,6 +4544,18 @@ class Explode(Func):
    pass


+class ExplodeOuter(Explode):
+    pass
+
+
+class Posexplode(Explode):
+    pass
+
+
+class PosexplodeOuter(Posexplode):
+    pass
+
+
 class Floor(Func):
    arg_types = {"this": True, "decimals": False}

@@ -4621,14 +4639,18 @@ class JSONArrayAgg(Func):
 # https://docs.oracle.com/en/database/oracle/oracle-database/19/sqlrf/JSON_TABLE.html
+# Note: parsing of JSON column definitions is currently incomplete.
 class JSONColumnDef(Expression):
-    arg_types = {"this": True, "kind": False, "path": False}
+    arg_types = {"this": False, "kind": False, "path": False, "nested_schema": False}
+
+
+class JSONSchema(Expression):
+    arg_types = {"expressions": True}


 # # https://docs.oracle.com/en/database/oracle/oracle-database/19/sqlrf/JSON_TABLE.html
 class JSONTable(Func):
    arg_types = {
        "this": True,
-        "expressions": True,
+        "schema": True,
        "path": False,
        "error_handling": False,
        "empty_handling": False,
@@ -4790,10 +4812,6 @@ class Nvl2(Func):
    arg_types = {"this": True, "true": True, "false": False}


-class Posexplode(Func):
-    pass
-
-
 # https://cloud.google.com/bigquery/docs/reference/standard-sql/bigqueryml-syntax-predict#mlpredict_function
 class Predict(Func):
    arg_types = {"this": True, "expression": True, "params_struct": False}
@@ -1226,9 +1226,10 @@ class Generator:
        kind = expression.args.get("kind")
        this = f" {self.sql(expression, 'this')}" if expression.this else ""
        for_or_in = expression.args.get("for_or_in")
+        for_or_in = f" {for_or_in}" if for_or_in else ""
        lock_type = expression.args.get("lock_type")
        override = " OVERRIDE" if expression.args.get("override") else ""
-        return f"LOCKING {kind}{this} {for_or_in} {lock_type}{override}"
+        return f"LOCKING {kind}{this}{for_or_in} {lock_type}{override}"

    def withdataproperty_sql(self, expression: exp.WithDataProperty) -> str:
        data_sql = f"WITH {'NO ' if expression.args.get('no') else ''}DATA"
@@ -2179,13 +2180,21 @@ class Generator:
        )

    def jsoncolumndef_sql(self, expression: exp.JSONColumnDef) -> str:
+        path = self.sql(expression, "path")
+        path = f" PATH {path}" if path else ""
+        nested_schema = self.sql(expression, "nested_schema")
+
+        if nested_schema:
+            return f"NESTED{path} {nested_schema}"
+
        this = self.sql(expression, "this")
        kind = self.sql(expression, "kind")
        kind = f" {kind}" if kind else ""
-        path = self.sql(expression, "path")
-        path = f" PATH {path}" if path else ""
        return f"{this}{kind}{path}"

+    def jsonschema_sql(self, expression: exp.JSONSchema) -> str:
+        return self.func("COLUMNS", *expression.expressions)
+
    def jsontable_sql(self, expression: exp.JSONTable) -> str:
        this = self.sql(expression, "this")
        path = self.sql(expression, "path")
@@ -2194,9 +2203,9 @@ class Generator:
        error_handling = f" {error_handling}" if error_handling else ""
        empty_handling = expression.args.get("empty_handling")
        empty_handling = f" {empty_handling}" if empty_handling else ""
-        columns = f" COLUMNS ({self.expressions(expression, skip_first=True)})"
+        schema = self.sql(expression, "schema")
        return self.func(
-            "JSON_TABLE", this, suffix=f"{path}{error_handling}{empty_handling}{columns})"
+            "JSON_TABLE", this, suffix=f"{path}{error_handling}{empty_handling} {schema})"
        )

    def openjsoncolumndef_sql(self, expression: exp.OpenJSONColumnDef) -> str:
@@ -441,6 +441,14 @@ def first(it: t.Iterable[T]) -> T:


+def merge_ranges(ranges: t.List[t.Tuple[A, A]]) -> t.List[t.Tuple[A, A]]:
+    """
+    Merges a sequence of ranges, represented as tuples (low, high) whose values
+    belong to some totally-ordered set.
+
+    Example:
+        >>> merge_ranges([(1, 3), (2, 6)])
+        [(1, 6)]
+    """
+    if not ranges:
+        return []
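The hunk above is cut off after the empty-input guard. For illustration, a plausible completion of such a helper (the merge loop is my sketch, not the diff's actual body):

import typing as t

A = t.TypeVar("A")

def merge_ranges(ranges: t.List[t.Tuple[A, A]]) -> t.List[t.Tuple[A, A]]:
    if not ranges:
        return []

    # Sort by lower bound, then fold each range into the previous one when they overlap
    ranges = sorted(ranges)
    merged = [ranges[0]]
    for start, end in ranges[1:]:
        last_start, last_end = merged[-1]
        if start <= last_end:
            merged[-1] = (last_start, max(last_end, end))  # overlapping: extend
        else:
            merged.append((start, end))  # disjoint: start a new range
    return merged

print(merge_ranges([(1, 3), (2, 6)]))  # [(1, 6)], matching the doctest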
@@ -6,6 +6,7 @@ from sqlglot import exp
 from sqlglot.errors import OptimizeError
 from sqlglot.generator import cached_generator
 from sqlglot.helper import while_changing
+from sqlglot.optimizer.scope import find_all_in_scope
 from sqlglot.optimizer.simplify import flatten, rewrite_between, uniq_sort

 logger = logging.getLogger("sqlglot")
@@ -63,15 +64,33 @@ def normalize(expression: exp.Expression, dnf: bool = False, max_distance: int =
    return expression


-def normalized(expression, dnf=False):
-    ancestor, root = (exp.And, exp.Or) if dnf else (exp.Or, exp.And)
-
-    return not any(connector.find_ancestor(ancestor) for connector in expression.find_all(root))
-
-
-def normalization_distance(expression, dnf=False):
+def normalized(expression: exp.Expression, dnf: bool = False) -> bool:
    """
-    The difference in the number of predicates between the current expression and the normalized form.
+    Checks whether a given expression is in a normal form of interest.

    Example:
+        >>> from sqlglot import parse_one
+        >>> normalized(parse_one("(a AND b) OR c OR (d AND e)"), dnf=True)
+        True
+        >>> normalized(parse_one("(a OR b) AND c"))  # Checks CNF by default
+        True
+        >>> normalized(parse_one("a AND (b OR c)"), dnf=True)
+        False
+
+    Args:
+        expression: The expression to check if it's normalized.
+        dnf: Whether or not to check if the expression is in Disjunctive Normal Form (DNF).
+            Default: False, i.e. we check if it's in Conjunctive Normal Form (CNF).
+    """
+    ancestor, root = (exp.And, exp.Or) if dnf else (exp.Or, exp.And)
+    return not any(
+        connector.find_ancestor(ancestor) for connector in find_all_in_scope(expression, root)
+    )
+
+
+def normalization_distance(expression: exp.Expression, dnf: bool = False) -> int:
+    """
+    The difference in the number of predicates between a given expression and its normalized form.

    This is used as an estimate of the cost of the conversion which is exponential in complexity.

@@ -82,10 +101,12 @@ def normalization_distance(expression, dnf=False):
        4

    Args:
-        expression (sqlglot.Expression): expression to compute distance
-        dnf (bool): compute to dnf distance instead
+        expression: The expression to compute the normalization distance for.
+        dnf: Whether or not to check if the expression is in Disjunctive Normal Form (DNF).
+            Default: False, i.e. we check if it's in Conjunctive Normal Form (CNF).

    Returns:
-        int: difference
+        The normalization distance.
    """
    return sum(_predicate_lengths(expression, dnf)) - (
        sum(1 for _ in expression.find_all(exp.Connector)) + 1
@@ -39,10 +39,14 @@ def optimize_joins(expression):
                if len(other_table_names(dep)) < 2:
                    continue

+                operator = type(on)
                for predicate in on.flatten():
                    if name in exp.column_table_names(predicate):
                        predicate.replace(exp.true())
-                        join.on(predicate, copy=False)
+                        predicate = exp._combine(
+                            [join.args.get("on"), predicate], operator, copy=False
+                        )
+                        join.on(predicate, append=False, copy=False)

    expression = reorder_joins(expression)
    expression = normalize(expression)
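The rewrite above combines a relocated predicate with the target join's existing ON clause using the source clause's own connector (operator = type(on)) and replaces the clause instead of always appending with AND. A hedged invocation sketch (my example, output not asserted):

from sqlglot import parse_one
from sqlglot.optimizer.optimize_joins import optimize_joins

# Predicates that reference the cross-joined table should be folded into its join
sql = "SELECT * FROM x CROSS JOIN z JOIN y ON x.a = y.a AND z.b = y.b"
print(optimize_joins(parse_one(sql)).sql())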
@@ -9,7 +9,9 @@ from sqlglot.schema import ensure_schema
 SELECT_ALL = object()

 # Selection to use if selection list is empty
-DEFAULT_SELECTION = lambda: alias("1", "_")
+DEFAULT_SELECTION = lambda is_agg: alias(
+    exp.Max(this=exp.Literal.number(1)) if is_agg else "1", "_"
+)


 def pushdown_projections(expression, schema=None, remove_unused_selections=True):
@@ -98,6 +100,7 @@ def _remove_unused_selections(scope, parent_selections, schema, alias_count):
    new_selections = []
    removed = False
    star = False
+    is_agg = False

    select_all = SELECT_ALL in parent_selections

@@ -112,6 +115,9 @@ def _remove_unused_selections(scope, parent_selections, schema, alias_count):
                star = True
            removed = True

+        if not is_agg and selection.find(exp.AggFunc):
+            is_agg = True
+
    if star:
        resolver = Resolver(scope, schema)
        names = {s.alias_or_name for s in new_selections}
@@ -124,7 +130,7 @@ def _remove_unused_selections(scope, parent_selections, schema, alias_count):

    # If there are no remaining selections, just select a single constant
    if not new_selections:
-        new_selections.append(DEFAULT_SELECTION())
+        new_selections.append(DEFAULT_SELECTION(is_agg))

    scope.expression.select(*new_selections, append=False, copy=False)
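The reason MAX(1) is used when the pruned projections were aggregates: replacing an unused COUNT(*) with a bare literal would turn a one-row aggregate subquery into one row per source row, so an aggregate placeholder preserves the query shape. A hedged sketch:

from sqlglot import parse_one
from sqlglot.optimizer.pushdown_projections import pushdown_projections

ast = parse_one("SELECT 1 AS one FROM (SELECT COUNT(*) AS c FROM t) AS sub")
# The unused COUNT(*) should be replaced by MAX(1) AS _, keeping the subquery aggregate
print(pushdown_projections(ast).sql())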
@@ -137,8 +137,8 @@ class Scope:
        if not self._collected:
            self._collect()

-    def walk(self, bfs=True):
-        return walk_in_scope(self.expression, bfs=bfs)
+    def walk(self, bfs=True, prune=None):
+        return walk_in_scope(self.expression, bfs=bfs, prune=None)

    def find(self, *expression_types, bfs=True):
        return find_in_scope(self.expression, expression_types, bfs=bfs)
@@ -731,7 +731,7 @@ def _traverse_ddl(scope):
        yield from _traverse_scope(query_scope)


-def walk_in_scope(expression, bfs=True):
+def walk_in_scope(expression, bfs=True, prune=None):
    """
    Returns a generator object which visits all nodes in the syntax tree, stopping at
    nodes that start child scopes.
@@ -740,16 +740,20 @@ def walk_in_scope(expression, bfs=True):
        expression (exp.Expression):
        bfs (bool): if set to True the BFS traversal order will be applied,
            otherwise the DFS traversal will be used instead.
+        prune ((node, parent, arg_key) -> bool): callable that returns True if
+            the generator should stop traversing this branch of the tree.

    Yields:
        tuple[exp.Expression, Optional[exp.Expression], str]: node, parent, arg key
    """
    # We'll use this variable to pass state into the dfs generator.
    # Whenever we set it to True, we exclude a subtree from traversal.
-    prune = False
+    crossed_scope_boundary = False

-    for node, parent, key in expression.walk(bfs=bfs, prune=lambda *_: prune):
-        prune = False
+    for node, parent, key in expression.walk(
+        bfs=bfs, prune=lambda *args: crossed_scope_boundary or (prune and prune(*args))
+    ):
+        crossed_scope_boundary = False

        yield node, parent, key

@@ -765,7 +769,7 @@ def walk_in_scope(expression, bfs=True):
            or isinstance(node, exp.UDTF)
            or isinstance(node, exp.Subqueryable)
        ):
-            prune = True
+            crossed_scope_boundary = True

            if isinstance(node, (exp.Subquery, exp.UDTF)):
                # The following args are not actually in the inner scope, so we should visit them
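The new prune hook lets callers cut off whole subtrees during an in-scope walk; propagate_constants in simplify.py (below) uses it to skip IF expressions. A small sketch:

from sqlglot import exp, parse_one
from sqlglot.optimizer.scope import walk_in_scope

ast = parse_one("SELECT CASE WHEN a = 1 THEN b END FROM t WHERE c = 2")
# Visit every node in scope, but do not descend into EQ subtrees
for node, parent, key in walk_in_scope(ast, prune=lambda n, *_: isinstance(n, exp.EQ)):
    print(type(node).__name__)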
@@ -5,9 +5,11 @@ import typing as t
 from collections import deque
 from decimal import Decimal

+import sqlglot
 from sqlglot import exp
 from sqlglot.generator import cached_generator
 from sqlglot.helper import first, merge_ranges, while_changing
 from sqlglot.optimizer.scope import find_all_in_scope, walk_in_scope

 # Final means that an expression should not be simplified
 FINAL = "final"
@@ -17,7 +19,7 @@ class UnsupportedUnit(Exception):
    pass


-def simplify(expression):
+def simplify(expression, constant_propagation=False):
    """
    Rewrite sqlglot AST to simplify expressions.

@@ -29,6 +31,8 @@ def simplify(expression):

    Args:
        expression (sqlglot.Expression): expression to simplify
+        constant_propagation: whether or not the constant propagation rule should be used

    Returns:
        sqlglot.Expression: simplified expression
    """
@@ -67,13 +71,16 @@ def simplify(expression):
        node = absorb_and_eliminate(node, root)
        node = simplify_concat(node)

+        if constant_propagation:
+            node = propagate_constants(node, root)
+
        exp.replace_children(node, lambda e: _simplify(e, False))

        # Post-order transformations
        node = simplify_not(node)
        node = flatten(node)
        node = simplify_connectors(node, root)
-        node = remove_compliments(node, root)
+        node = remove_complements(node, root)
        node = simplify_coalesce(node)
        node.parent = expression.parent
        node = simplify_literals(node, root)
@@ -287,19 +294,19 @@ def _simplify_comparison(expression, left, right, or_=False):
    return None


-def remove_compliments(expression, root=True):
+def remove_complements(expression, root=True):
    """
-    Removing compliments.
+    Removing complements.

    A AND NOT A -> FALSE
    A OR NOT A -> TRUE
    """
    if isinstance(expression, exp.Connector) and (root or not expression.same_parent):
-        compliment = exp.false() if isinstance(expression, exp.And) else exp.true()
+        complement = exp.false() if isinstance(expression, exp.And) else exp.true()

        for a, b in itertools.permutations(expression.flatten(), 2):
            if is_complement(a, b):
-                return compliment
+                return complement
    return expression
@@ -369,6 +376,51 @@ def absorb_and_eliminate(expression, root=True):
    return expression


+def propagate_constants(expression, root=True):
+    """
+    Propagate constants for conjunctions in DNF:
+
+    SELECT * FROM t WHERE a = b AND b = 5 becomes
+    SELECT * FROM t WHERE a = 5 AND b = 5
+
+    Reference: https://www.sqlite.org/optoverview.html
+    """
+
+    if (
+        isinstance(expression, exp.And)
+        and (root or not expression.same_parent)
+        and sqlglot.optimizer.normalize.normalized(expression, dnf=True)
+    ):
+        constant_mapping = {}
+        for expr, *_ in walk_in_scope(expression, prune=lambda node, *_: isinstance(node, exp.If)):
+            if isinstance(expr, exp.EQ):
+                l, r = expr.left, expr.right
+
+                # TODO: create a helper that can be used to detect nested literal expressions such
+                # as CAST(123456 AS BIGINT), since we usually want to treat those as literals too
+                if isinstance(l, exp.Column) and isinstance(r, exp.Literal):
+                    pass
+                elif isinstance(r, exp.Column) and isinstance(l, exp.Literal):
+                    l, r = r, l
+                else:
+                    continue
+
+                constant_mapping[l] = (id(l), r)
+
+        if constant_mapping:
+            for column in find_all_in_scope(expression, exp.Column):
+                parent = column.parent
+                column_id, constant = constant_mapping.get(column) or (None, None)
+                if (
+                    column_id is not None
+                    and id(column) != column_id
+                    and not (isinstance(parent, exp.Is) and isinstance(parent.expression, exp.Null))
+                ):
+                    column.replace(constant.copy())
+
+    return expression
+
+
 INVERSE_DATE_OPS: t.Dict[t.Type[exp.Expression], t.Type[exp.Expression]] = {
    exp.DateAdd: exp.Sub,
    exp.DateSub: exp.Add,
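A quick sketch of the new rule in action, enabled through the simplify flag added above (expected output taken from the docstring):

from sqlglot import parse_one
from sqlglot.optimizer.simplify import simplify

ast = parse_one("SELECT * FROM t WHERE a = b AND b = 5")
print(simplify(ast, constant_propagation=True).sql())
# per the docstring: SELECT * FROM t WHERE a = 5 AND b = 5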
@@ -609,21 +661,38 @@ SAFE_CONCATS = (exp.SafeConcat, exp.SafeDPipe)

 def simplify_concat(expression):
    """Reduces all groups that contain string literals by concatenating them."""
-    if not isinstance(expression, CONCATS) or isinstance(expression, exp.ConcatWs):
+    if not isinstance(expression, CONCATS) or (
+        # We can't reduce a CONCAT_WS call if we don't statically know the separator
+        isinstance(expression, exp.ConcatWs)
+        and not expression.expressions[0].is_string
+    ):
        return expression

+    if isinstance(expression, exp.ConcatWs):
+        sep_expr, *expressions = expression.expressions
+        sep = sep_expr.name
+        concat_type = exp.ConcatWs
+    else:
+        expressions = expression.expressions
+        sep = ""
+        concat_type = exp.SafeConcat if isinstance(expression, SAFE_CONCATS) else exp.Concat
+
    new_args = []
    for is_string_group, group in itertools.groupby(
-        expression.expressions or expression.flatten(), lambda e: e.is_string
+        expressions or expression.flatten(), lambda e: e.is_string
    ):
        if is_string_group:
-            new_args.append(exp.Literal.string("".join(string.name for string in group)))
+            new_args.append(exp.Literal.string(sep.join(string.name for string in group)))
        else:
            new_args.extend(group)

-    # Ensures we preserve the right concat type, i.e. whether it's "safe" or not
-    concat_type = exp.SafeConcat if isinstance(expression, SAFE_CONCATS) else exp.Concat
-    return new_args[0] if len(new_args) == 1 else concat_type(expressions=new_args)
+    if len(new_args) == 1 and new_args[0].is_string:
+        return new_args[0]
+
+    if concat_type is exp.ConcatWs:
+        new_args = [sep_expr] + new_args
+
+    return concat_type(expressions=new_args)


 DateRange = t.Tuple[datetime.date, datetime.date]
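Sketch of what the reworked function enables: a CONCAT_WS call whose separator is a string literal can now be folded too.

from sqlglot import parse_one
from sqlglot.optimizer.simplify import simplify

# The adjacent string literals should collapse into a single 'a-b' literal
print(simplify(parse_one("SELECT CONCAT_WS('-', 'a', 'b')")).sql())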
@@ -160,6 +160,9 @@ class Parser(metaclass=_Parser):
        TokenType.TIME,
        TokenType.TIMETZ,
        TokenType.TIMESTAMP,
+        TokenType.TIMESTAMP_S,
+        TokenType.TIMESTAMP_MS,
+        TokenType.TIMESTAMP_NS,
        TokenType.TIMESTAMPTZ,
        TokenType.TIMESTAMPLTZ,
        TokenType.DATETIME,
@@ -792,17 +795,18 @@ class Parser(metaclass=_Parser):
        "DECODE": lambda self: self._parse_decode(),
        "EXTRACT": lambda self: self._parse_extract(),
        "JSON_OBJECT": lambda self: self._parse_json_object(),
+        "JSON_TABLE": lambda self: self._parse_json_table(),
        "LOG": lambda self: self._parse_logarithm(),
        "MATCH": lambda self: self._parse_match_against(),
        "OPENJSON": lambda self: self._parse_open_json(),
        "POSITION": lambda self: self._parse_position(),
        "PREDICT": lambda self: self._parse_predict(),
-        "SAFE_CAST": lambda self: self._parse_cast(False),
+        "SAFE_CAST": lambda self: self._parse_cast(False, safe=True),
        "STRING_AGG": lambda self: self._parse_string_agg(),
        "SUBSTRING": lambda self: self._parse_substring(),
        "TRIM": lambda self: self._parse_trim(),
-        "TRY_CAST": lambda self: self._parse_cast(False),
-        "TRY_CONVERT": lambda self: self._parse_convert(False),
+        "TRY_CAST": lambda self: self._parse_cast(False, safe=True),
+        "TRY_CONVERT": lambda self: self._parse_convert(False, safe=True),
    }

    QUERY_MODIFIER_PARSERS = {
@@ -4135,7 +4139,7 @@ class Parser(metaclass=_Parser):

        return self.expression(exp.AnyValue, this=this, having=having, max=is_max)

-    def _parse_cast(self, strict: bool) -> exp.Expression:
+    def _parse_cast(self, strict: bool, safe: t.Optional[bool] = None) -> exp.Expression:
        this = self._parse_conjunction()

        if not self._match(TokenType.ALIAS):
@@ -4176,7 +4180,9 @@ class Parser(metaclass=_Parser):

            return this

-        return self.expression(exp.Cast if strict else exp.TryCast, this=this, to=to, format=fmt)
+        return self.expression(
+            exp.Cast if strict else exp.TryCast, this=this, to=to, format=fmt, safe=safe
+        )

    def _parse_concat(self) -> t.Optional[exp.Expression]:
        args = self._parse_csv(self._parse_conjunction)
@@ -4230,7 +4236,9 @@ class Parser(metaclass=_Parser):
        order = self._parse_order(this=seq_get(args, 0))
        return self.expression(exp.GroupConcat, this=order, separator=seq_get(args, 1))

-    def _parse_convert(self, strict: bool) -> t.Optional[exp.Expression]:
+    def _parse_convert(
+        self, strict: bool, safe: t.Optional[bool] = None
+    ) -> t.Optional[exp.Expression]:
        this = self._parse_bitwise()

        if self._match(TokenType.USING):
@@ -4242,7 +4250,7 @@ class Parser(metaclass=_Parser):
        else:
            to = None

-        return self.expression(exp.Cast if strict else exp.TryCast, this=this, to=to)
+        return self.expression(exp.Cast if strict else exp.TryCast, this=this, to=to, safe=safe)

    def _parse_decode(self) -> t.Optional[exp.Decode | exp.Case]:
        """
@@ -4347,6 +4355,50 @@ class Parser(metaclass=_Parser):
            encoding=encoding,
        )

+    # Note: this is currently incomplete; it only implements the "JSON_value_column" part
+    def _parse_json_column_def(self) -> exp.JSONColumnDef:
+        if not self._match_text_seq("NESTED"):
+            this = self._parse_id_var()
+            kind = self._parse_types(allow_identifiers=False)
+            nested = None
+        else:
+            this = None
+            kind = None
+            nested = True
+
+        path = self._match_text_seq("PATH") and self._parse_string()
+        nested_schema = nested and self._parse_json_schema()
+
+        return self.expression(
+            exp.JSONColumnDef,
+            this=this,
+            kind=kind,
+            path=path,
+            nested_schema=nested_schema,
+        )
+
+    def _parse_json_schema(self) -> exp.JSONSchema:
+        self._match_text_seq("COLUMNS")
+        return self.expression(
+            exp.JSONSchema,
+            expressions=self._parse_wrapped_csv(self._parse_json_column_def, optional=True),
+        )
+
+    def _parse_json_table(self) -> exp.JSONTable:
+        this = self._parse_format_json(self._parse_bitwise())
+        path = self._match(TokenType.COMMA) and self._parse_string()
+        error_handling = self._parse_on_handling("ERROR", "ERROR", "NULL")
+        empty_handling = self._parse_on_handling("EMPTY", "ERROR", "NULL")
+        schema = self._parse_json_schema()
+
+        return exp.JSONTable(
+            this=this,
+            schema=schema,
+            path=path,
+            error_handling=error_handling,
+            empty_handling=empty_handling,
+        )
+
    def _parse_logarithm(self) -> exp.Func:
        # Default argument order is base, expression
        args = self._parse_csv(self._parse_range)
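With JSON_TABLE parsing promoted into the base parser and NESTED columns modeled through exp.JSONSchema, an Oracle-style nested query should now round-trip. A hedged sketch (my example SQL, not from the diff):

import sqlglot

sql = "SELECT * FROM JSON_TABLE(payload, '$.items[*]' COLUMNS (id NUMBER PATH '$.id', NESTED PATH '$.tags[*]' COLUMNS (tag VARCHAR2(10) PATH '$')))"
print(sqlglot.parse_one(sql, read="oracle").sql("oracle"))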
@@ -4973,7 +5025,17 @@ class Parser(metaclass=_Parser):
        self._match(TokenType.ON)
        on = self._parse_conjunction()

+        return self.expression(
+            exp.Merge,
+            this=target,
+            using=using,
+            on=on,
+            expressions=self._parse_when_matched(),
+        )
+
+    def _parse_when_matched(self) -> t.List[exp.When]:
+        whens = []
+
        while self._match(TokenType.WHEN):
            matched = not self._match(TokenType.NOT)
            self._match_text_seq("MATCHED")
@@ -5020,14 +5082,7 @@ class Parser(metaclass=_Parser):
                    then=then,
                )
            )
-
-        return self.expression(
-            exp.Merge,
-            this=target,
-            using=using,
-            on=on,
-            expressions=whens,
-        )
+        return whens

    def _parse_show(self) -> t.Optional[exp.Expression]:
        parser = self._find_parser(self.SHOW_PARSERS, self.SHOW_TRIE)
@@ -5,7 +5,6 @@ import typing as t

 import sqlglot
 from sqlglot import expressions as exp
-from sqlglot._typing import T
 from sqlglot.dialects.dialect import Dialect
 from sqlglot.errors import ParseError, SchemaError
 from sqlglot.helper import dict_depth
@@ -71,7 +70,7 @@ class Schema(abc.ABC):
    def get_column_type(
        self,
        table: exp.Table | str,
-        column: exp.Column,
+        column: exp.Column | str,
        dialect: DialectType = None,
        normalize: t.Optional[bool] = None,
    ) -> exp.DataType:
@@ -88,6 +87,28 @@ class Schema(abc.ABC):
            The resulting column type.
        """

+    def has_column(
+        self,
+        table: exp.Table | str,
+        column: exp.Column | str,
+        dialect: DialectType = None,
+        normalize: t.Optional[bool] = None,
+    ) -> bool:
+        """
+        Returns whether or not `column` appears in `table`'s schema.
+
+        Args:
+            table: the source table.
+            column: the target column.
+            dialect: the SQL dialect that will be used to parse `table` if it's a string.
+            normalize: whether to normalize identifiers according to the dialect of interest.
+
+        Returns:
+            True if the column appears in the schema, False otherwise.
+        """
+        name = column if isinstance(column, str) else column.name
+        return name in self.column_names(table, dialect=dialect, normalize=normalize)
+
    @property
    @abc.abstractmethod
    def supported_table_args(self) -> t.Tuple[str, ...]:
@@ -101,7 +122,7 @@ class Schema(abc.ABC):
        return True


-class AbstractMappingSchema(t.Generic[T]):
+class AbstractMappingSchema:
    def __init__(
        self,
        mapping: t.Optional[t.Dict] = None,
@@ -140,7 +161,7 @@ class AbstractMappingSchema(t.Generic[T]):

    def find(
        self, table: exp.Table, trie: t.Optional[t.Dict] = None, raise_on_missing: bool = True
-    ) -> t.Optional[T]:
+    ) -> t.Optional[t.Any]:
        parts = self.table_parts(table)[0 : len(self.supported_table_args)]
        value, trie = in_trie(self.mapping_trie if trie is None else trie, parts)

@@ -170,7 +191,7 @@ class AbstractMappingSchema(t.Generic[T]):
    )


-class MappingSchema(AbstractMappingSchema[t.Dict[str, str]], Schema):
+class MappingSchema(AbstractMappingSchema, Schema):
    """
    Schema based on a nested mapping.

@@ -287,7 +308,7 @@ class MappingSchema(AbstractMappingSchema[t.Dict[str, str]], Schema):
    def get_column_type(
        self,
        table: exp.Table | str,
-        column: exp.Column,
+        column: exp.Column | str,
        dialect: DialectType = None,
        normalize: t.Optional[bool] = None,
    ) -> exp.DataType:
@@ -304,10 +325,26 @@ class MappingSchema(AbstractMappingSchema, Schema):
        if isinstance(column_type, exp.DataType):
            return column_type
        elif isinstance(column_type, str):
-            return self._to_data_type(column_type.upper(), dialect=dialect)
+            return self._to_data_type(column_type, dialect=dialect)

        return exp.DataType.build("unknown")

+    def has_column(
+        self,
+        table: exp.Table | str,
+        column: exp.Column | str,
+        dialect: DialectType = None,
+        normalize: t.Optional[bool] = None,
+    ) -> bool:
+        normalized_table = self._normalize_table(table, dialect=dialect, normalize=normalize)
+
+        normalized_column_name = self._normalize_name(
+            column if isinstance(column, str) else column.this, dialect=dialect, normalize=normalize
+        )
+
+        table_schema = self.find(normalized_table, raise_on_missing=False)
+        return normalized_column_name in table_schema if table_schema else False
+
    def _normalize(self, schema: t.Dict) -> t.Dict:
        """
        Normalizes all identifiers in the schema.
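A short sketch of the new schema capability:

from sqlglot.schema import MappingSchema

schema = MappingSchema({"db": {"t": {"x": "INT", "y": "TEXT"}}})
print(schema.has_column("db.t", "x"))  # True
print(schema.has_column("db.t", "z"))  # False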
@@ -121,6 +121,9 @@ class TokenType(AutoName):
    TIMESTAMP = auto()
    TIMESTAMPTZ = auto()
    TIMESTAMPLTZ = auto()
+    TIMESTAMP_S = auto()
+    TIMESTAMP_MS = auto()
+    TIMESTAMP_NS = auto()
    DATETIME = auto()
    DATETIME64 = auto()
    DATE = auto()
@@ -189,9 +189,9 @@ def explode_to_unnest(index_offset: int = 0) -> t.Callable[[exp.Expression], exp

        # we use list here because expression.selects is mutated inside the loop
        for select in expression.selects.copy():
-            explode = select.find(exp.Explode, exp.Posexplode)
+            explode = select.find(exp.Explode)

-            if isinstance(explode, (exp.Explode, exp.Posexplode)):
+            if explode:
                pos_alias = ""
                explode_alias = ""

@@ -204,7 +204,7 @@ def explode_to_unnest(index_offset: int = 0) -> t.Callable[[exp.Expression], exp
                alias = select.replace(exp.alias_(select.this, "", copy=False))
            else:
                alias = select.replace(exp.alias_(select, ""))
-                explode = alias.find(exp.Explode, exp.Posexplode)
+                explode = alias.find(exp.Explode)
                assert explode

            is_posexplode = isinstance(explode, exp.Posexplode)
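Since Posexplode now subclasses Explode (see the expressions.py hunk earlier), a single find(exp.Explode) covers both cases here. A sketch of the downstream effect when targeting an UNNEST dialect:

import sqlglot

# POSEXPLODE should still be detected and rewritten through UNNEST WITH ORDINALITY
print(sqlglot.transpile("SELECT POSEXPLODE(a) FROM t", read="spark", write="presto")[0])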