Adding upstream version 25.5.1.
Signed-off-by: Daniel Baumann <daniel@debian.org>
This commit is contained in:
parent
147b6e06e8
commit
4e506fbac7
136 changed files with 80990 additions and 72541 deletions
|
@ -9,6 +9,52 @@ if t.TYPE_CHECKING:
|
|||
from sqlglot.generator import Generator
|
||||
|
||||
|
||||
def preprocess(
|
||||
transforms: t.List[t.Callable[[exp.Expression], exp.Expression]],
|
||||
) -> t.Callable[[Generator, exp.Expression], str]:
|
||||
"""
|
||||
Creates a new transform by chaining a sequence of transformations and converts the resulting
|
||||
expression to SQL, using either the "_sql" method corresponding to the resulting expression,
|
||||
or the appropriate `Generator.TRANSFORMS` function (when applicable -- see below).
|
||||
|
||||
Args:
|
||||
transforms: sequence of transform functions. These will be called in order.
|
||||
|
||||
Returns:
|
||||
Function that can be used as a generator transform.
|
||||
"""
|
||||
|
||||
def _to_sql(self, expression: exp.Expression) -> str:
|
||||
expression_type = type(expression)
|
||||
|
||||
expression = transforms[0](expression)
|
||||
for transform in transforms[1:]:
|
||||
expression = transform(expression)
|
||||
|
||||
_sql_handler = getattr(self, expression.key + "_sql", None)
|
||||
if _sql_handler:
|
||||
return _sql_handler(expression)
|
||||
|
||||
transforms_handler = self.TRANSFORMS.get(type(expression))
|
||||
if transforms_handler:
|
||||
if expression_type is type(expression):
|
||||
if isinstance(expression, exp.Func):
|
||||
return self.function_fallback_sql(expression)
|
||||
|
||||
# Ensures we don't enter an infinite loop. This can happen when the original expression
|
||||
# has the same type as the final expression and there's no _sql method available for it,
|
||||
# because then it'd re-enter _to_sql.
|
||||
raise ValueError(
|
||||
f"Expression type {expression.__class__.__name__} requires a _sql method in order to be transformed."
|
||||
)
|
||||
|
||||
return transforms_handler(self, expression)
|
||||
|
||||
raise ValueError(f"Unsupported expression type {expression.__class__.__name__}.")
|
||||
|
||||
return _to_sql
|
||||
|
||||
|
||||
def unalias_group(expression: exp.Expression) -> exp.Expression:
|
||||
"""
|
||||
Replace references to select aliases in GROUP BY clauses.
|
||||
|
@ -393,7 +439,7 @@ def add_recursive_cte_column_names(expression: exp.Expression) -> exp.Expression
|
|||
for cte in expression.expressions:
|
||||
if not cte.args["alias"].columns:
|
||||
query = cte.this
|
||||
if isinstance(query, exp.Union):
|
||||
if isinstance(query, exp.SetOperation):
|
||||
query = query.this
|
||||
|
||||
cte.args["alias"].set(
|
||||
|
@ -623,47 +669,103 @@ def struct_kv_to_alias(expression: exp.Expression) -> exp.Expression:
|
|||
return expression
|
||||
|
||||
|
||||
def preprocess(
|
||||
transforms: t.List[t.Callable[[exp.Expression], exp.Expression]],
|
||||
) -> t.Callable[[Generator, exp.Expression], str]:
|
||||
def eliminate_join_marks(expression: exp.Expression) -> exp.Expression:
|
||||
"""
|
||||
Creates a new transform by chaining a sequence of transformations and converts the resulting
|
||||
expression to SQL, using either the "_sql" method corresponding to the resulting expression,
|
||||
or the appropriate `Generator.TRANSFORMS` function (when applicable -- see below).
|
||||
Remove join marks from an AST. This rule assumes that all marked columns are qualified.
|
||||
If this does not hold for a query, consider running `sqlglot.optimizer.qualify` first.
|
||||
|
||||
For example,
|
||||
SELECT * FROM a, b WHERE a.id = b.id(+) -- ... is converted to
|
||||
SELECT * FROM a LEFT JOIN b ON a.id = b.id -- this
|
||||
|
||||
Args:
|
||||
transforms: sequence of transform functions. These will be called in order.
|
||||
expression: The AST to remove join marks from.
|
||||
|
||||
Returns:
|
||||
Function that can be used as a generator transform.
|
||||
The AST with join marks removed.
|
||||
"""
|
||||
from sqlglot.optimizer.scope import traverse_scope
|
||||
|
||||
def _to_sql(self, expression: exp.Expression) -> str:
|
||||
expression_type = type(expression)
|
||||
for scope in traverse_scope(expression):
|
||||
query = scope.expression
|
||||
|
||||
expression = transforms[0](expression)
|
||||
for transform in transforms[1:]:
|
||||
expression = transform(expression)
|
||||
where = query.args.get("where")
|
||||
joins = query.args.get("joins")
|
||||
|
||||
_sql_handler = getattr(self, expression.key + "_sql", None)
|
||||
if _sql_handler:
|
||||
return _sql_handler(expression)
|
||||
if not where or not joins:
|
||||
continue
|
||||
|
||||
transforms_handler = self.TRANSFORMS.get(type(expression))
|
||||
if transforms_handler:
|
||||
if expression_type is type(expression):
|
||||
if isinstance(expression, exp.Func):
|
||||
return self.function_fallback_sql(expression)
|
||||
query_from = query.args["from"]
|
||||
|
||||
# Ensures we don't enter an infinite loop. This can happen when the original expression
|
||||
# has the same type as the final expression and there's no _sql method available for it,
|
||||
# because then it'd re-enter _to_sql.
|
||||
raise ValueError(
|
||||
f"Expression type {expression.__class__.__name__} requires a _sql method in order to be transformed."
|
||||
)
|
||||
# These keep track of the joins to be replaced
|
||||
new_joins: t.Dict[str, exp.Join] = {}
|
||||
old_joins = {join.alias_or_name: join for join in joins}
|
||||
|
||||
return transforms_handler(self, expression)
|
||||
for column in scope.columns:
|
||||
if not column.args.get("join_mark"):
|
||||
continue
|
||||
|
||||
raise ValueError(f"Unsupported expression type {expression.__class__.__name__}.")
|
||||
predicate = column.find_ancestor(exp.Predicate, exp.Select)
|
||||
assert isinstance(
|
||||
predicate, exp.Binary
|
||||
), "Columns can only be marked with (+) when involved in a binary operation"
|
||||
|
||||
return _to_sql
|
||||
predicate_parent = predicate.parent
|
||||
join_predicate = predicate.pop()
|
||||
|
||||
left_columns = [
|
||||
c for c in join_predicate.left.find_all(exp.Column) if c.args.get("join_mark")
|
||||
]
|
||||
right_columns = [
|
||||
c for c in join_predicate.right.find_all(exp.Column) if c.args.get("join_mark")
|
||||
]
|
||||
|
||||
assert not (
|
||||
left_columns and right_columns
|
||||
), "The (+) marker cannot appear in both sides of a binary predicate"
|
||||
|
||||
marked_column_tables = set()
|
||||
for col in left_columns or right_columns:
|
||||
table = col.table
|
||||
assert table, f"Column {col} needs to be qualified with a table"
|
||||
|
||||
col.set("join_mark", False)
|
||||
marked_column_tables.add(table)
|
||||
|
||||
assert (
|
||||
len(marked_column_tables) == 1
|
||||
), "Columns of only a single table can be marked with (+) in a given binary predicate"
|
||||
|
||||
join_this = old_joins.get(col.table, query_from).this
|
||||
new_join = exp.Join(this=join_this, on=join_predicate, kind="LEFT")
|
||||
|
||||
# Upsert new_join into new_joins dictionary
|
||||
new_join_alias_or_name = new_join.alias_or_name
|
||||
existing_join = new_joins.get(new_join_alias_or_name)
|
||||
if existing_join:
|
||||
existing_join.set("on", exp.and_(existing_join.args.get("on"), new_join.args["on"]))
|
||||
else:
|
||||
new_joins[new_join_alias_or_name] = new_join
|
||||
|
||||
# If the parent of the target predicate is a binary node, then it now has only one child
|
||||
if isinstance(predicate_parent, exp.Binary):
|
||||
if predicate_parent.left is None:
|
||||
predicate_parent.replace(predicate_parent.right)
|
||||
else:
|
||||
predicate_parent.replace(predicate_parent.left)
|
||||
|
||||
if query_from.alias_or_name in new_joins:
|
||||
only_old_joins = old_joins.keys() - new_joins.keys()
|
||||
assert (
|
||||
len(only_old_joins) >= 1
|
||||
), "Cannot determine which table to use in the new FROM clause"
|
||||
|
||||
new_from_name = list(only_old_joins)[0]
|
||||
query.set("from", exp.From(this=old_joins[new_from_name].this))
|
||||
|
||||
query.set("joins", list(new_joins.values()))
|
||||
|
||||
if not where.this:
|
||||
where.pop()
|
||||
|
||||
return expression
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue