1
0
Fork 0

Adding upstream version 6.0.4.

Signed-off-by: Daniel Baumann <daniel@debian.org>
This commit is contained in:
Daniel Baumann 2025-02-13 06:15:54 +01:00
parent d01130b3f1
commit 527597d2af
Signed by: daniel
GPG key ID: FBB4F0E80A80222F
122 changed files with 23162 additions and 0 deletions

View file

@ -0,0 +1,2 @@
from sqlglot.optimizer.optimizer import optimize
from sqlglot.optimizer.schema import Schema

View file

@ -0,0 +1,48 @@
import itertools
from sqlglot import alias, exp, select, table
from sqlglot.optimizer.scope import traverse_scope
from sqlglot.optimizer.simplify import simplify
def eliminate_subqueries(expression):
    """
    Rewrite duplicate subqueries from sqlglot AST.

    Queries are grouped by using each scope's expression as a dictionary key.
    Every group with more than one member is registered as a CTE named
    `_e_<n>`, and its members are replaced by references to that CTE.

    Example:
        >>> import sqlglot
        >>> expression = sqlglot.parse_one("SELECT 1 AS x, 2 AS y UNION ALL SELECT 1 AS x, 2 AS y")
        >>> eliminate_subqueries(expression).sql()
        'WITH _e_0 AS (SELECT 1 AS x, 2 AS y) SELECT * FROM _e_0 UNION ALL SELECT * FROM _e_0'

    Args:
        expression (sqlglot.Expression): expression to rewrite
    Returns:
        sqlglot.Expression: expression with duplicate subqueries replaced by CTE references
    """
    expression = simplify(expression)
    queries = {}

    for scope in traverse_scope(expression):
        query = scope.expression
        # Append in place instead of rebuilding the list for every scope
        queries.setdefault(query, []).append(query)

    sequence = itertools.count()

    for query, duplicates in queries.items():
        if len(duplicates) == 1:
            continue

        alias_ = f"_e_{next(sequence)}"

        for dup in duplicates:
            parent = dup.parent
            if isinstance(parent, exp.Subquery):
                # Keep the name the subquery was exposed under in the outer query
                parent.replace(alias(table(alias_), parent.alias_or_name, table=True))
            elif isinstance(parent, exp.Union):
                dup.replace(select("*").from_(alias_))

        expression.with_(alias_, as_=query, copy=False)

    return expression

View file

@ -0,0 +1,16 @@
from sqlglot import exp
def expand_multi_table_selects(expression):
    """
    Rewrite FROM clauses that list several sources into explicit CROSS JOINs.

    For every FROM node, each source after the first is attached to the FROM's
    parent as a CROSS join and dropped from the FROM's own source list.

    Args:
        expression (sqlglot.Expression): expression to rewrite (modified in place)
    Returns:
        sqlglot.Expression: the same expression
    """
    for from_ in expression.find_all(exp.From):
        parent = from_.parent
        sources = from_.expressions

        for query in sources[1:]:
            parent.join(
                query,
                join_type="CROSS",
                copy=False,
            )

        # Drop the moved sources by position. `list.remove` matches by equality,
        # so it could remove an equal earlier source (e.g. the head) instead of
        # the one that was just turned into a join.
        del sources[1:]

    return expression

View file

@ -0,0 +1,31 @@
from sqlglot import alias, exp
from sqlglot.errors import OptimizeError
from sqlglot.optimizer.scope import traverse_scope
def isolate_table_selects(expression):
    """
    Wrap every aliased table of a multi-source scope in its own `SELECT *` subquery.

    Scopes with exactly one selected source are left untouched.

    Args:
        expression (sqlglot.Expression): expression to rewrite (modified in place)
    Returns:
        sqlglot.Expression: the same expression
    Raises:
        OptimizeError: if a table is not wrapped in an alias
    """
    for scope in traverse_scope(expression):
        selected = scope.selected_sources
        if len(selected) == 1:
            continue

        for _, source in selected.values():
            if not isinstance(source, exp.Table):
                continue

            aliased = source.parent
            if not isinstance(aliased, exp.Alias):
                raise OptimizeError(
                    "Tables require an alias. Run qualify_tables optimization."
                )

            inner = alias(source, source.name or aliased.alias, table=True)
            isolated = exp.select("*").from_(inner, copy=False)
            aliased.replace(isolated.subquery(aliased.alias, copy=False))

    return expression

View file

@ -0,0 +1,136 @@
from sqlglot import exp
from sqlglot.helper import while_changing
from sqlglot.optimizer.simplify import flatten, simplify, uniq_sort
def normalize(expression, dnf=False, max_distance=128):
    """
    Rewrite sqlglot AST into conjunctive normal form.

    The expression is simplified, the distributive law is applied through
    `while_changing`, and the result is simplified again.

    Example:
        >>> import sqlglot
        >>> expression = sqlglot.parse_one("(x AND y) OR z")
        >>> normalize(expression).sql()
        '(x OR z) AND (y OR z)'

    Args:
        expression (sqlglot.Expression): expression to normalize
        dnf (bool): rewrite in disjunctive normal form instead
        max_distance (int): the maximal estimated distance from cnf to attempt conversion
    Returns:
        sqlglot.Expression: normalized expression
    """

    def _apply_distributive_law(node):
        return distributive_law(node, dnf, max_distance)

    simplified = simplify(expression)
    converted = while_changing(simplified, _apply_distributive_law)
    return simplify(converted)
def normalized(expression, dnf=False):
    """
    Check whether an expression is already in normal form.

    For CNF (the default) no AND may have an OR ancestor; for DNF no OR may
    have an AND ancestor.

    Args:
        expression (sqlglot.Expression): expression to check
        dnf (bool): check for disjunctive normal form instead
    Returns:
        bool: True if no offending connector nesting was found
    """
    if dnf:
        ancestor, root = exp.And, exp.Or
    else:
        ancestor, root = exp.Or, exp.And

    for connector in expression.find_all(root):
        if connector.find_ancestor(ancestor):
            return False
    return True
def normalization_distance(expression, dnf=False):
    """
    The difference in the number of predicates between the current expression and the normalized form.

    This is used as an estimate of the cost of the conversion which is exponential in complexity.

    Example:
        >>> import sqlglot
        >>> expression = sqlglot.parse_one("(a AND b) OR (c AND d)")
        >>> normalization_distance(expression)
        4

    Args:
        expression (sqlglot.Expression): expression to compute distance
        dnf (bool): compute to dnf distance instead
    Returns:
        int: difference
    """
    normalized_size = sum(_predicate_lengths(expression, dnf))
    # Number of connectors plus one
    current_size = len(list(expression.find_all(exp.Connector))) + 1
    return normalized_size - current_size
def _predicate_lengths(expression, dnf):
    """
    Returns a list of predicate lengths when expanded to normalized form.

    (A AND B) OR C -> [2, 2] because len(A OR C), len(B OR C).
    """
    expression = expression.unnest()

    if not isinstance(expression, exp.Connector):
        return [1]

    left, right = expression.args.values()
    left_lengths = _predicate_lengths(left, dnf)
    right_lengths = _predicate_lengths(right, dnf)

    # The connector that has to be distributed over its operands
    distributed = exp.And if dnf else exp.Or
    if isinstance(expression, distributed):
        return [a + b for a in left_lengths for b in right_lengths]

    return left_lengths + right_lengths
def distributive_law(expression, dnf, max_distance):
    """
    Apply one pass of the distributive law to an expression and its children.

    x OR (y AND z) -> (x OR y) AND (x OR z)
    (x AND y) OR (y AND z) -> (x OR y) AND (x OR z) AND (y OR y) AND (y OR z)

    Args:
        expression (sqlglot.Expression): expression to rewrite
        dnf (bool): distribute AND over OR (towards DNF) instead of OR over AND
        max_distance (int): connectors whose `normalization_distance` exceeds this are left as is
    Returns:
        sqlglot.Expression: the rewritten expression, or the input when nothing applies
    """
    if isinstance(expression.unnest(), exp.Connector):
        # Bail out early when the estimated blow-up is too large
        if normalization_distance(expression, dnf) > max_distance:
            return expression

    # `from_exp` is the connector being distributed, `to_exp` the one it is distributed over
    to_exp, from_exp = (exp.Or, exp.And) if dnf else (exp.And, exp.Or)

    # Rewrite the children first
    exp.replace_children(expression, lambda e: distributive_law(e, dnf, max_distance))

    if isinstance(expression, from_exp):
        a, b = expression.unnest_operands()

        from_func = exp.and_ if from_exp == exp.And else exp.or_
        to_func = exp.and_ if to_exp == exp.And else exp.or_

        if isinstance(a, to_exp) and isinstance(b, to_exp):
            # Both operands qualify: the one with more connectors is passed first
            if len(tuple(a.find_all(exp.Connector))) > len(
                tuple(b.find_all(exp.Connector))
            ):
                return _distribute(a, b, from_func, to_func)
            return _distribute(b, a, from_func, to_func)
        if isinstance(a, to_exp):
            return _distribute(b, a, from_func, to_func)
        if isinstance(b, to_exp):
            return _distribute(a, b, from_func, to_func)

    return expression
def _distribute(a, b, from_func, to_func):
    """
    Distribute `a` over the two operands of `b` and tidy up the result.

    When `a` is itself a connector, each of its children is replaced in place
    by the parenthesized distribution; otherwise a new node is built from `a`.
    """
    if not isinstance(a, exp.Connector):
        combined = to_func(from_func(a, b.left), from_func(a, b.right))
        return _simplify(combined)

    def _spread(child):
        return to_func(
            exp.paren(from_func(child, b.left)),
            exp.paren(from_func(child, b.right)),
        )

    exp.replace_children(a, _spread)
    return _simplify(a)
def _simplify(node):
    """Apply `flatten` and `uniq_sort` to a node, then recurse into its children."""
    simplified = uniq_sort(flatten(node))
    exp.replace_children(simplified, _simplify)
    return simplified

View file

@ -0,0 +1,75 @@
from sqlglot import exp
from sqlglot.helper import tsort
from sqlglot.optimizer.simplify import simplify
def optimize_joins(expression):
    """
    Removes cross joins if possible and reorder joins based on predicate dependencies.

    Predicates found in other joins' ON clauses that mention a condition-less
    join are moved onto that join. Afterwards the joins are reordered and
    their kinds normalized.
    """
    for select in expression.find_all(exp.Select):
        # Joined table name -> joins whose ON clause references it
        references = {}
        cross_joins = []

        for join in select.args.get("joins", []):
            name = join.this.alias_or_name
            tables = other_table_names(join, name)

            if not tables:
                cross_joins.append((name, join))
                continue

            for table in tables:
                references.setdefault(table, []).append(join)

        for name, join in cross_joins:
            for dep in references.get(name, []):
                on = dep.args["on"]
                on = on.replace(simplify(on))

                if not isinstance(on, exp.Connector):
                    continue

                for predicate in on.flatten():
                    if name in exp.column_table_names(predicate):
                        predicate.replace(exp.TRUE)
                        join.on(predicate, copy=False)

    return normalize(reorder_joins(expression))
