1
0
Fork 0

Merging upstream version 22.2.0.

Signed-off-by: Daniel Baumann <daniel@debian.org>
This commit is contained in:
Daniel Baumann 2025-02-13 21:29:39 +01:00
parent b13ba670fd
commit 2c28c49d7e
Signed by: daniel
GPG key ID: FBB4F0E80A80222F
148 changed files with 68457 additions and 63176 deletions

View file

@ -1,16 +1,19 @@
from __future__ import annotations
import json
import logging
import typing as t
from dataclasses import dataclass, field
from sqlglot import Schema, exp, maybe_parse
from sqlglot.errors import SqlglotError
from sqlglot.optimizer import Scope, build_scope, find_all_in_scope, qualify
from sqlglot.optimizer import Scope, build_scope, find_all_in_scope, normalize_identifiers, qualify
if t.TYPE_CHECKING:
from sqlglot.dialects.dialect import DialectType
logger = logging.getLogger("sqlglot")
@dataclass(frozen=True)
class Node:
@ -18,7 +21,8 @@ class Node:
expression: exp.Expression
source: exp.Expression
downstream: t.List[Node] = field(default_factory=list)
alias: str = ""
source_name: str = ""
reference_node_name: str = ""
def walk(self) -> t.Iterator[Node]:
yield self
@ -67,7 +71,7 @@ def lineage(
column: str | exp.Column,
sql: str | exp.Expression,
schema: t.Optional[t.Dict | Schema] = None,
sources: t.Optional[t.Dict[str, str | exp.Subqueryable]] = None,
sources: t.Optional[t.Dict[str, str | exp.Query]] = None,
dialect: DialectType = None,
**kwargs,
) -> Node:
@ -86,14 +90,12 @@ def lineage(
"""
expression = maybe_parse(sql, dialect=dialect)
column = normalize_identifiers.normalize_identifiers(column, dialect=dialect).name
if sources:
expression = exp.expand(
expression,
{
k: t.cast(exp.Subqueryable, maybe_parse(v, dialect=dialect))
for k, v in sources.items()
},
{k: t.cast(exp.Query, maybe_parse(v, dialect=dialect)) for k, v in sources.items()},
dialect=dialect,
)
@ -109,122 +111,141 @@ def lineage(
if not scope:
raise SqlglotError("Cannot build lineage, sql must be SELECT")
def to_node(
column: str | int,
scope: Scope,
scope_name: t.Optional[str] = None,
upstream: t.Optional[Node] = None,
alias: t.Optional[str] = None,
) -> Node:
aliases = {
dt.alias: dt.comments[0].split()[1]
for dt in scope.derived_tables
if dt.comments and dt.comments[0].startswith("source: ")
}
if not any(select.alias_or_name == column for select in scope.expression.selects):
raise SqlglotError(f"Cannot find column '{column}' in query.")
# Find the specific select clause that is the source of the column we want.
# This can either be a specific, named select or a generic `*` clause.
select = (
scope.expression.selects[column]
return to_node(column, scope, dialect)
def to_node(
column: str | int,
scope: Scope,
dialect: DialectType,
scope_name: t.Optional[str] = None,
upstream: t.Optional[Node] = None,
source_name: t.Optional[str] = None,
reference_node_name: t.Optional[str] = None,
) -> Node:
source_names = {
dt.alias: dt.comments[0].split()[1]
for dt in scope.derived_tables
if dt.comments and dt.comments[0].startswith("source: ")
}
# Find the specific select clause that is the source of the column we want.
# This can either be a specific, named select or a generic `*` clause.
select = (
scope.expression.selects[column]
if isinstance(column, int)
else next(
(select for select in scope.expression.selects if select.alias_or_name == column),
exp.Star() if scope.expression.is_star else scope.expression,
)
)
if isinstance(scope.expression, exp.Union):
upstream = upstream or Node(name="UNION", source=scope.expression, expression=select)
index = (
column
if isinstance(column, int)
else next(
(select for select in scope.expression.selects if select.alias_or_name == column),
exp.Star() if scope.expression.is_star else scope.expression,
(
i
for i, select in enumerate(scope.expression.selects)
if select.alias_or_name == column or select.is_star
),
-1, # mypy will not allow a None here, but a negative index should never be returned
)
)
if isinstance(scope.expression, exp.Union):
upstream = upstream or Node(name="UNION", source=scope.expression, expression=select)
if index == -1:
raise ValueError(f"Could not find {column} in {scope.expression}")
index = (
column
if isinstance(column, int)
else next(
(
i
for i, select in enumerate(scope.expression.selects)
if select.alias_or_name == column or select.is_star
),
-1, # mypy will not allow a None here, but a negative index should never be returned
)
for s in scope.union_scopes:
to_node(
index,
scope=s,
dialect=dialect,
upstream=upstream,
source_name=source_name,
reference_node_name=reference_node_name,
)
if index == -1:
raise ValueError(f"Could not find {column} in {scope.expression}")
return upstream
for s in scope.union_scopes:
to_node(index, scope=s, upstream=upstream, alias=alias)
if isinstance(scope.expression, exp.Select):
# For better ergonomics in our node labels, replace the full select with
# a version that has only the column we care about.
# "x", SELECT x, y FROM foo
# => "x", SELECT x FROM foo
source = t.cast(exp.Expression, scope.expression.select(select, append=False))
else:
source = scope.expression
return upstream
# Create the node for this step in the lineage chain, and attach it to the previous one.
node = Node(
name=f"{scope_name}.{column}" if scope_name else str(column),
source=source,
expression=select,
source_name=source_name or "",
reference_node_name=reference_node_name or "",
)
if isinstance(scope.expression, exp.Select):
# For better ergonomics in our node labels, replace the full select with
# a version that has only the column we care about.
# "x", SELECT x, y FROM foo
# => "x", SELECT x FROM foo
source = t.cast(exp.Expression, scope.expression.select(select, append=False))
else:
source = scope.expression
if upstream:
upstream.downstream.append(node)
# Create the node for this step in the lineage chain, and attach it to the previous one.
node = Node(
name=f"{scope_name}.{column}" if scope_name else str(column),
source=source,
expression=select,
alias=alias or "",
)
subquery_scopes = {
id(subquery_scope.expression): subquery_scope for subquery_scope in scope.subquery_scopes
}
if upstream:
upstream.downstream.append(node)
for subquery in find_all_in_scope(select, exp.UNWRAPPED_QUERIES):
subquery_scope = subquery_scopes.get(id(subquery))
if not subquery_scope:
logger.warning(f"Unknown subquery scope: {subquery.sql(dialect=dialect)}")
continue
subquery_scopes = {
id(subquery_scope.expression): subquery_scope
for subquery_scope in scope.subquery_scopes
}
for subquery in find_all_in_scope(select, exp.Subqueryable):
subquery_scope = subquery_scopes[id(subquery)]
for name in subquery.named_selects:
to_node(name, scope=subquery_scope, upstream=node)
# if the select is a star add all scope sources as downstreams
if select.is_star:
for source in scope.sources.values():
if isinstance(source, Scope):
source = source.expression
node.downstream.append(Node(name=select.sql(), source=source, expression=source))
# Find all columns that went into creating this one to list their lineage nodes.
source_columns = set(find_all_in_scope(select, exp.Column))
# If the source is a UDTF, find the columns used in the UDTF to generate the table
if isinstance(source, exp.UDTF):
source_columns |= set(source.find_all(exp.Column))
for c in source_columns:
table = c.table
source = scope.sources.get(table)
for name in subquery.named_selects:
to_node(name, scope=subquery_scope, dialect=dialect, upstream=node)
# if the select is a star add all scope sources as downstreams
if select.is_star:
for source in scope.sources.values():
if isinstance(source, Scope):
# The table itself came from a more specific scope. Recurse into that one using the unaliased column name.
to_node(
c.name,
scope=source,
scope_name=table,
upstream=node,
alias=aliases.get(table) or alias,
)
else:
# The source is not a scope - we've reached the end of the line. At this point, if a source is not found
# it means this column's lineage is unknown. This can happen if the definition of a source used in a query
# is not passed into the `sources` map.
source = source or exp.Placeholder()
node.downstream.append(Node(name=c.sql(), source=source, expression=source))
source = source.expression
node.downstream.append(Node(name=select.sql(), source=source, expression=source))
return node
# Find all columns that went into creating this one to list their lineage nodes.
source_columns = set(find_all_in_scope(select, exp.Column))
return to_node(column if isinstance(column, str) else column.name, scope)
# If the source is a UDTF, find the columns used in the UDTF to generate the table
if isinstance(source, exp.UDTF):
source_columns |= set(source.find_all(exp.Column))
for c in source_columns:
table = c.table
source = scope.sources.get(table)
if isinstance(source, Scope):
selected_node, _ = scope.selected_sources.get(table, (None, None))
# The table itself came from a more specific scope. Recurse into that one using the unaliased column name.
to_node(
c.name,
scope=source,
dialect=dialect,
scope_name=table,
upstream=node,
source_name=source_names.get(table) or source_name,
reference_node_name=selected_node.name if selected_node else None,
)
else:
# The source is not a scope - we've reached the end of the line. At this point, if a source is not found
# it means this column's lineage is unknown. This can happen if the definition of a source used in a query
# is not passed into the `sources` map.
source = source or exp.Placeholder()
node.downstream.append(Node(name=c.sql(), source=source, expression=source))
return node
class GraphHTML: