79 lines
3.2 KiB
Python
79 lines
3.2 KiB
Python
from __future__ import annotations
|
|
|
|
import typing as t
|
|
|
|
from sqlglot import expressions as exp
|
|
from sqlglot.dataframe.sql.column import Column
|
|
from sqlglot.dataframe.sql.util import get_tables_from_expression_with_join
|
|
from sqlglot.dialects import Spark
|
|
from sqlglot.helper import ensure_list
|
|
|
|
NORMALIZE_INPUT = t.TypeVar("NORMALIZE_INPUT", bound=t.Union[str, exp.Expression, Column])
|
|
|
|
if t.TYPE_CHECKING:
|
|
from sqlglot.dataframe.sql.session import SparkSession
|
|
|
|
|
|
def normalize(spark: SparkSession, expression_context: exp.Select, expr: t.List[NORMALIZE_INPUT]):
|
|
expr = ensure_list(expr)
|
|
expressions = _ensure_expressions(expr)
|
|
for expression in expressions:
|
|
identifiers = expression.find_all(exp.Identifier)
|
|
for identifier in identifiers:
|
|
Spark.normalize_identifier(identifier)
|
|
replace_alias_name_with_cte_name(spark, expression_context, identifier)
|
|
replace_branch_and_sequence_ids_with_cte_name(spark, expression_context, identifier)
|
|
|
|
|
|
def replace_alias_name_with_cte_name(
|
|
spark: SparkSession, expression_context: exp.Select, id: exp.Identifier
|
|
):
|
|
if id.alias_or_name in spark.name_to_sequence_id_mapping:
|
|
for cte in reversed(expression_context.ctes):
|
|
if cte.args["sequence_id"] in spark.name_to_sequence_id_mapping[id.alias_or_name]:
|
|
_set_alias_name(id, cte.alias_or_name)
|
|
break
|
|
|
|
|
|
def replace_branch_and_sequence_ids_with_cte_name(
|
|
spark: SparkSession, expression_context: exp.Select, id: exp.Identifier
|
|
):
|
|
if id.alias_or_name in spark.known_ids:
|
|
# Check if we have a join and if both the tables in that join share a common branch id
|
|
# If so we need to have this reference the left table by default unless the id is a sequence
|
|
# id then it keeps that reference. This handles the weird edge case in spark that shouldn't
|
|
# be common in practice
|
|
if expression_context.args.get("joins") and id.alias_or_name in spark.known_branch_ids:
|
|
join_table_aliases = [
|
|
x.alias_or_name for x in get_tables_from_expression_with_join(expression_context)
|
|
]
|
|
ctes_in_join = [
|
|
cte for cte in expression_context.ctes if cte.alias_or_name in join_table_aliases
|
|
]
|
|
if ctes_in_join[0].args["branch_id"] == ctes_in_join[1].args["branch_id"]:
|
|
assert len(ctes_in_join) == 2
|
|
_set_alias_name(id, ctes_in_join[0].alias_or_name)
|
|
return
|
|
|
|
for cte in reversed(expression_context.ctes):
|
|
if id.alias_or_name in (cte.args["branch_id"], cte.args["sequence_id"]):
|
|
_set_alias_name(id, cte.alias_or_name)
|
|
return
|
|
|
|
|
|
def _set_alias_name(id: exp.Identifier, name: str):
|
|
id.set("this", name)
|
|
|
|
|
|
def _ensure_expressions(values: t.List[NORMALIZE_INPUT]) -> t.List[exp.Expression]:
|
|
results = []
|
|
for value in values:
|
|
if isinstance(value, str):
|
|
results.append(Column.ensure_col(value).expression)
|
|
elif isinstance(value, Column):
|
|
results.append(value.expression)
|
|
elif isinstance(value, exp.Expression):
|
|
results.append(value)
|
|
else:
|
|
raise ValueError(f"Got an invalid type to normalize: {type(value)}")
|
|
return results
|