sqlglot.optimizer.qualify_tables
from __future__ import annotations

import itertools
import typing as t

from sqlglot import alias, exp
from sqlglot._typing import E
from sqlglot.dialects.dialect import DialectType
from sqlglot.helper import csv_reader, name_sequence
from sqlglot.optimizer.scope import Scope, traverse_scope
from sqlglot.schema import Schema


def qualify_tables(
    expression: E,
    db: t.Optional[str | exp.Identifier] = None,
    catalog: t.Optional[str | exp.Identifier] = None,
    schema: t.Optional[Schema] = None,
    dialect: DialectType = None,
) -> E:
    """
    Rewrite sqlglot AST to have fully qualified tables. Join constructs such as
    (t1 JOIN t2) AS t will be expanded into (SELECT * FROM t1 AS t1, t2 AS t2) AS t.

    The expression is mutated in place and also returned. Besides filling in missing
    `db` / `catalog` parts, every table, derived table, pivot and UDTF that lacks an
    alias is given one (generated aliases use the `_q_` name sequence).

    Examples:
        >>> import sqlglot
        >>> expression = sqlglot.parse_one("SELECT 1 FROM tbl")
        >>> qualify_tables(expression, db="db").sql()
        'SELECT 1 FROM db.tbl AS tbl'
        >>>
        >>> expression = sqlglot.parse_one("SELECT 1 FROM (t1 JOIN t2) AS t")
        >>> qualify_tables(expression).sql()
        'SELECT 1 FROM (SELECT * FROM t1 AS t1, t2 AS t2) AS t'

    Args:
        expression: Expression to qualify.
        db: Database name applied to tables that have no `db` part.
        catalog: Catalog name applied to tables that have a `db` part but no `catalog` part.
        schema: A schema to populate. When given, each source table backed by a
            `ReadCSV` expression is registered in it, using the CSV header for the
            column names and the first data row for the column types.
        dialect: The dialect used to parse `db` and `catalog` into identifiers.

    Returns:
        The qualified expression.
    """
    next_alias_name = name_sequence("_q_")
    db = exp.parse_identifier(db, dialect=dialect) if db else None
    catalog = exp.parse_identifier(catalog, dialect=dialect) if catalog else None

    def _new_table_alias() -> exp.TableAlias:
        # Builds a TableAlias around the next generated alias name.
        return exp.TableAlias(this=exp.to_identifier(next_alias_name()))

    def _qualify(table: exp.Table) -> None:
        # Only tables whose `this` is a plain identifier are qualified.
        if isinstance(table.this, exp.Identifier):
            if not table.args.get("db"):
                table.set("db", db)
            # A catalog is only attached once the table has a db part.
            if not table.args.get("catalog") and table.args.get("db"):
                table.set("catalog", catalog)

    if not isinstance(expression, exp.Subqueryable):
        # Qualify tables reachable from a non-query root, without descending into
        # Unionable nodes (NOTE(review): presumably because those are handled by the
        # scope traversal below -- confirm).
        for node, *_ in expression.walk(prune=lambda n, *_: isinstance(n, exp.Unionable)):
            if isinstance(node, exp.Table):
                _qualify(node)

    for scope in traverse_scope(expression):
        for derived_table in itertools.chain(scope.ctes, scope.derived_tables):
            if isinstance(derived_table, exp.Subquery):
                unnested = derived_table.unnest()
                if isinstance(unnested, exp.Table):
                    # Expand a parenthesized table / join construct into SELECT * FROM ...,
                    # moving the joins from the table onto the new SELECT.
                    joins = unnested.args.pop("joins", None)
                    derived_table.this.replace(exp.select("*").from_(unnested.copy(), copy=False))
                    derived_table.this.set("joins", joins)

            if not derived_table.args.get("alias"):
                alias_ = next_alias_name()
                derived_table.set("alias", exp.TableAlias(this=exp.to_identifier(alias_)))
                # NOTE(review): presumably the unaliased derived table is registered
                # in the scope under the key None -- confirm against Scope.
                scope.rename_source(None, alias_)

            pivots = derived_table.args.get("pivots")
            if pivots and not pivots[0].alias:
                pivots[0].set("alias", _new_table_alias())

        for name, source in scope.sources.items():
            if isinstance(source, exp.Table):
                _qualify(source)

                pivots = source.args.get("pivots")
                if not source.alias:
                    # Don't add the pivot's alias to the pivoted table, use the table's name instead
                    if pivots and pivots[0].alias == name:
                        name = source.name

                    # Mutates the source by attaching an alias to it
                    alias(source, name or source.name or next_alias_name(), copy=False, table=True)

                if pivots and not pivots[0].alias:
                    pivots[0].set("alias", _new_table_alias())

                if schema and isinstance(source.this, exp.ReadCSV):
                    with csv_reader(source.this) as reader:
                        # The first row supplies column names, the second row sample values.
                        header = next(reader)
                        columns = next(reader)
                        schema.add_table(
                            source,
                            {k: type(v).__name__ for k, v in zip(header, columns)},
                            match_depth=False,
                        )
            elif isinstance(source, Scope) and source.is_udtf:
                udtf = source.expression
                table_alias = udtf.args.get("alias") or _new_table_alias()
                udtf.set("alias", table_alias)

                # An existing alias node may carry only column names and no table name.
                if not table_alias.name:
                    table_alias.set("this", exp.to_identifier(next_alias_name()))
                if isinstance(udtf, exp.Values) and not table_alias.columns:
                    # Name the columns positionally, one per expression in the first VALUES row.
                    for i, _ in enumerate(udtf.expressions[0].expressions):
                        table_alias.append("columns", exp.to_identifier(f"_col_{i}"))
            else:
                for node, parent, _ in scope.walk():
                    if (
                        isinstance(node, exp.Table)
                        and not node.alias
                        and isinstance(parent, (exp.From, exp.Join))
                    ):
                        # Mutates the table by attaching an alias to it
                        alias(node, node.name, copy=False, table=True)

    return expression
def
qualify_tables( expression: ~E, db: Union[sqlglot.expressions.Identifier, str, NoneType] = None, catalog: Union[sqlglot.expressions.Identifier, str, NoneType] = None, schema: Optional[sqlglot.schema.Schema] = None, dialect: Union[str, sqlglot.dialects.dialect.Dialect, Type[sqlglot.dialects.dialect.Dialect], NoneType] = None) -> ~E:
def qualify_tables(
    expression: E,
    db: t.Optional[str | exp.Identifier] = None,
    catalog: t.Optional[str | exp.Identifier] = None,
    schema: t.Optional[Schema] = None,
    dialect: DialectType = None,
) -> E:
    """
    Rewrite sqlglot AST to have fully qualified tables. Join constructs such as
    (t1 JOIN t2) AS t will be expanded into (SELECT * FROM t1 AS t1, t2 AS t2) AS t.

    The expression is mutated in place and also returned. Besides filling in missing
    `db` / `catalog` parts, every table, derived table, pivot and UDTF that lacks an
    alias is given one (generated aliases use the `_q_` name sequence).

    Examples:
        >>> import sqlglot
        >>> expression = sqlglot.parse_one("SELECT 1 FROM tbl")
        >>> qualify_tables(expression, db="db").sql()
        'SELECT 1 FROM db.tbl AS tbl'
        >>>
        >>> expression = sqlglot.parse_one("SELECT 1 FROM (t1 JOIN t2) AS t")
        >>> qualify_tables(expression).sql()
        'SELECT 1 FROM (SELECT * FROM t1 AS t1, t2 AS t2) AS t'

    Args:
        expression: Expression to qualify.
        db: Database name applied to tables that have no `db` part.
        catalog: Catalog name applied to tables that have a `db` part but no `catalog` part.
        schema: A schema to populate. When given, each source table backed by a
            `ReadCSV` expression is registered in it, using the CSV header for the
            column names and the first data row for the column types.
        dialect: The dialect used to parse `db` and `catalog` into identifiers.

    Returns:
        The qualified expression.
    """
    next_alias_name = name_sequence("_q_")
    db = exp.parse_identifier(db, dialect=dialect) if db else None
    catalog = exp.parse_identifier(catalog, dialect=dialect) if catalog else None

    def _new_table_alias() -> exp.TableAlias:
        # Builds a TableAlias around the next generated alias name.
        return exp.TableAlias(this=exp.to_identifier(next_alias_name()))

    def _qualify(table: exp.Table) -> None:
        # Only tables whose `this` is a plain identifier are qualified.
        if isinstance(table.this, exp.Identifier):
            if not table.args.get("db"):
                table.set("db", db)
            # A catalog is only attached once the table has a db part.
            if not table.args.get("catalog") and table.args.get("db"):
                table.set("catalog", catalog)

    if not isinstance(expression, exp.Subqueryable):
        # Qualify tables reachable from a non-query root, without descending into
        # Unionable nodes (NOTE(review): presumably because those are handled by the
        # scope traversal below -- confirm).
        for node, *_ in expression.walk(prune=lambda n, *_: isinstance(n, exp.Unionable)):
            if isinstance(node, exp.Table):
                _qualify(node)

    for scope in traverse_scope(expression):
        for derived_table in itertools.chain(scope.ctes, scope.derived_tables):
            if isinstance(derived_table, exp.Subquery):
                unnested = derived_table.unnest()
                if isinstance(unnested, exp.Table):
                    # Expand a parenthesized table / join construct into SELECT * FROM ...,
                    # moving the joins from the table onto the new SELECT.
                    joins = unnested.args.pop("joins", None)
                    derived_table.this.replace(exp.select("*").from_(unnested.copy(), copy=False))
                    derived_table.this.set("joins", joins)

            if not derived_table.args.get("alias"):
                alias_ = next_alias_name()
                derived_table.set("alias", exp.TableAlias(this=exp.to_identifier(alias_)))
                # NOTE(review): presumably the unaliased derived table is registered
                # in the scope under the key None -- confirm against Scope.
                scope.rename_source(None, alias_)

            pivots = derived_table.args.get("pivots")
            if pivots and not pivots[0].alias:
                pivots[0].set("alias", _new_table_alias())

        for name, source in scope.sources.items():
            if isinstance(source, exp.Table):
                _qualify(source)

                pivots = source.args.get("pivots")
                if not source.alias:
                    # Don't add the pivot's alias to the pivoted table, use the table's name instead
                    if pivots and pivots[0].alias == name:
                        name = source.name

                    # Mutates the source by attaching an alias to it
                    alias(source, name or source.name or next_alias_name(), copy=False, table=True)

                if pivots and not pivots[0].alias:
                    pivots[0].set("alias", _new_table_alias())

                if schema and isinstance(source.this, exp.ReadCSV):
                    with csv_reader(source.this) as reader:
                        # The first row supplies column names, the second row sample values.
                        header = next(reader)
                        columns = next(reader)
                        schema.add_table(
                            source,
                            {k: type(v).__name__ for k, v in zip(header, columns)},
                            match_depth=False,
                        )
            elif isinstance(source, Scope) and source.is_udtf:
                udtf = source.expression
                table_alias = udtf.args.get("alias") or _new_table_alias()
                udtf.set("alias", table_alias)

                # An existing alias node may carry only column names and no table name.
                if not table_alias.name:
                    table_alias.set("this", exp.to_identifier(next_alias_name()))
                if isinstance(udtf, exp.Values) and not table_alias.columns:
                    # Name the columns positionally, one per expression in the first VALUES row.
                    for i, _ in enumerate(udtf.expressions[0].expressions):
                        table_alias.append("columns", exp.to_identifier(f"_col_{i}"))
            else:
                for node, parent, _ in scope.walk():
                    if (
                        isinstance(node, exp.Table)
                        and not node.alias
                        and isinstance(parent, (exp.From, exp.Join))
                    ):
                        # Mutates the table by attaching an alias to it
                        alias(node, node.name, copy=False, table=True)

    return expression
Rewrite sqlglot AST to have fully qualified tables. Join constructs such as (t1 JOIN t2) AS t will be expanded into (SELECT * FROM t1 AS t1, t2 AS t2) AS t.
Examples:
>>> import sqlglot
>>> expression = sqlglot.parse_one("SELECT 1 FROM tbl")
>>> qualify_tables(expression, db="db").sql()
'SELECT 1 FROM db.tbl AS tbl'
>>>
>>> expression = sqlglot.parse_one("SELECT 1 FROM (t1 JOIN t2) AS t")
>>> qualify_tables(expression).sql()
'SELECT 1 FROM (SELECT * FROM t1 AS t1, t2 AS t2) AS t'
Arguments:
- expression: Expression to qualify
- db: Database name
- catalog: Catalog name
- schema: A schema to populate
- dialect: The dialect used to parse the `db` and `catalog` arguments into identifiers.
Returns:
The qualified expression.