1
0
Fork 0

Merging upstream version 19.0.1.

Signed-off-by: Daniel Baumann <daniel@debian.org>
This commit is contained in:
Daniel Baumann 2025-02-13 21:16:09 +01:00
parent 348b067e1b
commit 89acb78953
Signed by: daniel
GPG key ID: FBB4F0E80A80222F
91 changed files with 45416 additions and 43096 deletions

View file

@ -4,7 +4,6 @@ import logging
from sqlglot import exp
from sqlglot.errors import OptimizeError
from sqlglot.generator import cached_generator
from sqlglot.helper import while_changing
from sqlglot.optimizer.scope import find_all_in_scope
from sqlglot.optimizer.simplify import flatten, rewrite_between, uniq_sort
@ -29,8 +28,6 @@ def normalize(expression: exp.Expression, dnf: bool = False, max_distance: int =
Returns:
sqlglot.Expression: normalized expression
"""
generate = cached_generator()
for node, *_ in tuple(expression.walk(prune=lambda e, *_: isinstance(e, exp.Connector))):
if isinstance(node, exp.Connector):
if normalized(node, dnf=dnf):
@ -49,7 +46,7 @@ def normalize(expression: exp.Expression, dnf: bool = False, max_distance: int =
try:
node = node.replace(
while_changing(node, lambda e: distributive_law(e, dnf, max_distance, generate))
while_changing(node, lambda e: distributive_law(e, dnf, max_distance))
)
except OptimizeError as e:
logger.info(e)
@ -133,7 +130,7 @@ def _predicate_lengths(expression, dnf):
return _predicate_lengths(left, dnf) + _predicate_lengths(right, dnf)
def distributive_law(expression, dnf, max_distance, generate):
def distributive_law(expression, dnf, max_distance):
"""
x OR (y AND z) -> (x OR y) AND (x OR z)
(x AND y) OR (y AND z) -> (x OR y) AND (x OR z) AND (y OR y) AND (y OR z)
@ -146,7 +143,7 @@ def distributive_law(expression, dnf, max_distance, generate):
if distance > max_distance:
raise OptimizeError(f"Normalization distance {distance} exceeds max {max_distance}")
exp.replace_children(expression, lambda e: distributive_law(e, dnf, max_distance, generate))
exp.replace_children(expression, lambda e: distributive_law(e, dnf, max_distance))
to_exp, from_exp = (exp.Or, exp.And) if dnf else (exp.And, exp.Or)
if isinstance(expression, from_exp):
@ -157,30 +154,30 @@ def distributive_law(expression, dnf, max_distance, generate):
if isinstance(a, to_exp) and isinstance(b, to_exp):
if len(tuple(a.find_all(exp.Connector))) > len(tuple(b.find_all(exp.Connector))):
return _distribute(a, b, from_func, to_func, generate)
return _distribute(b, a, from_func, to_func, generate)
return _distribute(a, b, from_func, to_func)
return _distribute(b, a, from_func, to_func)
if isinstance(a, to_exp):
return _distribute(b, a, from_func, to_func, generate)
return _distribute(b, a, from_func, to_func)
if isinstance(b, to_exp):
return _distribute(a, b, from_func, to_func, generate)
return _distribute(a, b, from_func, to_func)
return expression
def _distribute(a, b, from_func, to_func, generate):
def _distribute(a, b, from_func, to_func):
if isinstance(a, exp.Connector):
exp.replace_children(
a,
lambda c: to_func(
uniq_sort(flatten(from_func(c, b.left)), generate),
uniq_sort(flatten(from_func(c, b.right)), generate),
uniq_sort(flatten(from_func(c, b.left))),
uniq_sort(flatten(from_func(c, b.right))),
copy=False,
),
)
else:
a = to_func(
uniq_sort(flatten(from_func(a, b.left)), generate),
uniq_sort(flatten(from_func(a, b.right)), generate),
uniq_sort(flatten(from_func(a, b.left))),
uniq_sort(flatten(from_func(a, b.right))),
copy=False,
)

View file

@ -2,7 +2,7 @@ from __future__ import annotations
import typing as t
from sqlglot import exp, parse_one
from sqlglot import exp
from sqlglot._typing import E
from sqlglot.dialects.dialect import Dialect, DialectType
@ -49,7 +49,7 @@ def normalize_identifiers(expression, dialect=None):
The transformed expression.
"""
if isinstance(expression, str):
expression = parse_one(expression, dialect=dialect, into=exp.Identifier)
expression = exp.parse_identifier(expression, dialect=dialect)
dialect = Dialect.get_or_raise(dialect)

View file

@ -62,7 +62,7 @@ def qualify_tables(
if isinstance(source.this, exp.Identifier):
if not source.args.get("db"):
source.set("db", exp.to_identifier(db))
if not source.args.get("catalog"):
if not source.args.get("catalog") and source.args.get("db"):
source.set("catalog", exp.to_identifier(catalog))
if not source.alias:

View file

@ -7,8 +7,7 @@ from decimal import Decimal
import sqlglot
from sqlglot import exp
from sqlglot.generator import cached_generator
from sqlglot.helper import first, merge_ranges, while_changing
from sqlglot.helper import first, is_iterable, merge_ranges, while_changing
from sqlglot.optimizer.scope import find_all_in_scope, walk_in_scope
# Final means that an expression should not be simplified
@ -37,8 +36,6 @@ def simplify(expression, constant_propagation=False):
sqlglot.Expression: simplified expression
"""
generate = cached_generator()
# group by expressions cannot be simplified, for example
# select x + 1 + 1 FROM y GROUP BY x + 1 + 1
# the projection must exactly match the group by key
@ -67,7 +64,7 @@ def simplify(expression, constant_propagation=False):
# Pre-order transformations
node = expression
node = rewrite_between(node)
node = uniq_sort(node, generate, root)
node = uniq_sort(node, root)
node = absorb_and_eliminate(node, root)
node = simplify_concat(node)
node = simplify_conditionals(node)
@ -311,7 +308,7 @@ def remove_complements(expression, root=True):
return expression
def uniq_sort(expression, generate, root=True):
def uniq_sort(expression, root=True):
"""
Uniq and sort a connector.
@ -320,7 +317,7 @@ def uniq_sort(expression, generate, root=True):
if isinstance(expression, exp.Connector) and (root or not expression.same_parent):
result_func = exp.and_ if isinstance(expression, exp.And) else exp.or_
flattened = tuple(expression.flatten())
deduped = {generate(e): e for e in flattened}
deduped = {gen(e): e for e in flattened}
arr = tuple(deduped.items())
# check if the operands are already sorted, if not sort them
@ -1070,3 +1067,69 @@ def _flat_simplify(expression, simplifier, root=True):
lambda a, b: expression.__class__(this=a, expression=b), operands
)
return expression
def gen(expression: t.Any) -> str:
"""Simple pseudo sql generator for quickly generating sortable and uniq strings.
Sorting and deduping sql is a necessary step for optimization. Calling the actual
generator is expensive so we have a bare minimum sql generator here.
"""
if expression is None:
return "_"
if is_iterable(expression):
return ",".join(gen(e) for e in expression)
if not isinstance(expression, exp.Expression):
return str(expression)
etype = type(expression)
if etype in GEN_MAP:
return GEN_MAP[etype](expression)
return f"{expression.key} {gen(expression.args.values())}"
GEN_MAP = {
exp.Add: lambda e: _binary(e, "+"),
exp.And: lambda e: _binary(e, "AND"),
exp.Anonymous: lambda e: f"{e.this} {','.join(gen(e) for e in e.expressions)}",
exp.Between: lambda e: f"{gen(e.this)} BETWEEN {gen(e.args.get('low'))} AND {gen(e.args.get('high'))}",
exp.Boolean: lambda e: "TRUE" if e.this else "FALSE",
exp.Bracket: lambda e: f"{gen(e.this)}[{gen(e.expressions)}]",
exp.Column: lambda e: ".".join(gen(p) for p in e.parts),
exp.DataType: lambda e: f"{e.this.name} {gen(tuple(e.args.values())[1:])}",
exp.Div: lambda e: _binary(e, "/"),
exp.Dot: lambda e: _binary(e, "."),
exp.DPipe: lambda e: _binary(e, "||"),
exp.SafeDPipe: lambda e: _binary(e, "||"),
exp.EQ: lambda e: _binary(e, "="),
exp.GT: lambda e: _binary(e, ">"),
exp.GTE: lambda e: _binary(e, ">="),
exp.Identifier: lambda e: f'"{e.name}"' if e.quoted else e.name,
exp.ILike: lambda e: _binary(e, "ILIKE"),
exp.In: lambda e: f"{gen(e.this)} IN ({gen(tuple(e.args.values())[1:])})",
exp.Is: lambda e: _binary(e, "IS"),
exp.Like: lambda e: _binary(e, "LIKE"),
exp.Literal: lambda e: f"'{e.name}'" if e.is_string else e.name,
exp.LT: lambda e: _binary(e, "<"),
exp.LTE: lambda e: _binary(e, "<="),
exp.Mod: lambda e: _binary(e, "%"),
exp.Mul: lambda e: _binary(e, "*"),
exp.Neg: lambda e: _unary(e, "-"),
exp.NEQ: lambda e: _binary(e, "<>"),
exp.Not: lambda e: _unary(e, "NOT"),
exp.Null: lambda e: "NULL",
exp.Or: lambda e: _binary(e, "OR"),
exp.Paren: lambda e: f"({gen(e.this)})",
exp.Sub: lambda e: _binary(e, "-"),
exp.Subquery: lambda e: f"({gen(e.args.values())})",
exp.Table: lambda e: gen(e.args.values()),
exp.Var: lambda e: e.name,
}
def _binary(e: exp.Binary, op: str) -> str:
return f"{gen(e.left)} {op} {gen(e.right)}"
def _unary(e: exp.Unary, op: str) -> str:
return f"{op} {gen(e.this)}"