1
0
Fork 0

Adding upstream version 25.5.1.

Signed-off-by: Daniel Baumann <daniel@debian.org>
This commit is contained in:
Daniel Baumann 2025-02-13 21:41:00 +01:00
parent 147b6e06e8
commit 4e506fbac7
Signed by: daniel
GPG key ID: FBB4F0E80A80222F
136 changed files with 80990 additions and 72541 deletions

View file

@ -9,6 +9,52 @@ if t.TYPE_CHECKING:
from sqlglot.generator import Generator
def preprocess(
    transforms: t.List[t.Callable[[exp.Expression], exp.Expression]],
) -> t.Callable[[Generator, exp.Expression], str]:
    """
    Creates a new transform by chaining a sequence of transformations and converts the resulting
    expression to SQL, using either the "_sql" method corresponding to the resulting expression,
    or the appropriate `Generator.TRANSFORMS` function (when applicable -- see below).

    Args:
        transforms: sequence of transform functions. These will be called in order, each one
            receiving the expression returned by the previous one. An empty sequence is
            accepted and leaves the expression unchanged.

    Returns:
        Function that can be used as a generator transform. When called, that function raises
        `ValueError` if the transformed expression cannot be rendered: either no handler exists
        for it, or the only handler available is a `TRANSFORMS` entry for an unchanged,
        non-function expression type (which would recurse forever).
    """

    def _to_sql(self, expression: exp.Expression) -> str:
        # Remember the incoming type so that a no-op type change can be detected below.
        expression_type = type(expression)

        # A single loop covers the empty, single and multi-transform cases alike.
        for transform in transforms:
            expression = transform(expression)

        # Prefer the generator's dedicated "<key>_sql" method, when it has one.
        _sql_handler = getattr(self, expression.key + "_sql", None)
        if _sql_handler:
            return _sql_handler(expression)

        transforms_handler = self.TRANSFORMS.get(type(expression))
        if transforms_handler:
            if expression_type is type(expression):
                if isinstance(expression, exp.Func):
                    return self.function_fallback_sql(expression)

                # Ensures we don't enter an infinite loop. This can happen when the original expression
                # has the same type as the final expression and there's no _sql method available for it,
                # because then it'd re-enter _to_sql.
                raise ValueError(
                    f"Expression type {expression.__class__.__name__} requires a _sql method in order to be transformed."
                )

            return transforms_handler(self, expression)

        raise ValueError(f"Unsupported expression type {expression.__class__.__name__}.")

    return _to_sql
def unalias_group(expression: exp.Expression) -> exp.Expression:
"""
Replace references to select aliases in GROUP BY clauses.
@ -393,7 +439,7 @@ def add_recursive_cte_column_names(expression: exp.Expression) -> exp.Expression
for cte in expression.expressions:
if not cte.args["alias"].columns:
query = cte.this
if isinstance(query, exp.Union):
if isinstance(query, exp.SetOperation):
query = query.this
cte.args["alias"].set(
@ -623,47 +669,103 @@ def struct_kv_to_alias(expression: exp.Expression) -> exp.Expression:
return expression
def eliminate_join_marks(expression: exp.Expression) -> exp.Expression:
    """
    Remove join marks from an AST. This rule assumes that all marked columns are qualified.
    If this does not hold for a query, consider running `sqlglot.optimizer.qualify` first.

    For example,
        SELECT * FROM a, b WHERE a.id = b.id(+)    -- ... is converted to
        SELECT * FROM a LEFT JOIN b ON a.id = b.id -- this

    The AST is modified in place: marked predicates are popped out of the WHERE clause and
    the query's "joins" (and possibly "from") args are overwritten.

    Args:
        expression: The AST to remove join marks from.

    Returns:
        The AST with join marks removed (the same object that was passed in).

    Raises:
        AssertionError: if a marked column is not part of a binary predicate, is unqualified,
            if marks appear on both sides of a predicate or span several tables, or if no
            table is left to serve as the new FROM clause.
    """
    # Imported locally, as in the original block.
    from sqlglot.optimizer.scope import traverse_scope

    for scope in traverse_scope(expression):
        query = scope.expression

        where = query.args.get("where")
        joins = query.args.get("joins")

        # Join marks are rewritten only for queries that have both a WHERE clause and joins.
        if not where or not joins:
            continue

        query_from = query.args["from"]

        # These keep track of the joins to be replaced
        new_joins: t.Dict[str, exp.Join] = {}
        old_joins = {join.alias_or_name: join for join in joins}

        for column in scope.columns:
            if not column.args.get("join_mark"):
                continue

            predicate = column.find_ancestor(exp.Predicate, exp.Select)
            assert isinstance(
                predicate, exp.Binary
            ), "Columns can only be marked with (+) when involved in a binary operation"

            # Capture the parent before popping, since popping detaches the predicate.
            predicate_parent = predicate.parent
            join_predicate = predicate.pop()

            left_columns = [
                c for c in join_predicate.left.find_all(exp.Column) if c.args.get("join_mark")
            ]
            right_columns = [
                c for c in join_predicate.right.find_all(exp.Column) if c.args.get("join_mark")
            ]

            assert not (
                left_columns and right_columns
            ), "The (+) marker cannot appear in both sides of a binary predicate"

            marked_column_tables = set()
            for col in left_columns or right_columns:
                table = col.table
                assert table, f"Column {col} needs to be qualified with a table"

                # Clearing the mark also makes the outer loop skip this column later on.
                col.set("join_mark", False)
                marked_column_tables.add(table)

            assert (
                len(marked_column_tables) == 1
            ), "Columns of only a single table can be marked with (+) in a given binary predicate"

            # `col` is the last marked column from the loop above; all marked columns share
            # one table at this point. Fall back to the FROM table if it is not a joined one.
            join_this = old_joins.get(col.table, query_from).this
            new_join = exp.Join(this=join_this, on=join_predicate, kind="LEFT")

            # Upsert new_join into new_joins dictionary
            new_join_alias_or_name = new_join.alias_or_name
            existing_join = new_joins.get(new_join_alias_or_name)
            if existing_join:
                existing_join.set("on", exp.and_(existing_join.args.get("on"), new_join.args["on"]))
            else:
                new_joins[new_join_alias_or_name] = new_join

            # If the parent of the target predicate is a binary node, then it now has only one child
            if isinstance(predicate_parent, exp.Binary):
                if predicate_parent.left is None:
                    predicate_parent.replace(predicate_parent.right)
                else:
                    predicate_parent.replace(predicate_parent.left)

        if query_from.alias_or_name in new_joins:
            # Keep the original join order (dicts preserve insertion order) so that the table
            # promoted to FROM is chosen deterministically; indexing into a set difference
            # would make the choice depend on string hashing.
            only_old_joins = [name for name in old_joins if name not in new_joins]
            assert (
                len(only_old_joins) >= 1
            ), "Cannot determine which table to use in the new FROM clause"

            new_from_name = only_old_joins[0]
            query.set("from", exp.From(this=old_joins[new_from_name].this))

        # NOTE(review): only the rewritten LEFT joins are kept here; original joins that had no
        # marked predicate (other than one promoted to FROM above) are not carried over --
        # confirm this is intended.
        query.set("joins", list(new_joins.values()))

        # Drop the WHERE clause altogether if every predicate in it was moved into a join.
        if not where.this:
            where.pop()

    return expression