Merging upstream version 19.0.1.
Signed-off-by: Daniel Baumann <daniel@debian.org>
This commit is contained in:
parent
348b067e1b
commit
89acb78953
91 changed files with 45416 additions and 43096 deletions
|
@ -4,7 +4,6 @@ import logging
|
|||
|
||||
from sqlglot import exp
|
||||
from sqlglot.errors import OptimizeError
|
||||
from sqlglot.generator import cached_generator
|
||||
from sqlglot.helper import while_changing
|
||||
from sqlglot.optimizer.scope import find_all_in_scope
|
||||
from sqlglot.optimizer.simplify import flatten, rewrite_between, uniq_sort
|
||||
|
@ -29,8 +28,6 @@ def normalize(expression: exp.Expression, dnf: bool = False, max_distance: int =
|
|||
Returns:
|
||||
sqlglot.Expression: normalized expression
|
||||
"""
|
||||
generate = cached_generator()
|
||||
|
||||
for node, *_ in tuple(expression.walk(prune=lambda e, *_: isinstance(e, exp.Connector))):
|
||||
if isinstance(node, exp.Connector):
|
||||
if normalized(node, dnf=dnf):
|
||||
|
@ -49,7 +46,7 @@ def normalize(expression: exp.Expression, dnf: bool = False, max_distance: int =
|
|||
|
||||
try:
|
||||
node = node.replace(
|
||||
while_changing(node, lambda e: distributive_law(e, dnf, max_distance, generate))
|
||||
while_changing(node, lambda e: distributive_law(e, dnf, max_distance))
|
||||
)
|
||||
except OptimizeError as e:
|
||||
logger.info(e)
|
||||
|
@ -133,7 +130,7 @@ def _predicate_lengths(expression, dnf):
|
|||
return _predicate_lengths(left, dnf) + _predicate_lengths(right, dnf)
|
||||
|
||||
|
||||
def distributive_law(expression, dnf, max_distance, generate):
|
||||
def distributive_law(expression, dnf, max_distance):
|
||||
"""
|
||||
x OR (y AND z) -> (x OR y) AND (x OR z)
|
||||
(x AND y) OR (y AND z) -> (x OR y) AND (x OR z) AND (y OR y) AND (y OR z)
|
||||
|
@ -146,7 +143,7 @@ def distributive_law(expression, dnf, max_distance, generate):
|
|||
if distance > max_distance:
|
||||
raise OptimizeError(f"Normalization distance {distance} exceeds max {max_distance}")
|
||||
|
||||
exp.replace_children(expression, lambda e: distributive_law(e, dnf, max_distance, generate))
|
||||
exp.replace_children(expression, lambda e: distributive_law(e, dnf, max_distance))
|
||||
to_exp, from_exp = (exp.Or, exp.And) if dnf else (exp.And, exp.Or)
|
||||
|
||||
if isinstance(expression, from_exp):
|
||||
|
@ -157,30 +154,30 @@ def distributive_law(expression, dnf, max_distance, generate):
|
|||
|
||||
if isinstance(a, to_exp) and isinstance(b, to_exp):
|
||||
if len(tuple(a.find_all(exp.Connector))) > len(tuple(b.find_all(exp.Connector))):
|
||||
return _distribute(a, b, from_func, to_func, generate)
|
||||
return _distribute(b, a, from_func, to_func, generate)
|
||||
return _distribute(a, b, from_func, to_func)
|
||||
return _distribute(b, a, from_func, to_func)
|
||||
if isinstance(a, to_exp):
|
||||
return _distribute(b, a, from_func, to_func, generate)
|
||||
return _distribute(b, a, from_func, to_func)
|
||||
if isinstance(b, to_exp):
|
||||
return _distribute(a, b, from_func, to_func, generate)
|
||||
return _distribute(a, b, from_func, to_func)
|
||||
|
||||
return expression
|
||||
|
||||
|
||||
def _distribute(a, b, from_func, to_func, generate):
|
||||
def _distribute(a, b, from_func, to_func):
|
||||
if isinstance(a, exp.Connector):
|
||||
exp.replace_children(
|
||||
a,
|
||||
lambda c: to_func(
|
||||
uniq_sort(flatten(from_func(c, b.left)), generate),
|
||||
uniq_sort(flatten(from_func(c, b.right)), generate),
|
||||
uniq_sort(flatten(from_func(c, b.left))),
|
||||
uniq_sort(flatten(from_func(c, b.right))),
|
||||
copy=False,
|
||||
),
|
||||
)
|
||||
else:
|
||||
a = to_func(
|
||||
uniq_sort(flatten(from_func(a, b.left)), generate),
|
||||
uniq_sort(flatten(from_func(a, b.right)), generate),
|
||||
uniq_sort(flatten(from_func(a, b.left))),
|
||||
uniq_sort(flatten(from_func(a, b.right))),
|
||||
copy=False,
|
||||
)
|
||||
|
||||
|
|
|
@ -2,7 +2,7 @@ from __future__ import annotations
|
|||
|
||||
import typing as t
|
||||
|
||||
from sqlglot import exp, parse_one
|
||||
from sqlglot import exp
|
||||
from sqlglot._typing import E
|
||||
from sqlglot.dialects.dialect import Dialect, DialectType
|
||||
|
||||
|
@ -49,7 +49,7 @@ def normalize_identifiers(expression, dialect=None):
|
|||
The transformed expression.
|
||||
"""
|
||||
if isinstance(expression, str):
|
||||
expression = parse_one(expression, dialect=dialect, into=exp.Identifier)
|
||||
expression = exp.parse_identifier(expression, dialect=dialect)
|
||||
|
||||
dialect = Dialect.get_or_raise(dialect)
|
||||
|
||||
|
|
|
@ -62,7 +62,7 @@ def qualify_tables(
|
|||
if isinstance(source.this, exp.Identifier):
|
||||
if not source.args.get("db"):
|
||||
source.set("db", exp.to_identifier(db))
|
||||
if not source.args.get("catalog"):
|
||||
if not source.args.get("catalog") and source.args.get("db"):
|
||||
source.set("catalog", exp.to_identifier(catalog))
|
||||
|
||||
if not source.alias:
|
||||
|
|
|
@ -7,8 +7,7 @@ from decimal import Decimal
|
|||
|
||||
import sqlglot
|
||||
from sqlglot import exp
|
||||
from sqlglot.generator import cached_generator
|
||||
from sqlglot.helper import first, merge_ranges, while_changing
|
||||
from sqlglot.helper import first, is_iterable, merge_ranges, while_changing
|
||||
from sqlglot.optimizer.scope import find_all_in_scope, walk_in_scope
|
||||
|
||||
# Final means that an expression should not be simplified
|
||||
|
@ -37,8 +36,6 @@ def simplify(expression, constant_propagation=False):
|
|||
sqlglot.Expression: simplified expression
|
||||
"""
|
||||
|
||||
generate = cached_generator()
|
||||
|
||||
# group by expressions cannot be simplified, for example
|
||||
# select x + 1 + 1 FROM y GROUP BY x + 1 + 1
|
||||
# the projection must exactly match the group by key
|
||||
|
@ -67,7 +64,7 @@ def simplify(expression, constant_propagation=False):
|
|||
# Pre-order transformations
|
||||
node = expression
|
||||
node = rewrite_between(node)
|
||||
node = uniq_sort(node, generate, root)
|
||||
node = uniq_sort(node, root)
|
||||
node = absorb_and_eliminate(node, root)
|
||||
node = simplify_concat(node)
|
||||
node = simplify_conditionals(node)
|
||||
|
@ -311,7 +308,7 @@ def remove_complements(expression, root=True):
|
|||
return expression
|
||||
|
||||
|
||||
def uniq_sort(expression, generate, root=True):
|
||||
def uniq_sort(expression, root=True):
|
||||
"""
|
||||
Uniq and sort a connector.
|
||||
|
||||
|
@ -320,7 +317,7 @@ def uniq_sort(expression, generate, root=True):
|
|||
if isinstance(expression, exp.Connector) and (root or not expression.same_parent):
|
||||
result_func = exp.and_ if isinstance(expression, exp.And) else exp.or_
|
||||
flattened = tuple(expression.flatten())
|
||||
deduped = {generate(e): e for e in flattened}
|
||||
deduped = {gen(e): e for e in flattened}
|
||||
arr = tuple(deduped.items())
|
||||
|
||||
# check if the operands are already sorted, if not sort them
|
||||
|
@ -1070,3 +1067,69 @@ def _flat_simplify(expression, simplifier, root=True):
|
|||
lambda a, b: expression.__class__(this=a, expression=b), operands
|
||||
)
|
||||
return expression
|
||||
|
||||
|
||||
def gen(expression: t.Any) -> str:
    """Render ``expression`` as a cheap pseudo-SQL string for sorting and de-duplication.

    The optimizer needs stable, comparable strings for expressions, and running
    the real SQL generator for that is expensive. This function is a minimal
    stand-in for that purpose only.

    Args:
        expression: ``None``, an iterable of values, an ``exp.Expression``, or
            any other object.

    Returns:
        ``"_"`` for ``None``; the comma-joined renderings of the items for an
        iterable; ``str(expression)`` for anything that is not an
        ``exp.Expression``; otherwise the output of the renderer registered in
        ``GEN_MAP`` for the expression's exact type, or, when none is
        registered, the expression's ``key`` followed by the rendering of its
        argument values.
    """
    if expression is None:
        return "_"

    if is_iterable(expression):
        return ",".join(map(gen, expression))

    if not isinstance(expression, exp.Expression):
        return str(expression)

    # Dispatch on the exact class: subclasses of a registered type are not matched.
    renderer = GEN_MAP.get(type(expression))
    if renderer is None:
        # Generic fallback: the expression key plus all of its argument values.
        return f"{expression.key} {gen(expression.args.values())}"
    return renderer(expression)
|
||||
|
||||
|
||||
# Renderers used by ``gen``, keyed by the exact expression class. Each callable
# takes the expression node and returns its pseudo-SQL string.
GEN_MAP = {
    exp.Add: lambda node: _binary(node, "+"),
    exp.And: lambda node: _binary(node, "AND"),
    # ``this`` followed by the comma-joined renderings of the expressions.
    exp.Anonymous: lambda node: f"{node.this} {','.join(map(gen, node.expressions))}",
    exp.Between: lambda node: (
        f"{gen(node.this)} BETWEEN {gen(node.args.get('low'))} AND {gen(node.args.get('high'))}"
    ),
    exp.Boolean: lambda node: "TRUE" if node.this else "FALSE",
    exp.Bracket: lambda node: f"{gen(node.this)}[{gen(node.expressions)}]",
    exp.Column: lambda node: ".".join(gen(part) for part in node.parts),
    # Name of ``this`` followed by every argument value except the first.
    exp.DataType: lambda node: f"{node.this.name} {gen(tuple(node.args.values())[1:])}",
    exp.Div: lambda node: _binary(node, "/"),
    exp.Dot: lambda node: _binary(node, "."),
    exp.DPipe: lambda node: _binary(node, "||"),
    exp.SafeDPipe: lambda node: _binary(node, "||"),
    exp.EQ: lambda node: _binary(node, "="),
    exp.GT: lambda node: _binary(node, ">"),
    exp.GTE: lambda node: _binary(node, ">="),
    # Quoted identifiers keep their double quotes in the rendering.
    exp.Identifier: lambda node: f'"{node.name}"' if node.quoted else node.name,
    exp.ILike: lambda node: _binary(node, "ILIKE"),
    # ``this`` followed by every argument value except the first, inside parentheses.
    exp.In: lambda node: f"{gen(node.this)} IN ({gen(tuple(node.args.values())[1:])})",
    exp.Is: lambda node: _binary(node, "IS"),
    exp.Like: lambda node: _binary(node, "LIKE"),
    # String literals are wrapped in single quotes; other literals are emitted as-is.
    exp.Literal: lambda node: f"'{node.name}'" if node.is_string else node.name,
    exp.LT: lambda node: _binary(node, "<"),
    exp.LTE: lambda node: _binary(node, "<="),
    exp.Mod: lambda node: _binary(node, "%"),
    exp.Mul: lambda node: _binary(node, "*"),
    exp.Neg: lambda node: _unary(node, "-"),
    exp.NEQ: lambda node: _binary(node, "<>"),
    exp.Not: lambda node: _unary(node, "NOT"),
    exp.Null: lambda node: "NULL",
    exp.Or: lambda node: _binary(node, "OR"),
    exp.Paren: lambda node: f"({gen(node.this)})",
    exp.Sub: lambda node: _binary(node, "-"),
    exp.Subquery: lambda node: f"({gen(node.args.values())})",
    exp.Table: lambda node: gen(node.args.values()),
    exp.Var: lambda node: node.name,
}
|
||||
|
||||
|
||||
def _binary(e: exp.Binary, op: str) -> str:
    """Render a binary expression as ``<left> <op> <right>``, using ``gen`` for both operands."""
    lhs = gen(e.left)
    rhs = gen(e.right)
    return f"{lhs} {op} {rhs}"
|
||||
|
||||
|
||||
def _unary(e: exp.Unary, op: str) -> str:
    """Render a unary expression as ``<op> <operand>``, using ``gen`` for the operand."""
    operand = gen(e.this)
    return f"{op} {operand}"
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue