Edit on GitHub

sqlglot.optimizer.qualify_tables

 1import itertools
 2
 3from sqlglot import alias, exp
 4from sqlglot.helper import csv_reader
 5from sqlglot.optimizer.scope import Scope, traverse_scope
 6
 7
 8def qualify_tables(expression, db=None, catalog=None, schema=None):
 9    """
10    Rewrite sqlglot AST to have fully qualified tables. Additionally, this
11    replaces "join constructs" (*) by equivalent SELECT * subqueries.
12
13    Examples:
14        >>> import sqlglot
15        >>> expression = sqlglot.parse_one("SELECT 1 FROM tbl")
16        >>> qualify_tables(expression, db="db").sql()
17        'SELECT 1 FROM db.tbl AS tbl'
18        >>>
19        >>> expression = sqlglot.parse_one("SELECT * FROM (tbl1 JOIN tbl2 ON id1 = id2)")
20        >>> qualify_tables(expression).sql()
21        'SELECT * FROM (SELECT * FROM tbl1 AS tbl1 JOIN tbl2 AS tbl2 ON id1 = id2) AS _q_0'
22
23    Args:
24        expression (sqlglot.Expression): expression to qualify
25        db (str): Database name
26        catalog (str): Catalog name
27        schema: A schema to populate
28
29    Returns:
30        sqlglot.Expression: qualified expression
31
32    (*) See section 7.2.1.2 in https://www.postgresql.org/docs/current/queries-table-expressions.html
33    """
34    sequence = itertools.count()
35
36    next_name = lambda: f"_q_{next(sequence)}"
37
38    for scope in traverse_scope(expression):
39        for derived_table in itertools.chain(scope.ctes, scope.derived_tables):
40            # Expand join construct
41            if isinstance(derived_table, exp.Subquery):
42                unnested = derived_table.unnest()
43                if isinstance(unnested, exp.Table):
44                    derived_table.this.replace(exp.select("*").from_(unnested.copy(), copy=False))
45
46            if not derived_table.args.get("alias"):
47                alias_ = next_name()
48                derived_table.set("alias", exp.TableAlias(this=exp.to_identifier(alias_)))
49                scope.rename_source(None, alias_)
50
51            pivots = derived_table.args.get("pivots")
52            if pivots and not pivots[0].alias:
53                pivots[0].set("alias", exp.TableAlias(this=exp.to_identifier(next_name())))
54
55        for name, source in scope.sources.items():
56            if isinstance(source, exp.Table):
57                if isinstance(source.this, exp.Identifier):
58                    if not source.args.get("db"):
59                        source.set("db", exp.to_identifier(db))
60                    if not source.args.get("catalog"):
61                        source.set("catalog", exp.to_identifier(catalog))
62
63                if not source.alias:
64                    source = source.replace(
65                        alias(
66                            source,
67                            name or source.name or next_name(),
68                            copy=True,
69                            table=True,
70                        )
71                    )
72
73                pivots = source.args.get("pivots")
74                if pivots and not pivots[0].alias:
75                    pivots[0].set("alias", exp.TableAlias(this=exp.to_identifier(next_name())))
76
77                if schema and isinstance(source.this, exp.ReadCSV):
78                    with csv_reader(source.this) as reader:
79                        header = next(reader)
80                        columns = next(reader)
81                        schema.add_table(
82                            source, {k: type(v).__name__ for k, v in zip(header, columns)}
83                        )
84            elif isinstance(source, Scope) and source.is_udtf:
85                udtf = source.expression
86                table_alias = udtf.args.get("alias") or exp.TableAlias(this=next_name())
87                udtf.set("alias", table_alias)
88
89                if not table_alias.name:
90                    table_alias.set("this", next_name())
91                if isinstance(udtf, exp.Values) and not table_alias.columns:
92                    for i, e in enumerate(udtf.expressions[0].expressions):
93                        table_alias.append("columns", exp.to_identifier(f"_col_{i}"))
94
95    return expression
def qualify_tables(expression, db=None, catalog=None, schema=None):
 9def qualify_tables(expression, db=None, catalog=None, schema=None):
10    """
11    Rewrite sqlglot AST to have fully qualified tables. Additionally, this
12    replaces "join constructs" (*) by equivalent SELECT * subqueries.
13
14    Examples:
15        >>> import sqlglot
16        >>> expression = sqlglot.parse_one("SELECT 1 FROM tbl")
17        >>> qualify_tables(expression, db="db").sql()
18        'SELECT 1 FROM db.tbl AS tbl'
19        >>>
20        >>> expression = sqlglot.parse_one("SELECT * FROM (tbl1 JOIN tbl2 ON id1 = id2)")
21        >>> qualify_tables(expression).sql()
22        'SELECT * FROM (SELECT * FROM tbl1 AS tbl1 JOIN tbl2 AS tbl2 ON id1 = id2) AS _q_0'
23
24    Args:
25        expression (sqlglot.Expression): expression to qualify
26        db (str): Database name
27        catalog (str): Catalog name
28        schema: A schema to populate
29
30    Returns:
31        sqlglot.Expression: qualified expression
32
33    (*) See section 7.2.1.2 in https://www.postgresql.org/docs/current/queries-table-expressions.html
34    """
35    sequence = itertools.count()
36
37    next_name = lambda: f"_q_{next(sequence)}"
38
39    for scope in traverse_scope(expression):
40        for derived_table in itertools.chain(scope.ctes, scope.derived_tables):
41            # Expand join construct
42            if isinstance(derived_table, exp.Subquery):
43                unnested = derived_table.unnest()
44                if isinstance(unnested, exp.Table):
45                    derived_table.this.replace(exp.select("*").from_(unnested.copy(), copy=False))
46
47            if not derived_table.args.get("alias"):
48                alias_ = next_name()
49                derived_table.set("alias", exp.TableAlias(this=exp.to_identifier(alias_)))
50                scope.rename_source(None, alias_)
51
52            pivots = derived_table.args.get("pivots")
53            if pivots and not pivots[0].alias:
54                pivots[0].set("alias", exp.TableAlias(this=exp.to_identifier(next_name())))
55
56        for name, source in scope.sources.items():
57            if isinstance(source, exp.Table):
58                if isinstance(source.this, exp.Identifier):
59                    if not source.args.get("db"):
60                        source.set("db", exp.to_identifier(db))
61                    if not source.args.get("catalog"):
62                        source.set("catalog", exp.to_identifier(catalog))
63
64                if not source.alias:
65                    source = source.replace(
66                        alias(
67                            source,
68                            name or source.name or next_name(),
69                            copy=True,
70                            table=True,
71                        )
72                    )
73
74                pivots = source.args.get("pivots")
75                if pivots and not pivots[0].alias:
76                    pivots[0].set("alias", exp.TableAlias(this=exp.to_identifier(next_name())))
77
78                if schema and isinstance(source.this, exp.ReadCSV):
79                    with csv_reader(source.this) as reader:
80                        header = next(reader)
81                        columns = next(reader)
82                        schema.add_table(
83                            source, {k: type(v).__name__ for k, v in zip(header, columns)}
84                        )
85            elif isinstance(source, Scope) and source.is_udtf:
86                udtf = source.expression
87                table_alias = udtf.args.get("alias") or exp.TableAlias(this=next_name())
88                udtf.set("alias", table_alias)
89
90                if not table_alias.name:
91                    table_alias.set("this", next_name())
92                if isinstance(udtf, exp.Values) and not table_alias.columns:
93                    for i, e in enumerate(udtf.expressions[0].expressions):
94                        table_alias.append("columns", exp.to_identifier(f"_col_{i}"))
95
96    return expression

Rewrite a sqlglot AST so that all of its tables are fully qualified. Additionally, this replaces "join constructs" (*) with equivalent SELECT * subqueries.

Examples:
>>> import sqlglot
>>> expression = sqlglot.parse_one("SELECT 1 FROM tbl")
>>> qualify_tables(expression, db="db").sql()
'SELECT 1 FROM db.tbl AS tbl'
>>>
>>> expression = sqlglot.parse_one("SELECT * FROM (tbl1 JOIN tbl2 ON id1 = id2)")
>>> qualify_tables(expression).sql()
'SELECT * FROM (SELECT * FROM tbl1 AS tbl1 JOIN tbl2 AS tbl2 ON id1 = id2) AS _q_0'

Arguments:
  • expression (sqlglot.Expression): expression to qualify
  • db (str): Database name
  • catalog (str): Catalog name
  • schema: A schema to populate
Returns:
  • sqlglot.Expression: qualified expression

(*) See section 7.2.1.2 in https://www.postgresql.org/docs/current/queries-table-expressions.html