sqlglot.optimizer.qualify_columns
import itertools
import typing as t

from sqlglot import alias, exp
from sqlglot.errors import OptimizeError
from sqlglot.optimizer.scope import Scope, traverse_scope
from sqlglot.schema import ensure_schema


def qualify_columns(expression, schema):
    """
    Rewrite sqlglot AST to have fully qualified columns.

    The expression is modified in place and the same object is returned.

    Example:
        >>> import sqlglot
        >>> schema = {"tbl": {"col": "INT"}}
        >>> expression = sqlglot.parse_one("SELECT col FROM tbl")
        >>> qualify_columns(expression, schema).sql()
        'SELECT tbl.col AS col FROM tbl'

    Args:
        expression (sqlglot.Expression): expression to qualify
        schema (dict|sqlglot.optimizer.Schema): Database schema
    Returns:
        sqlglot.Expression: qualified expression
    """
    schema = ensure_schema(schema)

    for scope in traverse_scope(expression):
        # One resolver per scope: it lazily caches the scope's source columns.
        resolver = Resolver(scope, schema)
        _pop_table_column_aliases(scope.ctes)
        _pop_table_column_aliases(scope.derived_tables)
        _expand_using(scope, resolver)
        _qualify_columns(scope, resolver)
        # Star expansion and output aliasing are skipped for UDTF scopes.
        if not isinstance(scope.expression, exp.UDTF):
            _expand_stars(scope, resolver)
            _qualify_outputs(scope)
        _expand_group_by(scope, resolver)
        _expand_order_by(scope)

    return expression


def validate_qualify_columns(expression):
    """
    Raise an `OptimizeError` if any columns aren't qualified.

    Args:
        expression: expression to validate
    Returns:
        The same expression, when validation passes.
    Raises:
        OptimizeError: if a SELECT scope has external columns without being a
            correlated subquery, or if any SELECT scope still has unqualified columns.
    """
    unqualified_columns = []
    for scope in traverse_scope(expression):
        if isinstance(scope.expression, exp.Select):
            unqualified_columns.extend(scope.unqualified_columns)
            if scope.external_columns and not scope.is_correlated_subquery:
                column = scope.external_columns[0]
                raise OptimizeError(f"Unknown table: '{column.table}' for column '{column}'")

    if unqualified_columns:
        raise OptimizeError(f"Ambiguous columns: {unqualified_columns}")
    return expression


def _pop_table_column_aliases(derived_tables):
    """
    Remove table column aliases.

    (e.g. SELECT ... FROM (SELECT ...) AS foo(col1, col2)

    The "columns" arg of each table alias is dropped in place; tables without
    an alias are left untouched.
    """
    for derived_table in derived_tables:
        table_alias = derived_table.args.get("alias")
        if table_alias:
            table_alias.args.pop("columns", None)


def _expand_using(scope, resolver):
    """
    Rewrite `JOIN ... USING (...)` into an explicit `ON` condition.

    Each USING identifier becomes an equality between the first earlier source
    exposing that column and the joined table. Unqualified references to such
    columns elsewhere in the scope are replaced by a COALESCE over the
    participating tables.

    Raises:
        OptimizeError: if a USING column is missing from the joined table or
            from every earlier source.
    """
    joins = list(scope.expression.find_all(exp.Join))
    # NOTE(review): `names` is built from `alias` while `join_table` below uses
    # `alias_or_name`, so unaliased joined tables are not excluded from the
    # initial `ordered` list - confirm whether that is intended.
    names = {join.this.alias for join in joins}
    ordered = [key for key in scope.selected_sources if key not in names]

    # Mapping of automatically joined column names to source names
    column_tables = {}

    for join in joins:
        using = join.args.get("using")

        if not using:
            continue

        join_table = join.this.alias_or_name

        # Column name -> first already-available source that exposes it
        columns = {}

        for k in scope.selected_sources:
            if k in ordered:
                for column in resolver.get_source_columns(k):
                    if column not in columns:
                        columns[column] = k

        ordered.append(join_table)
        join_columns = resolver.get_source_columns(join_table)
        conditions = []

        for identifier in using:
            identifier = identifier.name
            table = columns.get(identifier)

            if not table or identifier not in join_columns:
                raise OptimizeError(f"Cannot automatically join: {identifier}")

            conditions.append(
                exp.condition(
                    exp.EQ(
                        this=exp.column(identifier, table=table),
                        expression=exp.column(identifier, table=join_table),
                    )
                )
            )

            tables = column_tables.setdefault(identifier, [])
            if table not in tables:
                tables.append(table)
            if join_table not in tables:
                tables.append(join_table)

        join.args.pop("using")
        join.set("on", exp.and_(*conditions))

    if column_tables:
        for column in scope.columns:
            if not column.table and column.name in column_tables:
                tables = column_tables[column.name]
                coalesce = [exp.column(column.name, table=table) for table in tables]
                replacement = exp.Coalesce(this=coalesce[0], expressions=coalesce[1:])

                # Ensure selects keep their output name
                if isinstance(column.parent, exp.Select):
                    replacement = exp.alias_(replacement, alias=column.name)

                scope.replace(column, replacement)


def _expand_group_by(scope, resolver):
    """
    Resolve GROUP BY references to source columns, select aliases or positions.

    Unqualified columns that the resolver can attribute to a source get that
    table; otherwise a column matching a select's output name is replaced by a
    copy of the selected expression. Integer positions are expanded afterwards.
    """
    group = scope.expression.args.get("group")
    if not group:
        return

    # Replace references to select aliases
    def transform(node, *_):
        if isinstance(node, exp.Column) and not node.table:
            table = resolver.get_table(node.name)

            # Source columns get priority over select aliases
            if table:
                node.set("table", table)
                return node

            selects = {s.alias_or_name: s for s in scope.selects}

            select = selects.get(node.name)
            if select:
                scope.clear_cache()
                if isinstance(select, exp.Alias):
                    select = select.this
                return select.copy()

        return node

    group.transform(transform, copy=False)
    group.set("expressions", _expand_positional_references(scope, group.expressions))
    scope.expression.set("group", group)


def _expand_order_by(scope):
    """Replace positional references (e.g. `ORDER BY 1`) in the ORDER BY clause."""
    order = scope.expression.args.get("order")
    if not order:
        return

    ordereds = order.expressions
    for ordered, new_expression in zip(
        ordereds,
        _expand_positional_references(scope, (o.this for o in ordereds)),
    ):
        ordered.set("this", new_expression)


def _expand_positional_references(scope, expressions):
    """
    Replace integer literals with copies of the select expression at that position.

    Positions are 1-based. Nodes that are not integers are passed through unchanged.

    Args:
        scope: scope whose `selects` the positions refer to
        expressions: iterable of expression nodes
    Returns:
        list: the expanded nodes, in the same order
    Raises:
        OptimizeError: if a position does not refer to an existing select.
    """
    new_nodes = []
    for node in expressions:
        if node.is_int:
            position = int(node.name)
            message = f"Unknown output column: {node.name}"

            # Positions start at 1. Without this check, `0` would become index -1
            # and silently resolve to the last select instead of failing.
            if position < 1:
                raise OptimizeError(message)

            try:
                select = scope.selects[position - 1]
            except IndexError:
                raise OptimizeError(message) from None

            if isinstance(select, exp.Alias):
                select = select.this
            new_nodes.append(select.copy())
            scope.clear_cache()
        else:
            new_nodes.append(node)

    return new_nodes


def _qualify_columns(scope, resolver):
    """
    Disambiguate columns, ensuring each column specifies a source.

    Raises:
        OptimizeError: if a column names a known source that does not expose it.
    """
    for column in scope.columns:
        column_table = column.table
        column_name = column.name

        if column_table and column_table in scope.sources:
            source_columns = resolver.get_source_columns(column_table)
            if source_columns and column_name not in source_columns and "*" not in source_columns:
                raise OptimizeError(f"Unknown column: {column_name}")

        if not column_table:
            column_table = resolver.get_table(column_name)

            # column_table can be a '' because bigquery unnest has no table alias
            if column_table:
                column.set("table", column_table)
        elif column_table not in scope.sources:
            # structs are used like tables (e.g. "struct"."field"), so they need to be qualified
            # separately and represented as dot(dot(...(<table>.<column>, field1), field2, ...))

            struct_root, *struct_fields = [
                val for val in reversed(list(column.args.values())) if val is not None
            ]

            if struct_root.name in scope.sources:
                # struct is already qualified, but we still need to change the AST representation
                struct_table = struct_root
                struct_root, *struct_fields = struct_fields
            else:
                struct_table = resolver.get_table(struct_root.name)

            if struct_table:
                # Absorb enclosing Dot nodes so the whole access chain is rebuilt
                while column.parent and isinstance(column.parent, exp.Dot):
                    column = column.parent
                    struct_fields.append(column.expression)

                new_column = exp.column(struct_root, table=struct_table)
                for field in struct_fields:
                    new_column = exp.Dot(this=new_column, expression=field)

                column.replace(new_column)

    columns_missing_from_scope = []
    # Determine whether each reference in the order by clause is to a column or an alias.
    for ordered in scope.find_all(exp.Ordered):
        for column in ordered.find_all(exp.Column):
            if (
                not column.table
                and column.parent is not ordered
                and column.name in resolver.all_columns
            ):
                columns_missing_from_scope.append(column)

    # Determine whether each reference in the having clause is to a column or an alias.
    for having in scope.find_all(exp.Having):
        for column in having.find_all(exp.Column):
            if (
                not column.table
                and column.find_ancestor(exp.AggFunc)
                and column.name in resolver.all_columns
            ):
                columns_missing_from_scope.append(column)

    for column in columns_missing_from_scope:
        column_table = resolver.get_table(column.name)

        if column_table:
            column.set("table", column_table)


def _expand_stars(scope, resolver):
    """
    Expand stars to lists of column selections.

    If the columns of any starred table are unknown (empty or containing "*"),
    the select list is left untouched.

    Raises:
        OptimizeError: if a star refers to a table that is not a source of the scope.
    """

    new_selections = []
    except_columns = {}
    replace_columns = {}

    for expression in scope.selects:
        if isinstance(expression, exp.Star):
            tables = list(scope.selected_sources)
            _add_except_columns(expression, tables, except_columns)
            _add_replace_columns(expression, tables, replace_columns)
        elif expression.is_star:
            tables = [expression.table]
            _add_except_columns(expression.this, tables, except_columns)
            _add_replace_columns(expression.this, tables, replace_columns)
        else:
            new_selections.append(expression)
            continue

        for table in tables:
            if table not in scope.sources:
                raise OptimizeError(f"Unknown table: {table}")
            columns = resolver.get_source_columns(table, only_visible=True)

            if columns and "*" not in columns:
                # Same key the _add_*_columns helpers used for this `tables` entry
                table_id = id(table)
                for name in columns:
                    if name not in except_columns.get(table_id, set()):
                        alias_ = replace_columns.get(table_id, {}).get(name, name)
                        column = exp.column(name, table)
                        new_selections.append(alias(column, alias_) if alias_ != name else column)
            else:
                return
    scope.expression.set("expressions", new_selections)


def _add_except_columns(expression, tables, except_columns):
    """Record the names in the star's "except" arg for each table, keyed by `id(table)`."""
    except_ = expression.args.get("except")

    if not except_:
        return

    columns = {e.name for e in except_}

    for table in tables:
        except_columns[id(table)] = columns


def _add_replace_columns(expression, tables, replace_columns):
    """Record the star's "replace" arg as a name -> alias map per table, keyed by `id(table)`."""
    replace = expression.args.get("replace")

    if not replace:
        return

    columns = {e.this.name: e.alias for e in replace}

    for table in tables:
        replace_columns[id(table)] = columns


def _qualify_outputs(scope):
    """
    Ensure all output columns are aliased.

    Unaliased selections get their output name, or `_col_<index>` when they have
    none. Names from `scope.outer_column_list` override the alias positionally.
    """
    new_selections = []

    for i, (selection, aliased_column) in enumerate(
        itertools.zip_longest(scope.selects, scope.outer_column_list)
    ):
        if isinstance(selection, exp.Subquery):
            if not selection.output_name:
                selection.set("alias", exp.TableAlias(this=exp.to_identifier(f"_col_{i}")))
        elif not isinstance(selection, exp.Alias) and not selection.is_star:
            # Build the alias around a placeholder, then swap in the real selection
            alias_ = alias(exp.column(""), alias=selection.output_name or f"_col_{i}")
            alias_.set("this", selection)
            selection = alias_

        if aliased_column:
            selection.set("alias", exp.to_identifier(aliased_column))

        new_selections.append(selection)

    scope.expression.set("expressions", new_selections)


class Resolver:
    """
    Helper for resolving columns.

    This is a class so we can lazily load some things and easily share them across functions.
    """

    def __init__(self, scope, schema):
        self.scope = scope
        self.schema = schema
        # Lazily computed caches, filled on first use
        self._source_columns = None
        self._unambiguous_columns = None
        self._all_columns = None

    def get_table(self, column_name: str) -> t.Optional[exp.Identifier]:
        """
        Get the table for a column name.

        Args:
            column_name: The column name to find the table for.
        Returns:
            The table name if it can be found/inferred.
        """
        if self._unambiguous_columns is None:
            self._unambiguous_columns = self._get_unambiguous_columns(
                self._get_all_source_columns()
            )

        table_name = self._unambiguous_columns.get(column_name)

        if not table_name:
            # If exactly one source has unknown columns, assume the column is from it
            sources_without_schema = tuple(
                source
                for source, columns in self._get_all_source_columns().items()
                if not columns or "*" in columns
            )
            if len(sources_without_schema) == 1:
                table_name = sources_without_schema[0]

        if table_name not in self.scope.selected_sources:
            return exp.to_identifier(table_name)

        node, _ = self.scope.selected_sources.get(table_name)

        if isinstance(node, exp.Subqueryable):
            # Climb to the ancestor whose alias is the source name
            while node and node.alias != table_name:
                node = node.parent

            # The walk can run off the top of the tree; don't dereference None below
            if node is None:
                return exp.to_identifier(table_name)

        node_alias = node.args.get("alias")
        if node_alias:
            return node_alias.this

        return exp.to_identifier(
            table_name, quoted=node.this.quoted if isinstance(node, exp.Table) else None
        )

    @property
    def all_columns(self):
        """All available columns of all sources in this scope"""
        if self._all_columns is None:
            self._all_columns = {
                column for columns in self._get_all_source_columns().values() for column in columns
            }
        return self._all_columns

    def get_source_columns(self, name, only_visible=False):
        """
        Resolve the source columns for a given source `name`.

        Raises:
            OptimizeError: if `name` is not a source of the scope.
        """
        if name not in self.scope.sources:
            raise OptimizeError(f"Unknown table: {name}")

        source = self.scope.sources[name]

        # If referencing a table, return the columns from the schema
        if isinstance(source, exp.Table):
            return self.schema.column_names(source, only_visible)

        if isinstance(source, Scope) and isinstance(source.expression, exp.Values):
            return source.expression.alias_column_names

        # Otherwise, if referencing another scope, return that scope's named selects
        return source.expression.named_selects

    def _get_all_source_columns(self):
        """Return (and cache) the columns of every selected and lateral source, by source name."""
        if self._source_columns is None:
            self._source_columns = {
                k: self.get_source_columns(k)
                for k in itertools.chain(self.scope.selected_sources, self.scope.lateral_sources)
            }
        return self._source_columns

    def _get_unambiguous_columns(self, source_columns):
        """
        Find all the unambiguous columns in sources.

        A column is unambiguous when it appears exactly once across all sources.

        Args:
            source_columns (dict): Mapping of names to source columns
        Returns:
            dict: Mapping of column name to source name
        """
        if not source_columns:
            return {}

        source_columns = list(source_columns.items())

        first_table, first_columns = source_columns[0]
        unambiguous_columns = {col: first_table for col in self._find_unique_columns(first_columns)}
        # Track every name seen so far, including names duplicated within a source,
        # so a later source cannot claim a name that an earlier one already exposes.
        all_columns = set(first_columns)

        for table, columns in source_columns[1:]:
            unique = self._find_unique_columns(columns)
            # Compare against all of this source's columns, not only its unique ones,
            # so the result does not depend on the order of the sources.
            ambiguous = all_columns.intersection(columns)
            all_columns.update(columns)
            for column in ambiguous:
                unambiguous_columns.pop(column, None)
            for column in unique.difference(ambiguous):
                unambiguous_columns[column] = table

        return unambiguous_columns

    @staticmethod
    def _find_unique_columns(columns):
        """
        Find the unique columns in a list of columns.

        Example:
            >>> sorted(Resolver._find_unique_columns(["a", "b", "b", "c"]))
            ['a', 'c']

        This is necessary because duplicate column names are ambiguous.
        """
        counts = {}
        for column in columns:
            counts[column] = counts.get(column, 0) + 1
        return {column for column, count in counts.items() if count == 1}
def
qualify_columns(expression, schema):
def qualify_columns(expression, schema):
    """
    Rewrite sqlglot AST to have fully qualified columns.

    Every scope of the expression is processed in turn; the expression is
    updated in place and returned.

    Example:
        >>> import sqlglot
        >>> schema = {"tbl": {"col": "INT"}}
        >>> expression = sqlglot.parse_one("SELECT col FROM tbl")
        >>> qualify_columns(expression, schema).sql()
        'SELECT tbl.col AS col FROM tbl'

    Args:
        expression (sqlglot.Expression): expression to qualify
        schema (dict|sqlglot.optimizer.Schema): Database schema
    Returns:
        sqlglot.Expression: qualified expression
    """
    normalized_schema = ensure_schema(schema)

    for scope in traverse_scope(expression):
        resolver = Resolver(scope, normalized_schema)

        # Drop column lists from CTE / derived table aliases first
        _pop_table_column_aliases(scope.ctes)
        _pop_table_column_aliases(scope.derived_tables)

        _expand_using(scope, resolver)
        _qualify_columns(scope, resolver)

        # Stars and output aliases are only handled for non-UDTF scopes
        has_select_list = not isinstance(scope.expression, exp.UDTF)
        if has_select_list:
            _expand_stars(scope, resolver)
            _qualify_outputs(scope)

        _expand_group_by(scope, resolver)
        _expand_order_by(scope)

    return expression
