sqlglot.optimizer.qualify_columns
from __future__ import annotations

import itertools
import typing as t

from sqlglot import alias, exp
from sqlglot._typing import E
from sqlglot.dialects.dialect import Dialect, DialectType
from sqlglot.errors import OptimizeError
from sqlglot.helper import seq_get
from sqlglot.optimizer.scope import Scope, build_scope, traverse_scope, walk_in_scope
from sqlglot.optimizer.simplify import simplify_parens
from sqlglot.schema import Schema, ensure_schema


def qualify_columns(
    expression: exp.Expression,
    schema: t.Dict | Schema,
    expand_alias_refs: bool = True,
    infer_schema: t.Optional[bool] = None,
) -> exp.Expression:
    """
    Rewrite sqlglot AST to have fully qualified columns.

    Example:
        >>> import sqlglot
        >>> schema = {"tbl": {"col": "INT"}}
        >>> expression = sqlglot.parse_one("SELECT col FROM tbl")
        >>> qualify_columns(expression, schema).sql()
        'SELECT tbl.col AS col FROM tbl'

    Args:
        expression: Expression to qualify.
        schema: Database schema.
        expand_alias_refs: Whether or not to expand references to aliases.
        infer_schema: Whether or not to infer the schema if missing.

    Returns:
        The qualified expression.
    """
    schema = ensure_schema(schema)
    # When the caller does not say, infer tables for columns only if no schema was given.
    infer_schema = schema.empty if infer_schema is None else infer_schema
    pseudocolumns = Dialect.get_or_raise(schema.dialect).PSEUDOCOLUMNS

    for scope in traverse_scope(expression):
        resolver = Resolver(scope, schema, infer_schema=infer_schema)
        _pop_table_column_aliases(scope.ctes)
        _pop_table_column_aliases(scope.derived_tables)
        using_column_tables = _expand_using(scope, resolver)

        # Alias references are expanded before qualification when there is no schema,
        # and after qualification when there is one.
        if schema.empty and expand_alias_refs:
            _expand_alias_refs(scope, resolver)

        _qualify_columns(scope, resolver)

        if not schema.empty and expand_alias_refs:
            _expand_alias_refs(scope, resolver)

        if not isinstance(scope.expression, exp.UDTF):
            _expand_stars(scope, resolver, using_column_tables, pseudocolumns)
            qualify_outputs(scope)

        _expand_group_by(scope)
        _expand_order_by(scope, resolver)

    return expression


def validate_qualify_columns(expression: E) -> E:
    """Raise an `OptimizeError` if any columns aren't qualified"""
    unqualified_columns = []
    for scope in traverse_scope(expression):
        if isinstance(scope.expression, exp.Select):
            unqualified_columns.extend(scope.unqualified_columns)
            if scope.external_columns and not scope.is_correlated_subquery and not scope.pivots:
                column = scope.external_columns[0]
                raise OptimizeError(
                    f"""Column '{column}' could not be resolved{f" for table: '{column.table}'" if column.table else ''}"""
                )

    if unqualified_columns:
        raise OptimizeError(f"Ambiguous columns: {unqualified_columns}")
    return expression


def _pop_table_column_aliases(derived_tables: t.List[exp.CTE | exp.Subquery]) -> None:
    """
    Remove table column aliases.

    For example, `col1` and `col2` will be dropped in SELECT ... FROM (SELECT ...) AS foo(col1, col2)
    """
    for derived_table in derived_tables:
        table_alias = derived_table.args.get("alias")
        if table_alias:
            table_alias.args.pop("columns", None)


def _expand_using(scope: Scope, resolver: Resolver) -> t.Dict[str, t.Any]:
    """
    Rewrite `JOIN ... USING (...)` clauses into explicit `ON` equality conditions.

    Unqualified references in the scope to a column named in a USING clause are
    replaced with a COALESCE over the tables that were joined on that column.

    Args:
        scope: The scope whose joins are rewritten.
        resolver: Resolver used to look up the columns of each source.

    Returns:
        Mapping of each USING column name to a dict whose keys are the joined
        source names, in join order (the values are always None).

    Raises:
        OptimizeError: If a USING column is missing from either side of the join
            while the columns of both sides are known.
    """
    joins = list(scope.find_all(exp.Join))
    names = {join.alias_or_name for join in joins}
    ordered = [key for key in scope.selected_sources if key not in names]

    # Mapping of automatically joined column names to an ordered set of source names (dict).
    column_tables: t.Dict[str, t.Dict[str, t.Any]] = {}

    for join in joins:
        using = join.args.get("using")

        if not using:
            continue

        join_table = join.alias_or_name

        columns = {}

        # First source (in selection order) that provides each column name wins.
        for source_name in scope.selected_sources:
            if source_name in ordered:
                for column_name in resolver.get_source_columns(source_name):
                    if column_name not in columns:
                        columns[column_name] = source_name

        source_table = ordered[-1]
        ordered.append(join_table)
        join_columns = resolver.get_source_columns(join_table)
        conditions = []

        for identifier in using:
            identifier = identifier.name
            table = columns.get(identifier)

            if not table or identifier not in join_columns:
                if (columns and "*" not in columns) and join_columns:
                    raise OptimizeError(f"Cannot automatically join: {identifier}")

            table = table or source_table
            conditions.append(
                exp.condition(
                    exp.EQ(
                        this=exp.column(identifier, table=table),
                        expression=exp.column(identifier, table=join_table),
                    )
                )
            )

            # Set all values in the dict to None, because we only care about the key ordering
            tables = column_tables.setdefault(identifier, {})
            if table not in tables:
                tables[table] = None
            if join_table not in tables:
                tables[join_table] = None

        join.args.pop("using")
        join.set("on", exp.and_(*conditions, copy=False))

    if column_tables:
        for column in scope.columns:
            if not column.table and column.name in column_tables:
                tables = column_tables[column.name]
                coalesce = [exp.column(column.name, table=table) for table in tables]
                replacement = exp.Coalesce(this=coalesce[0], expressions=coalesce[1:])

                # Ensure selects keep their output name
                if isinstance(column.parent, exp.Select):
                    replacement = alias(replacement, alias=column.name, copy=False)

                scope.replace(column, replacement)

    return column_tables


def _expand_alias_refs(scope: Scope, resolver: Resolver) -> None:
    """
    Expand references to projection aliases inside a SELECT.

    Columns that name an earlier projection alias are replaced by the aliased
    expression in later projections and in the WHERE, GROUP BY, HAVING and
    QUALIFY clauses. Does nothing when the scope's expression is not a SELECT.

    Args:
        scope: The scope to rewrite.
        resolver: Resolver used to find the table of a column in HAVING/QUALIFY.
    """
    expression = scope.expression

    if not isinstance(expression, exp.Select):
        return

    # Alias name -> (aliased expression, 1-based position of the projection).
    alias_to_expression: t.Dict[str, t.Tuple[exp.Expression, int]] = {}

    def replace_columns(
        node: t.Optional[exp.Expression], resolve_table: bool = False, literal_index: bool = False
    ) -> None:
        if not node:
            return

        for column, *_ in walk_in_scope(node):
            if not isinstance(column, exp.Column):
                continue

            table = resolver.get_table(column.name) if resolve_table and not column.table else None
            alias_expr, i = alias_to_expression.get(column.name, (None, 1))
            # True when expanding the alias would nest one aggregate inside another.
            double_agg = (
                (alias_expr.find(exp.AggFunc) and column.find_ancestor(exp.AggFunc))
                if alias_expr
                else False
            )

            if table and (not alias_expr or double_agg):
                column.set("table", table)
            elif not column.table and alias_expr and not double_agg:
                if isinstance(alias_expr, exp.Literal) and (literal_index or resolve_table):
                    if literal_index:
                        # Refer to a literal projection by its position instead of its value.
                        column.replace(exp.Literal.number(i))
                else:
                    column = column.replace(exp.paren(alias_expr))
                    simplified = simplify_parens(column)
                    if simplified is not column:
                        column.replace(simplified)

    for i, projection in enumerate(scope.expression.selects):
        replace_columns(projection)

        if isinstance(projection, exp.Alias):
            alias_to_expression[projection.alias] = (projection.this, i + 1)

    replace_columns(expression.args.get("where"))
    replace_columns(expression.args.get("group"), literal_index=True)
    replace_columns(expression.args.get("having"), resolve_table=True)
    replace_columns(expression.args.get("qualify"), resolve_table=True)
    scope.clear_cache()


def _expand_group_by(scope: Scope) -> None:
    """Replace positional references (e.g. `GROUP BY 1`) in the scope's GROUP BY, if it has one."""
    expression = scope.expression
    group = expression.args.get("group")
    if not group:
        return

    group.set("expressions", _expand_positional_references(scope, group.expressions))
    expression.set("group", group)


def _expand_order_by(scope: Scope, resolver: Resolver) -> None:
    """
    Replace positional references in the scope's ORDER BY with projection aliases.

    Unqualified columns inside aggregate functions in the ORDER BY are given a
    table via the resolver. When the query also has a GROUP BY, ordered terms
    that match a projection are replaced by a reference to that projection.
    """
    order = scope.expression.args.get("order")
    if not order:
        return

    ordereds = order.expressions
    for ordered, new_expression in zip(
        ordereds,
        _expand_positional_references(scope, (o.this for o in ordereds), alias=True),
    ):
        for agg in ordered.find_all(exp.AggFunc):
            for col in agg.find_all(exp.Column):
                if not col.table:
                    col.set("table", resolver.get_table(col.name))

        ordered.set("this", new_expression)

    if scope.expression.args.get("group"):
        selects = {s.this: exp.column(s.alias_or_name) for s in scope.expression.selects}

        for ordered in ordereds:
            ordered = ordered.this

            ordered.replace(
                exp.to_identifier(_select_by_pos(scope, ordered).alias)
                if ordered.is_int
                else selects.get(ordered, ordered)
            )


def _expand_positional_references(
    scope: Scope, expressions: t.Iterable[exp.Expression], alias: bool = False
) -> t.List[exp.Expression]:
    """
    Resolve integer position references against the scope's projections.

    Args:
        scope: The scope whose projections are referenced.
        expressions: The expressions to inspect.
        alias: If True, an integer is replaced with a column named after the
            projection's alias; otherwise with a copy of the projected
            expression (a literal projection leaves the original node as is).

    Returns:
        A new list with one entry per input expression.
    """
    new_nodes: t.List[exp.Expression] = []
    for node in expressions:
        if node.is_int:
            select = _select_by_pos(scope, t.cast(exp.Literal, node))

            if alias:
                new_nodes.append(exp.column(select.args["alias"].copy()))
            else:
                select = select.this

                if isinstance(select, exp.Literal):
                    new_nodes.append(node)
                else:
                    new_nodes.append(select.copy())
        else:
            new_nodes.append(node)

    return new_nodes


def _select_by_pos(scope: Scope, node: exp.Literal) -> exp.Alias:
    """
    Return the projection at the 1-based position given by `node`.

    Raises:
        OptimizeError: If the position is past the end of the projection list.
    """
    try:
        return scope.expression.selects[int(node.this) - 1].assert_is(exp.Alias)
    except IndexError:
        raise OptimizeError(f"Unknown output column: {node.name}")


def _qualify_columns(scope: Scope, resolver: Resolver) -> None:
    """Disambiguate columns, ensuring each column specifies a source"""
    for column in scope.columns:
        column_table = column.table
        column_name = column.name

        if column_table and column_table in scope.sources:
            source_columns = resolver.get_source_columns(column_table)
            if source_columns and column_name not in source_columns and "*" not in source_columns:
                raise OptimizeError(f"Unknown column: {column_name}")

        if not column_table:
            if scope.pivots and not column.find_ancestor(exp.Pivot):
                # If the column is under the Pivot expression, we need to qualify it
                # using the name of the pivoted source instead of the pivot's alias
                column.set("table", exp.to_identifier(scope.pivots[0].alias))
                continue

            column_table = resolver.get_table(column_name)

            # column_table can be a '' because bigquery unnest has no table alias
            if column_table:
                column.set("table", column_table)
        elif column_table not in scope.sources and (
            not scope.parent
            or column_table not in scope.parent.sources
            or not scope.is_correlated_subquery
        ):
            # structs are used like tables (e.g. "struct"."field"), so they need to be qualified
            # separately and represented as dot(dot(...(<table>.<column>, field1), field2, ...))

            root, *parts = column.parts

            if root.name in scope.sources:
                # struct is already qualified, but we still need to change the AST representation
                column_table = root
                root, *parts = parts
            else:
                column_table = resolver.get_table(root.name)

            if column_table:
                column.replace(exp.Dot.build([exp.column(root, table=column_table), *parts]))

    for pivot in scope.pivots:
        for column in pivot.find_all(exp.Column):
            if not column.table and column.name in resolver.all_columns:
                column_table = resolver.get_table(column.name)
                if column_table:
                    column.set("table", column_table)


def _expand_stars(
    scope: Scope,
    resolver: Resolver,
    using_column_tables: t.Dict[str, t.Any],
    pseudocolumns: t.Set[str],
) -> None:
    """Expand stars to lists of column selections"""

    new_selections = []
    # Both mappings are keyed by id() of the table-name object held in `tables`.
    except_columns: t.Dict[int, t.Set[str]] = {}
    replace_columns: t.Dict[int, t.Dict[str, str]] = {}
    coalesced_columns = set()

    # TODO: handle optimization of multiple PIVOTs (and possibly UNPIVOTs) in the future
    pivot_columns = None
    pivot_output_columns = None
    pivot = t.cast(t.Optional[exp.Pivot], seq_get(scope.pivots, 0))

    has_pivoted_source = pivot and not pivot.args.get("unpivot")
    if pivot and has_pivoted_source:
        pivot_columns = set(col.output_name for col in pivot.find_all(exp.Column))

        pivot_output_columns = [col.output_name for col in pivot.args.get("columns", [])]
        if not pivot_output_columns:
            pivot_output_columns = [col.alias_or_name for col in pivot.expressions]

    for expression in scope.expression.selects:
        if isinstance(expression, exp.Star):
            tables = list(scope.selected_sources)
            _add_except_columns(expression, tables, except_columns)
            _add_replace_columns(expression, tables, replace_columns)
        elif expression.is_star:
            tables = [expression.table]
            _add_except_columns(expression.this, tables, except_columns)
            _add_replace_columns(expression.this, tables, replace_columns)
        else:
            new_selections.append(expression)
            continue

        for table in tables:
            if table not in scope.sources:
                raise OptimizeError(f"Unknown table: {table}")

            columns = resolver.get_source_columns(table, only_visible=True)

            if pseudocolumns:
                columns = [name for name in columns if name.upper() not in pseudocolumns]

            if columns and "*" not in columns:
                table_id = id(table)
                columns_to_exclude = except_columns.get(table_id) or set()

                if pivot and has_pivoted_source and pivot_columns and pivot_output_columns:
                    implicit_columns = [col for col in columns if col not in pivot_columns]
                    new_selections.extend(
                        exp.alias_(exp.column(name, table=pivot.alias), name, copy=False)
                        for name in implicit_columns + pivot_output_columns
                        if name not in columns_to_exclude
                    )
                    continue

                for name in columns:
                    if name in using_column_tables and table in using_column_tables[name]:
                        # A USING column is emitted once, as a COALESCE over its joined tables.
                        if name in coalesced_columns:
                            continue

                        coalesced_columns.add(name)
                        tables = using_column_tables[name]
                        coalesce = [exp.column(name, table=table) for table in tables]

                        new_selections.append(
                            alias(
                                exp.Coalesce(this=coalesce[0], expressions=coalesce[1:]),
                                alias=name,
                                copy=False,
                            )
                        )
                    elif name not in columns_to_exclude:
                        alias_ = replace_columns.get(table_id, {}).get(name, name)
                        column = exp.column(name, table=table)
                        new_selections.append(
                            alias(column, alias_, copy=False) if alias_ != name else column
                        )
            else:
                # The columns of this source are unknown, so leave the projections untouched.
                return

    # Ensures we don't overwrite the initial selections with an empty list
    if new_selections:
        scope.expression.set("expressions", new_selections)


def _add_except_columns(
    expression: exp.Expression, tables, except_columns: t.Dict[int, t.Set[str]]
) -> None:
    """Record the names in the star's `except` arg, if any, under the id() of each table."""
    except_ = expression.args.get("except")

    if not except_:
        return

    columns = {e.name for e in except_}

    for table in tables:
        except_columns[id(table)] = columns


def _add_replace_columns(
    expression: exp.Expression, tables, replace_columns: t.Dict[int, t.Dict[str, str]]
) -> None:
    """Record the star's `replace` arg, if any, as a name -> alias mapping under the id() of each table."""
    replace = expression.args.get("replace")

    if not replace:
        return

    columns = {e.this.name: e.alias for e in replace}

    for table in tables:
        replace_columns[id(table)] = columns


def qualify_outputs(scope_or_expression: Scope | exp.Expression) -> None:
    """
    Ensure all output columns are aliased.

    Projections that are neither aliases nor stars are wrapped in an alias named
    after their output name, or `_col_<i>` when they have none. Subquery
    projections without an output name get a `_col_<i>` table alias. Names from
    the scope's outer column list override the alias of the projection at the
    same position.

    Args:
        scope_or_expression: A scope, or an expression to build a scope from.
            Nothing happens if no scope can be built from the expression.
    """
    if isinstance(scope_or_expression, exp.Expression):
        scope = build_scope(scope_or_expression)
        if not isinstance(scope, Scope):
            return
    else:
        scope = scope_or_expression

    new_selections = []
    for i, (selection, aliased_column) in enumerate(
        itertools.zip_longest(scope.expression.selects, scope.outer_column_list)
    ):
        # zip_longest pads with None once the projections run out, which happens when
        # there are more outer column names than projections; nothing is left to alias.
        if selection is None:
            break

        if isinstance(selection, exp.Subquery):
            if not selection.output_name:
                selection.set("alias", exp.TableAlias(this=exp.to_identifier(f"_col_{i}")))
        elif not isinstance(selection, exp.Alias) and not selection.is_star:
            selection = alias(
                selection,
                alias=selection.output_name or f"_col_{i}",
            )
        if aliased_column:
            selection.set("alias", exp.to_identifier(aliased_column))

        new_selections.append(selection)

    scope.expression.set("expressions", new_selections)


def quote_identifiers(expression: E, dialect: DialectType = None, identify: bool = True) -> E:
    """Makes sure all identifiers that need to be quoted are quoted."""
    return expression.transform(
        Dialect.get_or_raise(dialect).quote_identifier, identify=identify, copy=False
    )


class Resolver:
    """
    Helper for resolving columns.

    This is a class so we can lazily load some things and easily share them across functions.
    """

    def __init__(self, scope: Scope, schema: Schema, infer_schema: bool = True):
        self.scope = scope
        self.schema = schema
        # The three caches below are filled lazily on first use.
        self._source_columns: t.Optional[t.Dict[str, t.List[str]]] = None
        self._unambiguous_columns: t.Optional[t.Dict[str, str]] = None
        self._all_columns: t.Optional[t.Set[str]] = None
        self._infer_schema = infer_schema

    def get_table(self, column_name: str) -> t.Optional[exp.Identifier]:
        """
        Get the table for a column name.

        Args:
            column_name: The column name to find the table for.
        Returns:
            The table name if it can be found/inferred.
        """
        if self._unambiguous_columns is None:
            self._unambiguous_columns = self._get_unambiguous_columns(
                self._get_all_source_columns()
            )

        table_name = self._unambiguous_columns.get(column_name)

        if not table_name and self._infer_schema:
            # If exactly one source has unknown columns, assume the column belongs to it.
            sources_without_schema = tuple(
                source
                for source, columns in self._get_all_source_columns().items()
                if not columns or "*" in columns
            )
            if len(sources_without_schema) == 1:
                table_name = sources_without_schema[0]

        if table_name not in self.scope.selected_sources:
            return exp.to_identifier(table_name)

        node, _ = self.scope.selected_sources.get(table_name)

        if isinstance(node, exp.Subqueryable):
            # Walk up to the ancestor that carries the alias this source is selected under.
            while node and node.alias != table_name:
                node = node.parent

        # The walk above can run off the root without finding the alias;
        # fall back to the plain name instead of dereferencing None.
        if node is None:
            return exp.to_identifier(table_name)

        node_alias = node.args.get("alias")
        if node_alias:
            return exp.to_identifier(node_alias.this)

        return exp.to_identifier(table_name)

    @property
    def all_columns(self) -> t.Set[str]:
        """All available columns of all sources in this scope"""
        if self._all_columns is None:
            self._all_columns = {
                column for columns in self._get_all_source_columns().values() for column in columns
            }
        return self._all_columns

    def get_source_columns(self, name: str, only_visible: bool = False) -> t.List[str]:
        """Resolve the source columns for a given source `name`."""
        if name not in self.scope.sources:
            raise OptimizeError(f"Unknown table: {name}")

        source = self.scope.sources[name]

        if isinstance(source, exp.Table):
            columns = self.schema.column_names(source, only_visible)
        elif isinstance(source, Scope) and isinstance(source.expression, exp.Values):
            columns = source.expression.alias_column_names
        else:
            columns = source.expression.named_selects

        node, _ = self.scope.selected_sources.get(name) or (None, None)
        if isinstance(node, Scope):
            column_aliases = node.expression.alias_column_names
        elif isinstance(node, exp.Expression):
            column_aliases = node.alias_column_names
        else:
            column_aliases = []

        # If the source's columns are aliased, their aliases shadow the corresponding column names
        return [alias or name for (name, alias) in itertools.zip_longest(columns, column_aliases)]

    def _get_all_source_columns(self) -> t.Dict[str, t.List[str]]:
        """Return (and cache) the columns of every selected and lateral source, keyed by source name."""
        if self._source_columns is None:
            self._source_columns = {
                source_name: self.get_source_columns(source_name)
                for source_name, source in itertools.chain(
                    self.scope.selected_sources.items(), self.scope.lateral_sources.items()
                )
            }
        return self._source_columns

    def _get_unambiguous_columns(
        self, source_columns: t.Dict[str, t.List[str]]
    ) -> t.Dict[str, str]:
        """
        Find all the unambiguous columns in sources.

        Args:
            source_columns: Mapping of names to source columns.

        Returns:
            Mapping of column name to source name.
        """
        if not source_columns:
            return {}

        source_columns_pairs = list(source_columns.items())

        first_table, first_columns = source_columns_pairs[0]
        unambiguous_columns = {col: first_table for col in self._find_unique_columns(first_columns)}
        all_columns = set(unambiguous_columns)

        for table, columns in source_columns_pairs[1:]:
            unique = self._find_unique_columns(columns)
            ambiguous = set(all_columns).intersection(unique)
            all_columns.update(columns)

            for column in ambiguous:
                unambiguous_columns.pop(column, None)
            for column in unique.difference(ambiguous):
                unambiguous_columns[column] = table

        return unambiguous_columns

    @staticmethod
    def _find_unique_columns(columns: t.Collection[str]) -> t.Set[str]:
        """
        Find the unique columns in a list of columns.

        Example:
            >>> sorted(Resolver._find_unique_columns(["a", "b", "b", "c"]))
            ['a', 'c']

        This is necessary because duplicate column names are ambiguous.
        """
        counts: t.Dict[str, int] = {}
        for column in columns:
            counts[column] = counts.get(column, 0) + 1
        return {column for column, count in counts.items() if count == 1}
def
qualify_columns( expression: sqlglot.expressions.Expression, schema: Union[Dict, sqlglot.schema.Schema], expand_alias_refs: bool = True, infer_schema: Optional[bool] = None) -> sqlglot.expressions.Expression:
def qualify_columns(
    expression: exp.Expression,
    schema: t.Dict | Schema,
    expand_alias_refs: bool = True,
    infer_schema: t.Optional[bool] = None,
) -> exp.Expression:
    """
    Rewrite sqlglot AST to have fully qualified columns.

    Every scope of the expression is processed in traversal order; the nodes
    are modified in place and the same `expression` object is returned.

    Example:
        >>> import sqlglot
        >>> schema = {"tbl": {"col": "INT"}}
        >>> expression = sqlglot.parse_one("SELECT col FROM tbl")
        >>> qualify_columns(expression, schema).sql()
        'SELECT tbl.col AS col FROM tbl'

    Args:
        expression: Expression to qualify.
        schema: Database schema.
        expand_alias_refs: Whether or not to expand references to aliases.
        infer_schema: Whether or not to infer the schema if missing. When left
            as None, inference is enabled only if the schema is empty.

    Returns:
        The qualified expression.
    """
    schema = ensure_schema(schema)
    if infer_schema is None:
        infer_schema = schema.empty

    dialect = Dialect.get_or_raise(schema.dialect)
    pseudocolumns = dialect.PSEUDOCOLUMNS

    for current_scope in traverse_scope(expression):
        resolver = Resolver(current_scope, schema, infer_schema=infer_schema)

        _pop_table_column_aliases(current_scope.ctes)
        _pop_table_column_aliases(current_scope.derived_tables)
        using_column_tables = _expand_using(current_scope, resolver)

        # Without a schema, alias references are expanded before qualification.
        if expand_alias_refs and schema.empty:
            _expand_alias_refs(current_scope, resolver)

        _qualify_columns(current_scope, resolver)

        # With a schema, alias references are expanded after qualification.
        if expand_alias_refs and not schema.empty:
            _expand_alias_refs(current_scope, resolver)

        is_udtf = isinstance(current_scope.expression, exp.UDTF)
        if not is_udtf:
            _expand_stars(current_scope, resolver, using_column_tables, pseudocolumns)
            qualify_outputs(current_scope)

        _expand_group_by(current_scope)
        _expand_order_by(current_scope, resolver)

    return expression
Rewrite sqlglot AST to have fully qualified columns.
Example:
>>> import sqlglot >>> schema = {"tbl": {"col": "INT"}} >>> expression = sqlglot.parse_one("SELECT col FROM tbl") >>> qualify_columns(expression, schema).sql() 'SELECT tbl.col AS col FROM tbl'
Arguments:
- expression: Expression to qualify.
- schema: Database schema.
- expand_alias_refs: Whether or not to expand references to aliases.
- infer_schema: Whether or not to infer the schema if missing.
Returns:
The qualified expression.
def
validate_qualify_columns(expression: ~E) -> ~E:
def validate_qualify_columns(expression: E) -> E:
    """
    Raise an `OptimizeError` if any columns aren't qualified.

    Only scopes whose expression is a SELECT are inspected. An unresolved
    external column in a scope that is neither a correlated subquery nor
    pivoted raises immediately; unqualified columns are collected across all
    scopes and reported together at the end.

    Returns:
        The `expression` that was passed in, when validation succeeds.
    """
    ambiguous = []

    for scope in traverse_scope(expression):
        if not isinstance(scope.expression, exp.Select):
            continue

        ambiguous.extend(scope.unqualified_columns)

        if not scope.external_columns or scope.is_correlated_subquery or scope.pivots:
            continue

        column = scope.external_columns[0]
        for_table = f" for table: '{column.table}'" if column.table else ""
        raise OptimizeError(f"Column '{column}' could not be resolved{for_table}")

    if ambiguous:
        raise OptimizeError(f"Ambiguous columns: {ambiguous}")

    return expression
Raise an `OptimizeError` if any columns aren't qualified.
def
qualify_outputs( scope_or_expression: sqlglot.optimizer.scope.Scope | sqlglot.expressions.Expression) -> None:
def qualify_outputs(scope_or_expression: Scope | exp.Expression) -> None:
    """
    Ensure all output columns are aliased.

    Projections that are neither aliases nor stars are wrapped in an alias named
    after their output name, or `_col_<i>` when they have none. Subquery
    projections without an output name get a `_col_<i>` table alias. Names from
    the scope's outer column list override the alias of the projection at the
    same position.

    Args:
        scope_or_expression: A scope, or an expression to build a scope from.
            Nothing happens if no scope can be built from the expression.
    """
    if isinstance(scope_or_expression, exp.Expression):
        scope = build_scope(scope_or_expression)
        if not isinstance(scope, Scope):
            return
    else:
        scope = scope_or_expression

    new_selections = []
    for i, (selection, aliased_column) in enumerate(
        itertools.zip_longest(scope.expression.selects, scope.outer_column_list)
    ):
        # zip_longest pads with None once the projections run out, which happens when
        # there are more outer column names than projections; nothing is left to alias.
        if selection is None:
            break

        if isinstance(selection, exp.Subquery):
            if not selection.output_name:
                selection.set("alias", exp.TableAlias(this=exp.to_identifier(f"_col_{i}")))
        elif not isinstance(selection, exp.Alias) and not selection.is_star:
            selection = alias(
                selection,
                alias=selection.output_name or f"_col_{i}",
            )
        if aliased_column:
            selection.set("alias", exp.to_identifier(aliased_column))

        new_selections.append(selection)

    scope.expression.set("expressions", new_selections)
