Merging upstream version 25.16.1.
Signed-off-by: Daniel Baumann <daniel@debian.org>
This commit is contained in:
parent
7688e2bdf8
commit
bad79d1f7c
110 changed files with 75353 additions and 68092 deletions
|
@ -128,6 +128,13 @@ class _TypeAnnotator(type):
|
|||
klass.COERCES_TO[data_type] = coerces_to.copy()
|
||||
coerces_to |= {data_type}
|
||||
|
||||
# NULL can be coerced to any type, so e.g. NULL + 1 will have type INT
|
||||
klass.COERCES_TO[exp.DataType.Type.NULL] = {
|
||||
*text_precedence,
|
||||
*numeric_precedence,
|
||||
*timelike_precedence,
|
||||
}
|
||||
|
||||
return klass
|
||||
|
||||
|
||||
|
@ -201,31 +208,47 @@ class TypeAnnotator(metaclass=_TypeAnnotator):
|
|||
for name, source in scope.sources.items():
|
||||
if not isinstance(source, Scope):
|
||||
continue
|
||||
if isinstance(source.expression, exp.UDTF):
|
||||
|
||||
expression = source.expression
|
||||
if isinstance(expression, exp.UDTF):
|
||||
values = []
|
||||
|
||||
if isinstance(source.expression, exp.Lateral):
|
||||
if isinstance(source.expression.this, exp.Explode):
|
||||
values = [source.expression.this.this]
|
||||
elif isinstance(source.expression, exp.Unnest):
|
||||
values = [source.expression]
|
||||
if isinstance(expression, exp.Lateral):
|
||||
if isinstance(expression.this, exp.Explode):
|
||||
values = [expression.this.this]
|
||||
elif isinstance(expression, exp.Unnest):
|
||||
values = [expression]
|
||||
else:
|
||||
values = source.expression.expressions[0].expressions
|
||||
values = expression.expressions[0].expressions
|
||||
|
||||
if not values:
|
||||
continue
|
||||
|
||||
selects[name] = {
|
||||
alias: column
|
||||
for alias, column in zip(
|
||||
source.expression.alias_column_names,
|
||||
values,
|
||||
)
|
||||
alias: column.type
|
||||
for alias, column in zip(expression.alias_column_names, values)
|
||||
}
|
||||
elif isinstance(expression, exp.SetOperation) and len(expression.left.selects) == len(
|
||||
expression.right.selects
|
||||
):
|
||||
if expression.args.get("by_name"):
|
||||
r_type_by_select = {s.alias_or_name: s.type for s in expression.right.selects}
|
||||
selects[name] = {
|
||||
s.alias_or_name: self._maybe_coerce(
|
||||
t.cast(exp.DataType, s.type),
|
||||
r_type_by_select.get(s.alias_or_name) or exp.DataType.Type.UNKNOWN,
|
||||
)
|
||||
for s in expression.left.selects
|
||||
}
|
||||
else:
|
||||
selects[name] = {
|
||||
ls.alias_or_name: self._maybe_coerce(
|
||||
t.cast(exp.DataType, ls.type), t.cast(exp.DataType, rs.type)
|
||||
)
|
||||
for ls, rs in zip(expression.left.selects, expression.right.selects)
|
||||
}
|
||||
else:
|
||||
selects[name] = {
|
||||
select.alias_or_name: select for select in source.expression.selects
|
||||
}
|
||||
selects[name] = {s.alias_or_name: s.type for s in expression.selects}
|
||||
|
||||
# First annotate the current scope's column references
|
||||
for col in scope.columns:
|
||||
|
@ -237,7 +260,7 @@ class TypeAnnotator(metaclass=_TypeAnnotator):
|
|||
self._set_type(col, self.schema.get_column_type(source, col))
|
||||
elif source:
|
||||
if col.table in selects and col.name in selects[col.table]:
|
||||
self._set_type(col, selects[col.table][col.name].type)
|
||||
self._set_type(col, selects[col.table][col.name])
|
||||
elif isinstance(source.expression, exp.Unnest):
|
||||
self._set_type(col, source.expression.type)
|
||||
|
||||
|
@ -264,15 +287,13 @@ class TypeAnnotator(metaclass=_TypeAnnotator):
|
|||
|
||||
def _maybe_coerce(
|
||||
self, type1: exp.DataType | exp.DataType.Type, type2: exp.DataType | exp.DataType.Type
|
||||
) -> exp.DataType | exp.DataType.Type:
|
||||
) -> exp.DataType:
|
||||
type1_value = type1.this if isinstance(type1, exp.DataType) else type1
|
||||
type2_value = type2.this if isinstance(type2, exp.DataType) else type2
|
||||
|
||||
# We propagate the NULL / UNKNOWN types upwards if found
|
||||
if exp.DataType.Type.NULL in (type1_value, type2_value):
|
||||
return exp.DataType.Type.NULL
|
||||
# We propagate the UNKNOWN type upwards if found
|
||||
if exp.DataType.Type.UNKNOWN in (type1_value, type2_value):
|
||||
return exp.DataType.Type.UNKNOWN
|
||||
return exp.DataType.build("unknown")
|
||||
|
||||
return type2_value if type2_value in self.coerces_to.get(type1_value, {}) else type1_value
|
||||
|
||||
|
@ -282,17 +303,7 @@ class TypeAnnotator(metaclass=_TypeAnnotator):
|
|||
left, right = expression.left, expression.right
|
||||
left_type, right_type = left.type.this, right.type.this # type: ignore
|
||||
|
||||
if isinstance(expression, exp.Connector):
|
||||
if left_type == exp.DataType.Type.NULL and right_type == exp.DataType.Type.NULL:
|
||||
self._set_type(expression, exp.DataType.Type.NULL)
|
||||
elif exp.DataType.Type.NULL in (left_type, right_type):
|
||||
self._set_type(
|
||||
expression,
|
||||
exp.DataType.build("NULLABLE", expressions=exp.DataType.build("BOOLEAN")),
|
||||
)
|
||||
else:
|
||||
self._set_type(expression, exp.DataType.Type.BOOLEAN)
|
||||
elif isinstance(expression, exp.Predicate):
|
||||
if isinstance(expression, (exp.Connector, exp.Predicate)):
|
||||
self._set_type(expression, exp.DataType.Type.BOOLEAN)
|
||||
elif (left_type, right_type) in self.binary_coercions:
|
||||
self._set_type(expression, self.binary_coercions[(left_type, right_type)](left, right))
|
||||
|
@ -351,7 +362,7 @@ class TypeAnnotator(metaclass=_TypeAnnotator):
|
|||
last_datatype = expr_type
|
||||
break
|
||||
|
||||
if not expr_type.is_type(exp.DataType.Type.NULL, exp.DataType.Type.UNKNOWN):
|
||||
if not expr_type.is_type(exp.DataType.Type.UNKNOWN):
|
||||
last_datatype = self._maybe_coerce(last_datatype or expr_type, expr_type)
|
||||
|
||||
self._set_type(expression, last_datatype or exp.DataType.Type.UNKNOWN)
|
||||
|
|
|
@ -40,6 +40,8 @@ def replace_date_funcs(node: exp.Expression) -> exp.Expression:
|
|||
isinstance(node, (exp.Date, exp.TsOrDsToDate))
|
||||
and not node.expressions
|
||||
and not node.args.get("zone")
|
||||
and node.this.is_string
|
||||
and is_iso_date(node.this.name)
|
||||
):
|
||||
return exp.cast(node.this, to=exp.DataType.Type.DATE)
|
||||
if isinstance(node, exp.Timestamp) and not node.args.get("zone"):
|
||||
|
@ -90,6 +92,12 @@ def remove_redundant_casts(expression: exp.Expression) -> exp.Expression:
|
|||
and expression.to.this == expression.this.type.this
|
||||
):
|
||||
return expression.this
|
||||
if (
|
||||
isinstance(expression, (exp.Date, exp.TsOrDsToDate))
|
||||
and expression.this.type
|
||||
and expression.this.type.this == exp.DataType.Type.DATE
|
||||
):
|
||||
return expression.this
|
||||
return expression
|
||||
|
||||
|
||||
|
|
|
@ -19,8 +19,12 @@ def normalize_identifiers(expression: str, dialect: DialectType = None) -> exp.I
|
|||
|
||||
def normalize_identifiers(expression, dialect=None):
|
||||
"""
|
||||
Normalize all unquoted identifiers to either lower or upper case, depending
|
||||
on the dialect. This essentially makes those identifiers case-insensitive.
|
||||
Normalize identifiers by converting them to either lower or upper case,
|
||||
ensuring the semantics are preserved in each case (e.g. by respecting
|
||||
case-sensitivity).
|
||||
|
||||
This transformation reflects how identifiers would be resolved by the engine corresponding
|
||||
to each SQL dialect, and plays a very important role in the standardization of the AST.
|
||||
|
||||
It's possible to make this a no-op by adding a special comment next to the
|
||||
identifier of interest:
|
||||
|
@ -30,7 +34,7 @@ def normalize_identifiers(expression, dialect=None):
|
|||
In this example, the identifier `a` will not be normalized.
|
||||
|
||||
Note:
|
||||
Some dialects (e.g. BigQuery) treat identifiers as case-insensitive even
|
||||
Some dialects (e.g. DuckDB) treat all identifiers as case-insensitive even
|
||||
when they're quoted, so in these cases all identifiers are normalized.
|
||||
|
||||
Example:
|
||||
|
|
|
@ -30,6 +30,7 @@ def qualify(
|
|||
validate_qualify_columns: bool = True,
|
||||
quote_identifiers: bool = True,
|
||||
identify: bool = True,
|
||||
infer_csv_schemas: bool = False,
|
||||
) -> exp.Expression:
|
||||
"""
|
||||
Rewrite sqlglot AST to have normalized and qualified tables and columns.
|
||||
|
@ -60,13 +61,21 @@ def qualify(
|
|||
This step is necessary to ensure correctness for case sensitive queries.
|
||||
But this flag is provided in case this step is performed at a later time.
|
||||
identify: If True, quote all identifiers, else only necessary ones.
|
||||
infer_csv_schemas: Whether to scan READ_CSV calls in order to infer the CSVs' schemas.
|
||||
|
||||
Returns:
|
||||
The qualified expression.
|
||||
"""
|
||||
schema = ensure_schema(schema, dialect=dialect)
|
||||
expression = normalize_identifiers(expression, dialect=dialect)
|
||||
expression = qualify_tables(expression, db=db, catalog=catalog, schema=schema, dialect=dialect)
|
||||
expression = qualify_tables(
|
||||
expression,
|
||||
db=db,
|
||||
catalog=catalog,
|
||||
schema=schema,
|
||||
dialect=dialect,
|
||||
infer_csv_schemas=infer_csv_schemas,
|
||||
)
|
||||
|
||||
if isolate_tables:
|
||||
expression = isolate_table_selects(expression, schema=schema)
|
||||
|
|
|
@ -275,6 +275,17 @@ def _expand_alias_refs(scope: Scope, resolver: Resolver, expand_only_groupby: bo
|
|||
if isinstance(projection, exp.Alias):
|
||||
alias_to_expression[projection.alias] = (projection.this, i + 1)
|
||||
|
||||
parent_scope = scope
|
||||
while parent_scope.is_union:
|
||||
parent_scope = parent_scope.parent
|
||||
|
||||
# We shouldn't expand aliases if they match the recursive CTE's columns
|
||||
if parent_scope.is_cte:
|
||||
cte = parent_scope.expression.parent
|
||||
if cte.find_ancestor(exp.With).recursive:
|
||||
for recursive_cte_column in cte.args["alias"].columns or cte.this.selects:
|
||||
alias_to_expression.pop(recursive_cte_column.output_name, None)
|
||||
|
||||
replace_columns(expression.args.get("where"))
|
||||
replace_columns(expression.args.get("group"), literal_index=True)
|
||||
replace_columns(expression.args.get("having"), resolve_table=True)
|
||||
|
|
|
@ -18,6 +18,7 @@ def qualify_tables(
|
|||
db: t.Optional[str | exp.Identifier] = None,
|
||||
catalog: t.Optional[str | exp.Identifier] = None,
|
||||
schema: t.Optional[Schema] = None,
|
||||
infer_csv_schemas: bool = False,
|
||||
dialect: DialectType = None,
|
||||
) -> E:
|
||||
"""
|
||||
|
@ -39,6 +40,7 @@ def qualify_tables(
|
|||
db: Database name
|
||||
catalog: Catalog name
|
||||
schema: A schema to populate
|
||||
infer_csv_schemas: Whether to scan READ_CSV calls in order to infer the CSVs' schemas.
|
||||
dialect: The dialect to parse catalog and schema into.
|
||||
|
||||
Returns:
|
||||
|
@ -102,7 +104,7 @@ def qualify_tables(
|
|||
"alias", exp.TableAlias(this=exp.to_identifier(next_alias_name()))
|
||||
)
|
||||
|
||||
if schema and isinstance(source.this, exp.ReadCSV):
|
||||
if infer_csv_schemas and schema and isinstance(source.this, exp.ReadCSV):
|
||||
with csv_reader(source.this) as reader:
|
||||
header = next(reader)
|
||||
columns = next(reader)
|
||||
|
|
|
@ -65,6 +65,7 @@ class Scope:
|
|||
scope_type=ScopeType.ROOT,
|
||||
lateral_sources=None,
|
||||
cte_sources=None,
|
||||
can_be_correlated=None,
|
||||
):
|
||||
self.expression = expression
|
||||
self.sources = sources or {}
|
||||
|
@ -81,6 +82,7 @@ class Scope:
|
|||
self.cte_scopes = []
|
||||
self.union_scopes = []
|
||||
self.udtf_scopes = []
|
||||
self.can_be_correlated = can_be_correlated
|
||||
self.clear_cache()
|
||||
|
||||
def clear_cache(self):
|
||||
|
@ -110,6 +112,8 @@ class Scope:
|
|||
scope_type=scope_type,
|
||||
cte_sources={**self.cte_sources, **(cte_sources or {})},
|
||||
lateral_sources=lateral_sources.copy() if lateral_sources else None,
|
||||
can_be_correlated=self.can_be_correlated
|
||||
or scope_type in (ScopeType.SUBQUERY, ScopeType.UDTF),
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
|
@ -261,7 +265,11 @@ class Scope:
|
|||
|
||||
external_columns = [
|
||||
column
|
||||
for scope in itertools.chain(self.subquery_scopes, self.udtf_scopes)
|
||||
for scope in itertools.chain(
|
||||
self.subquery_scopes,
|
||||
self.udtf_scopes,
|
||||
(dts for dts in self.derived_table_scopes if dts.can_be_correlated),
|
||||
)
|
||||
for column in scope.external_columns
|
||||
]
|
||||
|
||||
|
@ -425,10 +433,7 @@ class Scope:
|
|||
@property
|
||||
def is_correlated_subquery(self):
|
||||
"""Determine if this scope is a correlated subquery"""
|
||||
return bool(
|
||||
(self.is_subquery or (self.parent and isinstance(self.parent.expression, exp.Lateral)))
|
||||
and self.external_columns
|
||||
)
|
||||
return bool(self.can_be_correlated and self.external_columns)
|
||||
|
||||
def rename_source(self, old_name, new_name):
|
||||
"""Rename a source in this scope"""
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue