
Merging upstream version 24.1.0.

Signed-off-by: Daniel Baumann <daniel@debian.org>
This commit is contained in:
Daniel Baumann 2025-02-13 21:37:09 +01:00
parent 9689eb837b
commit d5706efe6b
Signed by: daniel
GPG key ID: FBB4F0E80A80222F
70 changed files with 55134 additions and 50721 deletions

View file

@ -65,9 +65,6 @@ except ImportError:
pretty = False
"""Whether to format generated SQL by default."""
schema = MappingSchema()
"""The default schema used by SQLGlot (e.g. in the optimizer)."""
def tokenize(sql: str, read: DialectType = None, dialect: DialectType = None) -> t.List[Token]:
"""

View file

@ -25,6 +25,17 @@ if t.TYPE_CHECKING:
logger = logging.getLogger("sqlglot")
UNESCAPED_SEQUENCES = {
"\\a": "\a",
"\\b": "\b",
"\\f": "\f",
"\\n": "\n",
"\\r": "\r",
"\\t": "\t",
"\\v": "\v",
"\\\\": "\\",
}
class Dialects(str, Enum):
"""Dialects supported by SQLGLot."""
@ -145,14 +156,7 @@ class _Dialect(type):
if "\\" in klass.tokenizer_class.STRING_ESCAPES:
klass.UNESCAPED_SEQUENCES = {
"\\a": "\a",
"\\b": "\b",
"\\f": "\f",
"\\n": "\n",
"\\r": "\r",
"\\t": "\t",
"\\v": "\v",
"\\\\": "\\",
**UNESCAPED_SEQUENCES,
**klass.UNESCAPED_SEQUENCES,
}
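
The shared escape map now lives at module level and is merged underneath each dialect's own UNESCAPED_SEQUENCES whenever backslash is a string escape, so dialect entries override the defaults without restating them. A minimal sketch of the merge semantics (the SHARED_DEFAULTS and MyEscapes names are made up for illustration):

SHARED_DEFAULTS = {"\\n": "\n", "\\t": "\t", "\\\\": "\\"}

class MyEscapes:
    # hypothetical dialect-specific addition
    UNESCAPED_SEQUENCES = {"\\%": "%"}

# later entries win, so the dialect keeps the shared defaults plus its own overrides
MyEscapes.UNESCAPED_SEQUENCES = {**SHARED_DEFAULTS, **MyEscapes.UNESCAPED_SEQUENCES}
assert MyEscapes.UNESCAPED_SEQUENCES["\\n"] == "\n"
assert MyEscapes.UNESCAPED_SEQUENCES["\\%"] == "%"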

View file

@ -53,8 +53,9 @@ class Doris(MySQL):
exp.Map: rename_func("ARRAY_MAP"),
exp.RegexpLike: rename_func("REGEXP"),
exp.RegexpSplit: rename_func("SPLIT_BY_STRING"),
exp.StrToUnix: lambda self, e: self.func("UNIX_TIMESTAMP", e.this, self.format_time(e)),
exp.Split: rename_func("SPLIT_BY_STRING"),
exp.StringToArray: rename_func("SPLIT_BY_STRING"),
exp.StrToUnix: lambda self, e: self.func("UNIX_TIMESTAMP", e.this, self.format_time(e)),
exp.TimeStrToDate: rename_func("TO_DATE"),
exp.TsOrDsAdd: lambda self, e: self.func("DATE_ADD", e.this, e.expression),
exp.TsOrDsToDate: lambda self, e: self.func("TO_DATE", e.this),
@ -65,3 +66,477 @@ class Doris(MySQL):
),
exp.UnixToTime: rename_func("FROM_UNIXTIME"),
}
# https://github.com/apache/doris/blob/e4f41dbf1ec03f5937fdeba2ee1454a20254015b/fe/fe-core/src/main/antlr4/org/apache/doris/nereids/DorisLexer.g4#L93
RESERVED_KEYWORDS = {
"account_lock",
"account_unlock",
"add",
"adddate",
"admin",
"after",
"agg_state",
"aggregate",
"alias",
"all",
"alter",
"analyze",
"analyzed",
"and",
"anti",
"append",
"array",
"array_range",
"as",
"asc",
"at",
"authors",
"auto",
"auto_increment",
"backend",
"backends",
"backup",
"begin",
"belong",
"between",
"bigint",
"bin",
"binary",
"binlog",
"bitand",
"bitmap",
"bitmap_union",
"bitor",
"bitxor",
"blob",
"boolean",
"brief",
"broker",
"buckets",
"build",
"builtin",
"bulk",
"by",
"cached",
"call",
"cancel",
"case",
"cast",
"catalog",
"catalogs",
"chain",
"char",
"character",
"charset",
"check",
"clean",
"cluster",
"clusters",
"collate",
"collation",
"collect",
"column",
"columns",
"comment",
"commit",
"committed",
"compact",
"complete",
"config",
"connection",
"connection_id",
"consistent",
"constraint",
"constraints",
"convert",
"copy",
"count",
"create",
"creation",
"cron",
"cross",
"cube",
"current",
"current_catalog",
"current_date",
"current_time",
"current_timestamp",
"current_user",
"data",
"database",
"databases",
"date",
"date_add",
"date_ceil",
"date_diff",
"date_floor",
"date_sub",
"dateadd",
"datediff",
"datetime",
"datetimev2",
"datev2",
"datetimev1",
"datev1",
"day",
"days_add",
"days_sub",
"decimal",
"decimalv2",
"decimalv3",
"decommission",
"default",
"deferred",
"delete",
"demand",
"desc",
"describe",
"diagnose",
"disk",
"distinct",
"distinctpc",
"distinctpcsa",
"distributed",
"distribution",
"div",
"do",
"doris_internal_table_id",
"double",
"drop",
"dropp",
"dual",
"duplicate",
"dynamic",
"else",
"enable",
"encryptkey",
"encryptkeys",
"end",
"ends",
"engine",
"engines",
"enter",
"errors",
"events",
"every",
"except",
"exclude",
"execute",
"exists",
"expired",
"explain",
"export",
"extended",
"external",
"extract",
"failed_login_attempts",
"false",
"fast",
"feature",
"fields",
"file",
"filter",
"first",
"float",
"follower",
"following",
"for",
"foreign",
"force",
"format",
"free",
"from",
"frontend",
"frontends",
"full",
"function",
"functions",
"generic",
"global",
"grant",
"grants",
"graph",
"group",
"grouping",
"groups",
"hash",
"having",
"hdfs",
"help",
"histogram",
"hll",
"hll_union",
"hostname",
"hour",
"hub",
"identified",
"if",
"ignore",
"immediate",
"in",
"incremental",
"index",
"indexes",
"infile",
"inner",
"insert",
"install",
"int",
"integer",
"intermediate",
"intersect",
"interval",
"into",
"inverted",
"ipv4",
"ipv6",
"is",
"is_not_null_pred",
"is_null_pred",
"isnull",
"isolation",
"job",
"jobs",
"join",
"json",
"jsonb",
"key",
"keys",
"kill",
"label",
"largeint",
"last",
"lateral",
"ldap",
"ldap_admin_password",
"left",
"less",
"level",
"like",
"limit",
"lines",
"link",
"list",
"load",
"local",
"localtime",
"localtimestamp",
"location",
"lock",
"logical",
"low_priority",
"manual",
"map",
"match",
"match_all",
"match_any",
"match_phrase",
"match_phrase_edge",
"match_phrase_prefix",
"match_regexp",
"materialized",
"max",
"maxvalue",
"memo",
"merge",
"migrate",
"migrations",
"min",
"minus",
"minute",
"modify",
"month",
"mtmv",
"name",
"names",
"natural",
"negative",
"never",
"next",
"ngram_bf",
"no",
"non_nullable",
"not",
"null",
"nulls",
"observer",
"of",
"offset",
"on",
"only",
"open",
"optimized",
"or",
"order",
"outer",
"outfile",
"over",
"overwrite",
"parameter",
"parsed",
"partition",
"partitions",
"password",
"password_expire",
"password_history",
"password_lock_time",
"password_reuse",
"path",
"pause",
"percent",
"period",
"permissive",
"physical",
"plan",
"process",
"plugin",
"plugins",
"policy",
"preceding",
"prepare",
"primary",
"proc",
"procedure",
"processlist",
"profile",
"properties",
"property",
"quantile_state",
"quantile_union",
"query",
"quota",
"random",
"range",
"read",
"real",
"rebalance",
"recover",
"recycle",
"refresh",
"references",
"regexp",
"release",
"rename",
"repair",
"repeatable",
"replace",
"replace_if_not_null",
"replica",
"repositories",
"repository",
"resource",
"resources",
"restore",
"restrictive",
"resume",
"returns",
"revoke",
"rewritten",
"right",
"rlike",
"role",
"roles",
"rollback",
"rollup",
"routine",
"row",
"rows",
"s3",
"sample",
"schedule",
"scheduler",
"schema",
"schemas",
"second",
"select",
"semi",
"sequence",
"serializable",
"session",
"set",
"sets",
"shape",
"show",
"signed",
"skew",
"smallint",
"snapshot",
"soname",
"split",
"sql_block_rule",
"start",
"starts",
"stats",
"status",
"stop",
"storage",
"stream",
"streaming",
"string",
"struct",
"subdate",
"sum",
"superuser",
"switch",
"sync",
"system",
"table",
"tables",
"tablesample",
"tablet",
"tablets",
"task",
"tasks",
"temporary",
"terminated",
"text",
"than",
"then",
"time",
"timestamp",
"timestampadd",
"timestampdiff",
"tinyint",
"to",
"transaction",
"trash",
"tree",
"triggers",
"trim",
"true",
"truncate",
"type",
"type_cast",
"types",
"unbounded",
"uncommitted",
"uninstall",
"union",
"unique",
"unlock",
"unsigned",
"update",
"use",
"user",
"using",
"value",
"values",
"varchar",
"variables",
"variant",
"vault",
"verbose",
"version",
"view",
"warnings",
"week",
"when",
"where",
"whitelist",
"with",
"work",
"workload",
"write",
"xor",
"year",
}
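
With exp.Split/exp.StringToArray mapped to SPLIT_BY_STRING and the lexer's keyword list wired into the generator, split functions from other dialects should transpile to Doris, and identifiers that collide with reserved words should come back quoted. A hedged sketch (the expected outputs are my reading of the change, not verified against this exact build):

import sqlglot

print(sqlglot.transpile("SELECT STRING_TO_ARRAY(col, ',')", read="postgres", write="doris"))
# expected: ["SELECT SPLIT_BY_STRING(col, ',')"]

print(sqlglot.transpile("SELECT status FROM jobs", write="doris"))
# expected: both identifiers match Doris reserved keywords, so they should come back backtick-quoted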

View file

@ -573,6 +573,9 @@ class Hive(Dialect):
exp.OnProperty: lambda *_: "",
exp.PrimaryKeyColumnConstraint: lambda *_: "PRIMARY KEY",
exp.ParseJSON: lambda self, e: self.sql(e.this),
exp.WeekOfYear: rename_func("WEEKOFYEAR"),
exp.DayOfMonth: rename_func("DAYOFMONTH"),
exp.DayOfWeek: rename_func("DAYOFWEEK"),
}
PROPERTIES_LOCATION = {
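
These three transforms make Hive emit its native spellings instead of the generic WEEK_OF_YEAR / DAY_OF_MONTH / DAY_OF_WEEK names. A quick, hedged round-trip check:

import sqlglot

print(sqlglot.transpile("SELECT WEEKOFYEAR(dt), DAYOFMONTH(dt), DAYOFWEEK(dt) FROM t", read="hive", write="hive"))
# expected: the Hive function names are preserved rather than rewritten to WEEK_OF_YEAR etc.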

View file

@ -670,6 +670,7 @@ class MySQL(Dialect):
return self.expression(exp.GroupConcat, this=this, separator=separator)
class Generator(generator.Generator):
INTERVAL_ALLOWS_PLURAL_FORM = False
LOCKING_READS_SUPPORTED = True
NULL_ORDERING_SUPPORTED = None
JOIN_HINTS = False
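
Setting LOCKING_READS_SUPPORTED means the MySQL generator should now keep locking-read clauses instead of warning and dropping them. A hedged sketch:

import sqlglot

print(sqlglot.transpile("SELECT * FROM t FOR UPDATE", read="mysql", write="mysql"))
# expected: ["SELECT * FROM t FOR UPDATE"], with the FOR UPDATE clause no longer discarded as unsupported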

View file

@ -116,7 +116,10 @@ def _string_agg_sql(self: Postgres.Generator, expression: exp.GroupConcat) -> st
def _datatype_sql(self: Postgres.Generator, expression: exp.DataType) -> str:
if expression.is_type("array"):
return f"{self.expressions(expression, flat=True)}[]" if expression.expressions else "ARRAY"
if expression.expressions:
values = self.expressions(expression, key="values", flat=True)
return f"{self.expressions(expression, flat=True)}[{values}]"
return "ARRAY"
return self.datatype_sql(expression)
@ -333,6 +336,7 @@ class Postgres(Dialect):
"REGPROCEDURE": TokenType.OBJECT_IDENTIFIER,
"REGROLE": TokenType.OBJECT_IDENTIFIER,
"REGTYPE": TokenType.OBJECT_IDENTIFIER,
"FLOAT": TokenType.DOUBLE,
}
SINGLE_TOKENS = {
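
Together with the parser change further down, this lets Postgres round-trip sized array types (including the INT ARRAY[3] spelling), and FLOAT now tokenizes as DOUBLE to match Postgres treating FLOAT as double precision. Expected behaviour, hedged:

import sqlglot

print(sqlglot.transpile("CREATE TABLE t (a INT[3], b INT ARRAY[3])", read="postgres", write="postgres"))
# expected: both columns come back as a sized array, roughly CREATE TABLE t (a INT[3], b INT[3])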

View file

@ -63,6 +63,9 @@ class Redshift(Postgres):
"DATE_DIFF": _build_date_delta(exp.TsOrDsDiff),
"GETDATE": exp.CurrentTimestamp.from_arg_list,
"LISTAGG": exp.GroupConcat.from_arg_list,
"SPLIT_TO_ARRAY": lambda args: exp.StringToArray(
this=seq_get(args, 0), expression=seq_get(args, 1) or exp.Literal.string(",")
),
"STRTOL": exp.FromBase.from_arg_list,
}
@ -124,6 +127,7 @@ class Redshift(Postgres):
"TOP": TokenType.TOP,
"UNLOAD": TokenType.COMMAND,
"VARBYTE": TokenType.VARBINARY,
"MINUS": TokenType.EXCEPT,
}
KEYWORDS.pop("VALUES")
@ -186,6 +190,7 @@ class Redshift(Postgres):
e: f"{'COMPOUND ' if e.args['compound'] else ''}SORTKEY({self.format_args(*e.this)})",
exp.StartsWith: lambda self,
e: f"{self.sql(e.this)} LIKE {self.sql(e.expression)} || '%'",
exp.StringToArray: rename_func("SPLIT_TO_ARRAY"),
exp.TableSample: no_tablesample_sql,
exp.TsOrDsAdd: date_delta_sql("DATEADD"),
exp.TsOrDsDiff: date_delta_sql("DATEDIFF"),
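
SPLIT_TO_ARRAY now builds the new exp.StringToArray (defaulting the delimiter to ','), so it should convert cleanly to and from Postgres' STRING_TO_ARRAY; MINUS is also accepted as a synonym for EXCEPT. A hedged sketch:

import sqlglot

print(sqlglot.transpile("SELECT SPLIT_TO_ARRAY(col, '|')", read="redshift", write="postgres"))
# expected: ["SELECT STRING_TO_ARRAY(col, '|')"]

print(sqlglot.transpile("SELECT STRING_TO_ARRAY(col, '|')", read="postgres", write="redshift"))
# expected: ["SELECT SPLIT_TO_ARRAY(col, '|')"]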

View file

@ -473,6 +473,14 @@ class Snowflake(Dialect):
"TERSE USERS": _show_parser("USERS"),
}
CONSTRAINT_PARSERS = {
**parser.Parser.CONSTRAINT_PARSERS,
"WITH": lambda self: self._parse_with_constraint(),
"MASKING": lambda self: self._parse_with_constraint(),
"PROJECTION": lambda self: self._parse_with_constraint(),
"TAG": lambda self: self._parse_with_constraint(),
}
STAGED_FILE_SINGLE_TOKENS = {
TokenType.DOT,
TokenType.MOD,
@ -497,6 +505,29 @@ class Snowflake(Dialect):
),
}
def _parse_with_constraint(self) -> t.Optional[exp.Expression]:
if self._prev.token_type != TokenType.WITH:
self._retreat(self._index - 1)
if self._match_text_seq("MASKING", "POLICY"):
return self.expression(
exp.MaskingPolicyColumnConstraint,
this=self._parse_id_var(),
expressions=self._match(TokenType.USING)
and self._parse_wrapped_csv(self._parse_id_var),
)
if self._match_text_seq("PROJECTION", "POLICY"):
return self.expression(
exp.ProjectionPolicyColumnConstraint, this=self._parse_id_var()
)
if self._match(TokenType.TAG):
return self.expression(
exp.TagColumnConstraint,
expressions=self._parse_wrapped_csv(self._parse_property),
)
return None
def _parse_create(self) -> exp.Create | exp.Command:
expression = super()._parse_create()
if isinstance(expression, exp.Create) and expression.kind in self.NON_TABLE_CREATABLES:
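
The new constraint parsers let Snowflake column definitions carry MASKING POLICY, PROJECTION POLICY and TAG clauses (with or without a leading WITH) through a parse/generate round trip. A hedged sketch with made-up policy and tag names:

import sqlglot

sql = "CREATE TABLE t (c VARCHAR MASKING POLICY my_policy, d INT TAG (sensitivity='high'))"
print(sqlglot.parse_one(sql, read="snowflake").sql("snowflake"))
# expected: the MASKING POLICY and TAG (...) constraints survive the round trip instead of breaking the parse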

View file

@ -17,7 +17,6 @@ from sqlglot.dialects.dialect import (
min_or_least,
build_date_delta,
rename_func,
timestrtotime_sql,
trim_sql,
)
from sqlglot.helper import seq_get
@ -818,6 +817,7 @@ class TSQL(Dialect):
exp.Min: min_or_least,
exp.NumberToStr: _format_sql,
exp.ParseJSON: lambda self, e: self.sql(e, "this"),
exp.Repeat: rename_func("REPLICATE"),
exp.Select: transforms.preprocess(
[
transforms.eliminate_distinct_on,
@ -834,7 +834,9 @@ class TSQL(Dialect):
"HASHBYTES", exp.Literal.string(f"SHA2_{e.args.get('length', 256)}"), e.this
),
exp.TemporaryProperty: lambda self, e: "",
exp.TimeStrToTime: timestrtotime_sql,
exp.TimeStrToTime: lambda self, e: self.sql(
exp.cast(e.this, exp.DataType.Type.DATETIME)
),
exp.TimeToStr: _format_sql,
exp.Trim: trim_sql,
exp.TsOrDsAdd: date_delta_sql("DATEADD", cast=True),
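
With exp.Repeat mapped to REPLICATE (and TimeStrToTime now cast to DATETIME rather than routed through timestrtotime_sql), T-SQL output should use the native function. A hedged example:

import sqlglot

print(sqlglot.transpile("SELECT REPEAT('ab', 3)", read="mysql", write="tsql"))
# expected: ["SELECT REPLICATE('ab', 3)"]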

View file

@ -1632,6 +1632,7 @@ class AlterColumn(Expression):
"default": False,
"drop": False,
"comment": False,
"allow_null": False,
}
@ -1835,6 +1836,11 @@ class NotForReplicationColumnConstraint(ColumnConstraintKind):
arg_types = {}
# https://docs.snowflake.com/en/sql-reference/sql/create-table
class MaskingPolicyColumnConstraint(ColumnConstraintKind):
arg_types = {"this": True, "expressions": False}
class NotNullColumnConstraint(ColumnConstraintKind):
arg_types = {"allow_null": False}
@ -1844,6 +1850,11 @@ class OnUpdateColumnConstraint(ColumnConstraintKind):
pass
# https://docs.snowflake.com/en/sql-reference/sql/create-table
class TagColumnConstraint(ColumnConstraintKind):
arg_types = {"expressions": True}
# https://docs.snowflake.com/en/sql-reference/sql/create-external-table#optional-parameters
class TransformColumnConstraint(ColumnConstraintKind):
pass
@ -1869,6 +1880,11 @@ class PathColumnConstraint(ColumnConstraintKind):
pass
# https://docs.snowflake.com/en/sql-reference/sql/create-table
class ProjectionPolicyColumnConstraint(ColumnConstraintKind):
pass
# computed column expression
# https://learn.microsoft.com/en-us/sql/t-sql/statements/create-table-transact-sql?view=sql-server-ver16
class ComputedColumnConstraint(ColumnConstraintKind):
@ -1992,7 +2008,7 @@ class Connect(Expression):
class CopyParameter(Expression):
arg_types = {"this": True, "expression": False}
arg_types = {"this": True, "expression": False, "expressions": False}
class Copy(Expression):
@ -4825,6 +4841,11 @@ class ArrayToString(Func):
_sql_names = ["ARRAY_TO_STRING", "ARRAY_JOIN"]
class StringToArray(Func):
arg_types = {"this": True, "expression": True, "null": False}
_sql_names = ["STRING_TO_ARRAY", "SPLIT_BY_STRING"]
class ArrayOverlaps(Binary, Func):
pass
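
The new exp.StringToArray node can also be built programmatically; each dialect that maps it should render its own spelling. A hedged sketch:

from sqlglot import exp

node = exp.StringToArray(this=exp.column("c"), expression=exp.Literal.string(","))
print(node.sql("postgres"))   # expected: STRING_TO_ARRAY(c, ',')
print(node.sql("redshift"))   # expected: SPLIT_TO_ARRAY(c, ',')
print(node.sql("doris"))      # expected: SPLIT_BY_STRING(c, ',')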

View file

@ -123,6 +123,8 @@ class Generator(metaclass=_Generator):
exp.OnUpdateColumnConstraint: lambda self, e: f"ON UPDATE {self.sql(e, 'this')}",
exp.OutputModelProperty: lambda self, e: f"OUTPUT{self.sql(e, 'this')}",
exp.PathColumnConstraint: lambda self, e: f"PATH {self.sql(e, 'this')}",
exp.ProjectionPolicyColumnConstraint: lambda self,
e: f"PROJECTION POLICY {self.sql(e, 'this')}",
exp.RemoteWithConnectionModelProperty: lambda self,
e: f"REMOTE WITH CONNECTION {self.sql(e, 'this')}",
exp.ReturnsProperty: lambda self, e: (
@ -139,6 +141,7 @@ class Generator(metaclass=_Generator):
exp.StabilityProperty: lambda _, e: e.name,
exp.StrictProperty: lambda *_: "STRICT",
exp.TemporaryProperty: lambda *_: "TEMPORARY",
exp.TagColumnConstraint: lambda self, e: f"TAG ({self.expressions(e, flat=True)})",
exp.TitleColumnConstraint: lambda self, e: f"TITLE {self.sql(e, 'this')}",
exp.Timestamp: lambda self, e: self.func("TIMESTAMP", e.this, e.expression),
exp.ToMap: lambda self, e: f"MAP {self.sql(e, 'this')}",
@ -3022,9 +3025,16 @@ class Generator(metaclass=_Generator):
if comment:
return f"ALTER COLUMN {this} COMMENT {comment}"
if not expression.args.get("drop"):
allow_null = expression.args.get("allow_null")
drop = expression.args.get("drop")
if not drop and not allow_null:
self.unsupported("Unsupported ALTER COLUMN syntax")
if allow_null is not None:
keyword = "DROP" if drop else "SET"
return f"ALTER COLUMN {this} {keyword} NOT NULL"
return f"ALTER COLUMN {this} DROP DEFAULT"
def alterdiststyle_sql(self, expression: exp.AlterDistStyle) -> str:
@ -3850,9 +3860,16 @@ class Generator(metaclass=_Generator):
def copyparameter_sql(self, expression: exp.CopyParameter) -> str:
option = self.sql(expression, "this")
if option.upper() == "FILE_FORMAT":
values = self.expressions(expression, key="expression", flat=True, sep=" ")
return f"{option} = ({values})"
if expression.expressions:
upper = option.upper()
# Snowflake FILE_FORMAT options are separated by whitespace
sep = " " if upper == "FILE_FORMAT" else ", "
# Databricks copy/format options do not set their list of values with EQ
op = " " if upper in ("COPY_OPTIONS", "FORMAT_OPTIONS") else " = "
values = self.expressions(expression, flat=True, sep=sep)
return f"{option}{op}({values})"
value = self.sql(expression, "expression")
@ -3872,9 +3889,10 @@ class Generator(metaclass=_Generator):
else:
# Snowflake case: CREDENTIALS = (...)
credentials = self.expressions(expression, key="credentials", flat=True, sep=" ")
credentials = f"CREDENTIALS = ({credentials})" if credentials else ""
credentials = f"CREDENTIALS = ({credentials})" if cred_expr is not None else ""
storage = self.sql(expression, "storage")
storage = f"STORAGE_INTEGRATION = {storage}" if storage else ""
encryption = self.expressions(expression, key="encryption", flat=True, sep=" ")
encryption = f" ENCRYPTION = ({encryption})" if encryption else ""
@ -3929,3 +3947,11 @@ class Generator(metaclass=_Generator):
on_sql = self.func("ON", filter_col, retention_period)
return f"DATA_DELETION={on_sql}"
def maskingpolicycolumnconstraint_sql(
self, expression: exp.MaskingPolicyColumnConstraint
) -> str:
this = self.sql(expression, "this")
expressions = self.expressions(expression, flat=True)
expressions = f" USING ({expressions})" if expressions else ""
return f"MASKING POLICY {this}{expressions}"

View file

@ -3,7 +3,6 @@ from __future__ import annotations
import inspect
import typing as t
import sqlglot
from sqlglot import Schema, exp
from sqlglot.dialects.dialect import DialectType
from sqlglot.optimizer.annotate_types import annotate_types
@ -72,7 +71,7 @@ def optimize(
Returns:
The optimized expression.
"""
schema = ensure_schema(schema or sqlglot.schema, dialect=dialect)
schema = ensure_schema(schema, dialect=dialect)
possible_kwargs = {
"db": db,
"catalog": catalog,

View file

@ -63,6 +63,7 @@ def qualify_columns(
if schema.empty and expand_alias_refs:
_expand_alias_refs(scope, resolver)
_convert_columns_to_dots(scope, resolver)
_qualify_columns(scope, resolver)
if not schema.empty and expand_alias_refs:
@ -70,7 +71,13 @@ def qualify_columns(
if not isinstance(scope.expression, exp.UDTF):
if expand_stars:
_expand_stars(scope, resolver, using_column_tables, pseudocolumns)
_expand_stars(
scope,
resolver,
using_column_tables,
pseudocolumns,
annotator,
)
qualify_outputs(scope)
_expand_group_by(scope)
@ -329,6 +336,47 @@ def _select_by_pos(scope: Scope, node: exp.Literal) -> exp.Alias:
raise OptimizeError(f"Unknown output column: {node.name}")
def _convert_columns_to_dots(scope: Scope, resolver: Resolver) -> None:
"""
Converts `Column` instances that represent struct field lookup into chained `Dots`.
Struct field lookups look like columns (e.g. "struct"."field"), but they need to be
qualified separately and represented as Dot(Dot(...(<table>.<column>, field1), field2, ...)).
"""
converted = False
for column in itertools.chain(scope.columns, scope.stars):
if isinstance(column, exp.Dot):
continue
column_table: t.Optional[str | exp.Identifier] = column.table
if (
column_table
and column_table not in scope.sources
and (
not scope.parent
or column_table not in scope.parent.sources
or not scope.is_correlated_subquery
)
):
root, *parts = column.parts
if root.name in scope.sources:
# The struct is already qualified, but we still need to change the AST
column_table = root
root, *parts = parts
else:
column_table = resolver.get_table(root.name)
if column_table:
converted = True
column.replace(exp.Dot.build([exp.column(root, table=column_table), *parts]))
if converted:
# We want to re-aggregate the converted columns, otherwise they'd be skipped in
# a `for column in scope.columns` iteration, even though they shouldn't be
scope.clear_cache()
def _qualify_columns(scope: Scope, resolver: Resolver) -> None:
"""Disambiguate columns, ensuring each column specifies a source"""
for column in scope.columns:
@ -347,30 +395,10 @@ def _qualify_columns(scope: Scope, resolver: Resolver) -> None:
column.set("table", exp.to_identifier(scope.pivots[0].alias))
continue
column_table = resolver.get_table(column_name)
# column_table can be a '' because bigquery unnest has no table alias
column_table = resolver.get_table(column_name)
if column_table:
column.set("table", column_table)
elif column_table not in scope.sources and (
not scope.parent
or column_table not in scope.parent.sources
or not scope.is_correlated_subquery
):
# structs are used like tables (e.g. "struct"."field"), so they need to be qualified
# separately and represented as dot(dot(...(<table>.<column>, field1), field2, ...))
root, *parts = column.parts
if root.name in scope.sources:
# struct is already qualified, but we still need to change the AST representation
column_table = root
root, *parts = parts
else:
column_table = resolver.get_table(root.name)
if column_table:
column.replace(exp.Dot.build([exp.column(root, table=column_table), *parts]))
for pivot in scope.pivots:
for column in pivot.find_all(exp.Column):
@ -380,11 +408,64 @@ def _qualify_columns(scope: Scope, resolver: Resolver) -> None:
column.set("table", column_table)
def _expand_struct_stars(
expression: exp.Dot,
) -> t.List[exp.Alias]:
"""[BigQuery] Expand/Flatten foo.bar.* where bar is a struct column"""
dot_column = t.cast(exp.Column, expression.find(exp.Column))
if not dot_column.is_type(exp.DataType.Type.STRUCT):
return []
# All nested struct values are ColumnDefs, so normalize the first exp.Column in one
dot_column = dot_column.copy()
starting_struct = exp.ColumnDef(this=dot_column.this, kind=dot_column.type)
# First part is the table name and last part is the star so they can be dropped
dot_parts = expression.parts[1:-1]
# If we're expanding a nested struct eg. t.c.f1.f2.* find the last struct (f2 in this case)
for part in dot_parts[1:]:
for field in t.cast(exp.DataType, starting_struct.kind).expressions:
# Unable to expand star unless all fields are named
if not isinstance(field.this, exp.Identifier):
return []
if field.name == part.name and field.kind.is_type(exp.DataType.Type.STRUCT):
starting_struct = field
break
else:
# There is no matching field in the struct
return []
taken_names = set()
new_selections = []
for field in t.cast(exp.DataType, starting_struct.kind).expressions:
name = field.name
# Ambiguous or anonymous fields can't be expanded
if name in taken_names or not isinstance(field.this, exp.Identifier):
return []
taken_names.add(name)
this = field.this.copy()
root, *parts = [part.copy() for part in itertools.chain(dot_parts, [this])]
new_column = exp.column(
t.cast(exp.Identifier, root), table=dot_column.args.get("table"), fields=parts
)
new_selections.append(alias(new_column, this, copy=False))
return new_selections
def _expand_stars(
scope: Scope,
resolver: Resolver,
using_column_tables: t.Dict[str, t.Any],
pseudocolumns: t.Set[str],
annotator: TypeAnnotator,
) -> None:
"""Expand stars to lists of column selections"""
@ -392,6 +473,7 @@ def _expand_stars(
except_columns: t.Dict[int, t.Set[str]] = {}
replace_columns: t.Dict[int, t.Dict[str, str]] = {}
coalesced_columns = set()
dialect = resolver.schema.dialect
pivot_output_columns = None
pivot_exclude_columns = None
@ -413,16 +495,29 @@ def _expand_stars(
if not pivot_output_columns:
pivot_output_columns = [c.alias_or_name for c in pivot.expressions]
is_bigquery = dialect == "bigquery"
if is_bigquery and any(isinstance(col, exp.Dot) for col in scope.stars):
# Found struct expansion, annotate scope ahead of time
annotator.annotate_scope(scope)
for expression in scope.expression.selects:
tables = []
if isinstance(expression, exp.Star):
tables = list(scope.selected_sources)
tables.extend(scope.selected_sources)
_add_except_columns(expression, tables, except_columns)
_add_replace_columns(expression, tables, replace_columns)
elif expression.is_star and not isinstance(expression, exp.Dot):
tables = [expression.table]
_add_except_columns(expression.this, tables, except_columns)
_add_replace_columns(expression.this, tables, replace_columns)
else:
elif expression.is_star:
if not isinstance(expression, exp.Dot):
tables.append(expression.table)
_add_except_columns(expression.this, tables, except_columns)
_add_replace_columns(expression.this, tables, replace_columns)
elif is_bigquery:
struct_fields = _expand_struct_stars(expression)
if struct_fields:
new_selections.extend(struct_fields)
continue
if not tables:
new_selections.append(expression)
continue
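
The new _expand_struct_stars path means a BigQuery SELECT t.struct_col.* can be flattened into its named fields when the schema supplies the struct type. A hedged sketch of what I'd expect:

import sqlglot
from sqlglot.optimizer.qualify import qualify

schema = {"t": {"info": "STRUCT<a INT64, b STRING>"}}
expression = sqlglot.parse_one("SELECT t.info.* FROM t", read="bigquery")
print(qualify(expression, schema=schema, dialect="bigquery").sql("bigquery"))
# expected (modulo identifier quoting): SELECT t.info.a AS a, t.info.b AS b FROM t AS t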

View file

@ -86,6 +86,7 @@ class Scope:
def clear_cache(self):
self._collected = False
self._raw_columns = None
self._stars = None
self._derived_tables = None
self._udtfs = None
self._tables = None
@ -119,14 +120,20 @@ class Scope:
self._derived_tables = []
self._udtfs = []
self._raw_columns = []
self._stars = []
self._join_hints = []
for node in self.walk(bfs=False):
if node is self.expression:
continue
if isinstance(node, exp.Column) and not isinstance(node.this, exp.Star):
self._raw_columns.append(node)
if isinstance(node, exp.Dot) and node.is_star:
self._stars.append(node)
elif isinstance(node, exp.Column):
if isinstance(node.this, exp.Star):
self._stars.append(node)
else:
self._raw_columns.append(node)
elif isinstance(node, exp.Table) and not isinstance(node.parent, exp.JoinHint):
self._tables.append(node)
elif isinstance(node, exp.JoinHint):
@ -231,6 +238,14 @@ class Scope:
self._ensure_collected()
return self._subqueries
@property
def stars(self) -> t.List[exp.Column | exp.Dot]:
"""
List of star expressions (columns or dots) in this scope.
"""
self._ensure_collected()
return self._stars
@property
def columns(self):
"""

View file

@ -1134,6 +1134,8 @@ class Parser(metaclass=_Parser):
SELECT_START_TOKENS = {TokenType.L_PAREN, TokenType.WITH, TokenType.SELECT}
COPY_INTO_VARLEN_OPTIONS = {"FILE_FORMAT", "COPY_OPTIONS", "FORMAT_OPTIONS", "CREDENTIAL"}
STRICT_CAST = True
PREFIXED_PIVOT_COLUMNS = False
@ -1830,11 +1832,17 @@ class Parser(metaclass=_Parser):
self._retreat(index)
return self._parse_sequence_properties()
return self.expression(
exp.Property,
this=key.to_dot() if isinstance(key, exp.Column) else key,
value=self._parse_bitwise() or self._parse_var(any_token=True),
)
# Transform the key to exp.Dot if it's dotted identifiers wrapped in exp.Column or to exp.Var otherwise
if isinstance(key, exp.Column):
key = key.to_dot() if len(key.parts) > 1 else exp.var(key.name)
value = self._parse_bitwise() or self._parse_var(any_token=True)
# Transform the value to exp.Var if it was parsed as exp.Column(exp.Identifier())
if isinstance(value, exp.Column):
value = exp.var(value.name)
return self.expression(exp.Property, this=key, value=value)
def _parse_stored(self) -> exp.FileFormatProperty:
self._match(TokenType.ALIAS)
@ -1853,7 +1861,7 @@ class Parser(metaclass=_Parser):
),
)
def _parse_unquoted_field(self):
def _parse_unquoted_field(self) -> t.Optional[exp.Expression]:
field = self._parse_field()
if isinstance(field, exp.Identifier) and not field.quoted:
field = exp.var(field)
@ -2793,7 +2801,13 @@ class Parser(metaclass=_Parser):
if not alias and not columns:
return None
return self.expression(exp.TableAlias, this=alias, columns=columns)
table_alias = self.expression(exp.TableAlias, this=alias, columns=columns)
# We bubble up comments from the Identifier to the TableAlias
if isinstance(alias, exp.Identifier):
table_alias.add_comments(alias.pop_comments())
return table_alias
def _parse_subquery(
self, this: t.Optional[exp.Expression], parse_alias: bool = True
@ -4060,7 +4074,7 @@ class Parser(metaclass=_Parser):
return this
return self.expression(exp.Escape, this=this, expression=self._parse_string())
def _parse_interval(self, match_interval: bool = True) -> t.Optional[exp.Interval]:
def _parse_interval(self, match_interval: bool = True) -> t.Optional[exp.Add | exp.Interval]:
index = self._index
if not self._match(TokenType.INTERVAL) and match_interval:
@ -4090,23 +4104,33 @@ class Parser(metaclass=_Parser):
if this and this.is_number:
this = exp.Literal.string(this.name)
elif this and this.is_string:
parts = this.name.split()
if len(parts) == 2:
parts = exp.INTERVAL_STRING_RE.findall(this.name)
if len(parts) == 1:
if unit:
# This is not actually a unit, it's something else (e.g. a "window side")
unit = None
# Unconsume the eagerly-parsed unit, since the real unit was part of the string
self._retreat(self._index - 1)
this = exp.Literal.string(parts[0])
unit = self.expression(exp.Var, this=parts[1].upper())
this = exp.Literal.string(parts[0][0])
unit = self.expression(exp.Var, this=parts[0][1].upper())
if self.INTERVAL_SPANS and self._match_text_seq("TO"):
unit = self.expression(
exp.IntervalSpan, this=unit, expression=self._parse_var(any_token=True, upper=True)
)
return self.expression(exp.Interval, this=this, unit=unit)
interval = self.expression(exp.Interval, this=this, unit=unit)
index = self._index
self._match(TokenType.PLUS)
# Convert INTERVAL 'val_1' unit_1 [+] ... [+] 'val_n' unit_n into a sum of intervals
if self._match_set((TokenType.STRING, TokenType.NUMBER), advance=False):
return self.expression(
exp.Add, this=interval, expression=self._parse_interval(match_interval=False)
)
self._retreat(index)
return interval
def _parse_bitwise(self) -> t.Optional[exp.Expression]:
this = self._parse_term()
@ -4173,38 +4197,45 @@ class Parser(metaclass=_Parser):
) -> t.Optional[exp.Expression]:
interval = parse_interval and self._parse_interval()
if interval:
# Convert INTERVAL 'val_1' unit_1 [+] ... [+] 'val_n' unit_n into a sum of intervals
while True:
index = self._index
self._match(TokenType.PLUS)
if not self._match_set((TokenType.STRING, TokenType.NUMBER), advance=False):
self._retreat(index)
break
interval = self.expression( # type: ignore
exp.Add, this=interval, expression=self._parse_interval(match_interval=False)
)
return interval
index = self._index
data_type = self._parse_types(check_func=True, allow_identifiers=False)
this = self._parse_column()
if data_type:
index2 = self._index
this = self._parse_primary()
if isinstance(this, exp.Literal):
parser = self.TYPE_LITERAL_PARSERS.get(data_type.this)
if parser:
return parser(self, this, data_type)
return self.expression(exp.Cast, this=this, to=data_type)
if not data_type.expressions:
self._retreat(index)
return self._parse_id_var() if fallback_to_identifier else self._parse_column()
# The expressions arg gets set by the parser when we have something like DECIMAL(38, 0)
# in the input SQL. In that case, we'll produce these tokens: DECIMAL ( 38 , 0 )
#
# If the index difference here is greater than 1, that means the parser itself must have
# consumed additional tokens such as the DECIMAL scale and precision in the above example.
#
# If it's not greater than 1, then it must be 1, because we've consumed at least the type
# keyword, meaning that the expressions arg of the DataType must have gotten set by a
# callable in the TYPE_CONVERTERS mapping. For example, Snowflake converts DECIMAL to
# DECIMAL(38, 0)) in order to facilitate the data type's transpilation.
#
# In these cases, we don't really want to return the converted type, but instead retreat
# and try to parse a Column or Identifier in the section below.
if data_type.expressions and index2 - index > 1:
self._retreat(index2)
return self._parse_column_ops(data_type)
return self._parse_column_ops(data_type)
self._retreat(index)
if fallback_to_identifier:
return self._parse_id_var()
this = self._parse_column()
return this and self._parse_column_ops(this)
def _parse_type_size(self) -> t.Optional[exp.DataTypeParam]:
@ -4268,7 +4299,7 @@ class Parser(metaclass=_Parser):
if self._match(TokenType.L_PAREN):
if is_struct:
expressions = self._parse_csv(self._parse_struct_types)
expressions = self._parse_csv(lambda: self._parse_struct_types(type_required=True))
elif nested:
expressions = self._parse_csv(
lambda: self._parse_types(
@ -4369,8 +4400,26 @@ class Parser(metaclass=_Parser):
elif expressions:
this.set("expressions", expressions)
while self._match_pair(TokenType.L_BRACKET, TokenType.R_BRACKET):
this = exp.DataType(this=exp.DataType.Type.ARRAY, expressions=[this], nested=True)
index = self._index
# Postgres supports the INT ARRAY[3] syntax as a synonym for INT[3]
matched_array = self._match(TokenType.ARRAY)
while self._curr:
matched_l_bracket = self._match(TokenType.L_BRACKET)
if not matched_l_bracket and not matched_array:
break
matched_array = False
values = self._parse_csv(self._parse_conjunction) or None
if values and not schema:
self._retreat(index)
break
this = exp.DataType(
this=exp.DataType.Type.ARRAY, expressions=[this], values=values, nested=True
)
self._match(TokenType.R_BRACKET)
if self.TYPE_CONVERTER and isinstance(this.this, exp.DataType.Type):
converter = self.TYPE_CONVERTER.get(this.this)
@ -4386,15 +4435,16 @@ class Parser(metaclass=_Parser):
or self._parse_id_var()
)
self._match(TokenType.COLON)
column_def = self._parse_column_def(this)
if type_required and (
(isinstance(this, exp.Column) and this.this is column_def) or this is column_def
if (
type_required
and not isinstance(this, exp.DataType)
and not self._match_set(self.TYPE_TOKENS, advance=False)
):
self._retreat(index)
return self._parse_types()
return column_def
return self._parse_column_def(this)
def _parse_at_time_zone(self, this: t.Optional[exp.Expression]) -> t.Optional[exp.Expression]:
if not self._match_text_seq("AT", "TIME", "ZONE"):
@ -6030,7 +6080,19 @@ class Parser(metaclass=_Parser):
return self.expression(exp.AlterColumn, this=column, default=self._parse_conjunction())
if self._match(TokenType.COMMENT):
return self.expression(exp.AlterColumn, this=column, comment=self._parse_string())
if self._match_text_seq("DROP", "NOT", "NULL"):
return self.expression(
exp.AlterColumn,
this=column,
drop=True,
allow_null=True,
)
if self._match_text_seq("SET", "NOT", "NULL"):
return self.expression(
exp.AlterColumn,
this=column,
allow_null=False,
)
self._match_text_seq("SET", "DATA")
self._match_text_seq("TYPE")
return self.expression(
@ -6595,12 +6657,23 @@ class Parser(metaclass=_Parser):
return self.expression(exp.WithOperator, this=this, op=op)
def _parse_wrapped_options(self) -> t.List[t.Optional[exp.Expression]]:
opts = []
self._match(TokenType.EQ)
self._match(TokenType.L_PAREN)
opts: t.List[t.Optional[exp.Expression]] = []
while self._curr and not self._match(TokenType.R_PAREN):
opts.append(self._parse_conjunction())
if self._match_text_seq("FORMAT_NAME", "="):
# The FORMAT_NAME can be set to an identifier for Snowflake and T-SQL,
# so we parse it separately to use _parse_field()
prop = self.expression(
exp.Property, this=exp.var("FORMAT_NAME"), value=self._parse_field()
)
opts.append(prop)
else:
opts.append(self._parse_property())
self._match(TokenType.COMMA)
return opts
def _parse_copy_parameters(self) -> t.List[exp.CopyParameter]:
@ -6608,37 +6681,38 @@ class Parser(metaclass=_Parser):
options = []
while self._curr and not self._match(TokenType.R_PAREN, advance=False):
option = self._parse_unquoted_field()
value = None
option = self._parse_var(any_token=True)
prev = self._prev.text.upper()
# Some options are defined as functions with the values as params
if not isinstance(option, exp.Func):
prev = self._prev.text.upper()
# Different dialects might separate options and values by white space, "=" and "AS"
self._match(TokenType.EQ)
self._match(TokenType.ALIAS)
# Different dialects might separate options and values by white space, "=" and "AS"
self._match(TokenType.EQ)
self._match(TokenType.ALIAS)
if prev == "FILE_FORMAT" and self._match(TokenType.L_PAREN):
# Snowflake FILE_FORMAT case
value = self._parse_wrapped_options()
else:
value = self._parse_unquoted_field()
param = self.expression(exp.CopyParameter, this=option)
if prev in self.COPY_INTO_VARLEN_OPTIONS and self._match(
TokenType.L_PAREN, advance=False
):
# Snowflake FILE_FORMAT case, Databricks COPY & FORMAT options
param.set("expressions", self._parse_wrapped_options())
elif prev == "FILE_FORMAT":
# T-SQL's external file format case
param.set("expression", self._parse_field())
else:
param.set("expression", self._parse_unquoted_field())
param = self.expression(exp.CopyParameter, this=option, expression=value)
options.append(param)
if sep:
self._match(sep)
self._match(sep)
return options
def _parse_credentials(self) -> t.Optional[exp.Credentials]:
expr = self.expression(exp.Credentials)
if self._match_text_seq("STORAGE_INTEGRATION", advance=False):
expr.set("storage", self._parse_conjunction())
if self._match_text_seq("STORAGE_INTEGRATION", "="):
expr.set("storage", self._parse_field())
if self._match_text_seq("CREDENTIALS"):
# Snowflake supports CREDENTIALS = (...), while Redshift CREDENTIALS <string>
# Snowflake case: CREDENTIALS = (...), Redshift case: CREDENTIALS <string>
creds = (
self._parse_wrapped_options() if self._match(TokenType.EQ) else self._parse_field()
)
@ -6661,7 +6735,7 @@ class Parser(metaclass=_Parser):
self._match(TokenType.INTO)
this = (
self._parse_conjunction()
self._parse_select(nested=True, parse_subquery_alias=False)
if self._match(TokenType.L_PAREN, advance=False)
else self._parse_table(schema=True)
)
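
Among the parser changes, interval parsing now splits multi-word interval strings with INTERVAL_STRING_RE and folds chained interval literals into a sum of intervals. A hedged sketch:

import sqlglot

print(sqlglot.parse_one("SELECT INTERVAL '1 day'").sql())
# expected: SELECT INTERVAL '1' DAY

print(sqlglot.parse_one("SELECT INTERVAL '1' YEAR '2' MONTH").sql())
# expected: SELECT INTERVAL '1' YEAR + INTERVAL '2' MONTH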

View file

@ -155,13 +155,16 @@ class AbstractMappingSchema:
return [table.this.name]
return [table.text(part) for part in exp.TABLE_PARTS if table.text(part)]
def find(self, table: exp.Table, raise_on_missing: bool = True) -> t.Optional[t.Any]:
def find(
self, table: exp.Table, raise_on_missing: bool = True, ensure_data_types: bool = False
) -> t.Optional[t.Any]:
"""
Returns the schema of a given table.
Args:
table: the target table.
raise_on_missing: whether to raise in case the schema is not found.
ensure_data_types: whether to convert `str` types to their `DataType` equivalents.
Returns:
The schema of the target table.
@ -239,6 +242,20 @@ class MappingSchema(AbstractMappingSchema, Schema):
normalize=mapping_schema.normalize,
)
def find(
self, table: exp.Table, raise_on_missing: bool = True, ensure_data_types: bool = False
) -> t.Optional[t.Any]:
schema = super().find(
table, raise_on_missing=raise_on_missing, ensure_data_types=ensure_data_types
)
if ensure_data_types and isinstance(schema, dict):
schema = {
col: self._to_data_type(dtype) if isinstance(dtype, str) else dtype
for col, dtype in schema.items()
}
return schema
def copy(self, **kwargs) -> MappingSchema:
return MappingSchema(
**{ # type: ignore
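
find() can now normalize string column types into exp.DataType objects on demand via ensure_data_types. A hedged sketch:

import sqlglot
from sqlglot.schema import MappingSchema

schema = MappingSchema({"t": {"a": "INT", "b": "TEXT"}})
print(schema.find(sqlglot.exp.to_table("t"), ensure_data_types=True))
# expected: the column types come back as exp.DataType instances rather than the raw strings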