Ensure all output columns are aliased
def
quote_identifiers( expression: ~E, dialect: Union[str, sqlglot.dialects.dialect.Dialect, Type[sqlglot.dialects.dialect.Dialect], NoneType] = None, identify: bool = True) -> ~E:
def quote_identifiers(expression: E, dialect: DialectType = None, identify: bool = True) -> E:
    """
    Makes sure all identifiers that need to be quoted are quoted.

    Args:
        expression: The expression to transform; it is transformed without copying.
        dialect: The dialect whose `quote_identifier` is applied to the tree.
        identify: Forwarded to the dialect's `quote_identifier`.

    Returns:
        The result of the transformation.
    """
    quote_identifier = Dialect.get_or_raise(dialect).quote_identifier
    return expression.transform(quote_identifier, identify=identify, copy=False)
Makes sure all identifiers that need to be quoted are quoted.
class
Resolver:
class Resolver:
    """
    Helper for resolving columns.

    This is a class so we can lazily load some things and easily share them across functions.
    """

    def __init__(self, scope: Scope, schema: Schema, infer_schema: bool = True):
        self.scope = scope
        self.schema = schema
        # The three caches below are filled lazily on first use.
        self._source_columns: t.Optional[t.Dict[str, t.List[str]]] = None
        self._unambiguous_columns: t.Optional[t.Dict[str, str]] = None
        self._all_columns: t.Optional[t.Set[str]] = None
        self._infer_schema = infer_schema

    def get_table(self, column_name: str) -> t.Optional[exp.Identifier]:
        """
        Get the table for a column name.

        Args:
            column_name: The column name to find the table for.
        Returns:
            The table name if it can be found/inferred.
        """
        if self._unambiguous_columns is None:
            self._unambiguous_columns = self._get_unambiguous_columns(
                self._get_all_source_columns()
            )

        table_name = self._unambiguous_columns.get(column_name)

        if not table_name and self._infer_schema:
            # If exactly one source has unknown columns, assume the column belongs to it.
            sources_without_schema = tuple(
                source
                for source, columns in self._get_all_source_columns().items()
                if not columns or "*" in columns
            )
            if len(sources_without_schema) == 1:
                table_name = sources_without_schema[0]

        if table_name not in self.scope.selected_sources:
            return exp.to_identifier(table_name)

        node, _ = self.scope.selected_sources.get(table_name)

        if isinstance(node, exp.Subqueryable):
            # Walk up to the ancestor that carries the alias this source is selected under.
            while node and node.alias != table_name:
                node = node.parent

        # The walk above can run off the root without finding the alias;
        # fall back to the plain name instead of dereferencing None.
        if node is None:
            return exp.to_identifier(table_name)

        node_alias = node.args.get("alias")
        if node_alias:
            return exp.to_identifier(node_alias.this)

        return exp.to_identifier(table_name)

    @property
    def all_columns(self) -> t.Set[str]:
        """All available columns of all sources in this scope"""
        if self._all_columns is None:
            self._all_columns = {
                column for columns in self._get_all_source_columns().values() for column in columns
            }
        return self._all_columns

    def get_source_columns(self, name: str, only_visible: bool = False) -> t.List[str]:
        """Resolve the source columns for a given source `name`."""
        if name not in self.scope.sources:
            raise OptimizeError(f"Unknown table: {name}")

        source = self.scope.sources[name]

        if isinstance(source, exp.Table):
            columns = self.schema.column_names(source, only_visible)
        elif isinstance(source, Scope) and isinstance(source.expression, exp.Values):
            columns = source.expression.alias_column_names
        else:
            columns = source.expression.named_selects

        node, _ = self.scope.selected_sources.get(name) or (None, None)
        if isinstance(node, Scope):
            column_aliases = node.expression.alias_column_names
        elif isinstance(node, exp.Expression):
            column_aliases = node.alias_column_names
        else:
            column_aliases = []

        # If the source's columns are aliased, their aliases shadow the corresponding column names
        return [alias or name for (name, alias) in itertools.zip_longest(columns, column_aliases)]

    def _get_all_source_columns(self) -> t.Dict[str, t.List[str]]:
        """Return (and cache) the columns of every selected and lateral source, keyed by source name."""
        if self._source_columns is None:
            self._source_columns = {
                source_name: self.get_source_columns(source_name)
                for source_name, source in itertools.chain(
                    self.scope.selected_sources.items(), self.scope.lateral_sources.items()
                )
            }
        return self._source_columns

    def _get_unambiguous_columns(
        self, source_columns: t.Dict[str, t.List[str]]
    ) -> t.Dict[str, str]:
        """
        Find all the unambiguous columns in sources.

        Args:
            source_columns: Mapping of names to source columns.

        Returns:
            Mapping of column name to source name.
        """
        if not source_columns:
            return {}

        source_columns_pairs = list(source_columns.items())

        first_table, first_columns = source_columns_pairs[0]
        unambiguous_columns = {col: first_table for col in self._find_unique_columns(first_columns)}
        all_columns = set(unambiguous_columns)

        for table, columns in source_columns_pairs[1:]:
            unique = self._find_unique_columns(columns)
            ambiguous = set(all_columns).intersection(unique)
            all_columns.update(columns)

            for column in ambiguous:
                unambiguous_columns.pop(column, None)
            for column in unique.difference(ambiguous):
                unambiguous_columns[column] = table

        return unambiguous_columns

    @staticmethod
    def _find_unique_columns(columns: t.Collection[str]) -> t.Set[str]:
        """
        Find the unique columns in a list of columns.

        Example:
            >>> sorted(Resolver._find_unique_columns(["a", "b", "b", "c"]))
            ['a', 'c']

        This is necessary because duplicate column names are ambiguous.
        """
        counts: t.Dict[str, int] = {}
        for column in columns:
            counts[column] = counts.get(column, 0) + 1
        return {column for column, count in counts.items() if count == 1}