Rewrite sqlglot AST to have fully qualified columns.
Example:
>>> import sqlglot
>>> schema = {"tbl": {"col": "INT"}}
>>> expression = sqlglot.parse_one("SELECT col FROM tbl")
>>> qualify_columns(expression, schema).sql()
'SELECT tbl.col AS col FROM tbl'
Arguments:
- expression (sqlglot.Expression): expression to qualify
- schema (dict|sqlglot.optimizer.Schema): Database schema
Returns:
sqlglot.Expression: qualified expression
def
validate_qualify_columns(expression):
def validate_qualify_columns(expression):
    """
    Raise an `OptimizeError` if any columns aren't qualified.

    Only SELECT scopes are inspected. A scope with external columns that is not
    a correlated subquery fails immediately; leftover unqualified columns are
    collected across scopes and reported together at the end.

    Returns:
        The expression that was passed in, when nothing is wrong.
    """
    leftovers = []

    for scope in traverse_scope(expression):
        if not isinstance(scope.expression, exp.Select):
            continue

        leftovers.extend(scope.unqualified_columns)

        if scope.external_columns and not scope.is_correlated_subquery:
            column = scope.external_columns[0]
            raise OptimizeError(f"Unknown table: '{column.table}' for column '{column}'")

    if not leftovers:
        return expression

    raise OptimizeError(f"Ambiguous columns: {leftovers}")
Raise an OptimizeError
if any columns aren't qualified
class
Resolver:
class Resolver:
    """
    Helper for resolving columns.

    This is a class so we can lazily load some things and easily share them across functions.
    """

    def __init__(self, scope, schema):
        self.scope = scope
        self.schema = schema
        # Lazily computed caches, filled on first use
        self._source_columns = None
        self._unambiguous_columns = None
        self._all_columns = None

    def get_table(self, column_name: str) -> t.Optional[exp.Identifier]:
        """
        Get the table for a column name.

        Args:
            column_name: The column name to find the table for.
        Returns:
            The table name if it can be found/inferred.
        """
        if self._unambiguous_columns is None:
            self._unambiguous_columns = self._get_unambiguous_columns(
                self._get_all_source_columns()
            )

        table_name = self._unambiguous_columns.get(column_name)

        if not table_name:
            # If exactly one source has unknown columns, assume the column is from it
            sources_without_schema = tuple(
                source
                for source, columns in self._get_all_source_columns().items()
                if not columns or "*" in columns
            )
            if len(sources_without_schema) == 1:
                table_name = sources_without_schema[0]

        if table_name not in self.scope.selected_sources:
            return exp.to_identifier(table_name)

        node, _ = self.scope.selected_sources.get(table_name)

        if isinstance(node, exp.Subqueryable):
            # Climb to the ancestor whose alias is the source name
            while node and node.alias != table_name:
                node = node.parent

            # The walk can run off the top of the tree; don't dereference None below
            if node is None:
                return exp.to_identifier(table_name)

        node_alias = node.args.get("alias")
        if node_alias:
            return node_alias.this

        return exp.to_identifier(
            table_name, quoted=node.this.quoted if isinstance(node, exp.Table) else None
        )

    @property
    def all_columns(self):
        """All available columns of all sources in this scope"""
        if self._all_columns is None:
            self._all_columns = {
                column for columns in self._get_all_source_columns().values() for column in columns
            }
        return self._all_columns

    def get_source_columns(self, name, only_visible=False):
        """
        Resolve the source columns for a given source `name`.

        Raises:
            OptimizeError: if `name` is not a source of the scope.
        """
        if name not in self.scope.sources:
            raise OptimizeError(f"Unknown table: {name}")

        source = self.scope.sources[name]

        # If referencing a table, return the columns from the schema
        if isinstance(source, exp.Table):
            return self.schema.column_names(source, only_visible)

        if isinstance(source, Scope) and isinstance(source.expression, exp.Values):
            return source.expression.alias_column_names

        # Otherwise, if referencing another scope, return that scope's named selects
        return source.expression.named_selects

    def _get_all_source_columns(self):
        """Return (and cache) the columns of every selected and lateral source, by source name."""
        if self._source_columns is None:
            self._source_columns = {
                k: self.get_source_columns(k)
                for k in itertools.chain(self.scope.selected_sources, self.scope.lateral_sources)
            }
        return self._source_columns

    def _get_unambiguous_columns(self, source_columns):
        """
        Find all the unambiguous columns in sources.

        A column is unambiguous when it appears exactly once across all sources.

        Args:
            source_columns (dict): Mapping of names to source columns
        Returns:
            dict: Mapping of column name to source name
        """
        if not source_columns:
            return {}

        source_columns = list(source_columns.items())

        first_table, first_columns = source_columns[0]
        unambiguous_columns = {col: first_table for col in self._find_unique_columns(first_columns)}
        # Track every name seen so far, including names duplicated within a source,
        # so a later source cannot claim a name that an earlier one already exposes.
        all_columns = set(first_columns)

        for table, columns in source_columns[1:]:
            unique = self._find_unique_columns(columns)
            # Compare against all of this source's columns, not only its unique ones,
            # so the result does not depend on the order of the sources.
            ambiguous = all_columns.intersection(columns)
            all_columns.update(columns)
            for column in ambiguous:
                unambiguous_columns.pop(column, None)
            for column in unique.difference(ambiguous):
                unambiguous_columns[column] = table

        return unambiguous_columns

    @staticmethod
    def _find_unique_columns(columns):
        """
        Find the unique columns in a list of columns.

        Example:
            >>> sorted(Resolver._find_unique_columns(["a", "b", "b", "c"]))
            ['a', 'c']

        This is necessary because duplicate column names are ambiguous.
        """
        counts = {}
        for column in columns:
            counts[column] = counts.get(column, 0) + 1
        return {column for column, count in counts.items() if count == 1}
Helper for resolving columns.
This is a class so we can lazily load some things and easily share them across functions.
def get_table(self, column_name: str) -> t.Optional[exp.Identifier]:
    """
    Get the table for a column name.

    Args:
        column_name: The column name to find the table for.
    Returns:
        The table name if it can be found/inferred.
    """
    if self._unambiguous_columns is None:
        self._unambiguous_columns = self._get_unambiguous_columns(
            self._get_all_source_columns()
        )

    table_name = self._unambiguous_columns.get(column_name)

    if not table_name:
        # If exactly one source has unknown columns, assume the column is from it
        sources_without_schema = tuple(
            source
            for source, columns in self._get_all_source_columns().items()
            if not columns or "*" in columns
        )
        if len(sources_without_schema) == 1:
            table_name = sources_without_schema[0]

    if table_name not in self.scope.selected_sources:
        return exp.to_identifier(table_name)

    node, _ = self.scope.selected_sources.get(table_name)

    if isinstance(node, exp.Subqueryable):
        # Climb to the ancestor whose alias is the source name
        while node and node.alias != table_name:
            node = node.parent

        # The walk can run off the top of the tree; don't dereference None below
        if node is None:
            return exp.to_identifier(table_name)

    node_alias = node.args.get("alias")
    if node_alias:
        return node_alias.this

    return exp.to_identifier(
        table_name, quoted=node.this.quoted if isinstance(node, exp.Table) else None
    )
Get the table for a column name.
Arguments:
- column_name: The column name to find the table for.
Returns:
The table name if it can be found/inferred.
def
get_source_columns(self, name, only_visible=False):
def get_source_columns(self, name, only_visible=False):
    """
    Resolve the source columns for a given source `name`.

    Tables are looked up in the schema; a scope over a VALUES expression
    yields its alias column names; anything else yields the named selects of
    its expression.

    Raises:
        OptimizeError: if `name` is not a source of the scope.
    """
    if name not in self.scope.sources:
        raise OptimizeError(f"Unknown table: {name}")

    source = self.scope.sources[name]

    # Plain table reference: the schema is the authority
    if isinstance(source, exp.Table):
        return self.schema.column_names(source, only_visible)

    inner = source.expression

    # VALUES scopes take their column names from the alias
    if isinstance(source, Scope) and isinstance(inner, exp.Values):
        return inner.alias_column_names

    # Any other source exposes the names of its selections
    return inner.named_selects
Resolve the source columns for a given source name