def reorder_joins(expression):
    """
    Reorder joins by topological sort order based on predicate references.
    """
    for from_ in expression.find_all(exp.From):
        head_name = from_.expressions[0].alias_or_name
        parent = from_.parent
        joins = {join.this.alias_or_name: join for join in parent.args.get("joins", [])}

        # The head of the FROM clause starts out with no dependencies
        dag = {head_name: []}
        dag.update({name: other_table_names(join, name) for name, join in joins.items()})

        ordered = [joins[name] for name in tsort(dag) if name != head_name]
        parent.set("joins", ordered)

    return expression
def normalize(expression):
    """
    Remove INNER and OUTER from joins as they are optional.

    The kind of every join other than a CROSS join is cleared.
    """
    non_cross_joins = (
        join for join in expression.find_all(exp.Join) if join.kind != "CROSS"
    )
    for join in non_cross_joins:
        join.set("kind", None)

    return expression
def other_table_names(join, exclude):
    """
    List the table names referenced by a join's ON condition, except `exclude`.

    A join without an ON condition is treated as if its condition were TRUE.
    """
    condition = join.args.get("on") or exp.TRUE
    names = []
    for name in exp.column_table_names(condition):
        if name != exclude:
            names.append(name)
    return names

View file

@ -0,0 +1,43 @@
from sqlglot.optimizer.eliminate_subqueries import eliminate_subqueries
from sqlglot.optimizer.expand_multi_table_selects import expand_multi_table_selects
from sqlglot.optimizer.isolate_table_selects import isolate_table_selects
from sqlglot.optimizer.normalize import normalize
from sqlglot.optimizer.optimize_joins import optimize_joins
from sqlglot.optimizer.pushdown_predicates import pushdown_predicates
from sqlglot.optimizer.pushdown_projections import pushdown_projections
from sqlglot.optimizer.qualify_columns import qualify_columns
from sqlglot.optimizer.qualify_tables import qualify_tables
from sqlglot.optimizer.quote_identities import quote_identities
from sqlglot.optimizer.unnest_subqueries import unnest_subqueries
def optimize(expression, schema=None, db=None, catalog=None):
    """
    Rewrite a sqlglot AST into an optimized form.

    The input is copied first; the rules below are then applied to the copy in order.

    Args:
        expression (sqlglot.Expression): expression to optimize
        schema (dict|sqlglot.optimizer.Schema): database schema.
            This can either be an instance of `sqlglot.optimizer.Schema` or a mapping in one of
            the following forms:
                1. {table: {col: type}}
                2. {db: {table: {col: type}}}
                3. {catalog: {db: {table: {col: type}}}}
        db (str): specify the default database, as might be set by a `USE DATABASE db` statement
        catalog (str): specify the default catalog, as might be set by a `USE CATALOG c` statement
    Returns:
        sqlglot.Expression: optimized expression
    """
    rules = (
        lambda e: qualify_tables(e, db=db, catalog=catalog),
        isolate_table_selects,
        lambda e: qualify_columns(e, schema),
        pushdown_projections,
        normalize,
        unnest_subqueries,
        expand_multi_table_selects,
        pushdown_predicates,
        optimize_joins,
        eliminate_subqueries,
        quote_identities,
    )

    optimized = expression.copy()
    for rule in rules:
        optimized = rule(optimized)
    return optimized

View file

@ -0,0 +1,176 @@
from sqlglot import exp
from sqlglot.optimizer.normalize import normalized
from sqlglot.optimizer.scope import traverse_scope
from sqlglot.optimizer.simplify import simplify
def pushdown_predicates(expression):
    """
    Rewrite sqlglot AST to pushdown predicates in FROMS and JOINS

    Scopes are processed in the reverse of the order `traverse_scope` returns them.

    Example:
        >>> import sqlglot
        >>> sql = "SELECT * FROM (SELECT * FROM x AS x) AS y WHERE y.a = 1"
        >>> expression = sqlglot.parse_one(sql)
        >>> pushdown_predicates(expression).sql()
        'SELECT * FROM (SELECT * FROM x AS x WHERE y.a = 1) AS y WHERE TRUE'

    Args:
        expression (sqlglot.Expression): expression to optimize
    Returns:
        sqlglot.Expression: optimized expression
    """
    for scope in reversed(traverse_scope(expression)):
        query = scope.expression

        where = query.args.get("where")
        if where:
            pushdown(where.this, scope.selected_sources)

        # joins should only pushdown into itself, not to other joins
        # so we limit the selected sources to only itself
        for join in query.args.get("joins") or []:
            join_name = join.this.alias_or_name
            pushdown(
                join.args.get("on"),
                {join_name: scope.selected_sources[join_name]},
            )

    return expression
def pushdown(condition, sources):
    """
    Push a single condition down into the given sources.

    The condition is simplified in place, split into its top-level predicates
    and handed to the CNF or DNF strategy. A falsy condition is a no-op.
    """
    if not condition:
        return

    condition = condition.replace(simplify(condition))
    cnf_like = normalized(condition) or not normalized(condition, dnf=True)

    # CNF-like conditions are split on AND, DNF ones on OR
    connector = exp.And if cnf_like else exp.Or
    if isinstance(condition, connector):
        predicates = list(condition.flatten())
    else:
        predicates = [condition]

    strategy = pushdown_cnf if cnf_like else pushdown_dnf
    strategy(predicates, sources)
def pushdown_cnf(predicates, scope):
    """
    If the predicates are in CNF like form, we can simply replace each block in the parent.

    Each pushed predicate is replaced by TRUE at its original position and
    attached to the target node. Pushing into a join ends the search for that
    predicate.
    """
    for predicate in predicates:
        targets = nodes_for_predicate(predicate, scope)

        for target in targets.values():
            if isinstance(target, exp.Join):
                predicate.replace(exp.TRUE)
                target.on(predicate, copy=False)
                break
            elif isinstance(target, exp.Select):
                predicate.replace(exp.TRUE)
                target.where(replace_aliases(target, predicate), copy=False)
def pushdown_dnf(predicates, scope):
    """
    If the predicates are in DNF form, we can only push down conditions that are in all blocks.
    Additionally, we can't remove predicates from their original form.

    Args:
        predicates (list): the blocks of the condition, as split by `pushdown`
        scope (dict): mapping of source name to (node, source), as consumed by `nodes_for_predicate`
    """
    # find all the tables that predicates can be pushed down to
    # these are tables that are referenced in all blocks of a DNF
    # (a.x AND b.x) OR (a.y AND c.y)
    # only table a can be pushed down
    pushdown_tables = set()

    for a in predicates:
        a_tables = set(exp.column_table_names(a))

        for b in predicates:
            a_tables &= set(exp.column_table_names(b))

        pushdown_tables.update(a_tables)

    # table name -> condition to push down to that table
    conditions = {}

    # for every pushdown table, find all related conditions in all predicates
    # combine them with ORS
    # (a.x AND a.y AND b.x) OR (a.z AND c.y) -> (a.x AND a.y) OR (a.z)
    for table in sorted(pushdown_tables):
        for predicate in predicates:
            nodes = nodes_for_predicate(predicate, scope)

            if table not in nodes:
                continue

            predicate_condition = None

            # AND together the conditions of this block that reference the table
            for column in predicate.find_all(exp.Column):
                if column.table == table:
                    condition = column.find_ancestor(exp.Condition)
                    predicate_condition = (
                        exp.and_(predicate_condition, condition)
                        if predicate_condition
                        else condition
                    )

            if predicate_condition:
                conditions[table] = (
                    exp.or_(conditions[table], predicate_condition)
                    if table in conditions
                    else predicate_condition
                )

        # `nodes` is the mapping computed for the last predicate of the loop above
        for name, node in nodes.items():
            if name not in conditions:
                continue

            predicate = conditions[name]

            if isinstance(node, exp.Join):
                node.on(predicate, copy=False)
            elif isinstance(node, exp.Select):
                node.where(replace_aliases(node, predicate), copy=False)
def nodes_for_predicate(predicate, sources):
    """
    Find the nodes a predicate can be pushed into.

    Args:
        predicate (sqlglot.Expression): predicate to inspect
        sources (dict): mapping of table name to a (node, source) pair
    Returns:
        dict: mapping of table name to the Join or Select node that can receive
            the predicate. Empty as soon as a target join has a side set.
    """
    nodes = {}
    tables = exp.column_table_names(predicate)
    # True when the closest Join/Where ancestor of the predicate is a WHERE clause
    where_condition = isinstance(
        predicate.find_ancestor(exp.Join, exp.Where), exp.Where
    )

    for table in tables:
        node, source = sources.get(table) or (None, None)

        # if the predicate is in a where statement we can try to push it down
        # we want to find the root join or from statement
        if node and where_condition:
            node = node.find_ancestor(exp.Join, exp.From)

        # a node can reference a CTE which should be pushed down
        if isinstance(node, exp.From) and not isinstance(source, exp.Table):
            node = source.expression

        if isinstance(node, exp.Join):
            # Give up entirely when the join has a side
            if node.side:
                return {}
            nodes[table] = node
        elif isinstance(node, exp.Select) and len(tables) == 1:
            # Selects with a GROUP BY are not used as targets
            if not node.args.get("group"):
                nodes[table] = node
    return nodes
def replace_aliases(source, predicate):
    """
    Rewrite columns of a predicate that match a select of `source`, by name.

    Columns matching an aliased select are replaced by the aliased expression;
    columns matching an unaliased select are replaced by that select itself.

    Returns:
        the result of `predicate.transform`
    """
    aliases = {}
    for selection in source.selects:
        aliased = isinstance(selection, exp.Alias)
        key = selection.alias if aliased else selection.name
        aliases[key] = selection.this if aliased else selection

    def _replace_alias(node):
        if not isinstance(node, exp.Column) or node.name not in aliases:
            return node
        return aliases[node.name]

    return predicate.transform(_replace_alias)

View file

@ -0,0 +1,85 @@
from collections import defaultdict
from sqlglot import alias, exp
from sqlglot.optimizer.scope import Scope, traverse_scope
# Sentinel value that means an outer query is selecting ALL columns.
# It is looked up by membership in the set of parent selections.
SELECT_ALL = object()
def pushdown_projections(expression):
    """
    Rewrite sqlglot AST to remove unused columns projections.

    Example:
        >>> import sqlglot
        >>> sql = "SELECT y.a AS a FROM (SELECT x.a AS a, x.b AS b FROM x) AS y"
        >>> expression = sqlglot.parse_one(sql)
        >>> pushdown_projections(expression).sql()
        'SELECT y.a AS a FROM (SELECT x.a AS a FROM x) AS y'

    Args:
        expression (sqlglot.Expression): expression to optimize
    Returns:
        sqlglot.Expression: optimized expression
    """
    # Map of Scope to all columns being selected by outer queries.
    referenced_columns = defaultdict(set)

    # We build the scope tree (which is traversed in DFS postorder), then iterate
    # over the result in reverse order. This should ensure that the set of selected
    # columns for a particular scope are completely built by the time we get to it.
    for scope in reversed(traverse_scope(expression)):
        query = scope.expression
        parent_selections = referenced_columns.get(scope, {SELECT_ALL})

        if query.args.get("distinct"):
            # We can't remove columns SELECT DISTINCT nor UNION DISTINCT
            parent_selections = {SELECT_ALL}

        if isinstance(query, exp.Union):
            left, right = scope.union
            referenced_columns[left] = referenced_columns[right] = parent_selections

        if isinstance(query, exp.Select):
            _remove_unused_selections(scope, parent_selections)

            # Group columns by source name
            columns_by_source = defaultdict(set)
            for column in scope.columns:
                columns_by_source[column.table].add(column.name)

            # Push the selected columns down to the next scope
            for name, (_, source) in scope.selected_sources.items():
                if not isinstance(source, Scope):
                    continue
                needed = columns_by_source.get(name) or set()
                referenced_columns[source].update(needed)

    return expression
def _remove_unused_selections(scope, parent_selections):
    """
    Drop the selections of a scope that nothing refers to.

    A selection is kept when the parent selects everything, when the parent
    selects it by name, or when an unqualified ORDER BY column refers to it.
    """
    order = scope.expression.args.get("order")

    # Assume columns without a qualified table are references to output columns
    order_refs = (
        {c.name for c in order.find_all(exp.Column) if not c.table}
        if order
        else set()
    )

    kept = [
        selection
        for selection in scope.selects
        if SELECT_ALL in parent_selections
        or selection.alias_or_name in parent_selections
        or selection.alias_or_name in order_refs
    ]

    # If there are no remaining selections, just select a single constant
    if not kept:
        kept.append(alias("1", "_"))

    scope.expression.set("expressions", kept)

View file

@ -0,0 +1,422 @@
import itertools
from sqlglot import alias, exp
from sqlglot.errors import OptimizeError
from sqlglot.optimizer.schema import ensure_schema
from sqlglot.optimizer.scope import traverse_scope
# Expression types that are passed to `isinstance` below to skip
# star expansion, output aliasing and table column alias removal.
SKIP_QUALIFY = (exp.Unnest, exp.Lateral)
def qualify_columns(expression, schema):
    """
    Rewrite sqlglot AST to have fully qualified columns.

    Example:
        >>> import sqlglot
        >>> schema = {"tbl": {"col": "INT"}}
        >>> expression = sqlglot.parse_one("SELECT col FROM tbl")
        >>> qualify_columns(expression, schema).sql()
        'SELECT tbl.col AS col FROM tbl'

    Args:
        expression (sqlglot.Expression): expression to qualify
        schema (dict|sqlglot.optimizer.Schema): Database schema
    Returns:
        sqlglot.Expression: qualified expression
    """
    schema = ensure_schema(schema)

    for scope in traverse_scope(expression):
        # One resolver per scope, shared by the passes below
        resolver = _Resolver(scope, schema)
        _pop_table_column_aliases(scope.ctes)
        _pop_table_column_aliases(scope.derived_tables)
        _expand_using(scope, resolver)
        _expand_group_by(scope, resolver)
        _expand_order_by(scope)
        _qualify_columns(scope, resolver)
        # Scopes rooted at a SKIP_QUALIFY expression keep their selections as is
        if not isinstance(scope.expression, SKIP_QUALIFY):
            _expand_stars(scope, resolver)
            _qualify_outputs(scope)
        _check_unknown_tables(scope)

    return expression
def _pop_table_column_aliases(derived_tables):
    """
    Remove table column aliases.

    (e.g. SELECT ... FROM (SELECT ...) AS foo(col1, col2)

    Entries that are instances of SKIP_QUALIFY are left untouched.
    """
    qualifiable = (
        derived_table
        for derived_table in derived_tables
        if not isinstance(derived_table, SKIP_QUALIFY)
    )
    for derived_table in qualifiable:
        table_alias = derived_table.args.get("alias")
        if not table_alias:
            continue
        table_alias.args.pop("columns", None)
def _expand_using(scope, resolver):
    """
    Rewrite `JOIN ... USING (...)` clauses into explicit ON conditions.

    Unqualified columns of the scope that name a USING column are replaced by
    a COALESCE over the joined tables.

    Raises:
        OptimizeError: if a USING column is missing on either side of the join
    """
    joins = list(scope.expression.find_all(exp.Join))
    names = {join.this.alias for join in joins}
    # Sources that are not joined; processed joins are appended as the loop advances
    ordered = [key for key in scope.selected_sources if key not in names]

    # Mapping of automatically joined column names to source names
    column_tables = {}

    for join in joins:
        using = join.args.get("using")

        if not using:
            continue

        join_table = join.this.alias_or_name

        # Column name -> first source in `ordered` that provides it
        columns = {}

        for k in scope.selected_sources:
            if k in ordered:
                for column in resolver.get_source_columns(k):
                    if column not in columns:
                        columns[column] = k

        ordered.append(join_table)
        join_columns = resolver.get_source_columns(join_table)
        conditions = []

        for identifier in using:
            identifier = identifier.name
            table = columns.get(identifier)

            if not table or identifier not in join_columns:
                raise OptimizeError(f"Cannot automatically join: {identifier}")

            conditions.append(
                exp.condition(
                    exp.EQ(
                        this=exp.column(identifier, table=table),
                        expression=exp.column(identifier, table=join_table),
                    )
                )
            )

            # Remember both sides of the join, without duplicates
            tables = column_tables.setdefault(identifier, [])
            if table not in tables:
                tables.append(table)
            if join_table not in tables:
                tables.append(join_table)

        # Swap the USING clause for the equivalent ON condition
        join.args.pop("using")
        join.set("on", exp.and_(*conditions))

    if column_tables:
        for column in scope.columns:
            if not column.table and column.name in column_tables:
                tables = column_tables[column.name]
                coalesce = [exp.column(column.name, table=table) for table in tables]
                replacement = exp.Coalesce(this=coalesce[0], expressions=coalesce[1:])

                # Ensure selects keep their output name
                if isinstance(column.parent, exp.Select):
                    replacement = exp.alias_(replacement, alias=column.name)

                scope.replace(column, replacement)
def _expand_group_by(scope, resolver):
    """
    Expand select aliases and positional references in a GROUP BY clause.

    Does nothing when the scope has no GROUP BY.
    """
    group = scope.expression.args.get("group")
    if not group:
        return

    # Replace references to select aliases
    def _resolve(node, *_):
        if not isinstance(node, exp.Column) or node.table:
            return node

        table = resolver.get_table(node.name)

        # Source columns get priority over select aliases
        if table:
            node.set("table", exp.to_identifier(table))
            return node

        selects = {s.alias_or_name: s for s in scope.selects}
        select = selects.get(node.name)
        if not select:
            return node

        scope.clear_cache()
        if isinstance(select, exp.Alias):
            select = select.this
        return select.copy()

    group.transform(_resolve, copy=False)
    group.set("expressions", _expand_positional_references(scope, group.expressions))
    scope.expression.set("group", group)
def _expand_order_by(scope):
    """
    Expand positional references in an ORDER BY clause.

    Does nothing when the scope has no ORDER BY.
    """
    order = scope.expression.args.get("order")
    if not order:
        return

    ordereds = order.expressions
    expanded = _expand_positional_references(scope, (o.this for o in ordereds))

    for ordered, new_expression in zip(ordereds, expanded):
        ordered.set("this", new_expression)
def _expand_positional_references(scope, expressions):
    """
    Replace integer positions (e.g. `GROUP BY 1`) by a copy of the matching select.

    Args:
        scope: scope whose `selects` the positions refer to
        expressions (iterable): nodes to expand
    Returns:
        list: the nodes, with positions replaced by copies of the select they
            point to (unwrapped from its alias). Other nodes are passed through.
    Raises:
        OptimizeError: if a position does not refer to an existing select
    """
    new_nodes = []

    for node in expressions:
        if not node.is_int:
            new_nodes.append(node)
            continue

        position = int(node.name)
        selects = scope.selects

        # Positions are 1-based. Without the lower bound, position 0 would turn
        # into index -1 and silently resolve to the last select.
        if not 1 <= position <= len(selects):
            raise OptimizeError(f"Unknown output column: {node.name}")

        select = selects[position - 1]
        if isinstance(select, exp.Alias):
            select = select.this

        new_nodes.append(select.copy())
        scope.clear_cache()

    return new_nodes
def _qualify_columns(scope, resolver):
    """
    Disambiguate columns, ensuring each column specifies a source

    Raises:
        OptimizeError: for unknown or ambiguous columns
    """
    for column in scope.columns:
        column_table = column.table
        column_name = column.name

        known_source = column_table and column_table in scope.sources
        if known_source and column_name not in resolver.get_source_columns(column_table):
            raise OptimizeError(f"Unknown column: {column_name}")

        # Already qualified: nothing left to do for this column
        if column_table:
            continue

        column_table = resolver.get_table(column_name)

        if not (scope.is_subquery or scope.is_unnest):
            if column_name not in resolver.all_columns:
                raise OptimizeError(f"Unknown column: {column_name}")

            if column_table is None:
                raise OptimizeError(f"Ambiguous column: {column_name}")

        # column_table can be a '' because bigquery unnest has no table alias
        if column_table:
            column.set("table", exp.to_identifier(column_table))
def _expand_stars(scope, resolver):
    """
    Expand stars to lists of column selections

    Raises:
        OptimizeError: if a star refers to a table that is not a source of the scope
    """
    new_selections = []
    except_columns = {}
    replace_columns = {}

    for expression in scope.selects:
        if isinstance(expression, exp.Star):
            star, tables = expression, list(scope.selected_sources)
        elif isinstance(expression, exp.Column) and isinstance(
            expression.this, exp.Star
        ):
            star, tables = expression.this, [expression.table]
        else:
            new_selections.append(expression)
            continue

        _add_except_columns(star, tables, except_columns)
        _add_replace_columns(star, tables, replace_columns)

        for table in tables:
            if table not in scope.sources:
                raise OptimizeError(f"Unknown table: {table}")

            # Both mappings are keyed by the identity of the table name object
            table_id = id(table)
            excluded = except_columns.get(table_id, set())
            renamed = replace_columns.get(table_id, {})

            for name in resolver.get_source_columns(table):
                if name in excluded:
                    continue
                alias_ = renamed.get(name, name)
                column = exp.column(name, table)
                new_selections.append(
                    alias(column, alias_) if alias_ != name else column
                )

    scope.expression.set("expressions", new_selections)
def _add_except_columns(expression, tables, except_columns):
    """
    Record the names in a star's "except" arg for each of the given tables.

    `except_columns` is updated in place and keyed by `id(table)`. All tables
    share the same set of names.
    """
    except_ = expression.args.get("except")
    if except_:
        names = {e.name for e in except_}
        except_columns.update((id(table), names) for table in tables)
def _add_replace_columns(expression, tables, replace_columns):
    """
    Record the entries of a star's "replace" arg for each of the given tables.

    `replace_columns` is updated in place and keyed by `id(table)`. Each value
    maps `e.this.name` to `e.alias` for the replace entries.
    """
    replace = expression.args.get("replace")
    if replace:
        renames = {e.this.name: e.alias for e in replace}
        replace_columns.update((id(table), renames) for table in tables)
def _qualify_outputs(scope):
    """
    Ensure all output columns are aliased

    Columns are aliased to their own name, other unaliased selections to
    `_col_<position>`. Names from the outer column list take precedence.
    """
    new_selections = []
    pairs = itertools.zip_longest(scope.selects, scope.outer_column_list)

    for i, (selection, aliased_column) in enumerate(pairs):
        if isinstance(selection, exp.Column):
            output_name = selection.name
        elif not isinstance(selection, exp.Alias):
            output_name = f"_col_{i}"
        else:
            output_name = None

        if output_name is not None:
            # convoluted setter because a simple selection.replace(alias) would require a copy
            wrapper = alias(exp.column(""), output_name)
            wrapper.set("this", selection)
            selection = wrapper

        if aliased_column:
            selection.set("alias", exp.to_identifier(aliased_column))

        new_selections.append(selection)

    scope.expression.set("expressions", new_selections)
def _check_unknown_tables(scope):
    """
    Raise if the scope has external columns it is not allowed to have.

    Unnest scopes and correlated subqueries are exempt.

    Raises:
        OptimizeError: naming the table of the first external column
    """
    if not scope.external_columns or scope.is_unnest or scope.is_correlated_subquery:
        return
    raise OptimizeError(f"Unknown table: {scope.external_columns[0].text('table')}")
class _Resolver:
    """
    Helper for resolving columns.

    This is a class so we can lazily load some things and easily share them across functions.
    """

    def __init__(self, scope, schema):
        self.scope = scope
        self.schema = schema
        # Lazily computed caches, see the accessors below
        self._source_columns = None
        self._unambiguous_columns = None
        self._all_columns = None

    def get_table(self, column_name):
        """
        Get the table for a column name.

        Args:
            column_name (str)
        Returns:
            (str) table name, or None if the column is unknown or ambiguous
        """
        if self._unambiguous_columns is None:
            self._unambiguous_columns = self._get_unambiguous_columns(
                self._get_all_source_columns()
            )
        return self._unambiguous_columns.get(column_name)

    @property
    def all_columns(self):
        """All available columns of all sources in this scope"""
        if self._all_columns is None:
            self._all_columns = {
                column
                for columns in self._get_all_source_columns().values()
                for column in columns
            }
        return self._all_columns

    def get_source_columns(self, name):
        """
        Resolve the source columns for a given source `name`

        Raises:
            OptimizeError: if `name` is not a source of the scope, or if the
                schema lookup fails
        """
        if name not in self.scope.sources:
            raise OptimizeError(f"Unknown table: {name}")

        source = self.scope.sources[name]

        # If referencing a table, return the columns from the schema
        if isinstance(source, exp.Table):
            try:
                return self.schema.column_names(source)
            except Exception as e:
                raise OptimizeError(str(e)) from e

        # Otherwise, if referencing another scope, return that scope's named selects
        return source.expression.named_selects

    def _get_all_source_columns(self):
        """Mapping of every selected source name to its columns (cached)."""
        if self._source_columns is None:
            self._source_columns = {
                k: self.get_source_columns(k) for k in self.scope.selected_sources
            }
        return self._source_columns

    def _get_unambiguous_columns(self, source_columns):
        """
        Find all the unambiguous columns in sources.

        A column is unambiguous when its name occurs exactly once across all sources.

        Args:
            source_columns (dict): Mapping of names to source columns
        Returns:
            dict: Mapping of column name to source name
        """
        if not source_columns:
            return {}

        source_columns = list(source_columns.items())

        first_table, first_columns = source_columns[0]
        unambiguous_columns = {
            col: first_table for col in self._find_unique_columns(first_columns)
        }
        # Every name seen so far, including names duplicated within one source.
        # Seeding this from the unique names only would let a later source
        # claim a name that the first source already exposes more than once.
        all_columns = set(first_columns)

        for table, columns in source_columns[1:]:
            unique = self._find_unique_columns(columns)
            # Compare against all names of this source, not just its unique ones,
            # so a name it duplicates also invalidates an earlier mapping
            ambiguous = all_columns.intersection(columns)
            all_columns.update(columns)

            for column in ambiguous:
                unambiguous_columns.pop(column, None)
            for column in unique.difference(ambiguous):
                unambiguous_columns[column] = table

        return unambiguous_columns

    @staticmethod
    def _find_unique_columns(columns):
        """
        Find the unique columns in a list of columns.

        Example:
            >>> sorted(_Resolver._find_unique_columns(["a", "b", "b", "c"]))
            ['a', 'c']

        This is necessary because duplicate column names are ambiguous.
        """
        counts = {}
        for column in columns:
            counts[column] = counts.get(column, 0) + 1
        return {column for column, count in counts.items() if count == 1}

View file

@ -0,0 +1,54 @@
import itertools
from sqlglot import alias, exp
from sqlglot.optimizer.scope import traverse_scope
def qualify_tables(expression, db=None, catalog=None):
    """
    Rewrite sqlglot AST to have fully qualified tables.

    Derived tables and CTEs without an alias get a generated `_q_<n>` alias.
    Tables get the default db and catalog filled in and are wrapped in an alias.

    Example:
        >>> import sqlglot
        >>> expression = sqlglot.parse_one("SELECT 1 FROM tbl")
        >>> qualify_tables(expression, db="db").sql()
        'SELECT 1 FROM db.tbl AS tbl'

    Args:
        expression (sqlglot.Expression): expression to qualify
        db (str): Database name
        catalog (str): Catalog name
    Returns:
        sqlglot.Expression: qualified expression
    """
    sequence = itertools.count()
    defaults = (("db", db), ("catalog", catalog))

    def _next_alias():
        return f"_q_{next(sequence)}"

    for scope in traverse_scope(expression):
        for derived_table in scope.ctes + scope.derived_tables:
            if derived_table.args.get("alias"):
                continue
            alias_ = _next_alias()
            derived_table.set(
                "alias", exp.TableAlias(this=exp.to_identifier(alias_))
            )
            scope.rename_source(None, alias_)

        for source in scope.sources.values():
            if not isinstance(source, exp.Table):
                continue

            identifier = isinstance(source.this, exp.Identifier)

            if identifier:
                for arg, default in defaults:
                    if not source.args.get(arg):
                        source.set(arg, exp.to_identifier(default))

            if isinstance(source.parent, exp.Alias):
                continue

            # Tables named by an identifier are aliased to it, others get a generated alias
            alias_name = source.this if identifier else _next_alias()
            source.replace(alias(source.copy(), alias_name, table=True))

    return expression

View file

@ -0,0 +1,25 @@
from sqlglot import exp
def quote_identities(expression):
    """
    Rewrite sqlglot AST to ensure all identities are quoted.

    Example:
        >>> import sqlglot
        >>> expression = sqlglot.parse_one("SELECT x.a AS a FROM db.x")
        >>> quote_identities(expression).sql()
        '"x"."a" AS "a" FROM "db"."x"'

    Args:
        expression (sqlglot.Expression): expression to quote
    Returns:
        sqlglot.Expression: quoted expression
    """

    def _quote(node):
        is_identifier = isinstance(node, exp.Identifier)
        if is_identifier:
            node.set("quoted", True)
        return node

    quoted = expression.transform(_quote, copy=False)
    return quoted

129
sqlglot/optimizer/schema.py Normal file
View file

@ -0,0 +1,129 @@
import abc
from sqlglot import exp
from sqlglot.errors import OptimizeError
from sqlglot.helper import csv_reader
class Schema(abc.ABC):
    """Abstract base class for database schemas"""

    @abc.abstractmethod
    def column_names(self, table):
        """
        Get the column names for a table.

        Subclasses must override this method; the base implementation has no body.

        Args:
            table (sqlglot.expressions.Table): Table expression instance
        Returns:
            list[str]: list of column names
        """
class MappingSchema(Schema):
    """
    Schema based on a nested mapping.

    Args:
        schema (dict): Mapping in one of the following forms:
            1. {table: {col: type}}
            2. {db: {table: {col: type}}}
            3. {catalog: {db: {table: {col: type}}}}
            `None` and an empty mapping are accepted as a schema without tables.
    Raises:
        OptimizeError: if the mapping has any other nesting depth
    """

    def __init__(self, schema):
        self.schema = schema

        depth = _dict_depth(schema)

        # `_dict_depth` is 0 for a non-mapping and 1 for an empty mapping, so
        # the empty mapping has to be matched explicitly
        if not depth or (depth == 1 and not schema):  # None or {}
            self.supported_table_args = ()
        elif depth == 2:  # {table: {col: type}}
            self.supported_table_args = ("this",)
        elif depth == 3:  # {db: {table: {col: type}}}
            self.supported_table_args = ("db", "this")
        elif depth == 4:  # {catalog: {db: {table: {col: type}}}}
            self.supported_table_args = ("catalog", "db", "this")
        else:
            raise OptimizeError(f"Invalid schema shape. Depth: {depth}")

        # Table parts the mapping is too shallow to describe
        self.forbidden_args = {"catalog", "db", "this"} - set(self.supported_table_args)

    def column_names(self, table):
        """
        Get the column names for a table.

        Tables whose `this` is not an identifier are resolved through `fs_get`.

        Args:
            table (sqlglot.expressions.Table): Table expression instance
        Returns:
            list[str]: list of column names
        Raises:
            ValueError: if the table uses a part the schema does not support,
                or if it cannot be found in the mapping
        """
        if not isinstance(table.this, exp.Identifier):
            return fs_get(table)

        args = tuple(table.text(p) for p in self.supported_table_args)

        for forbidden in self.forbidden_args:
            if table.text(forbidden):
                raise ValueError(
                    f"Schema doesn't support {forbidden}. Received: {table.sql()}"
                )
        return list(_nested_get(self.schema, *zip(self.supported_table_args, args)))
def ensure_schema(schema):
    """Return `schema` itself if it is a Schema, otherwise wrap it in a MappingSchema."""
    return schema if isinstance(schema, Schema) else MappingSchema(schema)
def fs_get(table):
    """
    Get the column names of a table that is read through a function call.

    Only READ_CSV (case-insensitive) is handled: the first item produced by
    `csv_reader` is returned.

    Raises:
        ValueError: for any other name
    """
    function_name = table.this.name.upper()

    if function_name == "READ_CSV":
        with csv_reader(table) as reader:
            return next(reader)

    raise ValueError(f"Cannot read schema for {table}")
def _nested_get(d, *path):
    """
    Get a value for a nested dictionary.

    Args:
        d (dict): dictionary
        *path (tuple[str, str]): tuples of (name, key)
            `key` is the key in the dictionary to get.
            `name` is a string to use in the error if `key` isn't found.
    Returns:
        the value found after following every key of `path`
    Raises:
        ValueError: if a key is missing; "this" is reported as "table"
    """
    current = d
    for name, key in path:
        current = current.get(key)
        if current is None:
            label = "table" if name == "this" else name
            raise ValueError(f"Unknown {label}")
    return current
def _dict_depth(d):
"""
Get the nesting depth of a dictionary.
For example:
>>> _dict_depth(None)
0
>>> _dict_depth({})
1
>>> _dict_depth({"a": "b"})
1
>>> _dict_depth({"a": {}})
2
>>> _dict_depth({"a": {"b": {}}})
3
Args:
d (dict): dictionary
Returns:
int: depth
"""
try:
return 1 + _dict_depth(next(iter(d.values())))
except AttributeError:
# d doesn't have attribute "values"
return 0
except StopIteration:
# d.values() returns an empty sequence
return 1

438
sqlglot/optimizer/scope.py Normal file
View file

@ -0,0 +1,438 @@
from copy import copy
from enum import Enum, auto
from sqlglot import exp
from sqlglot.errors import OptimizeError
class ScopeType(Enum):
    """Kind of a `Scope`, relative to its parent scope."""

    ROOT = auto()  # default for a Scope created without an explicit scope_type
    SUBQUERY = auto()  # assigned by _traverse_subqueries
    DERIVED_TABLE = auto()  # assigned to derived tables by _traverse_select
    CTE = auto()  # assigned to CTEs by _traverse_select and _traverse_union
    UNION = auto()  # assigned to both operands of a Union by _traverse_union
    UNNEST = auto()  # assigned by _traverse_derived_tables to exp.Unnest nodes
class Scope:
    """
    Selection scope.

    Attributes:
        expression (exp.Select|exp.Union): Root expression of this scope
        sources (dict[str, exp.Table|Scope]): Mapping of source name to either
            a Table expression or another Scope instance. For example:
                SELECT * FROM x                     {"x": Table(this="x")}
                SELECT * FROM x AS y                {"y": Table(this="x")}
                SELECT * FROM (SELECT ...) AS y     {"y": Scope(...)}
        outer_column_list (list[str]): If this is a derived table or CTE, and the outer query
            defines a column list for the alias of this scope, this is that list of columns.
            For example:
                SELECT * FROM (SELECT ...) AS y(col1, col2)
            The inner query would have `["col1", "col2"]` for its `outer_column_list`
        parent (Scope): Parent scope
        scope_type (ScopeType): Type of this scope, relative to its parent
        subquery_scopes (list[Scope]): List of all child scopes for subqueries.
            This does not include derived tables or CTEs.
        union (tuple[Scope, Scope]): If this Scope is for a Union expression, this will be
            a tuple of the left and right child scopes.
    """

    def __init__(
        self,
        expression,
        sources=None,
        outer_column_list=None,
        parent=None,
        scope_type=ScopeType.ROOT,
    ):
        self.expression = expression
        self.parent = parent
        self.scope_type = scope_type
        self.sources = sources or {}
        self.outer_column_list = outer_column_list or []
        self.subquery_scopes = []
        self.union = None
        self.clear_cache()

    def clear_cache(self):
        """Forget everything derived from `expression` so it is recomputed on next access."""
        # Raw results of the tree walk done by `_collect`
        self._collected = False
        self._tables = None
        self._ctes = None
        self._subqueries = None
        self._derived_tables = None
        self._raw_columns = None
        # Values computed lazily by the properties below
        self._columns = None
        self._external_columns = None
        self._selected_sources = None

    def branch(self, expression, scope_type, add_sources=None, **kwargs):
        """Branch from the current scope to a new, inner scope"""
        # The child starts from a shallow copy so it can add sources without touching ours.
        inherited = copy(self.sources)
        if add_sources:
            inherited.update(add_sources)

        return Scope(
            expression=expression.unnest(),
            sources=inherited,
            parent=self,
            scope_type=scope_type,
            **kwargs,
        )

    def _collect(self):
        """Walk `expression` once and bucket the nodes that belong to this scope."""
        self._tables = []
        self._ctes = []
        self._subqueries = []
        self._derived_tables = []
        self._raw_columns = []

        # This flag passes state into the dfs generator: the `prune` callback reads
        # it, so setting it to True excludes the current node's subtree from traversal.
        skip_subtree = False

        for node, parent, _ in self.expression.dfs(prune=lambda *_: skip_subtree):
            skip_subtree = False

            if node is self.expression:
                continue

            if isinstance(node, exp.Column) and not isinstance(node.this, exp.Star):
                self._raw_columns.append(node)
            elif isinstance(node, exp.Table):
                self._tables.append(node)
            elif isinstance(node, (exp.Unnest, exp.Lateral)):
                self._derived_tables.append(node)
            elif isinstance(node, exp.CTE):
                self._ctes.append(node)
                skip_subtree = True
            elif isinstance(node, exp.Subquery) and isinstance(
                parent, (exp.From, exp.Join)
            ):
                self._derived_tables.append(node)
                skip_subtree = True
            elif isinstance(node, exp.Subqueryable):
                self._subqueries.append(node)
                skip_subtree = True

        self._collected = True

    def _ensure_collected(self):
        """Run `_collect` unless its results are already cached."""
        if self._collected:
            return
        self._collect()

    def replace(self, old, new):
        """
        Replace `old` with `new`.

        This can be used instead of `exp.Expression.replace` to ensure the `Scope` is kept up-to-date.

        Args:
            old (exp.Expression): old node
            new (exp.Expression): new node
        """
        old.replace(new)
        self.clear_cache()

    @property
    def tables(self):
        """
        List of tables in this scope.

        Returns:
            list[exp.Table]: tables
        """
        self._ensure_collected()
        return self._tables

    @property
    def ctes(self):
        """
        List of CTEs in this scope.

        Returns:
            list[exp.CTE]: ctes
        """
        self._ensure_collected()
        return self._ctes

    @property
    def derived_tables(self):
        """
        List of derived tables in this scope.

        For example:
            SELECT * FROM (SELECT ...) <- that's a derived table

        Returns:
            list[exp.Subquery]: derived tables
        """
        self._ensure_collected()
        return self._derived_tables

    @property
    def subqueries(self):
        """
        List of subqueries in this scope.

        For example:
            SELECT * FROM x WHERE a IN (SELECT ...) <- that's a subquery

        Returns:
            list[exp.Subqueryable]: subqueries
        """
        self._ensure_collected()
        return self._subqueries

    @property
    def columns(self):
        """
        List of columns in this scope.

        Returns:
            list[exp.Column]: Column instances in this scope, plus any
                Columns that reference this scope from correlated subqueries.
        """
        if self._columns is not None:
            return self._columns

        self._ensure_collected()

        candidates = list(self._raw_columns)
        for child in self.subquery_scopes:
            candidates.extend(child.external_columns)

        named_outputs = {e.alias_or_name for e in self.expression.expressions}

        # Inside QUALIFY / ORDER BY, a name matching a select output is left out.
        self._columns = [
            c
            for c in candidates
            if not c.find_ancestor(exp.Qualify, exp.Order)
            or c.name not in named_outputs
        ]
        return self._columns

    @property
    def selected_sources(self):
        """
        Mapping of nodes and sources that are actually selected from in this scope.

        That is, all tables in a schema are selectable at any point. But a
        table only becomes a selected source if it's included in a FROM or JOIN clause.

        Returns:
            dict[str, (exp.Table|exp.Subquery, exp.Table|Scope)]: selected sources and nodes
        """
        if self._selected_sources is None:
            referenced = [
                (
                    table.parent.alias
                    if isinstance(table.parent, exp.Alias)
                    else table.name,
                    table,
                )
                for table in self.tables
            ]
            referenced += [
                (derived_table.alias, derived_table.unnest())
                for derived_table in self.derived_tables
            ]

            self._selected_sources = {
                name: (node, self.sources[name])
                for name, node in referenced
                if name in self.sources
            }
        return self._selected_sources

    @property
    def selects(self):
        """
        Select expressions of this scope.

        For example, for the following expression:
            SELECT 1 as a, 2 as b FROM x

        The outputs are the "1 as a" and "2 as b" expressions.

        Returns:
            list[exp.Expression]: expressions (always empty for a Union scope)
        """
        if isinstance(self.expression, exp.Union):
            return []
        return self.expression.selects

    @property
    def external_columns(self):
        """
        Columns that appear to reference sources in outer scopes.

        Returns:
            list[exp.Column]: Column instances that don't reference
                sources in the current scope.
        """
        if self._external_columns is None:
            self._external_columns = [
                column
                for column in self.columns
                if column.table not in self.selected_sources
            ]
        return self._external_columns

    def source_columns(self, source_name):
        """
        Get all columns in the current scope for a particular source.

        Args:
            source_name (str): Name of the source

        Returns:
            list[exp.Column]: Column instances that reference `source_name`
        """
        return [c for c in self.columns if c.table == source_name]

    @property
    def is_subquery(self):
        """Determine if this scope is a subquery"""
        return self.scope_type == ScopeType.SUBQUERY

    @property
    def is_unnest(self):
        """Determine if this scope is an unnest"""
        return self.scope_type == ScopeType.UNNEST

    @property
    def is_correlated_subquery(self):
        """Determine if this scope is a correlated subquery"""
        return bool(self.is_subquery and self.external_columns)

    def rename_source(self, old_name, new_name):
        """Rename a source in this scope"""
        # A falsy `old_name` addresses the "" key; a missing source is re-registered as [].
        self.sources[new_name] = self.sources.pop(old_name or "", [])
def traverse_scope(expression):
    """
    Traverse an expression by its "scopes".

    "Scope" represents the current context of a Select statement.

    This is helpful for optimizing queries, where we need more information than
    the expression tree itself. For example, we might care about the source
    names within a subquery. A list is returned, rather than a generator, because
    scopes taken from a partially consumed generator could have incomplete properties.

    Examples:
        >>> import sqlglot
        >>> expression = sqlglot.parse_one("SELECT a FROM (SELECT a FROM x) AS y")
        >>> scopes = traverse_scope(expression)
        >>> scopes[0].expression.sql(), list(scopes[0].sources)
        ('SELECT a FROM x', ['x'])
        >>> scopes[1].expression.sql(), list(scopes[1].sources)
        ('SELECT a FROM (SELECT a FROM x) AS y', ['y'])

    Args:
        expression (exp.Expression): expression to traverse

    Returns:
        List[Scope]: scope instances
    """
    root = Scope(expression)
    return [*_traverse_scope(root)]
def _traverse_scope(scope):
    """
    Yield every scope nested in `scope`, and finally `scope` itself.

    Raises:
        OptimizeError: if the scope's expression is not a Select, Union,
            Lateral, Unnest or Subquery
    """
    expression = scope.expression

    if isinstance(expression, exp.Select):
        yield from _traverse_select(scope)
    elif isinstance(expression, exp.Union):
        yield from _traverse_union(scope)
    elif isinstance(expression, (exp.Lateral, exp.Unnest)):
        # No child scopes are traversed for these; only the scope itself is yielded.
        pass
    elif isinstance(expression, exp.Subquery):
        yield from _traverse_subqueries(scope)
    else:
        raise OptimizeError(f"Unexpected expression type: {type(scope.expression)}")

    yield scope
def _traverse_select(scope):
    """
    Yield the child scopes of a Select scope: CTEs first, then subqueries,
    then derived tables. Table sources are registered on `scope` last.
    """
    cte_scopes = _traverse_derived_tables(scope.ctes, scope, ScopeType.CTE)
    yield from cte_scopes

    yield from _traverse_subqueries(scope)

    derived_tables = scope.derived_tables
    yield from _traverse_derived_tables(
        derived_tables, scope, ScopeType.DERIVED_TABLE
    )

    _add_table_sources(scope)
def _traverse_union(scope):
    """
    Yield the child scopes of a Union scope (CTEs, then the left operand, then the
    right operand) and record the operands' top scopes on `scope.union`.
    """
    yield from _traverse_derived_tables(scope.ctes, scope, scope_type=ScopeType.CTE)

    def _walk_operand(operand):
        # The last scope to be yielded is the top-most scope of the operand.
        top = None
        for top in _traverse_scope(
            scope.branch(operand, scope_type=ScopeType.UNION)
        ):
            yield top
        return top

    left = yield from _walk_operand(scope.expression.left)
    right = yield from _walk_operand(scope.expression.right)

    scope.union = (left, right)
def _traverse_derived_tables(derived_tables, scope, scope_type):
    """
    Yield the scopes of each derived table (or CTE) and register them as sources.

    Args:
        derived_tables (list[exp.Expression]): CTE, Subquery, Unnest or Lateral nodes
        scope (Scope): the scope the derived tables belong to
        scope_type (ScopeType): type given to the child scopes (Unnest nodes
            always get ScopeType.UNNEST)
    """
    collected = {}
    # Only CTEs get to see the CTEs collected before them.
    shares_sources = scope_type == ScopeType.CTE

    for derived_table in derived_tables:
        is_unnest = isinstance(derived_table, exp.Unnest)
        if is_unnest or isinstance(derived_table, exp.Lateral):
            inner = derived_table
        else:
            inner = derived_table.this

        branched = scope.branch(
            inner,
            add_sources=collected if shares_sources else None,
            outer_column_list=derived_table.alias_column_names,
            scope_type=ScopeType.UNNEST if is_unnest else scope_type,
        )

        for child_scope in _traverse_scope(branched):
            yield child_scope
            # Tables without aliases will be set as "".
            # This shouldn't be a problem once qualify_columns runs, as it adds aliases on everything.
            # Until then, this means that only a single, unaliased derived table is allowed
            # (or rather, the latest one wins).
            collected[derived_table.alias] = child_scope

    scope.sources.update(collected)
def _add_table_sources(scope):
    """
    Register the tables of `scope` in `scope.sources`, keyed by alias or name.

    Raises:
        OptimizeError: if a table's source name is already taken by an existing source
    """
    known = scope.sources
    new_tables = {}

    for table in scope.tables:
        table_name = table.name
        parent = table.parent
        source_name = parent.alias if isinstance(parent, exp.Alias) else table_name

        if table_name in known:
            # This is a reference to a parent source (e.g. a CTE), not an actual table.
            known[source_name] = known[table_name]
            continue

        if source_name in known:
            raise OptimizeError(f"Duplicate table name: {source_name}")

        new_tables[source_name] = table

    # Real tables are merged only after the loop, so they are not visible to the
    # membership checks above while the loop runs.
    known.update(new_tables)
def _traverse_subqueries(scope):
    """
    Yield the scopes of every subquery in `scope`, appending each subquery's
    top-most scope to `scope.subquery_scopes`.
    """
    for subquery in scope.subqueries:
        branched = scope.branch(subquery, scope_type=ScopeType.SUBQUERY)

        # The last scope yielded for a subquery is its top-most scope.
        top = None
        for top in _traverse_scope(branched):
            yield top

        scope.subquery_scopes.append(top)

View file

@ -0,0 +1,383 @@
import datetime
import functools
import itertools
from collections import deque
from decimal import Decimal
from sqlglot import exp
from sqlglot.expressions import FALSE, NULL, TRUE
from sqlglot.generator import Generator
from sqlglot.helper import while_changing
# Shared generator (normalize=True, identify=True); `uniq_sort` uses the SQL text it
# produces as the key for de-duplicating and ordering connector operands.
GENERATOR = Generator(normalize=True, identify=True)
def simplify(expression):
    """
    Rewrite sqlglot AST to simplify expressions.

    Example:
        >>> import sqlglot
        >>> expression = sqlglot.parse_one("TRUE AND TRUE")
        >>> simplify(expression).sql()
        'TRUE'

    Args:
        expression (sqlglot.Expression): expression to simplify

    Returns:
        sqlglot.Expression: simplified expression
    """

    def _simplify(expression, root=True):
        # Connector-level rewrites run before the children are visited ...
        node = absorb_and_eliminate(uniq_sort(expression))
        exp.replace_children(node, lambda child: _simplify(child, False))

        # ... and the remaining boolean rewrites run on the simplified children.
        for rewrite in (simplify_not, flatten, simplify_connectors, remove_compliments):
            node = rewrite(node)

        # The parent is restored before the last two passes; `simplify_parens`
        # inspects `parent`.
        node.parent = expression.parent
        node = simplify_parens(simplify_literals(node))

        if root:
            expression.replace(node)
        return node

    expression = while_changing(expression, _simplify)
    remove_where_true(expression)
    return expression
def simplify_not(expression):
    """
    Simplify a NOT expression.

    De Morgan's Law
    NOT (x OR y) -> NOT x AND NOT y
    NOT (x AND y) -> NOT x OR NOT y

    Also folds NOT over an always-true operand or FALSE, and removes double negation.
    Anything that is not a NOT is returned unchanged.
    """
    if not isinstance(expression, exp.Not):
        return expression

    operand = expression.this

    if isinstance(operand, exp.Paren):
        condition = operand.unnest()
        if isinstance(condition, exp.And):
            return exp.or_(exp.not_(condition.left), exp.not_(condition.right))
        if isinstance(condition, exp.Or):
            return exp.and_(exp.not_(condition.left), exp.not_(condition.right))

    if always_true(operand):
        return FALSE
    if operand == FALSE:
        return TRUE
    if isinstance(operand, exp.Not):
        # double negation
        # NOT NOT x -> x
        return operand.this

    return expression
def flatten(expression):
    """
    Drop parentheses around an operand that is the same connector as its parent.

    A AND (B AND C) -> A AND B AND C
    A OR (B OR C) -> A OR B OR C

    The connector is modified in place and returned; other expressions are
    returned unchanged.
    """
    if not isinstance(expression, exp.Connector):
        return expression

    connector_type = expression.__class__
    for operand in expression.args.values():
        inner = operand.unnest()
        if isinstance(inner, connector_type):
            operand.replace(inner)

    return expression
def simplify_connectors(expression):
    """
    Fold an AND / OR whose operands are constants or identical.

    x AND x -> x, x OR x -> x
    x AND FALSE -> FALSE (also when x is NULL)
    x AND NULL -> NULL
    TRUE AND x -> x
    x OR TRUE -> TRUE
    FALSE OR x -> x
    NULL OR NULL / NULL OR FALSE -> NULL

    Args:
        expression (sqlglot.Expression): expression to simplify

    Returns:
        sqlglot.Expression: the folded operand or constant, or `expression`
            itself when it is not a connector or no rule applies
    """
    if isinstance(expression, exp.Connector):
        left = expression.left
        right = expression.right

        if left == right:
            return left

        if isinstance(expression, exp.And):
            # FALSE must be tested before NULL: under three-valued logic
            # `NULL AND FALSE` is FALSE, not NULL.
            if FALSE in (left, right):
                return FALSE
            if NULL in (left, right):
                return NULL
            if always_true(left) and always_true(right):
                return TRUE
            if always_true(left):
                return right
            if always_true(right):
                return left
        elif isinstance(expression, exp.Or):
            if always_true(left) or always_true(right):
                return TRUE
            if left == FALSE and right == FALSE:
                return FALSE
            if (
                (left == NULL and right == NULL)
                or (left == NULL and right == FALSE)
                or (left == FALSE and right == NULL)
            ):
                return NULL
            if left == FALSE:
                return right
            if right == FALSE:
                return left

    return expression
def remove_compliments(expression):
    """
    Fold a connector that contains an operand together with its negation.

    A AND NOT A -> FALSE
    A OR NOT A -> TRUE

    Anything else is returned unchanged.
    """
    if not isinstance(expression, exp.Connector):
        return expression

    operand_pairs = itertools.permutations(expression.flatten(), 2)
    if any(is_complement(a, b) for a, b in operand_pairs):
        return FALSE if isinstance(expression, exp.And) else TRUE

    return expression
def uniq_sort(expression):
    """
    Uniq and sort a connector.

    C AND A AND B AND B -> A AND B AND C

    Operands are keyed by the SQL that `GENERATOR` produces for them. A new
    connector is built only when the operands are out of order or contain
    duplicates; otherwise `expression` is returned as is.
    """
    if not isinstance(expression, exp.Connector):
        return expression

    build = exp.and_ if isinstance(expression, exp.And) else exp.or_
    operands = tuple(expression.flatten())
    by_sql = {GENERATOR.generate(operand): operand for operand in operands}
    keys = tuple(by_sql)

    # check if the operands are already sorted, if not sort them
    # A AND C AND B -> A AND B AND C
    if any(current < previous for previous, current in zip(keys, keys[1:])):
        return build(*(by_sql[key] for key in sorted(by_sql)))

    # we didn't have to sort but maybe we need to dedup
    if len(by_sql) < len(operands):
        return build(*by_sql.values())

    return expression
def absorb_and_eliminate(expression):
    """
    Apply the absorption and elimination laws to the operands of a connector.

    absorption:
        A AND (A OR B) -> A
        A OR (A AND B) -> A
        A AND (NOT A OR B) -> A AND B
        A OR (NOT A AND B) -> A OR B
    elimination:
        (A AND B) OR (A AND NOT B) -> A
        (A OR B) AND (A OR NOT B) -> A

    Operands are rewritten in place with `replace`; the function never builds a
    new top-level node.

    Args:
        expression (sqlglot.Expression): expression to rewrite

    Returns:
        sqlglot.Expression: the same `expression` object that was passed in
    """
    if isinstance(expression, exp.Connector):
        # `kind` is the opposite connector: inside an AND we look at OR operands
        # and vice versa.
        kind = exp.Or if isinstance(expression, exp.And) else exp.And

        # Every ordered pair of flattened operands is considered, so each operand
        # gets a turn as `a` (the nested connector) and as `b` (the other operand).
        for a, b in itertools.permutations(expression.flatten(), 2):
            if isinstance(a, kind):
                aa, ab = a.unnest_operands()

                # absorb
                if is_complement(b, aa):
                    # `aa` is NOT b: swap it for a boolean constant inside `a`
                    aa.replace(exp.TRUE if kind == exp.And else exp.FALSE)
                elif is_complement(b, ab):
                    ab.replace(exp.TRUE if kind == exp.And else exp.FALSE)
                elif (set(b.flatten()) if isinstance(b, kind) else {b}) < set(
                    a.flatten()
                ):
                    # b's operands are a strict subset of a's: `a` is replaced
                    # by a boolean constant
                    a.replace(exp.FALSE if kind == exp.And else exp.TRUE)
                elif isinstance(b, kind):
                    # eliminate
                    rhs = b.unnest_operands()
                    ba, bb = rhs

                    # a and b share one operand and the other operands are
                    # complements: both collapse to the shared operand
                    if aa in rhs and (is_complement(ab, ba) or is_complement(ab, bb)):
                        a.replace(aa)
                        b.replace(aa)
                    elif ab in rhs and (is_complement(aa, ba) or is_complement(aa, bb)):
                        a.replace(ab)
                        b.replace(ab)

    return expression
def simplify_literals(expression):
    """
    Fold literal operands of a binary expression and literal negations.

    1 + 2 -> 3
    5 - 3 - 1 -> 1
    -(-1) -> 1

    Args:
        expression (sqlglot.Expression): expression to simplify

    Returns:
        sqlglot.Expression: a new expression when something could be folded,
            otherwise `expression` itself
    """
    if isinstance(expression, exp.Binary):
        operands = []
        queue = deque(expression.flatten(unnest=False))
        size = len(queue)

        while queue:
            a = queue.popleft()

            for b in queue:
                result = _simplify_binary(expression, a, b)

                if result:
                    queue.remove(b)
                    # The folded value takes the place `a` had at the front of the
                    # queue. Appending it at the end would reorder the operands,
                    # which is wrong for non-commutative operators:
                    # 5 - 3 - 1 would become 1 - 2.
                    queue.appendleft(result)
                    break
            else:
                # `a` could not be combined with any remaining operand
                operands.append(a)

        if len(operands) < size:
            # Something was folded: rebuild a left-deep chain of the same operator.
            return functools.reduce(
                lambda a, b: expression.__class__(this=a, expression=b), operands
            )
    elif isinstance(expression, exp.Neg):
        this = expression.this
        if this.is_number:
            value = this.name
            if value.startswith("-"):
                return exp.Literal.number(value[1:])
            return exp.Literal.number(f"-{value}")

    return expression
def _simplify_binary(expression, a, b):
    """
    Try to evaluate the binary operation `expression` for the operands `a` and `b`.

    Args:
        expression (exp.Binary): supplies the operator; its own operands are not read
        a (sqlglot.Expression): left operand
        b (sqlglot.Expression): right operand

    Returns:
        sqlglot.Expression|None: the folded value, or None when the pair
            cannot be evaluated
    """
    if isinstance(expression, exp.Is):
        if isinstance(b, exp.Not):
            c = b.this
            not_ = True
        else:
            c = b
            not_ = False

        if c == NULL:
            if isinstance(a, exp.Literal):
                return TRUE if not_ else FALSE
            if a == NULL:
                return FALSE if not_ else TRUE
    elif NULL in (a, b):
        return NULL

    if isinstance(expression, exp.EQ) and a == b:
        return TRUE

    if a.is_number and b.is_number:
        a = int(a.name) if a.is_int else Decimal(a.name)
        b = int(b.name) if b.is_int else Decimal(b.name)

        if isinstance(expression, exp.Add):
            return exp.Literal.number(a + b)
        if isinstance(expression, exp.Sub):
            return exp.Literal.number(a - b)
        if isinstance(expression, exp.Mul):
            return exp.Literal.number(a * b)
        if isinstance(expression, exp.Div):
            if b == 0:
                # Folding would raise ZeroDivisionError inside the optimizer;
                # leave the division in the query for the engine to evaluate.
                return None
            if isinstance(a, int) and isinstance(b, int):
                return exp.Literal.number(a // b)
            return exp.Literal.number(a / b)

        boolean = eval_boolean(expression, a, b)

        if boolean:
            return boolean
    elif a.is_string and b.is_string:
        # Compare the literal text rather than the Expression objects, so that
        # ordering operators (<, <=, >, >=) can be evaluated as well.
        boolean = eval_boolean(expression, a.name, b.name)

        if boolean:
            return boolean
    elif isinstance(a, exp.Cast) and isinstance(b, exp.Interval):
        a, b = extract_date(a), extract_interval(b)
        # Either extraction may yield None (not a DATE cast, unsupported unit, ...)
        if a and b:
            if isinstance(expression, exp.Add):
                return date_literal(a + b)
            if isinstance(expression, exp.Sub):
                return date_literal(a - b)
    elif isinstance(a, exp.Interval) and isinstance(b, exp.Cast):
        a, b = extract_interval(a), extract_date(b)
        # you cannot subtract a date from an interval
        if a and b and isinstance(expression, exp.Add):
            return date_literal(a + b)

    return None
def simplify_parens(expression):
    """
    Remove parentheses that are not needed.

    A Paren is unwrapped unless it wraps a Select, or it wraps a Binary (other
    than IS / LIKE) and sits inside a Condition or Binary parent.
    """
    if not isinstance(expression, exp.Paren):
        return expression

    inner = expression.this
    if isinstance(inner, exp.Select):
        return expression

    removable = (
        not isinstance(expression.parent, (exp.Condition, exp.Binary))
        or isinstance(inner, (exp.Is, exp.Like))
        or not isinstance(inner, exp.Binary)
    )
    return inner if removable else expression
def remove_where_true(expression):
    """
    Drop conditions that are always true, modifying `expression` in place.

    A WHERE with an always-true condition is removed from its parent, and a JOIN
    with an always-true ON condition becomes a CROSS join without a condition.
    """
    for where_clause in expression.find_all(exp.Where):
        if not always_true(where_clause.this):
            continue
        where_clause.parent.set("where", None)

    for join_clause in expression.find_all(exp.Join):
        if not always_true(join_clause.args.get("on")):
            continue
        join_clause.set("kind", "CROSS")
        join_clause.set("on", None)
def always_true(expression):
    """Return True if `expression` is the TRUE constant or any Literal."""
    if isinstance(expression, exp.Literal):
        return True
    return expression == TRUE
def is_complement(a, b):
    """Return True if `b` is `NOT a`."""
    if not isinstance(b, exp.Not):
        return False
    return b.this == a
def eval_boolean(expression, a, b):
    """
    Evaluate a comparison between two already-extracted values.

    Args:
        expression (sqlglot.Expression): supplies the comparison operator
        a: left value
        b: right value

    Returns:
        the TRUE or FALSE constant, or None if `expression` is not one of
        EQ, IS, NEQ, GT, GTE, LT, LTE
    """
    # The first matching entry wins, same order as the operators are listed above.
    comparisons = (
        ((exp.EQ, exp.Is), lambda: a == b),
        (exp.NEQ, lambda: a != b),
        (exp.GT, lambda: a > b),
        (exp.GTE, lambda: a >= b),
        (exp.LT, lambda: a < b),
        (exp.LTE, lambda: a <= b),
    )

    for operator_types, compare in comparisons:
        if isinstance(expression, operator_types):
            return boolean_literal(compare())

    return None
def extract_date(cast):
    """
    Extract a Python date from a `CAST(... AS DATE)` expression.

    Args:
        cast (exp.Cast): cast expression

    Returns:
        datetime.date|None: the date, or None if the cast is not to DATE or
            its operand is not an ISO formatted date
    """
    if cast.args["to"].this == exp.DataType.Type.DATE:
        try:
            return datetime.date.fromisoformat(cast.name)
        except ValueError:
            # `fromisoformat` rejects anything that is not an ISO date, e.g. the
            # empty name of a cast whose operand is not a literal.
            return None
    return None
def extract_interval(interval):
    """
    Convert an INTERVAL expression into a `dateutil` relativedelta.

    Args:
        interval (exp.Interval): interval expression

    Returns:
        relativedelta|None: the delta, or None when `dateutil` is not installed,
            the interval's value is not an integer, or its unit is not one of
            year, month, week, day
    """
    try:
        from dateutil.relativedelta import relativedelta
    except ModuleNotFoundError:
        return None

    try:
        n = int(interval.name)
    except ValueError:
        # Non-integer or non-literal interval value: nothing we can fold.
        return None

    unit = interval.text("unit").lower()

    if unit in ("year", "month", "week", "day"):
        # relativedelta takes the plural form of the unit as keyword (years=, months=, ...)
        return relativedelta(**{f"{unit}s": n})
    return None
def date_literal(date):
    """Build a `CAST('<date>' AS DATE)` expression for `date`."""
    text = exp.Literal.string(date)
    date_type = exp.DataType.build("DATE")
    return exp.Cast(this=text, to=date_type)
def boolean_literal(condition):
    """Map a truthy / falsy value to the TRUE / FALSE constant."""
    if condition:
        return TRUE
    return FALSE

View file

@ -0,0 +1,220 @@
import itertools
from sqlglot import exp
from sqlglot.optimizer.scope import traverse_scope
def unnest_subqueries(expression):
    """
    Rewrite sqlglot AST to convert some predicates with subqueries into joins.

    Convert the subquery into a group by so it is not a many to many left join.
    Unnesting can only occur if the subquery does not have LIMIT or OFFSET.
    Unnesting non correlated subqueries only happens on IN statements or = ANY statements.

    Example:
        >>> import sqlglot
        >>> expression = sqlglot.parse_one("SELECT * FROM x AS x WHERE (SELECT y.a AS a FROM y AS y WHERE x.a = y.a) = 1 ")
        >>> unnest_subqueries(expression).sql()
        'SELECT * FROM x AS x LEFT JOIN (SELECT y.a AS a FROM y AS y WHERE TRUE GROUP BY y.a) AS "_u_0" ON x.a = "_u_0".a WHERE ("_u_0".a = 1 AND NOT "_u_0".a IS NULL)'

    Args:
        expression (sqlglot.Expression): expression to unnest

    Returns:
        sqlglot.Expression: unnested expression
    """
    # One counter is shared by all rewrites so generated aliases never collide.
    sequence = itertools.count()

    for scope in traverse_scope(expression):
        select = scope.expression
        parent_select = select.parent_select
        external_columns = scope.external_columns

        if not external_columns:
            unnest(select, parent_select, sequence)
            continue

        decorrelate(select, parent_select, external_columns, sequence)

    return expression
def unnest(select, parent_select, sequence):
    """
    Turn a non-correlated `IN (subquery)` / `= ANY (subquery)` into a LEFT JOIN.

    Nothing is changed when the subquery is not inside such a predicate of
    `parent_select`, selects more than one expression, or has LIMIT / OFFSET.

    Args:
        select (exp.Select): the subquery
        parent_select (exp.Select): the query containing the predicate
        sequence (Iterator[int]): counter used to generate the join alias
    """
    predicate = select.find_ancestor(exp.In, exp.Any)
    if not predicate or predicate.parent_select is not parent_select:
        return

    too_many_selects = len(select.selects) > 1
    if too_many_selects or select.find(exp.Limit, exp.Offset):
        return

    if isinstance(predicate, exp.Any):
        # For `x = ANY (...)` the predicate to rewrite is the enclosing equality.
        predicate = predicate.find_ancestor(exp.EQ)
        if not predicate or predicate.parent_select is not parent_select:
            return

    column = _other_operand(predicate)
    value = select.selects[0]
    alias = _alias(sequence)

    on = exp.condition(f'{column} = "{alias}"."{value.alias}"')
    # The predicate becomes a check that the joined row exists.
    _replace(predicate, f"NOT {on.right} IS NULL")

    grouped = select.group_by(value.this, copy=False)
    parent_select.join(
        grouped,
        on=on,
        join_type="LEFT",
        join_alias=alias,
        copy=False,
    )
def decorrelate(select, parent_select, external_columns, sequence):
    """
    Rewrite a correlated subquery as a grouped LEFT JOIN on `parent_select`.

    The predicates of the subquery's WHERE that reference outer columns are
    replaced by TRUE and moved to the outer query, either as the join
    condition (equalities) or as extra conditions on the predicate that
    contained the subquery. All changes are made in place.

    The function returns without changing the query when:
        - the subquery has no WHERE, its WHERE contains an OR, or the subquery has LIMIT / OFFSET
        - an external column, or the predicate around it, is not under the subquery's own WHERE
        - a predicate around an external column is not a Binary
        - none of the collected predicates is an equality
    Note that one alias is drawn from `sequence` before the per-column checks,
    so the counter may advance even if nothing is rewritten.

    Args:
        select (exp.Select): the correlated subquery
        parent_select (exp.Select): the query the join is added to
        external_columns (list[exp.Column]): columns of `select` that reference outer sources
        sequence (Iterator[int]): counter used to generate aliases
    """
    where = select.args.get("where")

    if not where or where.find(exp.Or) or select.find(exp.Limit, exp.Offset):
        return

    table_alias = _alias(sequence)
    # (key, column, predicate): `key` is the operand of `predicate` opposite to
    # the external `column`
    keys = []

    # for all external columns in the where statement,
    # split out the relevant data to convert it into a join
    for column in external_columns:
        if column.find_ancestor(exp.Where) is not where:
            return

        predicate = column.find_ancestor(exp.Predicate)

        if not predicate or predicate.find_ancestor(exp.Where) is not where:
            return

        if isinstance(predicate, exp.Binary):
            # identity (not equality) is used to find the side holding the column
            key = (
                predicate.right
                if any(node is column for node, *_ in predicate.left.walk())
                else predicate.left
            )
        else:
            return

        keys.append((key, column, predicate))

    if not any(isinstance(predicate, exp.EQ) for *_, predicate in keys):
        return

    value = select.selects[0]
    # key -> alias under which the key is projected by the subquery
    key_aliases = {}
    group_by = []

    for key, _, predicate in keys:
        # if we filter on the value of the subquery, it needs to be unique
        if key == value.this:
            key_aliases[key] = value.alias
            group_by.append(key)
        else:
            if key not in key_aliases:
                key_aliases[key] = _alias(sequence)
            # all predicates that are equalities must also be in the unique
            # so that we don't do a many to many join
            if isinstance(predicate, exp.EQ) and key not in group_by:
                group_by.append(key)

    # the predicate of the outer query that contains the subquery
    parent_predicate = select.find_ancestor(exp.Predicate)

    # if the value of the subquery is not an agg or a key, we need to collect it into an array
    # so that it can be grouped
    if not value.find(exp.AggFunc) and value.this not in group_by:
        select.select(
            f"ARRAY_AGG({value.this}) AS {value.alias}", append=False, copy=False
        )

    # exists queries should not have any selects as it only checks if there are any rows
    # all selects will be added by the optimizer and only used for join keys
    if isinstance(parent_predicate, exp.Exists):
        select.args["expressions"] = []

    for key, alias in key_aliases.items():
        if key in group_by:
            # add all keys to the projections of the subquery
            # so that we can use it as a join key
            if isinstance(parent_predicate, exp.Exists) or key != value.this:
                select.select(f"{key} AS {alias}", copy=False)
        else:
            # keys that are not grouped are collected into an array instead
            select.select(f"ARRAY_AGG({key}) AS {alias}", copy=False)

    # from here on `alias` is the joined column that holds the subquery's value
    alias = exp.column(value.alias, table_alias)
    other = _other_operand(parent_predicate)

    # replace the subquery predicate by an expression over the joined column
    if isinstance(parent_predicate, exp.Exists):
        if value.this in group_by:
            parent_predicate = _replace(parent_predicate, f"NOT {alias} IS NULL")
        else:
            parent_predicate = _replace(parent_predicate, "TRUE")
    elif isinstance(parent_predicate, exp.All):
        parent_predicate = _replace(
            parent_predicate.parent, f"ARRAY_ALL({alias}, _x -> _x = {other})"
        )
    elif isinstance(parent_predicate, exp.Any):
        if value.this in group_by:
            parent_predicate = _replace(parent_predicate.parent, f"{other} = {alias}")
        else:
            parent_predicate = _replace(
                parent_predicate, f"ARRAY_ANY({alias}, _x -> _x = {other})"
            )
    elif isinstance(parent_predicate, exp.In):
        if value.this in group_by:
            parent_predicate = _replace(parent_predicate, f"{other} = {alias}")
        else:
            parent_predicate = _replace(
                parent_predicate,
                f"ARRAY_ANY({alias}, _x -> _x = {parent_predicate.this})",
            )
    else:
        # any other predicate keeps its shape; only the subquery node is swapped
        select.parent.replace(alias)

    for key, column, predicate in keys:
        # the correlated predicate is neutralised inside the subquery ...
        predicate.replace(exp.TRUE)
        nested = exp.column(key_aliases[key], table_alias)

        # ... and re-expressed in the outer query
        if key in group_by:
            key.replace(nested)
            parent_predicate = _replace(
                parent_predicate, f"({parent_predicate} AND NOT {nested} IS NULL)"
            )
        elif isinstance(predicate, exp.EQ):
            parent_predicate = _replace(
                parent_predicate,
                f"({parent_predicate} AND ARRAY_CONTAINS({nested}, {column}))",
            )
        else:
            key.replace(exp.to_identifier("_x"))
            parent_predicate = _replace(
                parent_predicate,
                f'({parent_predicate} AND ARRAY_ANY({nested}, "_x" -> {predicate}))',
            )

    # the equality predicates (with their keys already rewritten above) become
    # the join condition
    parent_select.join(
        select.group_by(*group_by, copy=False),
        on=[predicate for *_, predicate in keys if isinstance(predicate, exp.EQ)],
        join_type="LEFT",
        join_alias=table_alias,
        copy=False,
    )
def _alias(sequence):
return f"_u_{next(sequence)}"
def _replace(expression, condition):
    """
    Replace `expression` in its tree with the condition built from `condition`.

    Returns:
        whatever `expression.replace` returns; callers in this file keep using
        the result as the replaced predicate
    """
    replacement = exp.condition(condition)
    return expression.replace(replacement)
def _is_subquery_operand(node):
    """Return True if `node` is a subquery, or a subquery predicate such as ANY / ALL / EXISTS."""
    if isinstance(node, (exp.Subquery, exp.Any, exp.All, exp.Exists)):
        return True
    # `unnest` strips wrapping parentheses, so `(SELECT ...)` is recognised too.
    return isinstance(node, exp.Expression) and isinstance(
        node.unnest(), exp.Subqueryable
    )


def _other_operand(expression):
    """
    Get the operand of a predicate that is NOT the subquery.

    Args:
        expression (sqlglot.Expression): predicate that contains a subquery

    Returns:
        sqlglot.Expression|None: for IN, the tested expression; for a binary
            predicate, the operand opposite to the subquery; otherwise None
    """
    if isinstance(expression, exp.In):
        return expression.this

    if isinstance(expression, exp.Binary):
        left = expression.left
        right = expression.right

        # Decide by where the subquery actually is. The predicate's own position
        # in its parent (`arg_key`) says nothing about that, e.g. for
        # `WHERE x.a = ANY (SELECT ...)` it would select the subquery side.
        if _is_subquery_operand(left):
            return right
        if _is_subquery_operand(right):
            return left

        # Neither side is recognisably a subquery: keep the previous behaviour.
        return right if expression.arg_key == "this" else left

    return None