sqlglot.optimizer.qualify_columns
1from __future__ import annotations 2 3import itertools 4import typing as t 5 6from sqlglot import alias, exp 7from sqlglot._typing import E 8from sqlglot.dialects.dialect import Dialect, DialectType 9from sqlglot.errors import OptimizeError 10from sqlglot.helper import seq_get 11from sqlglot.optimizer.scope import Scope, traverse_scope, walk_in_scope 12from sqlglot.schema import Schema, ensure_schema 13 14 15def qualify_columns( 16 expression: exp.Expression, 17 schema: t.Dict | Schema, 18 expand_alias_refs: bool = True, 19 infer_schema: t.Optional[bool] = None, 20) -> exp.Expression: 21 """ 22 Rewrite sqlglot AST to have fully qualified columns. 23 24 Example: 25 >>> import sqlglot 26 >>> schema = {"tbl": {"col": "INT"}} 27 >>> expression = sqlglot.parse_one("SELECT col FROM tbl") 28 >>> qualify_columns(expression, schema).sql() 29 'SELECT tbl.col AS col FROM tbl' 30 31 Args: 32 expression: expression to qualify 33 schema: Database schema 34 expand_alias_refs: whether or not to expand references to aliases 35 infer_schema: whether or not to infer the schema if missing 36 Returns: 37 sqlglot.Expression: qualified expression 38 """ 39 schema = ensure_schema(schema) 40 infer_schema = schema.empty if infer_schema is None else infer_schema 41 42 for scope in traverse_scope(expression): 43 resolver = Resolver(scope, schema, infer_schema=infer_schema) 44 _pop_table_column_aliases(scope.ctes) 45 _pop_table_column_aliases(scope.derived_tables) 46 using_column_tables = _expand_using(scope, resolver) 47 48 if schema.empty and expand_alias_refs: 49 _expand_alias_refs(scope, resolver) 50 51 _qualify_columns(scope, resolver) 52 53 if not schema.empty and expand_alias_refs: 54 _expand_alias_refs(scope, resolver) 55 56 if not isinstance(scope.expression, exp.UDTF): 57 _expand_stars(scope, resolver, using_column_tables) 58 _qualify_outputs(scope) 59 _expand_group_by(scope) 60 _expand_order_by(scope, resolver) 61 62 return expression 63 64 65def 
validate_qualify_columns(expression: E) -> E: 66 """Raise an `OptimizeError` if any columns aren't qualified""" 67 unqualified_columns = [] 68 for scope in traverse_scope(expression): 69 if isinstance(scope.expression, exp.Select): 70 unqualified_columns.extend(scope.unqualified_columns) 71 if scope.external_columns and not scope.is_correlated_subquery and not scope.pivots: 72 column = scope.external_columns[0] 73 raise OptimizeError( 74 f"""Column '{column}' could not be resolved{f" for table: '{column.table}'" if column.table else ''}""" 75 ) 76 77 if unqualified_columns: 78 raise OptimizeError(f"Ambiguous columns: {unqualified_columns}") 79 return expression 80 81 82def _pop_table_column_aliases(derived_tables: t.List[exp.CTE | exp.Subquery]) -> None: 83 """ 84 Remove table column aliases. 85 86 (e.g. SELECT ... FROM (SELECT ...) AS foo(col1, col2) 87 """ 88 for derived_table in derived_tables: 89 table_alias = derived_table.args.get("alias") 90 if table_alias: 91 table_alias.args.pop("columns", None) 92 93 94def _expand_using(scope: Scope, resolver: Resolver) -> t.Dict[str, t.Any]: 95 joins = list(scope.find_all(exp.Join)) 96 names = {join.alias_or_name for join in joins} 97 ordered = [key for key in scope.selected_sources if key not in names] 98 99 # Mapping of automatically joined column names to an ordered set of source names (dict). 
100 column_tables: t.Dict[str, t.Dict[str, t.Any]] = {} 101 102 for join in joins: 103 using = join.args.get("using") 104 105 if not using: 106 continue 107 108 join_table = join.alias_or_name 109 110 columns = {} 111 112 for k in scope.selected_sources: 113 if k in ordered: 114 for column in resolver.get_source_columns(k): 115 if column not in columns: 116 columns[column] = k 117 118 source_table = ordered[-1] 119 ordered.append(join_table) 120 join_columns = resolver.get_source_columns(join_table) 121 conditions = [] 122 123 for identifier in using: 124 identifier = identifier.name 125 table = columns.get(identifier) 126 127 if not table or identifier not in join_columns: 128 if columns and join_columns: 129 raise OptimizeError(f"Cannot automatically join: {identifier}") 130 131 table = table or source_table 132 conditions.append( 133 exp.condition( 134 exp.EQ( 135 this=exp.column(identifier, table=table), 136 expression=exp.column(identifier, table=join_table), 137 ) 138 ) 139 ) 140 141 # Set all values in the dict to None, because we only care about the key ordering 142 tables = column_tables.setdefault(identifier, {}) 143 if table not in tables: 144 tables[table] = None 145 if join_table not in tables: 146 tables[join_table] = None 147 148 join.args.pop("using") 149 join.set("on", exp.and_(*conditions, copy=False)) 150 151 if column_tables: 152 for column in scope.columns: 153 if not column.table and column.name in column_tables: 154 tables = column_tables[column.name] 155 coalesce = [exp.column(column.name, table=table) for table in tables] 156 replacement = exp.Coalesce(this=coalesce[0], expressions=coalesce[1:]) 157 158 # Ensure selects keep their output name 159 if isinstance(column.parent, exp.Select): 160 replacement = alias(replacement, alias=column.name, copy=False) 161 162 scope.replace(column, replacement) 163 164 return column_tables 165 166 167def _expand_alias_refs(scope: Scope, resolver: Resolver) -> None: 168 expression = scope.expression 169 
170 if not isinstance(expression, exp.Select): 171 return 172 173 alias_to_expression: t.Dict[str, t.Tuple[exp.Expression, int]] = {} 174 175 def replace_columns( 176 node: t.Optional[exp.Expression], resolve_table: bool = False, literal_index: bool = False 177 ) -> None: 178 if not node: 179 return 180 181 for column, *_ in walk_in_scope(node): 182 if not isinstance(column, exp.Column): 183 continue 184 table = resolver.get_table(column.name) if resolve_table and not column.table else None 185 alias_expr, i = alias_to_expression.get(column.name, (None, 1)) 186 double_agg = ( 187 (alias_expr.find(exp.AggFunc) and column.find_ancestor(exp.AggFunc)) 188 if alias_expr 189 else False 190 ) 191 192 if table and (not alias_expr or double_agg): 193 column.set("table", table) 194 elif not column.table and alias_expr and not double_agg: 195 if isinstance(alias_expr, exp.Literal): 196 if literal_index: 197 column.replace(exp.Literal.number(i)) 198 else: 199 column.replace(alias_expr.copy()) 200 201 for i, projection in enumerate(scope.selects): 202 replace_columns(projection) 203 204 if isinstance(projection, exp.Alias): 205 alias_to_expression[projection.alias] = (projection.this, i + 1) 206 207 replace_columns(expression.args.get("where")) 208 replace_columns(expression.args.get("group"), literal_index=True) 209 replace_columns(expression.args.get("having"), resolve_table=True) 210 replace_columns(expression.args.get("qualify"), resolve_table=True) 211 scope.clear_cache() 212 213 214def _expand_group_by(scope: Scope): 215 expression = scope.expression 216 group = expression.args.get("group") 217 if not group: 218 return 219 220 group.set("expressions", _expand_positional_references(scope, group.expressions)) 221 expression.set("group", group) 222 223 # group by expressions cannot be simplified, for example 224 # select x + 1 + 1 FROM y GROUP BY x + 1 + 1 225 # the projection must exactly match the group by key 226 groups = set(group.expressions) 227 group.meta["final"] = 
True 228 229 for e in expression.selects: 230 for node, *_ in e.walk(): 231 if node in groups: 232 e.meta["final"] = True 233 break 234 235 having = expression.args.get("having") 236 if having: 237 for node, *_ in having.walk(): 238 if node in groups: 239 having.meta["final"] = True 240 break 241 242 243def _expand_order_by(scope: Scope, resolver: Resolver): 244 order = scope.expression.args.get("order") 245 if not order: 246 return 247 248 ordereds = order.expressions 249 for ordered, new_expression in zip( 250 ordereds, 251 _expand_positional_references(scope, (o.this for o in ordereds)), 252 ): 253 for agg in ordered.find_all(exp.AggFunc): 254 for col in agg.find_all(exp.Column): 255 if not col.table: 256 col.set("table", resolver.get_table(col.name)) 257 258 ordered.set("this", new_expression) 259 260 if scope.expression.args.get("group"): 261 selects = {s.this: exp.column(s.alias_or_name) for s in scope.selects} 262 263 for ordered in ordereds: 264 ordered = ordered.this 265 266 ordered.replace( 267 exp.to_identifier(_select_by_pos(scope, ordered).alias) 268 if ordered.is_int 269 else selects.get(ordered, ordered) 270 ) 271 272 273def _expand_positional_references(scope: Scope, expressions: t.Iterable[E]) -> t.List[E]: 274 new_nodes = [] 275 for node in expressions: 276 if node.is_int: 277 select = _select_by_pos(scope, t.cast(exp.Literal, node)).this 278 279 if isinstance(select, exp.Literal): 280 new_nodes.append(node) 281 else: 282 new_nodes.append(select.copy()) 283 scope.clear_cache() 284 else: 285 new_nodes.append(node) 286 287 return new_nodes 288 289 290def _select_by_pos(scope: Scope, node: exp.Literal) -> exp.Alias: 291 try: 292 return scope.selects[int(node.this) - 1].assert_is(exp.Alias) 293 except IndexError: 294 raise OptimizeError(f"Unknown output column: {node.name}") 295 296 297def _qualify_columns(scope: Scope, resolver: Resolver) -> None: 298 """Disambiguate columns, ensuring each column specifies a source""" 299 for column in scope.columns: 
300 column_table = column.table 301 column_name = column.name 302 303 if column_table and column_table in scope.sources: 304 source_columns = resolver.get_source_columns(column_table) 305 if source_columns and column_name not in source_columns and "*" not in source_columns: 306 raise OptimizeError(f"Unknown column: {column_name}") 307 308 if not column_table: 309 if scope.pivots and not column.find_ancestor(exp.Pivot): 310 # If the column is under the Pivot expression, we need to qualify it 311 # using the name of the pivoted source instead of the pivot's alias 312 column.set("table", exp.to_identifier(scope.pivots[0].alias)) 313 continue 314 315 column_table = resolver.get_table(column_name) 316 317 # column_table can be a '' because bigquery unnest has no table alias 318 if column_table: 319 column.set("table", column_table) 320 elif column_table not in scope.sources and ( 321 not scope.parent or column_table not in scope.parent.sources 322 ): 323 # structs are used like tables (e.g. "struct"."field"), so they need to be qualified 324 # separately and represented as dot(dot(...(<table>.<column>, field1), field2, ...)) 325 326 root, *parts = column.parts 327 328 if root.name in scope.sources: 329 # struct is already qualified, but we still need to change the AST representation 330 column_table = root 331 root, *parts = parts 332 else: 333 column_table = resolver.get_table(root.name) 334 335 if column_table: 336 column.replace(exp.Dot.build([exp.column(root, table=column_table), *parts])) 337 338 for pivot in scope.pivots: 339 for column in pivot.find_all(exp.Column): 340 if not column.table and column.name in resolver.all_columns: 341 column_table = resolver.get_table(column.name) 342 if column_table: 343 column.set("table", column_table) 344 345 346def _expand_stars( 347 scope: Scope, resolver: Resolver, using_column_tables: t.Dict[str, t.Any] 348) -> None: 349 """Expand stars to lists of column selections""" 350 351 new_selections = [] 352 except_columns: 
t.Dict[int, t.Set[str]] = {} 353 replace_columns: t.Dict[int, t.Dict[str, str]] = {} 354 coalesced_columns = set() 355 356 # TODO: handle optimization of multiple PIVOTs (and possibly UNPIVOTs) in the future 357 pivot_columns = None 358 pivot_output_columns = None 359 pivot = t.cast(t.Optional[exp.Pivot], seq_get(scope.pivots, 0)) 360 361 has_pivoted_source = pivot and not pivot.args.get("unpivot") 362 if pivot and has_pivoted_source: 363 pivot_columns = set(col.output_name for col in pivot.find_all(exp.Column)) 364 365 pivot_output_columns = [col.output_name for col in pivot.args.get("columns", [])] 366 if not pivot_output_columns: 367 pivot_output_columns = [col.alias_or_name for col in pivot.expressions] 368 369 for expression in scope.selects: 370 if isinstance(expression, exp.Star): 371 tables = list(scope.selected_sources) 372 _add_except_columns(expression, tables, except_columns) 373 _add_replace_columns(expression, tables, replace_columns) 374 elif expression.is_star: 375 tables = [expression.table] 376 _add_except_columns(expression.this, tables, except_columns) 377 _add_replace_columns(expression.this, tables, replace_columns) 378 else: 379 new_selections.append(expression) 380 continue 381 382 for table in tables: 383 if table not in scope.sources: 384 raise OptimizeError(f"Unknown table: {table}") 385 386 columns = resolver.get_source_columns(table, only_visible=True) 387 388 # The _PARTITIONTIME and _PARTITIONDATE pseudo-columns are not returned by a SELECT * statement 389 # https://cloud.google.com/bigquery/docs/querying-partitioned-tables#query_an_ingestion-time_partitioned_table 390 if resolver.schema.dialect == "bigquery": 391 columns = [ 392 name 393 for name in columns 394 if name.upper() not in ("_PARTITIONTIME", "_PARTITIONDATE") 395 ] 396 397 if columns and "*" not in columns: 398 if pivot and has_pivoted_source and pivot_columns and pivot_output_columns: 399 implicit_columns = [col for col in columns if col not in pivot_columns] 400 
new_selections.extend( 401 exp.alias_(exp.column(name, table=pivot.alias), name, copy=False) 402 for name in implicit_columns + pivot_output_columns 403 ) 404 continue 405 406 table_id = id(table) 407 for name in columns: 408 if name in using_column_tables and table in using_column_tables[name]: 409 if name in coalesced_columns: 410 continue 411 412 coalesced_columns.add(name) 413 tables = using_column_tables[name] 414 coalesce = [exp.column(name, table=table) for table in tables] 415 416 new_selections.append( 417 alias( 418 exp.Coalesce(this=coalesce[0], expressions=coalesce[1:]), 419 alias=name, 420 copy=False, 421 ) 422 ) 423 elif name not in except_columns.get(table_id, set()): 424 alias_ = replace_columns.get(table_id, {}).get(name, name) 425 column = exp.column(name, table=table) 426 new_selections.append( 427 alias(column, alias_, copy=False) if alias_ != name else column 428 ) 429 else: 430 return 431 432 scope.expression.set("expressions", new_selections) 433 434 435def _add_except_columns( 436 expression: exp.Expression, tables, except_columns: t.Dict[int, t.Set[str]] 437) -> None: 438 except_ = expression.args.get("except") 439 440 if not except_: 441 return 442 443 columns = {e.name for e in except_} 444 445 for table in tables: 446 except_columns[id(table)] = columns 447 448 449def _add_replace_columns( 450 expression: exp.Expression, tables, replace_columns: t.Dict[int, t.Dict[str, str]] 451) -> None: 452 replace = expression.args.get("replace") 453 454 if not replace: 455 return 456 457 columns = {e.this.name: e.alias for e in replace} 458 459 for table in tables: 460 replace_columns[id(table)] = columns 461 462 463def _qualify_outputs(scope: Scope): 464 """Ensure all output columns are aliased""" 465 new_selections = [] 466 467 for i, (selection, aliased_column) in enumerate( 468 itertools.zip_longest(scope.selects, scope.outer_column_list) 469 ): 470 if isinstance(selection, exp.Subquery): 471 if not selection.output_name: 472 
selection.set("alias", exp.TableAlias(this=exp.to_identifier(f"_col_{i}"))) 473 elif not isinstance(selection, exp.Alias) and not selection.is_star: 474 selection = alias( 475 selection, 476 alias=selection.output_name or f"_col_{i}", 477 ) 478 if aliased_column: 479 selection.set("alias", exp.to_identifier(aliased_column)) 480 481 new_selections.append(selection) 482 483 scope.expression.set("expressions", new_selections) 484 485 486def quote_identifiers(expression: E, dialect: DialectType = None, identify: bool = True) -> E: 487 """Makes sure all identifiers that need to be quoted are quoted.""" 488 return expression.transform( 489 Dialect.get_or_raise(dialect).quote_identifier, identify=identify, copy=False 490 ) 491 492 493class Resolver: 494 """ 495 Helper for resolving columns. 496 497 This is a class so we can lazily load some things and easily share them across functions. 498 """ 499 500 def __init__(self, scope: Scope, schema: Schema, infer_schema: bool = True): 501 self.scope = scope 502 self.schema = schema 503 self._source_columns = None 504 self._unambiguous_columns: t.Optional[t.Dict[str, str]] = None 505 self._all_columns = None 506 self._infer_schema = infer_schema 507 508 def get_table(self, column_name: str) -> t.Optional[exp.Identifier]: 509 """ 510 Get the table for a column name. 511 512 Args: 513 column_name: The column name to find the table for. 514 Returns: 515 The table name if it can be found/inferred. 
516 """ 517 if self._unambiguous_columns is None: 518 self._unambiguous_columns = self._get_unambiguous_columns( 519 self._get_all_source_columns() 520 ) 521 522 table_name = self._unambiguous_columns.get(column_name) 523 524 if not table_name and self._infer_schema: 525 sources_without_schema = tuple( 526 source 527 for source, columns in self._get_all_source_columns().items() 528 if not columns or "*" in columns 529 ) 530 if len(sources_without_schema) == 1: 531 table_name = sources_without_schema[0] 532 533 if table_name not in self.scope.selected_sources: 534 return exp.to_identifier(table_name) 535 536 node, _ = self.scope.selected_sources.get(table_name) 537 538 if isinstance(node, exp.Subqueryable): 539 while node and node.alias != table_name: 540 node = node.parent 541 542 node_alias = node.args.get("alias") 543 if node_alias: 544 return exp.to_identifier(node_alias.this) 545 546 return exp.to_identifier(table_name) 547 548 @property 549 def all_columns(self): 550 """All available columns of all sources in this scope""" 551 if self._all_columns is None: 552 self._all_columns = { 553 column for columns in self._get_all_source_columns().values() for column in columns 554 } 555 return self._all_columns 556 557 def get_source_columns(self, name, only_visible=False): 558 """Resolve the source columns for a given source `name`""" 559 if name not in self.scope.sources: 560 raise OptimizeError(f"Unknown table: {name}") 561 562 source = self.scope.sources[name] 563 564 # If referencing a table, return the columns from the schema 565 if isinstance(source, exp.Table): 566 return self.schema.column_names(source, only_visible) 567 568 if isinstance(source, Scope) and isinstance(source.expression, exp.Values): 569 return source.expression.alias_column_names 570 571 # Otherwise, if referencing another scope, return that scope's named selects 572 return source.expression.named_selects 573 574 def _get_all_source_columns(self): 575 if self._source_columns is None: 576 
self._source_columns = { 577 k: self.get_source_columns(k) 578 for k in itertools.chain(self.scope.selected_sources, self.scope.lateral_sources) 579 } 580 return self._source_columns 581 582 def _get_unambiguous_columns(self, source_columns): 583 """ 584 Find all the unambiguous columns in sources. 585 586 Args: 587 source_columns (dict): Mapping of names to source columns 588 Returns: 589 dict: Mapping of column name to source name 590 """ 591 if not source_columns: 592 return {} 593 594 source_columns = list(source_columns.items()) 595 596 first_table, first_columns = source_columns[0] 597 unambiguous_columns = {col: first_table for col in self._find_unique_columns(first_columns)} 598 all_columns = set(unambiguous_columns) 599 600 for table, columns in source_columns[1:]: 601 unique = self._find_unique_columns(columns) 602 ambiguous = set(all_columns).intersection(unique) 603 all_columns.update(columns) 604 for column in ambiguous: 605 unambiguous_columns.pop(column, None) 606 for column in unique.difference(ambiguous): 607 unambiguous_columns[column] = table 608 609 return unambiguous_columns 610 611 @staticmethod 612 def _find_unique_columns(columns): 613 """ 614 Find the unique columns in a list of columns. 615 616 Example: 617 >>> sorted(Resolver._find_unique_columns(["a", "b", "b", "c"])) 618 ['a', 'c'] 619 620 This is necessary because duplicate column names are ambiguous. 621 """ 622 counts = {} 623 for column in columns: 624 counts[column] = counts.get(column, 0) + 1 625 return {column for column, count in counts.items() if count == 1}
def
qualify_columns( expression: sqlglot.expressions.Expression, schema: Union[Dict, sqlglot.schema.Schema], expand_alias_refs: bool = True, infer_schema: Optional[bool] = None) -> sqlglot.expressions.Expression:
def qualify_columns(
    expression: exp.Expression,
    schema: t.Dict | Schema,
    expand_alias_refs: bool = True,
    infer_schema: t.Optional[bool] = None,
) -> exp.Expression:
    """
    Rewrite sqlglot AST to have fully qualified columns.

    Example:
        >>> import sqlglot
        >>> schema = {"tbl": {"col": "INT"}}
        >>> expression = sqlglot.parse_one("SELECT col FROM tbl")
        >>> qualify_columns(expression, schema).sql()
        'SELECT tbl.col AS col FROM tbl'

    Args:
        expression: expression to qualify
        schema: Database schema
        expand_alias_refs: whether or not to expand references to aliases
        infer_schema: whether or not to infer the schema if missing
    Returns:
        sqlglot.Expression: qualified expression
    """
    schema = ensure_schema(schema)
    if infer_schema is None:
        # The caller did not decide: infer only when there is no schema to consult.
        infer_schema = schema.empty

    for scope in traverse_scope(expression):
        resolver = Resolver(scope, schema, infer_schema=infer_schema)
        _pop_table_column_aliases(scope.ctes)
        _pop_table_column_aliases(scope.derived_tables)
        using_column_tables = _expand_using(scope, resolver)

        # Alias references are expanded before qualification when the schema is empty,
        # and after it otherwise.
        if expand_alias_refs and schema.empty:
            _expand_alias_refs(scope, resolver)

        _qualify_columns(scope, resolver)

        if expand_alias_refs and not schema.empty:
            _expand_alias_refs(scope, resolver)

        # The remaining passes are skipped for UDTF scopes.
        if isinstance(scope.expression, exp.UDTF):
            continue

        _expand_stars(scope, resolver, using_column_tables)
        _qualify_outputs(scope)
        _expand_group_by(scope)
        _expand_order_by(scope, resolver)

    return expression
Rewrite sqlglot AST to have fully qualified columns.
Example:
>>> import sqlglot
>>> schema = {"tbl": {"col": "INT"}}
>>> expression = sqlglot.parse_one("SELECT col FROM tbl")
>>> qualify_columns(expression, schema).sql()
'SELECT tbl.col AS col FROM tbl'
Arguments:
- expression: expression to qualify
- schema: Database schema
- expand_alias_refs: whether or not to expand references to aliases
- infer_schema: whether or not to infer the schema if missing
Returns:
sqlglot.Expression: qualified expression
def
validate_qualify_columns(expression: ~E) -> ~E:
def validate_qualify_columns(expression: E) -> E:
    """
    Raise an `OptimizeError` if any columns aren't qualified.

    Only SELECT scopes are inspected. The first scope with external columns that is
    neither a correlated subquery nor pivoted is reported immediately; unqualified
    columns are collected from all scopes and reported together at the end.

    Returns:
        The expression that was passed in, when nothing is reported.
    """
    ambiguous = []

    for scope in traverse_scope(expression):
        if not isinstance(scope.expression, exp.Select):
            continue

        ambiguous.extend(scope.unqualified_columns)

        unresolved = scope.external_columns
        if not unresolved or scope.is_correlated_subquery or scope.pivots:
            continue

        column = unresolved[0]
        table_hint = f" for table: '{column.table}'" if column.table else ""
        raise OptimizeError(f"Column '{column}' could not be resolved{table_hint}")

    if ambiguous:
        raise OptimizeError(f"Ambiguous columns: {ambiguous}")

    return expression
Raise an `OptimizeError` if any columns aren't qualified.
def
quote_identifiers( expression: ~E, dialect: Union[str, sqlglot.dialects.dialect.Dialect, Type[sqlglot.dialects.dialect.Dialect], NoneType] = None, identify: bool = True) -> ~E:
def quote_identifiers(expression: E, dialect: DialectType = None, identify: bool = True) -> E:
    """
    Makes sure all identifiers that need to be quoted are quoted.

    The quoting rule is the `quote_identifier` method of the dialect obtained from
    `Dialect.get_or_raise(dialect)`. It is applied via `expression.transform` with
    `copy=False`, and the result of that call is returned.
    """
    quote = Dialect.get_or_raise(dialect).quote_identifier
    return expression.transform(quote, identify=identify, copy=False)
Makes sure all identifiers that need to be quoted are quoted.
class
Resolver:
class Resolver:
    """
    Helper for resolving columns.

    This is a class so we can lazily load some things and easily share them across functions.
    """

    def __init__(self, scope: Scope, schema: Schema, infer_schema: bool = True):
        self.scope = scope
        self.schema = schema
        # Lazily computed caches, filled on first use.
        self._source_columns = None
        self._unambiguous_columns: t.Optional[t.Dict[str, str]] = None
        self._all_columns = None
        self._infer_schema = infer_schema

    def get_table(self, column_name: str) -> t.Optional[exp.Identifier]:
        """
        Get the table for a column name.

        Args:
            column_name: The column name to find the table for.
        Returns:
            The table name if it can be found/inferred.
        """
        if self._unambiguous_columns is None:
            self._unambiguous_columns = self._get_unambiguous_columns(
                self._get_all_source_columns()
            )

        table_name = self._unambiguous_columns.get(column_name)

        if not table_name and self._infer_schema:
            # If exactly one source has unknown columns, attribute the column to it.
            sources_without_schema = tuple(
                source
                for source, columns in self._get_all_source_columns().items()
                if not columns or "*" in columns
            )
            if len(sources_without_schema) == 1:
                table_name = sources_without_schema[0]

        if table_name not in self.scope.selected_sources:
            return exp.to_identifier(table_name)

        node, _ = self.scope.selected_sources[table_name]

        if isinstance(node, exp.Subqueryable):
            # Climb to the ancestor that carries the alias the source is known by.
            while node and node.alias != table_name:
                node = node.parent

        # The climb above can run out of ancestors; fall back to the plain name then
        # instead of dereferencing None.
        node_alias = node.args.get("alias") if node is not None else None
        if node_alias:
            return exp.to_identifier(node_alias.this)

        return exp.to_identifier(table_name)

    @property
    def all_columns(self):
        """All available columns of all sources in this scope"""
        if self._all_columns is None:
            self._all_columns = {
                column for columns in self._get_all_source_columns().values() for column in columns
            }
        return self._all_columns

    def get_source_columns(self, name, only_visible=False):
        """
        Resolve the source columns for a given source `name`.

        Raises:
            OptimizeError: if `name` is not a source of the scope.
        """
        if name not in self.scope.sources:
            raise OptimizeError(f"Unknown table: {name}")

        source = self.scope.sources[name]

        # If referencing a table, return the columns from the schema
        if isinstance(source, exp.Table):
            return self.schema.column_names(source, only_visible)

        if isinstance(source, Scope) and isinstance(source.expression, exp.Values):
            return source.expression.alias_column_names

        # Otherwise, if referencing another scope, return that scope's named selects
        return source.expression.named_selects

    def _get_all_source_columns(self):
        """Return (and cache) the columns of every selected and lateral source, by name."""
        if self._source_columns is None:
            self._source_columns = {
                k: self.get_source_columns(k)
                for k in itertools.chain(self.scope.selected_sources, self.scope.lateral_sources)
            }
        return self._source_columns

    def _get_unambiguous_columns(self, source_columns):
        """
        Find all the unambiguous columns in sources.

        A column name is unambiguous only if it occurs exactly once across all sources,
        so a name repeated inside a single source is ambiguous as well.

        Args:
            source_columns (dict): Mapping of names to source columns
        Returns:
            dict: Mapping of column name to source name
        """
        if not source_columns:
            return {}

        # First source that exposes each name; only consulted for names that occur once.
        owners = {}
        for table, columns in source_columns.items():
            for column in columns:
                owners.setdefault(column, table)

        unique = self._find_unique_columns(
            itertools.chain.from_iterable(source_columns.values())
        )
        return {column: owners[column] for column in unique}

    @staticmethod
    def _find_unique_columns(columns):
        """
        Find the unique columns in a list of columns.

        Example:
            >>> sorted(Resolver._find_unique_columns(["a", "b", "b", "c"]))
            ['a', 'c']

        This is necessary because duplicate column names are ambiguous.
        """
        counts = {}
        for column in columns:
            counts[column] = counts.get(column, 0) + 1
        return {column for column, count in counts.items() if count == 1}
