75 lines
2.2 KiB
Python
75 lines
2.2 KiB
Python
from sqlglot import exp
|
|
from sqlglot.helper import tsort
|
|
from sqlglot.optimizer.simplify import simplify
|
|
|
|
|
|
def optimize_joins(expression):
|
|
"""
|
|
Removes cross joins if possible and reorder joins based on predicate dependencies.
|
|
"""
|
|
for select in expression.find_all(exp.Select):
|
|
references = {}
|
|
cross_joins = []
|
|
|
|
for join in select.args.get("joins", []):
|
|
name = join.this.alias_or_name
|
|
tables = other_table_names(join, name)
|
|
|
|
if tables:
|
|
for table in tables:
|
|
references[table] = references.get(table, []) + [join]
|
|
else:
|
|
cross_joins.append((name, join))
|
|
|
|
for name, join in cross_joins:
|
|
for dep in references.get(name, []):
|
|
on = dep.args["on"]
|
|
on = on.replace(simplify(on))
|
|
|
|
if isinstance(on, exp.Connector):
|
|
for predicate in on.flatten():
|
|
if name in exp.column_table_names(predicate):
|
|
predicate.replace(exp.true())
|
|
join.on(predicate, copy=False)
|
|
|
|
expression = reorder_joins(expression)
|
|
expression = normalize(expression)
|
|
return expression
|
|
|
|
|
|
def reorder_joins(expression):
|
|
"""
|
|
Reorder joins by topological sort order based on predicate references.
|
|
"""
|
|
for from_ in expression.find_all(exp.From):
|
|
head = from_.expressions[0]
|
|
parent = from_.parent
|
|
joins = {join.this.alias_or_name: join for join in parent.args.get("joins", [])}
|
|
dag = {head.alias_or_name: []}
|
|
|
|
for name, join in joins.items():
|
|
dag[name] = other_table_names(join, name)
|
|
|
|
parent.set(
|
|
"joins",
|
|
[joins[name] for name in tsort(dag) if name != head.alias_or_name],
|
|
)
|
|
return expression
|
|
|
|
|
|
def normalize(expression):
|
|
"""
|
|
Remove INNER and OUTER from joins as they are optional.
|
|
"""
|
|
for join in expression.find_all(exp.Join):
|
|
if join.kind != "CROSS":
|
|
join.set("kind", None)
|
|
return expression
|
|
|
|
|
|
def other_table_names(join, exclude):
|
|
return [
|
|
name
|
|
for name in (exp.column_table_names(join.args.get("on") or exp.true()))
|
|
if name != exclude
|
|
]
|