1
0
Fork 0

Adding upstream version 20.1.0.

Signed-off-by: Daniel Baumann <daniel@debian.org>
This commit is contained in:
Daniel Baumann 2025-02-13 21:16:46 +01:00
parent 6a89523da4
commit 5bd573dda1
Signed by: daniel
GPG key ID: FBB4F0E80A80222F
127 changed files with 73384 additions and 73067 deletions

View file

@ -1,8 +1,11 @@
from __future__ import annotations
import itertools
import typing as t
from sqlglot import alias, exp
from sqlglot._typing import E
from sqlglot.dialects.dialect import DialectType
from sqlglot.helper import csv_reader, name_sequence
from sqlglot.optimizer.scope import Scope, traverse_scope
from sqlglot.schema import Schema
@ -10,9 +13,10 @@ from sqlglot.schema import Schema
def qualify_tables(
expression: E,
db: t.Optional[str] = None,
catalog: t.Optional[str] = None,
db: t.Optional[str | exp.Identifier] = None,
catalog: t.Optional[str | exp.Identifier] = None,
schema: t.Optional[Schema] = None,
dialect: DialectType = None,
) -> E:
"""
Rewrite sqlglot AST to have fully qualified tables. Join constructs such as
@ -33,11 +37,14 @@ def qualify_tables(
db: Database name
catalog: Catalog name
schema: A schema to populate
dialect: The dialect to parse catalog and schema into.
Returns:
The qualified expression.
"""
next_alias_name = name_sequence("_q_")
db = exp.parse_identifier(db, dialect=dialect) if db else None
catalog = exp.parse_identifier(catalog, dialect=dialect) if catalog else None
for scope in traverse_scope(expression):
for derived_table in itertools.chain(scope.ctes, scope.derived_tables):
@ -61,9 +68,9 @@ def qualify_tables(
if isinstance(source, exp.Table):
if isinstance(source.this, exp.Identifier):
if not source.args.get("db"):
source.set("db", exp.to_identifier(db))
source.set("db", db)
if not source.args.get("catalog") and source.args.get("db"):
source.set("catalog", exp.to_identifier(catalog))
source.set("catalog", catalog)
if not source.alias:
# Mutates the source by attaching an alias to it