1
0
Fork 0
sqlglot/sqlglot/dataframe/sql/normalize.py
Daniel Baumann cac8fd11fe
Adding upstream version 16.4.0.
Signed-off-by: Daniel Baumann <daniel@debian.org>
2025-02-13 20:04:17 +01:00

79 lines
3.2 KiB
Python

from __future__ import annotations
import typing as t
from sqlglot import expressions as exp
from sqlglot.dataframe.sql.column import Column
from sqlglot.dataframe.sql.util import get_tables_from_expression_with_join
from sqlglot.dialects import Spark
from sqlglot.helper import ensure_list
# Inputs accepted by `normalize`: a string, a sqlglot expression, or a dataframe
# `Column` (see `_ensure_expressions` for how each kind is converted).
NORMALIZE_INPUT = t.TypeVar("NORMALIZE_INPUT", bound=t.Union[str, exp.Expression, Column])

if t.TYPE_CHECKING:
    # Only needed for type annotations; not imported at runtime (presumably to avoid
    # a circular import with the session module -- confirm).
    from sqlglot.dataframe.sql.session import SparkSession
def normalize(
    spark: SparkSession,
    expression_context: exp.Select,
    expr: t.Union[NORMALIZE_INPUT, t.List[NORMALIZE_INPUT]],
):
    """Normalize the identifiers in ``expr`` and point them at CTEs of ``expression_context``.

    Every input is converted to a sqlglot expression. Then every identifier found
    inside it goes through three steps in order:

    1. It is passed to ``Spark.normalize_identifier``. The return value is discarded,
       so this relies on that call mutating the identifier.
    2. It is renamed to a CTE alias by dataframe name
       (``replace_alias_name_with_cte_name``).
    3. It is renamed by branch/sequence id
       (``replace_branch_and_sequence_ids_with_cte_name``).

    Args:
        spark: session providing the name / branch-id / sequence-id bookkeeping.
        expression_context: the SELECT whose CTEs identifiers are resolved against.
        expr: a single input or a list of inputs (str, ``exp.Expression`` or
            ``Column``). A lone value is accepted because it is wrapped with
            ``ensure_list``.

    Raises:
        ValueError: if any input is not a str, ``exp.Expression`` or ``Column``.
    """
    expressions = _ensure_expressions(ensure_list(expr))
    for expression in expressions:
        for identifier in expression.find_all(exp.Identifier):
            Spark.normalize_identifier(identifier)
            replace_alias_name_with_cte_name(spark, expression_context, identifier)
            replace_branch_and_sequence_ids_with_cte_name(spark, expression_context, identifier)
def replace_alias_name_with_cte_name(
    spark: SparkSession, expression_context: exp.Select, id: exp.Identifier
):
    """Rename ``id`` to the newest CTE associated with the name it currently holds.

    ``spark.name_to_sequence_id_mapping`` maps a name to a collection of sequence ids.
    If ``id``'s name is a key there, the CTEs of ``expression_context`` are scanned
    from last to first. ``id`` is renamed to the alias of the first CTE whose
    ``sequence_id`` appears in that collection. In every other case ``id`` is left
    unchanged.
    """
    current_name = id.alias_or_name
    if current_name not in spark.name_to_sequence_id_mapping:
        return
    wanted_sequence_ids = spark.name_to_sequence_id_mapping[current_name]
    newest_match = next(
        (
            candidate
            for candidate in reversed(expression_context.ctes)
            if candidate.args["sequence_id"] in wanted_sequence_ids
        ),
        None,
    )
    if newest_match is not None:
        _set_alias_name(id, newest_match.alias_or_name)
def replace_branch_and_sequence_ids_with_cte_name(
    spark: SparkSession, expression_context: exp.Select, id: exp.Identifier
):
    """Rename ``id`` to a CTE alias when its name is a known branch or sequence id.

    Identifiers whose name is not in ``spark.known_ids`` are left untouched.

    Join special case: this applies when ``expression_context`` has joins, the name
    is a known *branch* id, and exactly two of the context's CTEs are joined tables
    sharing the same ``branch_id``. The identifier is then resolved to the first
    (left) of those two CTEs. Sequence ids are not branch ids, so they never take
    this path and keep their own reference. The case mirrors an unusual Spark edge
    case that should be rare in practice.

    Otherwise the CTEs are scanned from last to first. ``id`` is renamed to the alias
    of the first CTE whose ``branch_id`` or ``sequence_id`` equals the name.
    """
    name = id.alias_or_name
    if name not in spark.known_ids:
        return

    if expression_context.args.get("joins") and name in spark.known_branch_ids:
        join_table_aliases = {
            table.alias_or_name
            for table in get_tables_from_expression_with_join(expression_context)
        }
        ctes_in_join = [
            cte for cte in expression_context.ctes if cte.alias_or_name in join_table_aliases
        ]
        # Check the length *before* indexing so fewer than two CTEs cannot raise
        # IndexError. The shared-branch rule is defined for a pair of tables, so a
        # chained join (more than two CTEs) falls back to the generic lookup below
        # instead of failing.
        if (
            len(ctes_in_join) == 2
            and ctes_in_join[0].args["branch_id"] == ctes_in_join[1].args["branch_id"]
        ):
            _set_alias_name(id, ctes_in_join[0].alias_or_name)
            return

    # Generic resolution: the most recent CTE carrying this branch or sequence id wins.
    for cte in reversed(expression_context.ctes):
        if name in (cte.args["branch_id"], cte.args["sequence_id"]):
            _set_alias_name(id, cte.alias_or_name)
            return
def _set_alias_name(id: exp.Identifier, name: str) -> None:
    """Rename identifier ``id`` to ``name`` by overwriting its ``this`` argument via ``set``."""
    text_arg_key = "this"
    id.set(text_arg_key, name)
def _ensure_expressions(values: t.List[NORMALIZE_INPUT]) -> t.List[exp.Expression]:
    """Convert each normalize input into a sqlglot expression, preserving order.

    - Strings are wrapped with ``Column.ensure_col`` and its ``expression`` is used.
    - ``Column`` objects contribute their ``expression``.
    - sqlglot expressions pass through unchanged.

    Raises:
        ValueError: on the first value that is none of the above.
    """

    def _as_expression(item) -> exp.Expression:
        # Order matters: str first, then Column, then raw Expression.
        if isinstance(item, str):
            return Column.ensure_col(item).expression
        if isinstance(item, Column):
            return item.expression
        if isinstance(item, exp.Expression):
            return item
        raise ValueError(f"Got an invalid type to normalize: {type(item)}")

    return [_as_expression(item) for item in values]