Merging upstream version 24.1.0.
Signed-off-by: Daniel Baumann <daniel@debian.org>
This commit is contained in:
parent 9689eb837b
commit d5706efe6b
70 changed files with 55134 additions and 50721 deletions
sqlglot/__init__.py
@@ -65,9 +65,6 @@ except ImportError:
 pretty = False
 """Whether to format generated SQL by default."""
 
-schema = MappingSchema()
-"""The default schema used by SQLGlot (e.g. in the optimizer)."""
-
 
 def tokenize(sql: str, read: DialectType = None, dialect: DialectType = None) -> t.List[Token]:
     """
sqlglot/dialects/dialect.py
@@ -25,6 +25,17 @@ if t.TYPE_CHECKING:
 
 logger = logging.getLogger("sqlglot")
 
+UNESCAPED_SEQUENCES = {
+    "\\a": "\a",
+    "\\b": "\b",
+    "\\f": "\f",
+    "\\n": "\n",
+    "\\r": "\r",
+    "\\t": "\t",
+    "\\v": "\v",
+    "\\\\": "\\",
+}
+
 
 class Dialects(str, Enum):
     """Dialects supported by SQLGLot."""
@@ -145,14 +156,7 @@ class _Dialect(type):
 
         if "\\" in klass.tokenizer_class.STRING_ESCAPES:
             klass.UNESCAPED_SEQUENCES = {
-                "\\a": "\a",
-                "\\b": "\b",
-                "\\f": "\f",
-                "\\n": "\n",
-                "\\r": "\r",
-                "\\t": "\t",
-                "\\v": "\v",
-                "\\\\": "\\",
+                **UNESCAPED_SEQUENCES,
                 **klass.UNESCAPED_SEQUENCES,
            }
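Note: the merge order above matters — because **klass.UNESCAPED_SEQUENCES is unpacked last, dialect-specific escape entries override the shared module-level defaults. A minimal sketch of that dict-merge behavior (the "<TAB>" override is hypothetical):

    base = {"\\n": "\n", "\\t": "\t"}      # shared defaults
    dialect_specific = {"\\t": "<TAB>"}    # hypothetical per-dialect override
    merged = {**base, **dialect_specific}  # later unpackings win
    assert merged["\\t"] == "<TAB>"
    assert merged["\\n"] == "\n"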
sqlglot/dialects/doris.py
@@ -53,8 +53,9 @@ class Doris(MySQL):
             exp.Map: rename_func("ARRAY_MAP"),
             exp.RegexpLike: rename_func("REGEXP"),
             exp.RegexpSplit: rename_func("SPLIT_BY_STRING"),
-            exp.StrToUnix: lambda self, e: self.func("UNIX_TIMESTAMP", e.this, self.format_time(e)),
             exp.Split: rename_func("SPLIT_BY_STRING"),
+            exp.StringToArray: rename_func("SPLIT_BY_STRING"),
+            exp.StrToUnix: lambda self, e: self.func("UNIX_TIMESTAMP", e.this, self.format_time(e)),
             exp.TimeStrToDate: rename_func("TO_DATE"),
             exp.TsOrDsAdd: lambda self, e: self.func("DATE_ADD", e.this, e.expression),
             exp.TsOrDsToDate: lambda self, e: self.func("TO_DATE", e.this),
@@ -65,3 +66,477 @@ class Doris(MySQL):
             ),
             exp.UnixToTime: rename_func("FROM_UNIXTIME"),
         }
+
+        # https://github.com/apache/doris/blob/e4f41dbf1ec03f5937fdeba2ee1454a20254015b/fe/fe-core/src/main/antlr4/org/apache/doris/nereids/DorisLexer.g4#L93
+        RESERVED_KEYWORDS = {
+            "account_lock",
+            "account_unlock",
+            "add",
+            "adddate",
+            "admin",
+            "after",
+            "agg_state",
+            "aggregate",
+            "alias",
+            "all",
+            "alter",
+            "analyze",
+            "analyzed",
+            "and",
+            "anti",
+            "append",
+            "array",
+            "array_range",
+            "as",
+            "asc",
+            "at",
+            "authors",
+            "auto",
+            "auto_increment",
+            "backend",
+            "backends",
+            "backup",
+            "begin",
+            "belong",
+            "between",
+            "bigint",
+            "bin",
+            "binary",
+            "binlog",
+            "bitand",
+            "bitmap",
+            "bitmap_union",
+            "bitor",
+            "bitxor",
+            "blob",
+            "boolean",
+            "brief",
+            "broker",
+            "buckets",
+            "build",
+            "builtin",
+            "bulk",
+            "by",
+            "cached",
+            "call",
+            "cancel",
+            "case",
+            "cast",
+            "catalog",
+            "catalogs",
+            "chain",
+            "char",
+            "character",
+            "charset",
+            "check",
+            "clean",
+            "cluster",
+            "clusters",
+            "collate",
+            "collation",
+            "collect",
+            "column",
+            "columns",
+            "comment",
+            "commit",
+            "committed",
+            "compact",
+            "complete",
+            "config",
+            "connection",
+            "connection_id",
+            "consistent",
+            "constraint",
+            "constraints",
+            "convert",
+            "copy",
+            "count",
+            "create",
+            "creation",
+            "cron",
+            "cross",
+            "cube",
+            "current",
+            "current_catalog",
+            "current_date",
+            "current_time",
+            "current_timestamp",
+            "current_user",
+            "data",
+            "database",
+            "databases",
+            "date",
+            "date_add",
+            "date_ceil",
+            "date_diff",
+            "date_floor",
+            "date_sub",
+            "dateadd",
+            "datediff",
+            "datetime",
+            "datetimev2",
+            "datev2",
+            "datetimev1",
+            "datev1",
+            "day",
+            "days_add",
+            "days_sub",
+            "decimal",
+            "decimalv2",
+            "decimalv3",
+            "decommission",
+            "default",
+            "deferred",
+            "delete",
+            "demand",
+            "desc",
+            "describe",
+            "diagnose",
+            "disk",
+            "distinct",
+            "distinctpc",
+            "distinctpcsa",
+            "distributed",
+            "distribution",
+            "div",
+            "do",
+            "doris_internal_table_id",
+            "double",
+            "drop",
+            "dropp",
+            "dual",
+            "duplicate",
+            "dynamic",
+            "else",
+            "enable",
+            "encryptkey",
+            "encryptkeys",
+            "end",
+            "ends",
+            "engine",
+            "engines",
+            "enter",
+            "errors",
+            "events",
+            "every",
+            "except",
+            "exclude",
+            "execute",
+            "exists",
+            "expired",
+            "explain",
+            "export",
+            "extended",
+            "external",
+            "extract",
+            "failed_login_attempts",
+            "false",
+            "fast",
+            "feature",
+            "fields",
+            "file",
+            "filter",
+            "first",
+            "float",
+            "follower",
+            "following",
+            "for",
+            "foreign",
+            "force",
+            "format",
+            "free",
+            "from",
+            "frontend",
+            "frontends",
+            "full",
+            "function",
+            "functions",
+            "generic",
+            "global",
+            "grant",
+            "grants",
+            "graph",
+            "group",
+            "grouping",
+            "groups",
+            "hash",
+            "having",
+            "hdfs",
+            "help",
+            "histogram",
+            "hll",
+            "hll_union",
+            "hostname",
+            "hour",
+            "hub",
+            "identified",
+            "if",
+            "ignore",
+            "immediate",
+            "in",
+            "incremental",
+            "index",
+            "indexes",
+            "infile",
+            "inner",
+            "insert",
+            "install",
+            "int",
+            "integer",
+            "intermediate",
+            "intersect",
+            "interval",
+            "into",
+            "inverted",
+            "ipv4",
+            "ipv6",
+            "is",
+            "is_not_null_pred",
+            "is_null_pred",
+            "isnull",
+            "isolation",
+            "job",
+            "jobs",
+            "join",
+            "json",
+            "jsonb",
+            "key",
+            "keys",
+            "kill",
+            "label",
+            "largeint",
+            "last",
+            "lateral",
+            "ldap",
+            "ldap_admin_password",
+            "left",
+            "less",
+            "level",
+            "like",
+            "limit",
+            "lines",
+            "link",
+            "list",
+            "load",
+            "local",
+            "localtime",
+            "localtimestamp",
+            "location",
+            "lock",
+            "logical",
+            "low_priority",
+            "manual",
+            "map",
+            "match",
+            "match_all",
+            "match_any",
+            "match_phrase",
+            "match_phrase_edge",
+            "match_phrase_prefix",
+            "match_regexp",
+            "materialized",
+            "max",
+            "maxvalue",
+            "memo",
+            "merge",
+            "migrate",
+            "migrations",
+            "min",
+            "minus",
+            "minute",
+            "modify",
+            "month",
+            "mtmv",
+            "name",
+            "names",
+            "natural",
+            "negative",
+            "never",
+            "next",
+            "ngram_bf",
+            "no",
+            "non_nullable",
+            "not",
+            "null",
+            "nulls",
+            "observer",
+            "of",
+            "offset",
+            "on",
+            "only",
+            "open",
+            "optimized",
+            "or",
+            "order",
+            "outer",
+            "outfile",
+            "over",
+            "overwrite",
+            "parameter",
+            "parsed",
+            "partition",
+            "partitions",
+            "password",
+            "password_expire",
+            "password_history",
+            "password_lock_time",
+            "password_reuse",
+            "path",
+            "pause",
+            "percent",
+            "period",
+            "permissive",
+            "physical",
+            "plan",
+            "process",
+            "plugin",
+            "plugins",
+            "policy",
+            "preceding",
+            "prepare",
+            "primary",
+            "proc",
+            "procedure",
+            "processlist",
+            "profile",
+            "properties",
+            "property",
+            "quantile_state",
+            "quantile_union",
+            "query",
+            "quota",
+            "random",
+            "range",
+            "read",
+            "real",
+            "rebalance",
+            "recover",
+            "recycle",
+            "refresh",
+            "references",
+            "regexp",
+            "release",
+            "rename",
+            "repair",
+            "repeatable",
+            "replace",
+            "replace_if_not_null",
+            "replica",
+            "repositories",
+            "repository",
+            "resource",
+            "resources",
+            "restore",
+            "restrictive",
+            "resume",
+            "returns",
+            "revoke",
+            "rewritten",
+            "right",
+            "rlike",
+            "role",
+            "roles",
+            "rollback",
+            "rollup",
+            "routine",
+            "row",
+            "rows",
+            "s3",
+            "sample",
+            "schedule",
+            "scheduler",
+            "schema",
+            "schemas",
+            "second",
+            "select",
+            "semi",
+            "sequence",
+            "serializable",
+            "session",
+            "set",
+            "sets",
+            "shape",
+            "show",
+            "signed",
+            "skew",
+            "smallint",
+            "snapshot",
+            "soname",
+            "split",
+            "sql_block_rule",
+            "start",
+            "starts",
+            "stats",
+            "status",
+            "stop",
+            "storage",
+            "stream",
+            "streaming",
+            "string",
+            "struct",
+            "subdate",
+            "sum",
+            "superuser",
+            "switch",
+            "sync",
+            "system",
+            "table",
+            "tables",
+            "tablesample",
+            "tablet",
+            "tablets",
+            "task",
+            "tasks",
+            "temporary",
+            "terminated",
+            "text",
+            "than",
+            "then",
+            "time",
+            "timestamp",
+            "timestampadd",
+            "timestampdiff",
+            "tinyint",
+            "to",
+            "transaction",
+            "trash",
+            "tree",
+            "triggers",
+            "trim",
+            "true",
+            "truncate",
+            "type",
+            "type_cast",
+            "types",
+            "unbounded",
+            "uncommitted",
+            "uninstall",
+            "union",
+            "unique",
+            "unlock",
+            "unsigned",
+            "update",
+            "use",
+            "user",
+            "using",
+            "value",
+            "values",
+            "varchar",
+            "variables",
+            "variant",
+            "vault",
+            "verbose",
+            "version",
+            "view",
+            "warnings",
+            "week",
+            "when",
+            "where",
+            "whitelist",
+            "with",
+            "work",
+            "workload",
+            "write",
+            "xor",
+            "year",
+        }
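Note: in sqlglot, a generator's RESERVED_KEYWORDS set forces quoting of identifiers that collide with those names. A hedged sketch (identifier and exact quoting are illustrative; output may vary by version):

    import sqlglot

    # "duplicate" is in the set above, so the Doris generator should quote it
    print(sqlglot.transpile("SELECT duplicate FROM t", read="mysql", write="doris")[0])
    # e.g. SELECT `duplicate` FROM t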
sqlglot/dialects/hive.py
@@ -573,6 +573,9 @@ class Hive(Dialect):
             exp.OnProperty: lambda *_: "",
             exp.PrimaryKeyColumnConstraint: lambda *_: "PRIMARY KEY",
             exp.ParseJSON: lambda self, e: self.sql(e.this),
+            exp.WeekOfYear: rename_func("WEEKOFYEAR"),
+            exp.DayOfMonth: rename_func("DAYOFMONTH"),
+            exp.DayOfWeek: rename_func("DAYOFWEEK"),
         }
 
         PROPERTIES_LOCATION = {
sqlglot/dialects/mysql.py
@@ -670,6 +670,7 @@ class MySQL(Dialect):
         return self.expression(exp.GroupConcat, this=this, separator=separator)
 
     class Generator(generator.Generator):
         INTERVAL_ALLOWS_PLURAL_FORM = False
         LOCKING_READS_SUPPORTED = True
+        NULL_ORDERING_SUPPORTED = None
         JOIN_HINTS = False
sqlglot/dialects/postgres.py
@@ -116,7 +116,10 @@ def _string_agg_sql(self: Postgres.Generator, expression: exp.GroupConcat) -> str:
 
 def _datatype_sql(self: Postgres.Generator, expression: exp.DataType) -> str:
     if expression.is_type("array"):
-        return f"{self.expressions(expression, flat=True)}[]" if expression.expressions else "ARRAY"
+        if expression.expressions:
+            values = self.expressions(expression, key="values", flat=True)
+            return f"{self.expressions(expression, flat=True)}[{values}]"
+        return "ARRAY"
     return self.datatype_sql(expression)
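Note: the new branch preserves array bounds via the DataType's "values" arg instead of always emitting bare "[]". A hedged sketch (exact output depends on the sqlglot version):

    import sqlglot

    ddl = sqlglot.parse_one("CREATE TABLE t (x INT[3])", read="postgres")
    print(ddl.sql("postgres"))  # e.g. CREATE TABLE t (x INT[3]); the size is no longer dropped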
@@ -333,6 +336,7 @@ class Postgres(Dialect):
             "REGPROCEDURE": TokenType.OBJECT_IDENTIFIER,
             "REGROLE": TokenType.OBJECT_IDENTIFIER,
             "REGTYPE": TokenType.OBJECT_IDENTIFIER,
+            "FLOAT": TokenType.DOUBLE,
         }
 
         SINGLE_TOKENS = {
sqlglot/dialects/redshift.py
@@ -63,6 +63,9 @@ class Redshift(Postgres):
             "DATE_DIFF": _build_date_delta(exp.TsOrDsDiff),
             "GETDATE": exp.CurrentTimestamp.from_arg_list,
             "LISTAGG": exp.GroupConcat.from_arg_list,
+            "SPLIT_TO_ARRAY": lambda args: exp.StringToArray(
+                this=seq_get(args, 0), expression=seq_get(args, 1) or exp.Literal.string(",")
+            ),
             "STRTOL": exp.FromBase.from_arg_list,
         }
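Note: because SPLIT_TO_ARRAY now parses into the new exp.StringToArray node (with "," as the default delimiter when the second argument is omitted), it can transpile to dialects with a native equivalent. A hedged sketch:

    import sqlglot

    print(sqlglot.transpile("SELECT SPLIT_TO_ARRAY(col)", read="redshift", write="postgres")[0])
    # e.g. SELECT STRING_TO_ARRAY(col, ',')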
@@ -124,6 +127,7 @@ class Redshift(Postgres):
             "TOP": TokenType.TOP,
             "UNLOAD": TokenType.COMMAND,
             "VARBYTE": TokenType.VARBINARY,
+            "MINUS": TokenType.EXCEPT,
         }
 
         KEYWORDS.pop("VALUES")
@@ -186,6 +190,7 @@ class Redshift(Postgres):
             e: f"{'COMPOUND ' if e.args['compound'] else ''}SORTKEY({self.format_args(*e.this)})",
             exp.StartsWith: lambda self,
             e: f"{self.sql(e.this)} LIKE {self.sql(e.expression)} || '%'",
+            exp.StringToArray: rename_func("SPLIT_TO_ARRAY"),
             exp.TableSample: no_tablesample_sql,
             exp.TsOrDsAdd: date_delta_sql("DATEADD"),
             exp.TsOrDsDiff: date_delta_sql("DATEDIFF"),
sqlglot/dialects/snowflake.py
@@ -473,6 +473,14 @@ class Snowflake(Dialect):
             "TERSE USERS": _show_parser("USERS"),
         }
 
+        CONSTRAINT_PARSERS = {
+            **parser.Parser.CONSTRAINT_PARSERS,
+            "WITH": lambda self: self._parse_with_constraint(),
+            "MASKING": lambda self: self._parse_with_constraint(),
+            "PROJECTION": lambda self: self._parse_with_constraint(),
+            "TAG": lambda self: self._parse_with_constraint(),
+        }
+
         STAGED_FILE_SINGLE_TOKENS = {
             TokenType.DOT,
             TokenType.MOD,
@@ -497,6 +505,29 @@ class Snowflake(Dialect):
             ),
         }
 
+        def _parse_with_constraint(self) -> t.Optional[exp.Expression]:
+            if self._prev.token_type != TokenType.WITH:
+                self._retreat(self._index - 1)
+
+            if self._match_text_seq("MASKING", "POLICY"):
+                return self.expression(
+                    exp.MaskingPolicyColumnConstraint,
+                    this=self._parse_id_var(),
+                    expressions=self._match(TokenType.USING)
+                    and self._parse_wrapped_csv(self._parse_id_var),
+                )
+            if self._match_text_seq("PROJECTION", "POLICY"):
+                return self.expression(
+                    exp.ProjectionPolicyColumnConstraint, this=self._parse_id_var()
+                )
+            if self._match(TokenType.TAG):
+                return self.expression(
+                    exp.TagColumnConstraint,
+                    expressions=self._parse_wrapped_csv(self._parse_property),
+                )
+
+            return None
+
         def _parse_create(self) -> exp.Create | exp.Command:
             expression = super()._parse_create()
             if isinstance(expression, exp.Create) and expression.kind in self.NON_TABLE_CREATABLES:
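Note: with these parsers, Snowflake column-level MASKING POLICY / PROJECTION POLICY / TAG clauses should round-trip. A hedged sketch (table, column, and policy names are made up):

    import sqlglot

    sql = "CREATE TABLE t (c TEXT MASKING POLICY my_policy USING (c, c2))"
    print(sqlglot.transpile(sql, read="snowflake", write="snowflake")[0])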
sqlglot/dialects/tsql.py
@@ -17,7 +17,6 @@ from sqlglot.dialects.dialect import (
     min_or_least,
     build_date_delta,
     rename_func,
-    timestrtotime_sql,
     trim_sql,
 )
 from sqlglot.helper import seq_get
@@ -818,6 +817,7 @@ class TSQL(Dialect):
             exp.Min: min_or_least,
             exp.NumberToStr: _format_sql,
+            exp.ParseJSON: lambda self, e: self.sql(e, "this"),
             exp.Repeat: rename_func("REPLICATE"),
             exp.Select: transforms.preprocess(
                 [
                     transforms.eliminate_distinct_on,
@@ -834,7 +834,9 @@ class TSQL(Dialect):
                 "HASHBYTES", exp.Literal.string(f"SHA2_{e.args.get('length', 256)}"), e.this
             ),
             exp.TemporaryProperty: lambda self, e: "",
-            exp.TimeStrToTime: timestrtotime_sql,
+            exp.TimeStrToTime: lambda self, e: self.sql(
+                exp.cast(e.this, exp.DataType.Type.DATETIME)
+            ),
             exp.TimeToStr: _format_sql,
             exp.Trim: trim_sql,
             exp.TsOrDsAdd: date_delta_sql("DATEADD", cast=True),
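Note: instead of the generic timestrtotime_sql helper (which was also dropped from the imports above), T-SQL now renders TimeStrToTime as a cast to DATETIME. A hedged sketch driving the node directly:

    from sqlglot import exp

    node = exp.TimeStrToTime(this=exp.Literal.string("2020-01-01 00:00:00"))
    print(node.sql("tsql"))
    # e.g. CAST('2020-01-01 00:00:00' AS DATETIME)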
sqlglot/expressions.py
@@ -1632,6 +1632,7 @@ class AlterColumn(Expression):
         "default": False,
         "drop": False,
         "comment": False,
+        "allow_null": False,
     }
@@ -1835,6 +1836,11 @@ class NotForReplicationColumnConstraint(ColumnConstraintKind):
     arg_types = {}
 
 
+# https://docs.snowflake.com/en/sql-reference/sql/create-table
+class MaskingPolicyColumnConstraint(ColumnConstraintKind):
+    arg_types = {"this": True, "expressions": False}
+
+
 class NotNullColumnConstraint(ColumnConstraintKind):
     arg_types = {"allow_null": False}
@@ -1844,6 +1850,11 @@ class OnUpdateColumnConstraint(ColumnConstraintKind):
     pass
 
 
+# https://docs.snowflake.com/en/sql-reference/sql/create-table
+class TagColumnConstraint(ColumnConstraintKind):
+    arg_types = {"expressions": True}
+
+
 # https://docs.snowflake.com/en/sql-reference/sql/create-external-table#optional-parameters
 class TransformColumnConstraint(ColumnConstraintKind):
     pass
@@ -1869,6 +1880,11 @@ class PathColumnConstraint(ColumnConstraintKind):
     pass
 
 
+# https://docs.snowflake.com/en/sql-reference/sql/create-table
+class ProjectionPolicyColumnConstraint(ColumnConstraintKind):
+    pass
+
+
 # computed column expression
 # https://learn.microsoft.com/en-us/sql/t-sql/statements/create-table-transact-sql?view=sql-server-ver16
 class ComputedColumnConstraint(ColumnConstraintKind):
@@ -1992,7 +2008,7 @@ class Connect(Expression):
 
 
 class CopyParameter(Expression):
-    arg_types = {"this": True, "expression": False}
+    arg_types = {"this": True, "expression": False, "expressions": False}
 
 
 class Copy(Expression):
@@ -4825,6 +4841,11 @@ class ArrayToString(Func):
     _sql_names = ["ARRAY_TO_STRING", "ARRAY_JOIN"]
 
 
+class StringToArray(Func):
+    arg_types = {"this": True, "expression": True, "null": False}
+    _sql_names = ["STRING_TO_ARRAY", "SPLIT_BY_STRING"]
+
+
 class ArrayOverlaps(Binary, Func):
     pass
sqlglot/generator.py
@@ -123,6 +123,8 @@ class Generator(metaclass=_Generator):
         exp.OnUpdateColumnConstraint: lambda self, e: f"ON UPDATE {self.sql(e, 'this')}",
         exp.OutputModelProperty: lambda self, e: f"OUTPUT{self.sql(e, 'this')}",
         exp.PathColumnConstraint: lambda self, e: f"PATH {self.sql(e, 'this')}",
+        exp.ProjectionPolicyColumnConstraint: lambda self,
+        e: f"PROJECTION POLICY {self.sql(e, 'this')}",
         exp.RemoteWithConnectionModelProperty: lambda self,
         e: f"REMOTE WITH CONNECTION {self.sql(e, 'this')}",
         exp.ReturnsProperty: lambda self, e: (
@@ -139,6 +141,7 @@ class Generator(metaclass=_Generator):
         exp.StabilityProperty: lambda _, e: e.name,
         exp.StrictProperty: lambda *_: "STRICT",
         exp.TemporaryProperty: lambda *_: "TEMPORARY",
+        exp.TagColumnConstraint: lambda self, e: f"TAG ({self.expressions(e, flat=True)})",
         exp.TitleColumnConstraint: lambda self, e: f"TITLE {self.sql(e, 'this')}",
         exp.Timestamp: lambda self, e: self.func("TIMESTAMP", e.this, e.expression),
         exp.ToMap: lambda self, e: f"MAP {self.sql(e, 'this')}",
@@ -3022,9 +3025,16 @@ class Generator(metaclass=_Generator):
         if comment:
             return f"ALTER COLUMN {this} COMMENT {comment}"
 
-        if not expression.args.get("drop"):
+        allow_null = expression.args.get("allow_null")
+        drop = expression.args.get("drop")
+
+        if not drop and not allow_null:
             self.unsupported("Unsupported ALTER COLUMN syntax")
 
+        if allow_null is not None:
+            keyword = "DROP" if drop else "SET"
+            return f"ALTER COLUMN {this} {keyword} NOT NULL"
+
         return f"ALTER COLUMN {this} DROP DEFAULT"
 
     def alterdiststyle_sql(self, expression: exp.AlterDistStyle) -> str:
@@ -3850,9 +3860,16 @@ class Generator(metaclass=_Generator):
     def copyparameter_sql(self, expression: exp.CopyParameter) -> str:
         option = self.sql(expression, "this")
 
-        if option.upper() == "FILE_FORMAT":
-            values = self.expressions(expression, key="expression", flat=True, sep=" ")
-            return f"{option} = ({values})"
+        if expression.expressions:
+            upper = option.upper()
+
+            # Snowflake FILE_FORMAT options are separated by whitespace
+            sep = " " if upper == "FILE_FORMAT" else ", "
+
+            # Databricks copy/format options do not set their list of values with EQ
+            op = " " if upper in ("COPY_OPTIONS", "FORMAT_OPTIONS") else " = "
+            values = self.expressions(expression, flat=True, sep=sep)
+            return f"{option}{op}({values})"
 
         value = self.sql(expression, "expression")
@@ -3872,9 +3889,10 @@ class Generator(metaclass=_Generator):
         else:
+            # Snowflake case: CREDENTIALS = (...)
             credentials = self.expressions(expression, key="credentials", flat=True, sep=" ")
-            credentials = f"CREDENTIALS = ({credentials})" if credentials else ""
+            credentials = f"CREDENTIALS = ({credentials})" if cred_expr is not None else ""
 
         storage = self.sql(expression, "storage")
         storage = f"STORAGE_INTEGRATION = {storage}" if storage else ""
 
         encryption = self.expressions(expression, key="encryption", flat=True, sep=" ")
         encryption = f" ENCRYPTION = ({encryption})" if encryption else ""
@@ -3929,3 +3947,11 @@ class Generator(metaclass=_Generator):
         on_sql = self.func("ON", filter_col, retention_period)
 
         return f"DATA_DELETION={on_sql}"
+
+    def maskingpolicycolumnconstraint_sql(
+        self, expression: exp.MaskingPolicyColumnConstraint
+    ) -> str:
+        this = self.sql(expression, "this")
+        expressions = self.expressions(expression, flat=True)
+        expressions = f" USING ({expressions})" if expressions else ""
+        return f"MASKING POLICY {this}{expressions}"
sqlglot/optimizer/optimizer.py
@@ -3,7 +3,6 @@ from __future__ import annotations
 import inspect
 import typing as t
 
-import sqlglot
 from sqlglot import Schema, exp
 from sqlglot.dialects.dialect import DialectType
 from sqlglot.optimizer.annotate_types import annotate_types
@@ -72,7 +71,7 @@ def optimize(
     Returns:
         The optimized expression.
     """
-    schema = ensure_schema(schema or sqlglot.schema, dialect=dialect)
+    schema = ensure_schema(schema, dialect=dialect)
     possible_kwargs = {
         "db": db,
         "catalog": catalog,
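Note: together with the removal of the module-level schema = MappingSchema() in sqlglot/__init__.py above, this drops the implicit global-schema fallback; callers now pass a schema explicitly. A hedged sketch:

    import sqlglot
    from sqlglot.optimizer import optimize

    expression = sqlglot.parse_one("SELECT a FROM t")
    optimized = optimize(expression, schema={"t": {"a": "INT"}})
    print(optimized.sql())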
sqlglot/optimizer/qualify_columns.py
@@ -63,6 +63,7 @@ def qualify_columns(
         if schema.empty and expand_alias_refs:
             _expand_alias_refs(scope, resolver)
 
+        _convert_columns_to_dots(scope, resolver)
         _qualify_columns(scope, resolver)
 
         if not schema.empty and expand_alias_refs:
@@ -70,7 +71,13 @@ def qualify_columns(
 
         if not isinstance(scope.expression, exp.UDTF):
             if expand_stars:
-                _expand_stars(scope, resolver, using_column_tables, pseudocolumns)
+                _expand_stars(
+                    scope,
+                    resolver,
+                    using_column_tables,
+                    pseudocolumns,
+                    annotator,
+                )
             qualify_outputs(scope)
 
         _expand_group_by(scope)
@@ -329,6 +336,47 @@ def _select_by_pos(scope: Scope, node: exp.Literal) -> exp.Alias:
     raise OptimizeError(f"Unknown output column: {node.name}")
 
 
+def _convert_columns_to_dots(scope: Scope, resolver: Resolver) -> None:
+    """
+    Converts `Column` instances that represent struct field lookup into chained `Dots`.
+
+    Struct field lookups look like columns (e.g. "struct"."field"), but they need to be
+    qualified separately and represented as Dot(Dot(...(<table>.<column>, field1), field2, ...)).
+    """
+    converted = False
+    for column in itertools.chain(scope.columns, scope.stars):
+        if isinstance(column, exp.Dot):
+            continue
+
+        column_table: t.Optional[str | exp.Identifier] = column.table
+        if (
+            column_table
+            and column_table not in scope.sources
+            and (
+                not scope.parent
+                or column_table not in scope.parent.sources
+                or not scope.is_correlated_subquery
+            )
+        ):
+            root, *parts = column.parts
+
+            if root.name in scope.sources:
+                # The struct is already qualified, but we still need to change the AST
+                column_table = root
+                root, *parts = parts
+            else:
+                column_table = resolver.get_table(root.name)
+
+            if column_table:
+                converted = True
+                column.replace(exp.Dot.build([exp.column(root, table=column_table), *parts]))
+
+    if converted:
+        # We want to re-aggregate the converted columns, otherwise they'd be skipped in
+        # a `for column in scope.columns` iteration, even though they shouldn't be
+        scope.clear_cache()
+
+
 def _qualify_columns(scope: Scope, resolver: Resolver) -> None:
     """Disambiguate columns, ensuring each column specifies a source"""
     for column in scope.columns:
@@ -347,30 +395,10 @@ def _qualify_columns(scope: Scope, resolver: Resolver) -> None:
             column.set("table", exp.to_identifier(scope.pivots[0].alias))
             continue
 
-        column_table = resolver.get_table(column_name)
-
         # column_table can be a '' because bigquery unnest has no table alias
+        column_table = resolver.get_table(column_name)
         if column_table:
             column.set("table", column_table)
-        elif column_table not in scope.sources and (
-            not scope.parent
-            or column_table not in scope.parent.sources
-            or not scope.is_correlated_subquery
-        ):
-            # structs are used like tables (e.g. "struct"."field"), so they need to be qualified
-            # separately and represented as dot(dot(...(<table>.<column>, field1), field2, ...))
-
-            root, *parts = column.parts
-
-            if root.name in scope.sources:
-                # struct is already qualified, but we still need to change the AST representation
-                column_table = root
-                root, *parts = parts
-            else:
-                column_table = resolver.get_table(root.name)
-
-            if column_table:
-                column.replace(exp.Dot.build([exp.column(root, table=column_table), *parts]))
 
     for pivot in scope.pivots:
         for column in pivot.find_all(exp.Column):
@@ -380,11 +408,64 @@ def _qualify_columns(scope: Scope, resolver: Resolver) -> None:
                 column.set("table", column_table)
 
 
+def _expand_struct_stars(
+    expression: exp.Dot,
+) -> t.List[exp.Alias]:
+    """[BigQuery] Expand/Flatten foo.bar.* where bar is a struct column"""
+
+    dot_column = t.cast(exp.Column, expression.find(exp.Column))
+    if not dot_column.is_type(exp.DataType.Type.STRUCT):
+        return []
+
+    # All nested struct values are ColumnDefs, so normalize the first exp.Column in one
+    dot_column = dot_column.copy()
+    starting_struct = exp.ColumnDef(this=dot_column.this, kind=dot_column.type)
+
+    # First part is the table name and last part is the star so they can be dropped
+    dot_parts = expression.parts[1:-1]
+
+    # If we're expanding a nested struct eg. t.c.f1.f2.* find the last struct (f2 in this case)
+    for part in dot_parts[1:]:
+        for field in t.cast(exp.DataType, starting_struct.kind).expressions:
+            # Unable to expand star unless all fields are named
+            if not isinstance(field.this, exp.Identifier):
+                return []
+
+            if field.name == part.name and field.kind.is_type(exp.DataType.Type.STRUCT):
+                starting_struct = field
+                break
+        else:
+            # There is no matching field in the struct
+            return []
+
+    taken_names = set()
+    new_selections = []
+
+    for field in t.cast(exp.DataType, starting_struct.kind).expressions:
+        name = field.name
+
+        # Ambiguous or anonymous fields can't be expanded
+        if name in taken_names or not isinstance(field.this, exp.Identifier):
+            return []
+
+        taken_names.add(name)
+
+        this = field.this.copy()
+        root, *parts = [part.copy() for part in itertools.chain(dot_parts, [this])]
+        new_column = exp.column(
+            t.cast(exp.Identifier, root), table=dot_column.args.get("table"), fields=parts
+        )
+        new_selections.append(alias(new_column, this, copy=False))
+
+    return new_selections
+
+
 def _expand_stars(
     scope: Scope,
     resolver: Resolver,
     using_column_tables: t.Dict[str, t.Any],
     pseudocolumns: t.Set[str],
+    annotator: TypeAnnotator,
 ) -> None:
     """Expand stars to lists of column selections"""
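Note: a hedged end-to-end sketch of the BigQuery struct-star expansion (the schema and names are made up for illustration; output may vary by version):

    import sqlglot
    from sqlglot.optimizer.qualify import qualify

    schema = {"t": {"s": "STRUCT<a INT64, b INT64>"}}
    expression = sqlglot.parse_one("SELECT t.s.* FROM t", read="bigquery")
    print(qualify(expression, schema=schema, dialect="bigquery").sql("bigquery"))
    # e.g. SELECT t.s.a AS a, t.s.b AS b FROM t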
@@ -392,6 +473,7 @@ def _expand_stars(
     except_columns: t.Dict[int, t.Set[str]] = {}
     replace_columns: t.Dict[int, t.Dict[str, str]] = {}
     coalesced_columns = set()
+    dialect = resolver.schema.dialect
 
     pivot_output_columns = None
     pivot_exclude_columns = None
@@ -413,16 +495,29 @@ def _expand_stars(
         if not pivot_output_columns:
             pivot_output_columns = [c.alias_or_name for c in pivot.expressions]
 
+    is_bigquery = dialect == "bigquery"
+    if is_bigquery and any(isinstance(col, exp.Dot) for col in scope.stars):
+        # Found struct expansion, annotate scope ahead of time
+        annotator.annotate_scope(scope)
+
     for expression in scope.expression.selects:
+        tables = []
         if isinstance(expression, exp.Star):
-            tables = list(scope.selected_sources)
+            tables.extend(scope.selected_sources)
             _add_except_columns(expression, tables, except_columns)
             _add_replace_columns(expression, tables, replace_columns)
-        elif expression.is_star and not isinstance(expression, exp.Dot):
-            tables = [expression.table]
-            _add_except_columns(expression.this, tables, except_columns)
-            _add_replace_columns(expression.this, tables, replace_columns)
-        else:
+        elif expression.is_star:
+            if not isinstance(expression, exp.Dot):
+                tables.append(expression.table)
+                _add_except_columns(expression.this, tables, except_columns)
+                _add_replace_columns(expression.this, tables, replace_columns)
+            elif is_bigquery:
+                struct_fields = _expand_struct_stars(expression)
+                if struct_fields:
+                    new_selections.extend(struct_fields)
+                    continue
+
+        if not tables:
             new_selections.append(expression)
             continue
sqlglot/optimizer/scope.py
@@ -86,6 +86,7 @@ class Scope:
     def clear_cache(self):
         self._collected = False
         self._raw_columns = None
+        self._stars = None
         self._derived_tables = None
         self._udtfs = None
         self._tables = None
@@ -119,14 +120,20 @@ class Scope:
         self._derived_tables = []
         self._udtfs = []
         self._raw_columns = []
+        self._stars = []
         self._join_hints = []
 
         for node in self.walk(bfs=False):
             if node is self.expression:
                 continue
 
-            if isinstance(node, exp.Column) and not isinstance(node.this, exp.Star):
-                self._raw_columns.append(node)
+            if isinstance(node, exp.Dot) and node.is_star:
+                self._stars.append(node)
+            elif isinstance(node, exp.Column):
+                if isinstance(node.this, exp.Star):
+                    self._stars.append(node)
+                else:
+                    self._raw_columns.append(node)
             elif isinstance(node, exp.Table) and not isinstance(node.parent, exp.JoinHint):
                 self._tables.append(node)
             elif isinstance(node, exp.JoinHint):
@@ -231,6 +238,14 @@ class Scope:
         self._ensure_collected()
         return self._subqueries
 
+    @property
+    def stars(self) -> t.List[exp.Column | exp.Dot]:
+        """
+        List of star expressions (columns or dots) in this scope.
+        """
+        self._ensure_collected()
+        return self._stars
+
     @property
     def columns(self):
         """
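Note: a hedged usage sketch of the new property (the printed repr is illustrative):

    import sqlglot
    from sqlglot.optimizer.scope import build_scope

    scope = build_scope(sqlglot.parse_one("SELECT t.* FROM t"))
    print(scope.stars)  # star expressions in the scope, as exp.Column or exp.Dot nodes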
sqlglot/parser.py
@@ -1134,6 +1134,8 @@ class Parser(metaclass=_Parser):
 
     SELECT_START_TOKENS = {TokenType.L_PAREN, TokenType.WITH, TokenType.SELECT}
 
+    COPY_INTO_VARLEN_OPTIONS = {"FILE_FORMAT", "COPY_OPTIONS", "FORMAT_OPTIONS", "CREDENTIAL"}
+
     STRICT_CAST = True
 
     PREFIXED_PIVOT_COLUMNS = False
@@ -1830,11 +1832,17 @@ class Parser(metaclass=_Parser):
             self._retreat(index)
             return self._parse_sequence_properties()
 
-        return self.expression(
-            exp.Property,
-            this=key.to_dot() if isinstance(key, exp.Column) else key,
-            value=self._parse_bitwise() or self._parse_var(any_token=True),
-        )
+        # Transform the key to exp.Dot if it's dotted identifiers wrapped in exp.Column or to exp.Var otherwise
+        if isinstance(key, exp.Column):
+            key = key.to_dot() if len(key.parts) > 1 else exp.var(key.name)
+
+        value = self._parse_bitwise() or self._parse_var(any_token=True)
+
+        # Transform the value to exp.Var if it was parsed as exp.Column(exp.Identifier())
+        if isinstance(value, exp.Column):
+            value = exp.var(value.name)
+
+        return self.expression(exp.Property, this=key, value=value)
 
     def _parse_stored(self) -> exp.FileFormatProperty:
         self._match(TokenType.ALIAS)
@@ -1853,7 +1861,7 @@ class Parser(metaclass=_Parser):
             ),
         )
 
-    def _parse_unquoted_field(self):
+    def _parse_unquoted_field(self) -> t.Optional[exp.Expression]:
         field = self._parse_field()
         if isinstance(field, exp.Identifier) and not field.quoted:
             field = exp.var(field)
@@ -2793,7 +2801,13 @@ class Parser(metaclass=_Parser):
         if not alias and not columns:
             return None
 
-        return self.expression(exp.TableAlias, this=alias, columns=columns)
+        table_alias = self.expression(exp.TableAlias, this=alias, columns=columns)
+
+        # We bubble up comments from the Identifier to the TableAlias
+        if isinstance(alias, exp.Identifier):
+            table_alias.add_comments(alias.pop_comments())
+
+        return table_alias
 
     def _parse_subquery(
         self, this: t.Optional[exp.Expression], parse_alias: bool = True
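Note: the intent is that a comment attached to the alias identifier survives on the TableAlias node and is re-emitted next to the alias. A hedged sketch:

    import sqlglot

    print(sqlglot.transpile("SELECT * FROM tbl AS foo /* alias comment */")[0])
    # e.g. SELECT * FROM tbl AS foo /* alias comment */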
@@ -4060,7 +4074,7 @@ class Parser(metaclass=_Parser):
             return this
         return self.expression(exp.Escape, this=this, expression=self._parse_string())
 
-    def _parse_interval(self, match_interval: bool = True) -> t.Optional[exp.Interval]:
+    def _parse_interval(self, match_interval: bool = True) -> t.Optional[exp.Add | exp.Interval]:
         index = self._index
 
         if not self._match(TokenType.INTERVAL) and match_interval:
@@ -4090,23 +4104,33 @@ class Parser(metaclass=_Parser):
         if this and this.is_number:
             this = exp.Literal.string(this.name)
         elif this and this.is_string:
-            parts = this.name.split()
-
-            if len(parts) == 2:
+            parts = exp.INTERVAL_STRING_RE.findall(this.name)
+            if len(parts) == 1:
                 if unit:
-                    # This is not actually a unit, it's something else (e.g. a "window side")
-                    unit = None
+                    # Unconsume the eagerly-parsed unit, since the real unit was part of the string
                     self._retreat(self._index - 1)
 
-                this = exp.Literal.string(parts[0])
-                unit = self.expression(exp.Var, this=parts[1].upper())
+                this = exp.Literal.string(parts[0][0])
+                unit = self.expression(exp.Var, this=parts[0][1].upper())
 
         if self.INTERVAL_SPANS and self._match_text_seq("TO"):
             unit = self.expression(
                 exp.IntervalSpan, this=unit, expression=self._parse_var(any_token=True, upper=True)
             )
 
-        return self.expression(exp.Interval, this=this, unit=unit)
+        interval = self.expression(exp.Interval, this=this, unit=unit)
+
+        index = self._index
+        self._match(TokenType.PLUS)
+
+        # Convert INTERVAL 'val_1' unit_1 [+] ... [+] 'val_n' unit_n into a sum of intervals
+        if self._match_set((TokenType.STRING, TokenType.NUMBER), advance=False):
+            return self.expression(
+                exp.Add, this=interval, expression=self._parse_interval(match_interval=False)
+            )
+
+        self._retreat(index)
+        return interval
 
     def _parse_bitwise(self) -> t.Optional[exp.Expression]:
         this = self._parse_term()
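Note: consecutive interval segments now parse into a sum (exp.Add) of exp.Interval nodes inside _parse_interval itself, instead of in the _parse_type loop removed below. A hedged sketch:

    import sqlglot
    from sqlglot import exp

    e = sqlglot.parse_one("SELECT INTERVAL '1' YEAR '2' MONTH")
    print(e.find(exp.Add))  # e.g. INTERVAL '1' YEAR + INTERVAL '2' MONTH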
@@ -4173,38 +4197,45 @@ class Parser(metaclass=_Parser):
     ) -> t.Optional[exp.Expression]:
         interval = parse_interval and self._parse_interval()
         if interval:
-            # Convert INTERVAL 'val_1' unit_1 [+] ... [+] 'val_n' unit_n into a sum of intervals
-            while True:
-                index = self._index
-                self._match(TokenType.PLUS)
-
-                if not self._match_set((TokenType.STRING, TokenType.NUMBER), advance=False):
-                    self._retreat(index)
-                    break
-
-                interval = self.expression(  # type: ignore
-                    exp.Add, this=interval, expression=self._parse_interval(match_interval=False)
-                )
-
             return interval
 
         index = self._index
         data_type = self._parse_types(check_func=True, allow_identifiers=False)
-        this = self._parse_column()
 
         if data_type:
+            index2 = self._index
+            this = self._parse_primary()
+
             if isinstance(this, exp.Literal):
                 parser = self.TYPE_LITERAL_PARSERS.get(data_type.this)
                 if parser:
                     return parser(self, this, data_type)
 
                 return self.expression(exp.Cast, this=this, to=data_type)
 
-            if not data_type.expressions:
-                self._retreat(index)
-                return self._parse_id_var() if fallback_to_identifier else self._parse_column()
-
-            return self._parse_column_ops(data_type)
+            # The expressions arg gets set by the parser when we have something like DECIMAL(38, 0)
+            # in the input SQL. In that case, we'll produce these tokens: DECIMAL ( 38 , 0 )
+            #
+            # If the index difference here is greater than 1, that means the parser itself must have
+            # consumed additional tokens such as the DECIMAL scale and precision in the above example.
+            #
+            # If it's not greater than 1, then it must be 1, because we've consumed at least the type
+            # keyword, meaning that the expressions arg of the DataType must have gotten set by a
+            # callable in the TYPE_CONVERTERS mapping. For example, Snowflake converts DECIMAL to
+            # DECIMAL(38, 0)) in order to facilitate the data type's transpilation.
+            #
+            # In these cases, we don't really want to return the converted type, but instead retreat
+            # and try to parse a Column or Identifier in the section below.
+            if data_type.expressions and index2 - index > 1:
+                self._retreat(index2)
+                return self._parse_column_ops(data_type)
+
+            self._retreat(index)
+
+        if fallback_to_identifier:
+            return self._parse_id_var()
 
+        this = self._parse_column()
         return this and self._parse_column_ops(this)
 
     def _parse_type_size(self) -> t.Optional[exp.DataTypeParam]:
@@ -4268,7 +4299,7 @@ class Parser(metaclass=_Parser):
 
         if self._match(TokenType.L_PAREN):
             if is_struct:
-                expressions = self._parse_csv(self._parse_struct_types)
+                expressions = self._parse_csv(lambda: self._parse_struct_types(type_required=True))
             elif nested:
                 expressions = self._parse_csv(
                     lambda: self._parse_types(
@@ -4369,8 +4400,26 @@ class Parser(metaclass=_Parser):
         elif expressions:
             this.set("expressions", expressions)
 
-        while self._match_pair(TokenType.L_BRACKET, TokenType.R_BRACKET):
-            this = exp.DataType(this=exp.DataType.Type.ARRAY, expressions=[this], nested=True)
+        index = self._index
+
+        # Postgres supports the INT ARRAY[3] syntax as a synonym for INT[3]
+        matched_array = self._match(TokenType.ARRAY)
+
+        while self._curr:
+            matched_l_bracket = self._match(TokenType.L_BRACKET)
+            if not matched_l_bracket and not matched_array:
+                break
+
+            matched_array = False
+            values = self._parse_csv(self._parse_conjunction) or None
+            if values and not schema:
+                self._retreat(index)
+                break
+
+            this = exp.DataType(
+                this=exp.DataType.Type.ARRAY, expressions=[this], values=values, nested=True
+            )
+            self._match(TokenType.R_BRACKET)
 
         if self.TYPE_CONVERTER and isinstance(this.this, exp.DataType.Type):
             converter = self.TYPE_CONVERTER.get(this.this)
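Note: this is the parser side of the sized-array support — Postgres's INT ARRAY[3] synonym and INT[3] both build an ARRAY DataType whose "values" carry the size. A hedged sketch:

    import sqlglot

    ddl = sqlglot.parse_one("CREATE TABLE t (x INT ARRAY[3])", read="postgres")
    print(ddl.sql("postgres"))  # e.g. CREATE TABLE t (x INT[3])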
@@ -4386,15 +4435,16 @@ class Parser(metaclass=_Parser):
             or self._parse_id_var()
         )
         self._match(TokenType.COLON)
-        column_def = self._parse_column_def(this)
 
-        if type_required and (
-            (isinstance(this, exp.Column) and this.this is column_def) or this is column_def
+        if (
+            type_required
+            and not isinstance(this, exp.DataType)
+            and not self._match_set(self.TYPE_TOKENS, advance=False)
         ):
             self._retreat(index)
             return self._parse_types()
 
-        return column_def
+        return self._parse_column_def(this)
 
     def _parse_at_time_zone(self, this: t.Optional[exp.Expression]) -> t.Optional[exp.Expression]:
         if not self._match_text_seq("AT", "TIME", "ZONE"):
@@ -6030,7 +6080,19 @@ class Parser(metaclass=_Parser):
             return self.expression(exp.AlterColumn, this=column, default=self._parse_conjunction())
         if self._match(TokenType.COMMENT):
             return self.expression(exp.AlterColumn, this=column, comment=self._parse_string())
 
+        if self._match_text_seq("DROP", "NOT", "NULL"):
+            return self.expression(
+                exp.AlterColumn,
+                this=column,
+                drop=True,
+                allow_null=True,
+            )
+
+        if self._match_text_seq("SET", "NOT", "NULL"):
+            return self.expression(
+                exp.AlterColumn,
+                this=column,
+                allow_null=False,
+            )
+
         self._match_text_seq("SET", "DATA")
         self._match_text_seq("TYPE")
         return self.expression(
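Note: combined with the altercolumn_sql change in the generator earlier, ALTER COLUMN nullability should now round-trip. A hedged sketch:

    import sqlglot

    print(sqlglot.transpile("ALTER TABLE t ALTER COLUMN c DROP NOT NULL")[0])
    print(sqlglot.transpile("ALTER TABLE t ALTER COLUMN c SET NOT NULL")[0])
    # e.g. the statements are preserved as written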
@@ -6595,12 +6657,23 @@ class Parser(metaclass=_Parser):
         return self.expression(exp.WithOperator, this=this, op=op)
 
     def _parse_wrapped_options(self) -> t.List[t.Optional[exp.Expression]]:
-        opts = []
         self._match(TokenType.EQ)
         self._match(TokenType.L_PAREN)
 
+        opts: t.List[t.Optional[exp.Expression]] = []
         while self._curr and not self._match(TokenType.R_PAREN):
-            opts.append(self._parse_conjunction())
+            if self._match_text_seq("FORMAT_NAME", "="):
+                # The FORMAT_NAME can be set to an identifier for Snowflake and T-SQL,
+                # so we parse it separately to use _parse_field()
+                prop = self.expression(
+                    exp.Property, this=exp.var("FORMAT_NAME"), value=self._parse_field()
+                )
+                opts.append(prop)
+            else:
+                opts.append(self._parse_property())
+
             self._match(TokenType.COMMA)
 
         return opts
 
     def _parse_copy_parameters(self) -> t.List[exp.CopyParameter]:
@@ -6608,37 +6681,38 @@ class Parser(metaclass=_Parser):
 
         options = []
         while self._curr and not self._match(TokenType.R_PAREN, advance=False):
-            option = self._parse_unquoted_field()
-            value = None
+            option = self._parse_var(any_token=True)
+            prev = self._prev.text.upper()
 
-            # Some options are defined as functions with the values as params
-            if not isinstance(option, exp.Func):
-                prev = self._prev.text.upper()
-                # Different dialects might separate options and values by white space, "=" and "AS"
-                self._match(TokenType.EQ)
-                self._match(TokenType.ALIAS)
+            # Different dialects might separate options and values by white space, "=" and "AS"
+            self._match(TokenType.EQ)
+            self._match(TokenType.ALIAS)
 
-                if prev == "FILE_FORMAT" and self._match(TokenType.L_PAREN):
-                    # Snowflake FILE_FORMAT case
-                    value = self._parse_wrapped_options()
-                else:
-                    value = self._parse_unquoted_field()
+            param = self.expression(exp.CopyParameter, this=option)
 
-            param = self.expression(exp.CopyParameter, this=option, expression=value)
+            if prev in self.COPY_INTO_VARLEN_OPTIONS and self._match(
+                TokenType.L_PAREN, advance=False
+            ):
+                # Snowflake FILE_FORMAT case, Databricks COPY & FORMAT options
+                param.set("expressions", self._parse_wrapped_options())
+            elif prev == "FILE_FORMAT":
+                # T-SQL's external file format case
+                param.set("expression", self._parse_field())
+            else:
+                param.set("expression", self._parse_unquoted_field())
+
             options.append(param)
 
-            if sep:
-                self._match(sep)
+            self._match(sep)
 
         return options
 
     def _parse_credentials(self) -> t.Optional[exp.Credentials]:
         expr = self.expression(exp.Credentials)
 
-        if self._match_text_seq("STORAGE_INTEGRATION", advance=False):
-            expr.set("storage", self._parse_conjunction())
+        if self._match_text_seq("STORAGE_INTEGRATION", "="):
+            expr.set("storage", self._parse_field())
+
         if self._match_text_seq("CREDENTIALS"):
-            # Snowflake supports CREDENTIALS = (...), while Redshift CREDENTIALS <string>
+            # Snowflake case: CREDENTIALS = (...), Redshift case: CREDENTIALS <string>
             creds = (
                 self._parse_wrapped_options() if self._match(TokenType.EQ) else self._parse_field()
             )
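Note: a hedged Snowflake example exercising the variable-length option parsing above, together with copyparameter_sql on the generator side (stage and table names are made up; exact output may vary):

    import sqlglot

    sql = "COPY INTO my_table FROM @my_stage FILE_FORMAT = (TYPE = CSV FIELD_DELIMITER = '|')"
    print(sqlglot.transpile(sql, read="snowflake", write="snowflake")[0])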
@@ -6661,7 +6735,7 @@ class Parser(metaclass=_Parser):
         self._match(TokenType.INTO)
 
         this = (
-            self._parse_conjunction()
+            self._parse_select(nested=True, parse_subquery_alias=False)
             if self._match(TokenType.L_PAREN, advance=False)
             else self._parse_table(schema=True)
        )
sqlglot/schema.py
@@ -155,13 +155,16 @@ class AbstractMappingSchema:
             return [table.this.name]
         return [table.text(part) for part in exp.TABLE_PARTS if table.text(part)]
 
-    def find(self, table: exp.Table, raise_on_missing: bool = True) -> t.Optional[t.Any]:
+    def find(
+        self, table: exp.Table, raise_on_missing: bool = True, ensure_data_types: bool = False
+    ) -> t.Optional[t.Any]:
         """
         Returns the schema of a given table.
 
         Args:
             table: the target table.
             raise_on_missing: whether to raise in case the schema is not found.
+            ensure_data_types: whether to convert `str` types to their `DataType` equivalents.
 
         Returns:
             The schema of the target table.
@@ -239,6 +242,20 @@ class MappingSchema(AbstractMappingSchema, Schema):
             normalize=mapping_schema.normalize,
         )
 
+    def find(
+        self, table: exp.Table, raise_on_missing: bool = True, ensure_data_types: bool = False
+    ) -> t.Optional[t.Any]:
+        schema = super().find(
+            table, raise_on_missing=raise_on_missing, ensure_data_types=ensure_data_types
+        )
+        if ensure_data_types and isinstance(schema, dict):
+            schema = {
+                col: self._to_data_type(dtype) if isinstance(dtype, str) else dtype
+                for col, dtype in schema.items()
+            }
+
+        return schema
+
     def copy(self, **kwargs) -> MappingSchema:
         return MappingSchema(
             **{  # type: ignore
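Note: a hedged usage sketch of the new ensure_data_types flag — string type annotations are converted to exp.DataType instances on lookup:

    from sqlglot import exp
    from sqlglot.schema import MappingSchema

    schema = MappingSchema({"t": {"a": "INT"}})
    columns = schema.find(exp.to_table("t"), ensure_data_types=True)
    print(columns)  # e.g. {'a': <exp.DataType for INT>} rather than {'a': 'INT'}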