1
0
Fork 0

Merging upstream version 25.16.1.

Signed-off-by: Daniel Baumann <daniel@debian.org>
This commit is contained in:
Daniel Baumann 2025-02-13 21:52:32 +01:00
parent 7688e2bdf8
commit bad79d1f7c
Signed by: daniel
GPG key ID: FBB4F0E80A80222F
110 changed files with 75353 additions and 68092 deletions

View file

@ -128,6 +128,13 @@ class _TypeAnnotator(type):
klass.COERCES_TO[data_type] = coerces_to.copy()
coerces_to |= {data_type}
# NULL can be coerced to any type, so e.g. NULL + 1 will have type INT
klass.COERCES_TO[exp.DataType.Type.NULL] = {
*text_precedence,
*numeric_precedence,
*timelike_precedence,
}
return klass
@ -201,31 +208,47 @@ class TypeAnnotator(metaclass=_TypeAnnotator):
for name, source in scope.sources.items():
if not isinstance(source, Scope):
continue
if isinstance(source.expression, exp.UDTF):
expression = source.expression
if isinstance(expression, exp.UDTF):
values = []
if isinstance(source.expression, exp.Lateral):
if isinstance(source.expression.this, exp.Explode):
values = [source.expression.this.this]
elif isinstance(source.expression, exp.Unnest):
values = [source.expression]
if isinstance(expression, exp.Lateral):
if isinstance(expression.this, exp.Explode):
values = [expression.this.this]
elif isinstance(expression, exp.Unnest):
values = [expression]
else:
values = source.expression.expressions[0].expressions
values = expression.expressions[0].expressions
if not values:
continue
selects[name] = {
alias: column
for alias, column in zip(
source.expression.alias_column_names,
values,
)
alias: column.type
for alias, column in zip(expression.alias_column_names, values)
}
elif isinstance(expression, exp.SetOperation) and len(expression.left.selects) == len(
expression.right.selects
):
if expression.args.get("by_name"):
r_type_by_select = {s.alias_or_name: s.type for s in expression.right.selects}
selects[name] = {
s.alias_or_name: self._maybe_coerce(
t.cast(exp.DataType, s.type),
r_type_by_select.get(s.alias_or_name) or exp.DataType.Type.UNKNOWN,
)
for s in expression.left.selects
}
else:
selects[name] = {
ls.alias_or_name: self._maybe_coerce(
t.cast(exp.DataType, ls.type), t.cast(exp.DataType, rs.type)
)
for ls, rs in zip(expression.left.selects, expression.right.selects)
}
else:
selects[name] = {
select.alias_or_name: select for select in source.expression.selects
}
selects[name] = {s.alias_or_name: s.type for s in expression.selects}
# First annotate the current scope's column references
for col in scope.columns:
@ -237,7 +260,7 @@ class TypeAnnotator(metaclass=_TypeAnnotator):
self._set_type(col, self.schema.get_column_type(source, col))
elif source:
if col.table in selects and col.name in selects[col.table]:
self._set_type(col, selects[col.table][col.name].type)
self._set_type(col, selects[col.table][col.name])
elif isinstance(source.expression, exp.Unnest):
self._set_type(col, source.expression.type)
@ -264,15 +287,13 @@ class TypeAnnotator(metaclass=_TypeAnnotator):
def _maybe_coerce(
self, type1: exp.DataType | exp.DataType.Type, type2: exp.DataType | exp.DataType.Type
) -> exp.DataType | exp.DataType.Type:
) -> exp.DataType:
type1_value = type1.this if isinstance(type1, exp.DataType) else type1
type2_value = type2.this if isinstance(type2, exp.DataType) else type2
# We propagate the NULL / UNKNOWN types upwards if found
if exp.DataType.Type.NULL in (type1_value, type2_value):
return exp.DataType.Type.NULL
# We propagate the UNKNOWN type upwards if found
if exp.DataType.Type.UNKNOWN in (type1_value, type2_value):
return exp.DataType.Type.UNKNOWN
return exp.DataType.build("unknown")
return type2_value if type2_value in self.coerces_to.get(type1_value, {}) else type1_value
@ -282,17 +303,7 @@ class TypeAnnotator(metaclass=_TypeAnnotator):
left, right = expression.left, expression.right
left_type, right_type = left.type.this, right.type.this # type: ignore
if isinstance(expression, exp.Connector):
if left_type == exp.DataType.Type.NULL and right_type == exp.DataType.Type.NULL:
self._set_type(expression, exp.DataType.Type.NULL)
elif exp.DataType.Type.NULL in (left_type, right_type):
self._set_type(
expression,
exp.DataType.build("NULLABLE", expressions=exp.DataType.build("BOOLEAN")),
)
else:
self._set_type(expression, exp.DataType.Type.BOOLEAN)
elif isinstance(expression, exp.Predicate):
if isinstance(expression, (exp.Connector, exp.Predicate)):
self._set_type(expression, exp.DataType.Type.BOOLEAN)
elif (left_type, right_type) in self.binary_coercions:
self._set_type(expression, self.binary_coercions[(left_type, right_type)](left, right))
@ -351,7 +362,7 @@ class TypeAnnotator(metaclass=_TypeAnnotator):
last_datatype = expr_type
break
if not expr_type.is_type(exp.DataType.Type.NULL, exp.DataType.Type.UNKNOWN):
if not expr_type.is_type(exp.DataType.Type.UNKNOWN):
last_datatype = self._maybe_coerce(last_datatype or expr_type, expr_type)
self._set_type(expression, last_datatype or exp.DataType.Type.UNKNOWN)

View file

@ -40,6 +40,8 @@ def replace_date_funcs(node: exp.Expression) -> exp.Expression:
isinstance(node, (exp.Date, exp.TsOrDsToDate))
and not node.expressions
and not node.args.get("zone")
and node.this.is_string
and is_iso_date(node.this.name)
):
return exp.cast(node.this, to=exp.DataType.Type.DATE)
if isinstance(node, exp.Timestamp) and not node.args.get("zone"):
@ -90,6 +92,12 @@ def remove_redundant_casts(expression: exp.Expression) -> exp.Expression:
and expression.to.this == expression.this.type.this
):
return expression.this
if (
isinstance(expression, (exp.Date, exp.TsOrDsToDate))
and expression.this.type
and expression.this.type.this == exp.DataType.Type.DATE
):
return expression.this
return expression

View file

@ -19,8 +19,12 @@ def normalize_identifiers(expression: str, dialect: DialectType = None) -> exp.I
def normalize_identifiers(expression, dialect=None):
"""
Normalize all unquoted identifiers to either lower or upper case, depending
on the dialect. This essentially makes those identifiers case-insensitive.
Normalize identifiers by converting them to either lower or upper case,
ensuring the semantics are preserved in each case (e.g. by respecting
case-sensitivity).
This transformation reflects how identifiers would be resolved by the engine corresponding
to each SQL dialect, and plays a very important role in the standardization of the AST.
It's possible to make this a no-op by adding a special comment next to the
identifier of interest:
@ -30,7 +34,7 @@ def normalize_identifiers(expression, dialect=None):
In this example, the identifier `a` will not be normalized.
Note:
Some dialects (e.g. BigQuery) treat identifiers as case-insensitive even
Some dialects (e.g. DuckDB) treat all identifiers as case-insensitive even
when they're quoted, so in these cases all identifiers are normalized.
Example:

View file

@ -30,6 +30,7 @@ def qualify(
validate_qualify_columns: bool = True,
quote_identifiers: bool = True,
identify: bool = True,
infer_csv_schemas: bool = False,
) -> exp.Expression:
"""
Rewrite sqlglot AST to have normalized and qualified tables and columns.
@ -60,13 +61,21 @@ def qualify(
This step is necessary to ensure correctness for case sensitive queries.
But this flag is provided in case this step is performed at a later time.
identify: If True, quote all identifiers, else only necessary ones.
infer_csv_schemas: Whether to scan READ_CSV calls in order to infer the CSVs' schemas.
Returns:
The qualified expression.
"""
schema = ensure_schema(schema, dialect=dialect)
expression = normalize_identifiers(expression, dialect=dialect)
expression = qualify_tables(expression, db=db, catalog=catalog, schema=schema, dialect=dialect)
expression = qualify_tables(
expression,
db=db,
catalog=catalog,
schema=schema,
dialect=dialect,
infer_csv_schemas=infer_csv_schemas,
)
if isolate_tables:
expression = isolate_table_selects(expression, schema=schema)

View file

@ -275,6 +275,17 @@ def _expand_alias_refs(scope: Scope, resolver: Resolver, expand_only_groupby: bo
if isinstance(projection, exp.Alias):
alias_to_expression[projection.alias] = (projection.this, i + 1)
parent_scope = scope
while parent_scope.is_union:
parent_scope = parent_scope.parent
# We shouldn't expand aliases if they match the recursive CTE's columns
if parent_scope.is_cte:
cte = parent_scope.expression.parent
if cte.find_ancestor(exp.With).recursive:
for recursive_cte_column in cte.args["alias"].columns or cte.this.selects:
alias_to_expression.pop(recursive_cte_column.output_name, None)
replace_columns(expression.args.get("where"))
replace_columns(expression.args.get("group"), literal_index=True)
replace_columns(expression.args.get("having"), resolve_table=True)

View file

@ -18,6 +18,7 @@ def qualify_tables(
db: t.Optional[str | exp.Identifier] = None,
catalog: t.Optional[str | exp.Identifier] = None,
schema: t.Optional[Schema] = None,
infer_csv_schemas: bool = False,
dialect: DialectType = None,
) -> E:
"""
@ -39,6 +40,7 @@ def qualify_tables(
db: Database name
catalog: Catalog name
schema: A schema to populate
infer_csv_schemas: Whether to scan READ_CSV calls in order to infer the CSVs' schemas.
dialect: The dialect to parse catalog and schema into.
Returns:
@ -102,7 +104,7 @@ def qualify_tables(
"alias", exp.TableAlias(this=exp.to_identifier(next_alias_name()))
)
if schema and isinstance(source.this, exp.ReadCSV):
if infer_csv_schemas and schema and isinstance(source.this, exp.ReadCSV):
with csv_reader(source.this) as reader:
header = next(reader)
columns = next(reader)

View file

@ -65,6 +65,7 @@ class Scope:
scope_type=ScopeType.ROOT,
lateral_sources=None,
cte_sources=None,
can_be_correlated=None,
):
self.expression = expression
self.sources = sources or {}
@ -81,6 +82,7 @@ class Scope:
self.cte_scopes = []
self.union_scopes = []
self.udtf_scopes = []
self.can_be_correlated = can_be_correlated
self.clear_cache()
def clear_cache(self):
@ -110,6 +112,8 @@ class Scope:
scope_type=scope_type,
cte_sources={**self.cte_sources, **(cte_sources or {})},
lateral_sources=lateral_sources.copy() if lateral_sources else None,
can_be_correlated=self.can_be_correlated
or scope_type in (ScopeType.SUBQUERY, ScopeType.UDTF),
**kwargs,
)
@ -261,7 +265,11 @@ class Scope:
external_columns = [
column
for scope in itertools.chain(self.subquery_scopes, self.udtf_scopes)
for scope in itertools.chain(
self.subquery_scopes,
self.udtf_scopes,
(dts for dts in self.derived_table_scopes if dts.can_be_correlated),
)
for column in scope.external_columns
]
@ -425,10 +433,7 @@ class Scope:
@property
def is_correlated_subquery(self):
"""Determine if this scope is a correlated subquery"""
return bool(
(self.is_subquery or (self.parent and isinstance(self.parent.expression, exp.Lateral)))
and self.external_columns
)
return bool(self.can_be_correlated and self.external_columns)
def rename_source(self, old_name, new_name):
"""Rename a source in this scope"""