Helper for resolving columns.
This is a class so we can lazily load some things and easily share them across functions.
Resolver( scope: sqlglot.optimizer.scope.Scope, schema: sqlglot.schema.Schema, infer_schema: bool = True)
506 def __init__(self, scope: Scope, schema: Schema, infer_schema: bool = True): 507 self.scope = scope 508 self.schema = schema 509 self._source_columns: t.Optional[t.Dict[str, t.List[str]]] = None 510 self._unambiguous_columns: t.Optional[t.Dict[str, str]] = None 511 self._all_columns: t.Optional[t.Set[str]] = None 512 self._infer_schema = infer_schema
def get_table(self, column_name: str) -> t.Optional[exp.Identifier]:
    """
    Get the table for a column name.

    Args:
        column_name: The column name to find the table for.
    Returns:
        The table name if it can be found/inferred.
    """
    if self._unambiguous_columns is None:
        self._unambiguous_columns = self._get_unambiguous_columns(
            self._get_all_source_columns()
        )

    table_name = self._unambiguous_columns.get(column_name)

    if not table_name and self._infer_schema:
        # If exactly one source has unknown columns, assume the column belongs to it.
        sources_without_schema = tuple(
            source
            for source, columns in self._get_all_source_columns().items()
            if not columns or "*" in columns
        )
        if len(sources_without_schema) == 1:
            table_name = sources_without_schema[0]

    if table_name not in self.scope.selected_sources:
        return exp.to_identifier(table_name)

    node, _ = self.scope.selected_sources.get(table_name)

    if isinstance(node, exp.Subqueryable):
        # Walk up to the ancestor that carries the alias this source is selected under.
        while node and node.alias != table_name:
            node = node.parent

    # The walk above can run off the root without finding the alias;
    # fall back to the plain name instead of dereferencing None.
    if node is None:
        return exp.to_identifier(table_name)

    node_alias = node.args.get("alias")
    if node_alias:
        return exp.to_identifier(node_alias.this)

    return exp.to_identifier(table_name)
Get the table for a column name.
Arguments:
- column_name: The column name to find the table for.
Returns:
The table name if it can be found/inferred.
def
get_source_columns(self, name: str, only_visible: bool = False) -> List[str]:
def get_source_columns(self, name: str, only_visible: bool = False) -> t.List[str]:
    """
    Resolve the source columns for a given source `name`.

    Args:
        name: Name of a source in this resolver's scope.
        only_visible: Forwarded to the schema's `column_names` lookup when
            the source is a table.

    Returns:
        The source's column names, each replaced by its column alias where
        the selected source defines one at that position.

    Raises:
        OptimizeError: If `name` is not a source of the scope.
    """
    sources = self.scope.sources
    if name not in sources:
        raise OptimizeError(f"Unknown table: {name}")

    source = sources[name]

    if isinstance(source, exp.Table):
        columns = self.schema.column_names(source, only_visible)
    elif isinstance(source, Scope) and isinstance(source.expression, exp.Values):
        columns = source.expression.alias_column_names
    else:
        columns = source.expression.named_selects

    selected = self.scope.selected_sources.get(name)
    node = selected[0] if selected else None

    if isinstance(node, Scope):
        column_aliases = node.expression.alias_column_names
    elif isinstance(node, exp.Expression):
        column_aliases = node.alias_column_names
    else:
        column_aliases = []

    # If the source's columns are aliased, their aliases shadow the corresponding column names
    resolved = []
    for column, column_alias in itertools.zip_longest(columns, column_aliases):
        resolved.append(column_alias or column)
    return resolved
Resolve the source columns for a given source `name`.