
Merging upstream version 26.0.1.

Signed-off-by: Daniel Baumann <daniel@debian.org>
Daniel Baumann 2025-02-13 21:59:10 +01:00
parent e2fd836612
commit 63d24513e5
Signed by: daniel
GPG key ID: FBB4F0E80A80222F
65 changed files with 45416 additions and 44542 deletions

@@ -254,6 +254,27 @@ def to_node(
        if dt.comments and dt.comments[0].startswith("source: ")
    }

    pivots = scope.pivots
    pivot = pivots[0] if len(pivots) == 1 and not pivots[0].unpivot else None
    if pivot:
        # For each aggregation function, the pivot creates a new column for each field in category
        # combined with the aggfunc. So the columns parsed have this order: cat_a_value_sum, cat_a,
        # b_value_sum, b. Because of this step wise manner the aggfunc 'sum(value) as value_sum'
        # belongs to the column indices 0, 2, and the aggfunc 'max(price)' without an alias belongs
        # to the column indices 1, 3. Here, only the columns used in the aggregations are of interest
        # in the lineage, so lookup the pivot column name by index and map that with the columns used
        # in the aggregation.
        #
        # Example: PIVOT (SUM(value) AS value_sum, MAX(price)) FOR category IN ('a' AS cat_a, 'b')
        pivot_columns = pivot.args["columns"]
        pivot_aggs_count = len(pivot.expressions)
        pivot_column_mapping = {}
        for i, agg in enumerate(pivot.expressions):
            agg_cols = list(agg.find_all(exp.Column))
            for col_index in range(i, len(pivot_columns), pivot_aggs_count):
                pivot_column_mapping[pivot_columns[col_index].name] = agg_cols
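        # Illustration: for the example above, using the column order described in the comment,
        # this loop yields roughly
        #   {"cat_a_value_sum": [value], "cat_a": [price], "b_value_sum": [value], "b": [price]}
        # i.e. each pivot output column maps to the source columns referenced by its aggfunc.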

    for c in source_columns:
        table = c.table
        source = scope.sources.get(table)
@@ -265,6 +286,7 @@ def to_node(
            elif source.scope_type == ScopeType.CTE:
                selected_node, _ = scope.selected_sources.get(table, (None, None))
                reference_node_name = selected_node.name if selected_node else None

            # The table itself came from a more specific scope. Recurse into that one using the unaliased column name.
            to_node(
                c.name,
@@ -276,10 +298,45 @@ def to_node(
                reference_node_name=reference_node_name,
                trim_selects=trim_selects,
            )
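            # Illustrative example: for SELECT x FROM (SELECT y AS x FROM t) AS sub, resolving
            # sub.x recurses into the subquery's scope as column "x", which then resolves to t.y.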
        elif pivot and pivot.alias_or_name == c.table:
            downstream_columns = []

            column_name = c.name
            if any(column_name == pivot_column.name for pivot_column in pivot_columns):
                downstream_columns.extend(pivot_column_mapping[column_name])
            else:
                # The column is not in the pivot, so it must be an implicit column of the
                # pivoted source -- adapt column to be from the implicit pivoted source.
                downstream_columns.append(exp.column(c.this, table=pivot.parent.this))
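                # Illustrative: with FROM sales PIVOT (...) AS p, a non-pivot column p.region is
                # rewritten here as sales.region, since pivot.parent is the pivoted table node.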

            for downstream_column in downstream_columns:
                table = downstream_column.table
                source = scope.sources.get(table)
                if isinstance(source, Scope):
                    to_node(
                        downstream_column.name,
                        scope=source,
                        scope_name=table,
                        dialect=dialect,
                        upstream=node,
                        source_name=source_names.get(table) or source_name,
                        reference_node_name=reference_node_name,
                        trim_selects=trim_selects,
                    )
                else:
                    source = source or exp.Placeholder()
                    node.downstream.append(
                        Node(
                            name=downstream_column.sql(comments=False),
                            source=source,
                            expression=source,
                        )
                    )
        else:
            # The source is not a scope and the column is not in any pivot - we've reached the end
            # of the line. At this point, if a source is not found it means this column's lineage
            # is unknown. This can happen if the definition of a source used in a query is not
            # passed into the `sources` map.
            source = source or exp.Placeholder()
            node.downstream.append(
                Node(name=c.sql(comments=False), source=source, expression=source)