Helper for resolving columns.
This is a class so we can lazily load some things and easily share them across functions.
Resolver( scope: sqlglot.optimizer.scope.Scope, schema: sqlglot.schema.Schema, infer_schema: bool = True)
def get_table(self, column_name: str) -> t.Optional[exp.Identifier]:
    """
    Get the table for a column name.

    Args:
        column_name: The column name to find the table for.
    Returns:
        The table name if it can be found/inferred.
    """
    if self._unambiguous_columns is None:
        self._unambiguous_columns = self._get_unambiguous_columns(
            self._get_all_source_columns()
        )

    table_name = self._unambiguous_columns.get(column_name)

    if not table_name and self._infer_schema:
        # If exactly one source has unknown columns (none listed, or a "*"),
        # attribute the column to it.
        sources_without_schema = tuple(
            source
            for source, columns in self._get_all_source_columns().items()
            if not columns or "*" in columns
        )
        if len(sources_without_schema) == 1:
            table_name = sources_without_schema[0]

    if table_name not in self.scope.selected_sources:
        return exp.to_identifier(table_name)

    node, _ = self.scope.selected_sources[table_name]

    if isinstance(node, exp.Subqueryable):
        # Climb to the ancestor that carries the alias the source is known by.
        while node and node.alias != table_name:
            node = node.parent

    # The climb above can run out of ancestors; fall back to the plain name then
    # instead of dereferencing None.
    node_alias = node.args.get("alias") if node is not None else None
    if node_alias:
        return exp.to_identifier(node_alias.this)

    return exp.to_identifier(table_name)
Get the table for a column name.
Arguments:
- column_name: The column name to find the table for.
Returns:
The table name if it can be found/inferred.
def
get_source_columns(self, name, only_visible=False):
def get_source_columns(self, name, only_visible=False):
    """
    Resolve the source columns for a given source `name`.

    Tables are looked up in the schema (forwarding `only_visible`), a scope over a
    VALUES expression yields its alias column names, and any other source yields the
    named selects of its expression.

    Raises:
        OptimizeError: if `name` is not among the scope's sources.
    """
    sources = self.scope.sources
    if name not in sources:
        raise OptimizeError(f"Unknown table: {name}")

    source = sources[name]

    if isinstance(source, exp.Table):
        # A physical table: the schema is the authority on its columns.
        return self.schema.column_names(source, only_visible)

    source_expression = source.expression
    is_values_scope = isinstance(source, Scope) and isinstance(source_expression, exp.Values)

    if is_values_scope:
        return source_expression.alias_column_names

    # Any other source exposes the named selects of its expression.
    return source_expression.named_selects
Resolve the source columns for a given source `name`.