sqlglot.optimizer.qualify_tables
import itertools

from sqlglot import alias, exp
from sqlglot.helper import csv_reader
from sqlglot.optimizer.scope import Scope, traverse_scope


def qualify_tables(expression, db=None, catalog=None, schema=None):
    """
    Rewrite sqlglot AST to have fully qualified tables. Additionally, this
    replaces "join constructs" (*) by equivalent SELECT * subqueries.

    Examples:
        >>> import sqlglot
        >>> expression = sqlglot.parse_one("SELECT 1 FROM tbl")
        >>> qualify_tables(expression, db="db").sql()
        'SELECT 1 FROM db.tbl AS tbl'
        >>>
        >>> expression = sqlglot.parse_one("SELECT * FROM (tbl1 JOIN tbl2 ON id1 = id2)")
        >>> qualify_tables(expression).sql()
        'SELECT * FROM (SELECT * FROM tbl1 AS tbl1 JOIN tbl2 AS tbl2 ON id1 = id2) AS _q_0'

    Args:
        expression (sqlglot.Expression): expression to qualify. Nodes of the
            tree are modified in place; the same object is returned.
        db (str): Database name, set on identifier-named table references
            that have no db part.
        catalog (str): Catalog name, set on identifier-named table references
            that have no catalog part.
        schema: A schema to populate. Only used for `READ_CSV` table sources:
            each one is registered via `schema.add_table`, with every column
            typed by the Python type name of its value in the first data row.

    Returns:
        sqlglot.Expression: qualified expression

    (*) See section 7.2.1.2 in https://www.postgresql.org/docs/current/queries-table-expressions.html
    """
    sequence = itertools.count()

    def next_name():
        # One counter for the whole call, so generated aliases (_q_0, _q_1, ...)
        # are unique across every scope of the expression.
        return f"_q_{next(sequence)}"

    for scope in traverse_scope(expression):
        for derived_table in itertools.chain(scope.ctes, scope.derived_tables):
            # Expand join construct
            if isinstance(derived_table, exp.Subquery):
                unnested = derived_table.unnest()
                if isinstance(unnested, exp.Table):
                    derived_table.this.replace(exp.select("*").from_(unnested.copy(), copy=False))

            if not derived_table.args.get("alias"):
                alias_ = next_name()
                derived_table.set("alias", exp.TableAlias(this=exp.to_identifier(alias_)))
                # NOTE(review): presumably moves the source registered under the
                # key None (the previously unaliased table) to the new alias name.
                scope.rename_source(None, alias_)

            pivots = derived_table.args.get("pivots")
            if pivots and not pivots[0].alias:
                pivots[0].set("alias", exp.TableAlias(this=exp.to_identifier(next_name())))

        for name, source in scope.sources.items():
            if isinstance(source, exp.Table):
                # Only identifier-named tables are qualified with db/catalog;
                # e.g. a READ_CSV call stored in `this` is left alone.
                if isinstance(source.this, exp.Identifier):
                    if not source.args.get("db"):
                        source.set("db", exp.to_identifier(db))
                    if not source.args.get("catalog"):
                        source.set("catalog", exp.to_identifier(catalog))

                if not source.alias:
                    # Alias by the scope's source name, else the table name,
                    # else a freshly generated name.
                    source = source.replace(
                        alias(
                            source,
                            name or source.name or next_name(),
                            copy=True,
                            table=True,
                        )
                    )

                pivots = source.args.get("pivots")
                if pivots and not pivots[0].alias:
                    pivots[0].set("alias", exp.TableAlias(this=exp.to_identifier(next_name())))

                if schema and isinstance(source.this, exp.ReadCSV):
                    # First CSV row is the header; the second row's values
                    # determine each column's (Python type name) type.
                    with csv_reader(source.this) as reader:
                        header = next(reader)
                        columns = next(reader)
                        schema.add_table(
                            source, {k: type(v).__name__ for k, v in zip(header, columns)}
                        )
            elif isinstance(source, Scope) and source.is_udtf:
                udtf = source.expression
                # Generated names are wrapped in identifiers, like every other
                # alias this function creates, rather than stored as bare strings.
                table_alias = udtf.args.get("alias") or exp.TableAlias(
                    this=exp.to_identifier(next_name())
                )
                udtf.set("alias", table_alias)

                if not table_alias.name:
                    table_alias.set("this", exp.to_identifier(next_name()))
                if isinstance(udtf, exp.Values) and not table_alias.columns:
                    # Name VALUES columns positionally (_col_0, _col_1, ...), one
                    # per expression in the first element (presumably the first row).
                    for i in range(len(udtf.expressions[0].expressions)):
                        table_alias.append("columns", exp.to_identifier(f"_col_{i}"))

    return expression
def qualify_tables(expression, db=None, catalog=None, schema=None):
def qualify_tables(expression, db=None, catalog=None, schema=None):
    """
    Rewrite sqlglot AST to have fully qualified tables. Additionally, this
    replaces "join constructs" (*) by equivalent SELECT * subqueries.

    Examples:
        >>> import sqlglot
        >>> expression = sqlglot.parse_one("SELECT 1 FROM tbl")
        >>> qualify_tables(expression, db="db").sql()
        'SELECT 1 FROM db.tbl AS tbl'
        >>>
        >>> expression = sqlglot.parse_one("SELECT * FROM (tbl1 JOIN tbl2 ON id1 = id2)")
        >>> qualify_tables(expression).sql()
        'SELECT * FROM (SELECT * FROM tbl1 AS tbl1 JOIN tbl2 AS tbl2 ON id1 = id2) AS _q_0'

    Args:
        expression (sqlglot.Expression): expression to qualify. Nodes of the
            tree are modified in place; the same object is returned.
        db (str): Database name, set on identifier-named table references
            that have no db part.
        catalog (str): Catalog name, set on identifier-named table references
            that have no catalog part.
        schema: A schema to populate. Only used for `READ_CSV` table sources:
            each one is registered via `schema.add_table`, with every column
            typed by the Python type name of its value in the first data row.

    Returns:
        sqlglot.Expression: qualified expression

    (*) See section 7.2.1.2 in https://www.postgresql.org/docs/current/queries-table-expressions.html
    """
    sequence = itertools.count()

    def next_name():
        # One counter for the whole call, so generated aliases (_q_0, _q_1, ...)
        # are unique across every scope of the expression.
        return f"_q_{next(sequence)}"

    for scope in traverse_scope(expression):
        for derived_table in itertools.chain(scope.ctes, scope.derived_tables):
            # Expand join construct
            if isinstance(derived_table, exp.Subquery):
                unnested = derived_table.unnest()
                if isinstance(unnested, exp.Table):
                    derived_table.this.replace(exp.select("*").from_(unnested.copy(), copy=False))

            if not derived_table.args.get("alias"):
                alias_ = next_name()
                derived_table.set("alias", exp.TableAlias(this=exp.to_identifier(alias_)))
                # NOTE(review): presumably moves the source registered under the
                # key None (the previously unaliased table) to the new alias name.
                scope.rename_source(None, alias_)

            pivots = derived_table.args.get("pivots")
            if pivots and not pivots[0].alias:
                pivots[0].set("alias", exp.TableAlias(this=exp.to_identifier(next_name())))

        for name, source in scope.sources.items():
            if isinstance(source, exp.Table):
                # Only identifier-named tables are qualified with db/catalog;
                # e.g. a READ_CSV call stored in `this` is left alone.
                if isinstance(source.this, exp.Identifier):
                    if not source.args.get("db"):
                        source.set("db", exp.to_identifier(db))
                    if not source.args.get("catalog"):
                        source.set("catalog", exp.to_identifier(catalog))

                if not source.alias:
                    # Alias by the scope's source name, else the table name,
                    # else a freshly generated name.
                    source = source.replace(
                        alias(
                            source,
                            name or source.name or next_name(),
                            copy=True,
                            table=True,
                        )
                    )

                pivots = source.args.get("pivots")
                if pivots and not pivots[0].alias:
                    pivots[0].set("alias", exp.TableAlias(this=exp.to_identifier(next_name())))

                if schema and isinstance(source.this, exp.ReadCSV):
                    # First CSV row is the header; the second row's values
                    # determine each column's (Python type name) type.
                    with csv_reader(source.this) as reader:
                        header = next(reader)
                        columns = next(reader)
                        schema.add_table(
                            source, {k: type(v).__name__ for k, v in zip(header, columns)}
                        )
            elif isinstance(source, Scope) and source.is_udtf:
                udtf = source.expression
                # Generated names are wrapped in identifiers, like every other
                # alias this function creates, rather than stored as bare strings.
                table_alias = udtf.args.get("alias") or exp.TableAlias(
                    this=exp.to_identifier(next_name())
                )
                udtf.set("alias", table_alias)

                if not table_alias.name:
                    table_alias.set("this", exp.to_identifier(next_name()))
                if isinstance(udtf, exp.Values) and not table_alias.columns:
                    # Name VALUES columns positionally (_col_0, _col_1, ...), one
                    # per expression in the first element (presumably the first row).
                    for i in range(len(udtf.expressions[0].expressions)):
                        table_alias.append("columns", exp.to_identifier(f"_col_{i}"))

    return expression
Rewrite sqlglot AST to have fully qualified tables. Additionally, this replaces "join constructs" (*) by equivalent SELECT * subqueries.
Examples:
>>> import sqlglot
>>> expression = sqlglot.parse_one("SELECT 1 FROM tbl")
>>> qualify_tables(expression, db="db").sql()
'SELECT 1 FROM db.tbl AS tbl'
>>>
>>> expression = sqlglot.parse_one("SELECT * FROM (tbl1 JOIN tbl2 ON id1 = id2)")
>>> qualify_tables(expression).sql()
'SELECT * FROM (SELECT * FROM tbl1 AS tbl1 JOIN tbl2 AS tbl2 ON id1 = id2) AS _q_0'
Arguments:
- expression (sqlglot.Expression): expression to qualify
- db (str): Database name
- catalog (str): Catalog name
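- schema: A schema to populate. Only `READ_CSV` table sources are added to it; each column is typed by the Python type name of its value in the CSV's first data row.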
Returns:
sqlglot.Expression: qualified expression
(*) See section 7.2.1.2 in https://www.postgresql.org/docs/current/queries-table-expressions.html