Edit on GitHub

sqlglot.optimizer.qualify_tables

  1from __future__ import annotations
  2
  3import itertools
  4import typing as t
  5
  6from sqlglot import alias, exp
  7from sqlglot._typing import E
  8from sqlglot.dialects.dialect import DialectType
  9from sqlglot.helper import csv_reader, name_sequence
 10from sqlglot.optimizer.scope import Scope, traverse_scope
 11from sqlglot.schema import Schema
 12
 13
 14def qualify_tables(
 15    expression: E,
 16    db: t.Optional[str | exp.Identifier] = None,
 17    catalog: t.Optional[str | exp.Identifier] = None,
 18    schema: t.Optional[Schema] = None,
 19    dialect: DialectType = None,
 20) -> E:
 21    """
 22    Rewrite sqlglot AST to have fully qualified tables. Join constructs such as
 23    (t1 JOIN t2) AS t will be expanded into (SELECT * FROM t1 AS t1, t2 AS t2) AS t.
 24
 25    Examples:
 26        >>> import sqlglot
 27        >>> expression = sqlglot.parse_one("SELECT 1 FROM tbl")
 28        >>> qualify_tables(expression, db="db").sql()
 29        'SELECT 1 FROM db.tbl AS tbl'
 30        >>>
 31        >>> expression = sqlglot.parse_one("SELECT 1 FROM (t1 JOIN t2) AS t")
 32        >>> qualify_tables(expression).sql()
 33        'SELECT 1 FROM (SELECT * FROM t1 AS t1, t2 AS t2) AS t'
 34
 35    Args:
 36        expression: Expression to qualify
 37        db: Database name
 38        catalog: Catalog name
 39        schema: A schema to populate
 40        dialect: The dialect to parse catalog and schema into.
 41
 42    Returns:
 43        The qualified expression.
 44    """
 45    next_alias_name = name_sequence("_q_")
 46    db = exp.parse_identifier(db, dialect=dialect) if db else None
 47    catalog = exp.parse_identifier(catalog, dialect=dialect) if catalog else None
 48
 49    def _qualify(table: exp.Table) -> None:
 50        if isinstance(table.this, exp.Identifier):
 51            if not table.args.get("db"):
 52                table.set("db", db)
 53            if not table.args.get("catalog") and table.args.get("db"):
 54                table.set("catalog", catalog)
 55
 56    if not isinstance(expression, exp.Subqueryable):
 57        for node, *_ in expression.walk(prune=lambda n, *_: isinstance(n, exp.Unionable)):
 58            if isinstance(node, exp.Table):
 59                _qualify(node)
 60
 61    for scope in traverse_scope(expression):
 62        for derived_table in itertools.chain(scope.ctes, scope.derived_tables):
 63            if isinstance(derived_table, exp.Subquery):
 64                unnested = derived_table.unnest()
 65                if isinstance(unnested, exp.Table):
 66                    joins = unnested.args.pop("joins", None)
 67                    derived_table.this.replace(exp.select("*").from_(unnested.copy(), copy=False))
 68                    derived_table.this.set("joins", joins)
 69
 70            if not derived_table.args.get("alias"):
 71                alias_ = next_alias_name()
 72                derived_table.set("alias", exp.TableAlias(this=exp.to_identifier(alias_)))
 73                scope.rename_source(None, alias_)
 74
 75            pivots = derived_table.args.get("pivots")
 76            if pivots and not pivots[0].alias:
 77                pivots[0].set("alias", exp.TableAlias(this=exp.to_identifier(next_alias_name())))
 78
 79        for name, source in scope.sources.items():
 80            if isinstance(source, exp.Table):
 81                _qualify(source)
 82
 83                pivots = pivots = source.args.get("pivots")
 84                if not source.alias:
 85                    # Don't add the pivot's alias to the pivoted table, use the table's name instead
 86                    if pivots and pivots[0].alias == name:
 87                        name = source.name
 88
 89                    # Mutates the source by attaching an alias to it
 90                    alias(source, name or source.name or next_alias_name(), copy=False, table=True)
 91
 92                if pivots and not pivots[0].alias:
 93                    pivots[0].set(
 94                        "alias", exp.TableAlias(this=exp.to_identifier(next_alias_name()))
 95                    )
 96
 97                if schema and isinstance(source.this, exp.ReadCSV):
 98                    with csv_reader(source.this) as reader:
 99                        header = next(reader)
100                        columns = next(reader)
101                        schema.add_table(
102                            source,
103                            {k: type(v).__name__ for k, v in zip(header, columns)},
104                            match_depth=False,
105                        )
106            elif isinstance(source, Scope) and source.is_udtf:
107                udtf = source.expression
108                table_alias = udtf.args.get("alias") or exp.TableAlias(
109                    this=exp.to_identifier(next_alias_name())
110                )
111                udtf.set("alias", table_alias)
112
113                if not table_alias.name:
114                    table_alias.set("this", exp.to_identifier(next_alias_name()))
115                if isinstance(udtf, exp.Values) and not table_alias.columns:
116                    for i, e in enumerate(udtf.expressions[0].expressions):
117                        table_alias.append("columns", exp.to_identifier(f"_col_{i}"))
118            else:
119                for node, parent, _ in scope.walk():
120                    if (
121                        isinstance(node, exp.Table)
122                        and not node.alias
123                        and isinstance(parent, (exp.From, exp.Join))
124                    ):
125                        # Mutates the table by attaching an alias to it
126                        alias(node, node.name, copy=False, table=True)
127
128    return expression
def qualify_tables( expression: ~E, db: Union[sqlglot.expressions.Identifier, str, NoneType] = None, catalog: Union[sqlglot.expressions.Identifier, str, NoneType] = None, schema: Optional[sqlglot.schema.Schema] = None, dialect: Union[str, sqlglot.dialects.dialect.Dialect, Type[sqlglot.dialects.dialect.Dialect], NoneType] = None) -> ~E:
 15def qualify_tables(
 16    expression: E,
 17    db: t.Optional[str | exp.Identifier] = None,
 18    catalog: t.Optional[str | exp.Identifier] = None,
 19    schema: t.Optional[Schema] = None,
 20    dialect: DialectType = None,
 21) -> E:
 22    """
 23    Rewrite sqlglot AST to have fully qualified tables. Join constructs such as
 24    (t1 JOIN t2) AS t will be expanded into (SELECT * FROM t1 AS t1, t2 AS t2) AS t.
 25
 26    Examples:
 27        >>> import sqlglot
 28        >>> expression = sqlglot.parse_one("SELECT 1 FROM tbl")
 29        >>> qualify_tables(expression, db="db").sql()
 30        'SELECT 1 FROM db.tbl AS tbl'
 31        >>>
 32        >>> expression = sqlglot.parse_one("SELECT 1 FROM (t1 JOIN t2) AS t")
 33        >>> qualify_tables(expression).sql()
 34        'SELECT 1 FROM (SELECT * FROM t1 AS t1, t2 AS t2) AS t'
 35
 36    Args:
 37        expression: Expression to qualify
 38        db: Database name
 39        catalog: Catalog name
 40        schema: A schema to populate
 41        dialect: The dialect to parse catalog and schema into.
 42
 43    Returns:
 44        The qualified expression.
 45    """
 46    next_alias_name = name_sequence("_q_")
 47    db = exp.parse_identifier(db, dialect=dialect) if db else None
 48    catalog = exp.parse_identifier(catalog, dialect=dialect) if catalog else None
 49
 50    def _qualify(table: exp.Table) -> None:
 51        if isinstance(table.this, exp.Identifier):
 52            if not table.args.get("db"):
 53                table.set("db", db)
 54            if not table.args.get("catalog") and table.args.get("db"):
 55                table.set("catalog", catalog)
 56
 57    if not isinstance(expression, exp.Subqueryable):
 58        for node, *_ in expression.walk(prune=lambda n, *_: isinstance(n, exp.Unionable)):
 59            if isinstance(node, exp.Table):
 60                _qualify(node)
 61
 62    for scope in traverse_scope(expression):
 63        for derived_table in itertools.chain(scope.ctes, scope.derived_tables):
 64            if isinstance(derived_table, exp.Subquery):
 65                unnested = derived_table.unnest()
 66                if isinstance(unnested, exp.Table):
 67                    joins = unnested.args.pop("joins", None)
 68                    derived_table.this.replace(exp.select("*").from_(unnested.copy(), copy=False))
 69                    derived_table.this.set("joins", joins)
 70
 71            if not derived_table.args.get("alias"):
 72                alias_ = next_alias_name()
 73                derived_table.set("alias", exp.TableAlias(this=exp.to_identifier(alias_)))
 74                scope.rename_source(None, alias_)
 75
 76            pivots = derived_table.args.get("pivots")
 77            if pivots and not pivots[0].alias:
 78                pivots[0].set("alias", exp.TableAlias(this=exp.to_identifier(next_alias_name())))
 79
 80        for name, source in scope.sources.items():
 81            if isinstance(source, exp.Table):
 82                _qualify(source)
 83
 84                pivots = pivots = source.args.get("pivots")
 85                if not source.alias:
 86                    # Don't add the pivot's alias to the pivoted table, use the table's name instead
 87                    if pivots and pivots[0].alias == name:
 88                        name = source.name
 89
 90                    # Mutates the source by attaching an alias to it
 91                    alias(source, name or source.name or next_alias_name(), copy=False, table=True)
 92
 93                if pivots and not pivots[0].alias:
 94                    pivots[0].set(
 95                        "alias", exp.TableAlias(this=exp.to_identifier(next_alias_name()))
 96                    )
 97
 98                if schema and isinstance(source.this, exp.ReadCSV):
 99                    with csv_reader(source.this) as reader:
100                        header = next(reader)
101                        columns = next(reader)
102                        schema.add_table(
103                            source,
104                            {k: type(v).__name__ for k, v in zip(header, columns)},
105                            match_depth=False,
106                        )
107            elif isinstance(source, Scope) and source.is_udtf:
108                udtf = source.expression
109                table_alias = udtf.args.get("alias") or exp.TableAlias(
110                    this=exp.to_identifier(next_alias_name())
111                )
112                udtf.set("alias", table_alias)
113
114                if not table_alias.name:
115                    table_alias.set("this", exp.to_identifier(next_alias_name()))
116                if isinstance(udtf, exp.Values) and not table_alias.columns:
117                    for i, e in enumerate(udtf.expressions[0].expressions):
118                        table_alias.append("columns", exp.to_identifier(f"_col_{i}"))
119            else:
120                for node, parent, _ in scope.walk():
121                    if (
122                        isinstance(node, exp.Table)
123                        and not node.alias
124                        and isinstance(parent, (exp.From, exp.Join))
125                    ):
126                        # Mutates the table by attaching an alias to it
127                        alias(node, node.name, copy=False, table=True)
128
129    return expression

Rewrite sqlglot AST to have fully qualified tables. Join constructs such as (t1 JOIN t2) AS t will be expanded into (SELECT * FROM t1 AS t1, t2 AS t2) AS t.

Examples:
>>> import sqlglot
>>> expression = sqlglot.parse_one("SELECT 1 FROM tbl")
>>> qualify_tables(expression, db="db").sql()
'SELECT 1 FROM db.tbl AS tbl'
>>>
>>> expression = sqlglot.parse_one("SELECT 1 FROM (t1 JOIN t2) AS t")
>>> qualify_tables(expression).sql()
'SELECT 1 FROM (SELECT * FROM t1 AS t1, t2 AS t2) AS t'
Arguments:
  • expression: Expression to qualify
  • db: Database name
  • catalog: Catalog name
  • schema: A schema to populate with column types inferred from READ_CSV table sources.
  • dialect: The dialect used to parse the db and catalog names into identifiers.
Returns:

The qualified expression.