
Merging upstream version 23.16.0.

Signed-off-by: Daniel Baumann <daniel@debian.org>
Daniel Baumann 2025-02-13 21:35:32 +01:00
parent d0f42f708a
commit 213191b8e3
Signed by: daniel
GPG key ID: FBB4F0E80A80222F
93 changed files with 64106 additions and 59061 deletions

sqlglot/dialects/athena.py

@ -13,6 +13,8 @@ class Athena(Trino):
}
class Generator(Trino.Generator):
WITH_PROPERTIES_PREFIX = "TBLPROPERTIES"
PROPERTIES_LOCATION = {
**Trino.Generator.PROPERTIES_LOCATION,
exp.LocationProperty: exp.Properties.Location.POST_SCHEMA,
@ -32,6 +34,3 @@ class Athena(Trino):
return (
f"{self.property_name(expression, string_key=True)}={self.sql(expression, 'value')}"
)
def with_properties(self, properties: exp.Properties) -> str:
return self.properties(properties, prefix=self.seg("TBLPROPERTIES"))
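A minimal sketch of the new hook in action (usage assumed, not part of this diff): property rendering is now driven by the WITH_PROPERTIES_PREFIX class attribute instead of a per-dialect with_properties() override.

import sqlglot

# Trino-style WITH (...) properties should come out under Athena's
# TBLPROPERTIES keyword; the property shown is illustrative.
sql = "CREATE TABLE t (c INT) WITH (format='PARQUET')"
print(sqlglot.transpile(sql, read="trino", write="athena")[0])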

sqlglot/dialects/bigquery.py

@ -156,7 +156,7 @@ def _build_date(args: t.List) -> exp.Date | exp.DateFromParts:
def _build_to_hex(args: t.List) -> exp.Hex | exp.MD5:
# TO_HEX(MD5(..)) is common in BigQuery, so it's parsed into MD5 to simplify its transpilation
arg = seq_get(args, 0)
return exp.MD5(this=arg.this) if isinstance(arg, exp.MD5Digest) else exp.Hex(this=arg)
return exp.MD5(this=arg.this) if isinstance(arg, exp.MD5Digest) else exp.LowerHex(this=arg)
def _array_contains_sql(self: BigQuery.Generator, expression: exp.ArrayContains) -> str:
@ -212,6 +212,7 @@ class BigQuery(Dialect):
SUPPORTS_USER_DEFINED_TYPES = False
SUPPORTS_SEMI_ANTI_JOIN = False
LOG_BASE_FIRST = False
HEX_LOWERCASE = True
# https://cloud.google.com/bigquery/docs/reference/standard-sql/lexical#case_sensitivity
NORMALIZATION_STRATEGY = NormalizationStrategy.CASE_INSENSITIVE
@ -568,6 +569,8 @@ class BigQuery(Dialect):
CAN_IMPLEMENT_ARRAY_ANY = True
SUPPORTS_TO_NUMBER = False
NAMED_PLACEHOLDER_TOKEN = "@"
HEX_FUNC = "TO_HEX"
WITH_PROPERTIES_PREFIX = "OPTIONS"
TRANSFORMS = {
**generator.Generator.TRANSFORMS,
@ -603,13 +606,12 @@ class BigQuery(Dialect):
),
exp.GenerateSeries: rename_func("GENERATE_ARRAY"),
exp.GroupConcat: rename_func("STRING_AGG"),
exp.Hex: rename_func("TO_HEX"),
exp.Hex: lambda self, e: self.func("UPPER", self.func("TO_HEX", self.sql(e, "this"))),
exp.If: if_sql(false_value="NULL"),
exp.ILike: no_ilike_sql,
exp.IntDiv: rename_func("DIV"),
exp.JSONFormat: rename_func("TO_JSON_STRING"),
exp.Max: max_or_greatest,
exp.Mod: rename_func("MOD"),
exp.MD5: lambda self, e: self.func("TO_HEX", self.func("MD5", e.this)),
exp.MD5Digest: rename_func("MD5"),
exp.Min: min_or_least,
@ -634,6 +636,7 @@ class BigQuery(Dialect):
transforms.eliminate_semi_and_anti_joins,
]
),
exp.SHA: rename_func("SHA1"),
exp.SHA2: lambda self, e: self.func(
"SHA256" if e.text("length") == "256" else "SHA512", e.this
),
@ -801,6 +804,15 @@ class BigQuery(Dialect):
"within",
}
def mod_sql(self, expression: exp.Mod) -> str:
this = expression.this
expr = expression.expression
return self.func(
"MOD",
this.unnest() if isinstance(this, exp.Paren) else this,
expr.unnest() if isinstance(expr, exp.Paren) else expr,
)
def column_parts(self, expression: exp.Column) -> str:
if expression.meta.get("quoted_column"):
# If a column reference is of the form `dataset.table`.name, we need
@ -896,9 +908,6 @@ class BigQuery(Dialect):
self.unsupported("INTERSECT without DISTINCT is not supported in BigQuery")
return f"INTERSECT{' DISTINCT' if expression.args.get('distinct') else ' ALL'}"
def with_properties(self, properties: exp.Properties) -> str:
return self.properties(properties, prefix=self.seg("OPTIONS"))
def version_sql(self, expression: exp.Version) -> str:
if expression.name == "TIMESTAMP":
expression.set("this", "SYSTEM_TIME")
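A hedged sketch of the new hex semantics (queries are illustrative): BigQuery's TO_HEX returns lowercase hex, so it now parses into exp.LowerHex, while uppercase exp.Hex coming from other dialects is wrapped in UPPER.

import sqlglot

node = sqlglot.parse_one("TO_HEX(x)", read="bigquery")
print(type(node).__name__)           # expected: LowerHex
print(node.sql(dialect="bigquery"))  # expected: TO_HEX(x), no LOWER() wrapper
# Uppercase-hex semantics (e.g. MySQL's HEX) should get wrapped instead:
print(sqlglot.transpile("HEX(x)", read="mysql", write="bigquery")[0])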

sqlglot/dialects/clickhouse.py

@ -15,6 +15,7 @@ from sqlglot.dialects.dialect import (
build_json_extract_path,
rename_func,
var_map_sql,
timestamptrunc_sql,
)
from sqlglot.helper import is_int, seq_get
from sqlglot.tokens import Token, TokenType
@ -30,6 +31,27 @@ def _build_date_format(args: t.List) -> exp.TimeToStr:
return expr
def _unix_to_time_sql(self: ClickHouse.Generator, expression: exp.UnixToTime) -> str:
scale = expression.args.get("scale")
timestamp = expression.this
if scale in (None, exp.UnixToTime.SECONDS):
return self.func("fromUnixTimestamp", exp.cast(timestamp, exp.DataType.Type.BIGINT))
if scale == exp.UnixToTime.MILLIS:
return self.func("fromUnixTimestamp64Milli", exp.cast(timestamp, exp.DataType.Type.BIGINT))
if scale == exp.UnixToTime.MICROS:
return self.func("fromUnixTimestamp64Micro", exp.cast(timestamp, exp.DataType.Type.BIGINT))
if scale == exp.UnixToTime.NANOS:
return self.func("fromUnixTimestamp64Nano", exp.cast(timestamp, exp.DataType.Type.BIGINT))
return self.func(
"fromUnixTimestamp",
exp.cast(
exp.Div(this=timestamp, expression=exp.func("POW", 10, scale)), exp.DataType.Type.BIGINT
),
)
def _lower_func(sql: str) -> str:
index = sql.index("(")
return sql[:index].lower() + sql[index:]
@ -146,6 +168,9 @@ class ClickHouse(Dialect):
"TUPLE": exp.Struct.from_arg_list,
"UNIQ": exp.ApproxDistinct.from_arg_list,
"XOR": lambda args: exp.Xor(expressions=args),
"MD5": exp.MD5Digest.from_arg_list,
"SHA256": lambda args: exp.SHA2(this=seq_get(args, 0), length=exp.Literal.number(256)),
"SHA512": lambda args: exp.SHA2(this=seq_get(args, 0), length=exp.Literal.number(512)),
}
AGG_FUNCTIONS = {
@ -353,6 +378,11 @@ class ClickHouse(Dialect):
"CODEC": lambda self: self._parse_compress(),
}
ALTER_PARSERS = {
**parser.Parser.ALTER_PARSERS,
"REPLACE": lambda self: self._parse_alter_table_replace(),
}
SCHEMA_UNNAMED_CONSTRAINTS = {
*parser.Parser.SCHEMA_UNNAMED_CONSTRAINTS,
"INDEX",
@ -578,6 +608,44 @@ class ClickHouse(Dialect):
granularity=granularity,
)
def _parse_partition(self) -> t.Optional[exp.Partition]:
# https://clickhouse.com/docs/en/sql-reference/statements/alter/partition#how-to-set-partition-expression
if not self._match(TokenType.PARTITION):
return None
if self._match_text_seq("ID"):
# Corresponds to the PARTITION ID <string_value> syntax
expressions: t.List[exp.Expression] = [
self.expression(exp.PartitionId, this=self._parse_string())
]
else:
expressions = self._parse_expressions()
return self.expression(exp.Partition, expressions=expressions)
def _parse_alter_table_replace(self) -> t.Optional[exp.Expression]:
partition = self._parse_partition()
if not partition or not self._match(TokenType.FROM):
return None
return self.expression(
exp.ReplacePartition, expression=partition, source=self._parse_table_parts()
)
def _parse_projection_def(self) -> t.Optional[exp.ProjectionDef]:
if not self._match_text_seq("PROJECTION"):
return None
return self.expression(
exp.ProjectionDef,
this=self._parse_id_var(),
expression=self._parse_wrapped(self._parse_statement),
)
def _parse_constraint(self) -> t.Optional[exp.Expression]:
return super()._parse_constraint() or self._parse_projection_def()
class Generator(generator.Generator):
QUERY_HINTS = False
STRUCT_DELIMITER = ("(", ")")
@ -687,6 +755,16 @@ class ClickHouse(Dialect):
),
exp.VarMap: lambda self, e: _lower_func(var_map_sql(self, e)),
exp.Xor: lambda self, e: self.func("xor", e.this, e.expression, *e.expressions),
exp.MD5Digest: rename_func("MD5"),
exp.MD5: lambda self, e: self.func("LOWER", self.func("HEX", self.func("MD5", e.this))),
exp.SHA: rename_func("SHA1"),
exp.SHA2: lambda self, e: self.func(
"SHA256" if e.text("length") == "256" else "SHA512", e.this
),
exp.UnixToTime: _unix_to_time_sql,
exp.TimestampTrunc: timestamptrunc_sql(zone=True),
exp.Variance: rename_func("varSamp"),
exp.Stddev: rename_func("stddevSamp"),
}
PROPERTIES_LOCATION = {
@ -828,3 +906,17 @@ class ClickHouse(Dialect):
granularity = f" GRANULARITY {granularity}" if granularity else ""
return f"INDEX{this}{expr}{index_type}{granularity}"
def partition_sql(self, expression: exp.Partition) -> str:
return f"PARTITION {self.expressions(expression, flat=True)}"
def partitionid_sql(self, expression: exp.PartitionId) -> str:
return f"ID {self.sql(expression.this)}"
def replacepartition_sql(self, expression: exp.ReplacePartition) -> str:
return (
f"REPLACE {self.sql(expression.expression)} FROM {self.sql(expression, 'source')}"
)
def projectiondef_sql(self, expression: exp.ProjectionDef) -> str:
return f"PROJECTION {self.sql(expression.this)} {self.wrap(expression.expression)}"

sqlglot/dialects/databricks.py

@ -57,7 +57,7 @@ class Databricks(Spark):
),
exp.DatetimeDiff: _timestamp_diff,
exp.TimestampDiff: _timestamp_diff,
exp.DatetimeTrunc: timestamptrunc_sql,
exp.DatetimeTrunc: timestamptrunc_sql(),
exp.JSONExtract: lambda self, e: self.binary(e, ":"),
exp.Select: transforms.preprocess(
[

sqlglot/dialects/dialect.py

@ -248,6 +248,9 @@ class Dialect(metaclass=_Dialect):
CONCAT_COALESCE = False
"""A `NULL` arg in `CONCAT` yields `NULL` by default, but in some dialects it yields an empty string."""
HEX_LOWERCASE = False
"""Whether the `HEX` function returns a lowercase hexadecimal string."""
DATE_FORMAT = "'%Y-%m-%d'"
DATEINT_FORMAT = "'%Y%m%d'"
TIME_FORMAT = "'%Y-%m-%d %H:%M:%S'"
@ -769,8 +772,14 @@ def date_add_interval_sql(
return func
def timestamptrunc_sql(self: Generator, expression: exp.TimestampTrunc) -> str:
return self.func("DATE_TRUNC", unit_to_str(expression), expression.this)
def timestamptrunc_sql(zone: bool = False) -> t.Callable[[Generator, exp.TimestampTrunc], str]:
def _timestamptrunc_sql(self: Generator, expression: exp.TimestampTrunc) -> str:
args = [unit_to_str(expression), expression.this]
if zone:
args.append(expression.args.get("zone"))
return self.func("DATE_TRUNC", *args)
return _timestamptrunc_sql
def no_timestamp_sql(self: Generator, expression: exp.Timestamp) -> str:
@ -1141,10 +1150,23 @@ def filter_array_using_unnest(self: Generator, expression: exp.ArrayFilter) -> s
return self.sql(exp.Array(expressions=[filtered]))
def to_number_with_nls_param(self, expression: exp.ToNumber) -> str:
def to_number_with_nls_param(self: Generator, expression: exp.ToNumber) -> str:
return self.func(
"TO_NUMBER",
expression.this,
expression.args.get("format"),
expression.args.get("nlsparam"),
)
def build_default_decimal_type(
precision: t.Optional[int] = None, scale: t.Optional[int] = None
) -> t.Callable[[exp.DataType], exp.DataType]:
def _builder(dtype: exp.DataType) -> exp.DataType:
if dtype.expressions or precision is None:
return dtype
params = f"{precision}{f', {scale}' if scale is not None else ''}"
return exp.DataType.build(f"DECIMAL({params})")
return _builder
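A small sketch of the new helper (hypothetical usage): the builder only fills in defaults when the DECIMAL has no explicit parameters.

from sqlglot import exp
from sqlglot.dialects.dialect import build_default_decimal_type

builder = build_default_decimal_type(precision=18, scale=3)
print(builder(exp.DataType.build("DECIMAL")).sql())         # DECIMAL(18, 3)
print(builder(exp.DataType.build("DECIMAL(10, 2)")).sql())  # left untouched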

sqlglot/dialects/duckdb.py

@ -5,12 +5,14 @@ import typing as t
from sqlglot import exp, generator, parser, tokens, transforms
from sqlglot.dialects.dialect import (
Dialect,
JSON_EXTRACT_TYPE,
NormalizationStrategy,
approx_count_distinct_sql,
arg_max_or_min_no_count,
arrow_json_extract_sql,
binary_from_function,
bool_xor_sql,
build_default_decimal_type,
date_trunc_to_time,
datestrtodate_sql,
encode_decode_sql,
@ -155,6 +157,13 @@ def _unix_to_time_sql(self: DuckDB.Generator, expression: exp.UnixToTime) -> str
return self.func("TO_TIMESTAMP", exp.Div(this=timestamp, expression=exp.func("POW", 10, scale)))
def _arrow_json_extract_sql(self: DuckDB.Generator, expression: JSON_EXTRACT_TYPE) -> str:
arrow_sql = arrow_json_extract_sql(self, expression)
if not expression.same_parent and isinstance(expression.parent, exp.Binary):
arrow_sql = self.wrap(arrow_sql)
return arrow_sql
class DuckDB(Dialect):
NULL_ORDERING = "nulls_are_last"
SUPPORTS_USER_DEFINED_TYPES = False
@ -304,6 +313,22 @@ class DuckDB(Dialect):
),
}
TYPE_CONVERTER = {
# https://duckdb.org/docs/sql/data_types/numeric
exp.DataType.Type.DECIMAL: build_default_decimal_type(precision=18, scale=3),
}
def _parse_table_sample(self, as_modifier: bool = False) -> t.Optional[exp.TableSample]:
# https://duckdb.org/docs/sql/samples.html
sample = super()._parse_table_sample(as_modifier=as_modifier)
if sample and not sample.args.get("method"):
if sample.args.get("size"):
sample.set("method", exp.var("RESERVOIR"))
else:
sample.set("method", exp.var("SYSTEM"))
return sample
def _parse_bracket(
self, this: t.Optional[exp.Expression] = None
) -> t.Optional[exp.Expression]:
@ -320,24 +345,6 @@ class DuckDB(Dialect):
args = self._parse_wrapped_csv(self._parse_conjunction)
return self.expression(exp.Map, keys=seq_get(args, 0), values=seq_get(args, 1))
def _parse_types(
self, check_func: bool = False, schema: bool = False, allow_identifiers: bool = True
) -> t.Optional[exp.Expression]:
this = super()._parse_types(
check_func=check_func, schema=schema, allow_identifiers=allow_identifiers
)
# DuckDB treats NUMERIC and DECIMAL without precision as DECIMAL(18, 3)
# See: https://duckdb.org/docs/sql/data_types/numeric
if (
isinstance(this, exp.DataType)
and this.is_type("numeric", "decimal")
and not this.expressions
):
return exp.DataType.build("DECIMAL(18, 3)")
return this
def _parse_struct_types(self, type_required: bool = False) -> t.Optional[exp.Expression]:
return self._parse_field_def()
@ -368,6 +375,7 @@ class DuckDB(Dialect):
CAN_IMPLEMENT_ARRAY_ANY = True
SUPPORTS_TO_NUMBER = False
COPY_HAS_INTO_KEYWORD = False
STAR_EXCEPT = "EXCLUDE"
TRANSFORMS = {
**generator.Generator.TRANSFORMS,
@ -406,11 +414,12 @@ class DuckDB(Dialect):
exp.IntDiv: lambda self, e: self.binary(e, "//"),
exp.IsInf: rename_func("ISINF"),
exp.IsNan: rename_func("ISNAN"),
exp.JSONExtract: arrow_json_extract_sql,
exp.JSONExtractScalar: arrow_json_extract_sql,
exp.JSONExtract: _arrow_json_extract_sql,
exp.JSONExtractScalar: _arrow_json_extract_sql,
exp.JSONFormat: _json_format_sql,
exp.LogicalOr: rename_func("BOOL_OR"),
exp.LogicalAnd: rename_func("BOOL_AND"),
exp.MD5Digest: lambda self, e: self.func("UNHEX", self.func("MD5", e.this)),
exp.MonthsBetween: lambda self, e: self.func(
"DATEDIFF",
"'month'",
@ -449,7 +458,7 @@ class DuckDB(Dialect):
exp.TimestampDiff: lambda self, e: self.func(
"DATE_DIFF", exp.Literal.string(e.unit), e.expression, e.this
),
exp.TimestampTrunc: timestamptrunc_sql,
exp.TimestampTrunc: timestamptrunc_sql(),
exp.TimeStrToDate: lambda self, e: self.sql(exp.cast(e.this, exp.DataType.Type.DATE)),
exp.TimeStrToTime: timestrtotime_sql,
exp.TimeStrToUnix: lambda self, e: self.func(
@ -499,8 +508,6 @@ class DuckDB(Dialect):
exp.DataType.Type.TIMESTAMP_NS: "TIMESTAMP_NS",
}
STAR_MAPPING = {**generator.Generator.STAR_MAPPING, "except": "EXCLUDE"}
UNWRAPPED_INTERVAL_VALUES = (exp.Literal, exp.Paren)
# DuckDB doesn't generally support CREATE TABLE .. properties
@ -550,6 +557,15 @@ class DuckDB(Dialect):
# This sample clause only applies to a single source, not the entire resulting relation
tablesample_keyword = "TABLESAMPLE"
if expression.args.get("size"):
method = expression.args.get("method")
if method and method.name.upper() != "RESERVOIR":
self.unsupported(
f"Sampling method {method} is not supported with a discrete sample count, "
"defaulting to reservoir sampling"
)
expression.set("method", exp.var("RESERVOIR"))
return super().tablesample_sql(
expression, sep=sep, tablesample_keyword=tablesample_keyword
)
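A hedged example of the new sampling defaults (query is illustrative):

import sqlglot

# A discrete row count implies reservoir sampling, so the parser records the
# method explicitly; percentage samples default to SYSTEM instead.
ast = sqlglot.parse_one("SELECT * FROM t USING SAMPLE 5", read="duckdb")
print(ast.sql(dialect="duckdb"))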

sqlglot/dialects/hive.py

@ -254,7 +254,7 @@ class Hive(Dialect):
"REFRESH": TokenType.REFRESH,
"TIMESTAMP AS OF": TokenType.TIMESTAMP_SNAPSHOT,
"VERSION AS OF": TokenType.VERSION_SNAPSHOT,
"WITH SERDEPROPERTIES": TokenType.SERDE_PROPERTIES,
"SERDEPROPERTIES": TokenType.SERDE_PROPERTIES,
}
NUMERIC_LITERALS = {
@ -332,7 +332,7 @@ class Hive(Dialect):
PROPERTY_PARSERS = {
**parser.Parser.PROPERTY_PARSERS,
"WITH SERDEPROPERTIES": lambda self: exp.SerdeProperties(
"SERDEPROPERTIES": lambda self: exp.SerdeProperties(
expressions=self._parse_wrapped_csv(self._parse_property)
),
}
@ -422,6 +422,15 @@ class Hive(Dialect):
super()._parse_order(skip_order_token=self._match(TokenType.SORT_BY)),
)
def _parse_parameter(self) -> exp.Parameter:
self._match(TokenType.L_BRACE)
this = self._parse_identifier() or self._parse_primary_or_var()
expression = self._match(TokenType.COLON) and (
self._parse_identifier() or self._parse_primary_or_var()
)
self._match(TokenType.R_BRACE)
return self.expression(exp.Parameter, this=this, expression=expression)
class Generator(generator.Generator):
LIMIT_FETCH = "LIMIT"
TABLESAMPLE_WITH_METHOD = False
@ -434,6 +443,7 @@ class Hive(Dialect):
LAST_DAY_SUPPORTS_DATE_PART = False
JSON_PATH_SINGLE_QUOTE_ESCAPE = True
SUPPORTS_TO_NUMBER = False
WITH_PROPERTIES_PREFIX = "TBLPROPERTIES"
EXPRESSIONS_WITHOUT_NESTED_CTES = {
exp.Insert,
@ -453,11 +463,12 @@ class Hive(Dialect):
**generator.Generator.TYPE_MAPPING,
exp.DataType.Type.BIT: "BOOLEAN",
exp.DataType.Type.DATETIME: "TIMESTAMP",
exp.DataType.Type.ROWVERSION: "BINARY",
exp.DataType.Type.TEXT: "STRING",
exp.DataType.Type.TIME: "TIMESTAMP",
exp.DataType.Type.TIMESTAMPTZ: "TIMESTAMP",
exp.DataType.Type.UTINYINT: "SMALLINT",
exp.DataType.Type.VARBINARY: "BINARY",
exp.DataType.Type.ROWVERSION: "BINARY",
}
TRANSFORMS = {
@ -552,7 +563,6 @@ class Hive(Dialect):
exp.UnixToTime: _unix_to_time_sql,
exp.UnixToTimeStr: rename_func("FROM_UNIXTIME"),
exp.PartitionedByProperty: lambda self, e: f"PARTITIONED BY {self.sql(e, 'this')}",
exp.SerdeProperties: lambda self, e: self.properties(e, prefix="WITH SERDEPROPERTIES"),
exp.NumberToStr: rename_func("FORMAT_NUMBER"),
exp.National: lambda self, e: self.national_sql(e, prefix=""),
exp.ClusteredColumnConstraint: lambda self,
@ -618,9 +628,6 @@ class Hive(Dialect):
expression.this.this if isinstance(expression.this, exp.Order) else expression.this,
)
def with_properties(self, properties: exp.Properties) -> str:
return self.properties(properties, prefix=self.seg("TBLPROPERTIES"))
def datatype_sql(self, expression: exp.DataType) -> str:
if expression.this in self.PARAMETERIZABLE_TEXT_TYPES and (
not expression.expressions or expression.expressions[0].name == "MAX"
@ -655,3 +662,23 @@ class Hive(Dialect):
values.append(e)
return self.func("STRUCT", *values)
def alterset_sql(self, expression: exp.AlterSet) -> str:
exprs = self.expressions(expression, flat=True)
exprs = f" {exprs}" if exprs else ""
location = self.sql(expression, "location")
location = f" LOCATION {location}" if location else ""
file_format = self.expressions(expression, key="file_format", flat=True, sep=" ")
file_format = f" FILEFORMAT {file_format}" if file_format else ""
serde = self.sql(expression, "serde")
serde = f" SERDE {serde}" if serde else ""
tags = self.expressions(expression, key="tag", flat=True, sep="")
tags = f" TAGS {tags}" if tags else ""
return f"SET{serde}{exprs}{location}{file_format}{tags}"
def serdeproperties_sql(self, expression: exp.SerdeProperties) -> str:
prefix = "WITH " if expression.args.get("with") else ""
exprs = self.expressions(expression, flat=True)
return f"{prefix}SERDEPROPERTIES ({exprs})"

sqlglot/dialects/mysql.py

@ -110,6 +110,20 @@ def _trim_sql(self: MySQL.Generator, expression: exp.Trim) -> str:
return f"TRIM({trim_type}{remove_chars}{from_part}{target})"
def _unix_to_time_sql(self: MySQL.Generator, expression: exp.UnixToTime) -> str:
scale = expression.args.get("scale")
timestamp = expression.this
if scale in (None, exp.UnixToTime.SECONDS):
return self.func("FROM_UNIXTIME", timestamp, self.format_time(expression))
return self.func(
"FROM_UNIXTIME",
exp.Div(this=timestamp, expression=exp.func("POW", 10, scale)),
self.format_time(expression),
)
def date_add_sql(
kind: str,
) -> t.Callable[[generator.Generator, exp.Expression], str]:
@ -251,7 +265,7 @@ class MySQL(Dialect):
"@@": TokenType.SESSION_PARAMETER,
}
COMMANDS = tokens.Tokenizer.COMMANDS - {TokenType.SHOW}
COMMANDS = {*tokens.Tokenizer.COMMANDS, TokenType.REPLACE} - {TokenType.SHOW}
class Parser(parser.Parser):
FUNC_TOKENS = {
@ -723,7 +737,7 @@ class MySQL(Dialect):
exp.TsOrDsAdd: date_add_sql("ADD"),
exp.TsOrDsDiff: lambda self, e: self.func("DATEDIFF", e.this, e.expression),
exp.TsOrDsToDate: _ts_or_ds_to_date_sql,
exp.UnixToTime: lambda self, e: self.func("FROM_UNIXTIME", e.this, self.format_time(e)),
exp.UnixToTime: _unix_to_time_sql,
exp.Week: _remove_ts_or_ds_to_date(),
exp.WeekOfYear: _remove_ts_or_ds_to_date(rename_func("WEEKOFYEAR")),
exp.Year: _remove_ts_or_ds_to_date(),
@ -805,6 +819,9 @@ class MySQL(Dialect):
exp.DataType.Type.TIMESTAMPLTZ,
}
def dpipe_sql(self, expression: exp.DPipe) -> str:
return self.func("CONCAT", *expression.flatten())
def extract_sql(self, expression: exp.Extract) -> str:
unit = expression.name
if unit and unit.lower() == "epoch":
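An illustrative transpilation exercising the new dpipe_sql (example query assumed):

import sqlglot

# || concatenation chains flatten into a single CONCAT call for MySQL.
print(sqlglot.transpile("SELECT a || b || c", read="postgres", write="mysql")[0])
# expected: SELECT CONCAT(a, b, c)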

sqlglot/dialects/postgres.py

@ -227,12 +227,27 @@ def _build_regexp_replace(args: t.List) -> exp.RegexpReplace:
return exp.RegexpReplace.from_arg_list(args)
def _unix_to_time_sql(self: Postgres.Generator, expression: exp.UnixToTime) -> str:
scale = expression.args.get("scale")
timestamp = expression.this
if scale in (None, exp.UnixToTime.SECONDS):
return self.func("TO_TIMESTAMP", timestamp, self.format_time(expression))
return self.func(
"TO_TIMESTAMP",
exp.Div(this=timestamp, expression=exp.func("POW", 10, scale)),
self.format_time(expression),
)
class Postgres(Dialect):
INDEX_OFFSET = 1
TYPED_DIVISION = True
CONCAT_COALESCE = True
NULL_ORDERING = "nulls_are_large"
TIME_FORMAT = "'YYYY-MM-DD HH24:MI:SS'"
TABLESAMPLE_SIZE_IS_PERCENT = True
TIME_MAPPING = {
"AM": "%p",
@ -528,7 +543,7 @@ class Postgres(Dialect):
exp.Substring: _substring_sql,
exp.TimeFromParts: rename_func("MAKE_TIME"),
exp.TimestampFromParts: rename_func("MAKE_TIMESTAMP"),
exp.TimestampTrunc: timestamptrunc_sql,
exp.TimestampTrunc: timestamptrunc_sql(zone=True),
exp.TimeStrToTime: timestrtotime_sql,
exp.TimeToStr: lambda self, e: self.func("TO_CHAR", e.this, self.format_time(e)),
exp.ToChar: lambda self, e: self.function_fallback_sql(e),
@ -543,6 +558,7 @@ class Postgres(Dialect):
exp.VariancePop: rename_func("VAR_POP"),
exp.Variance: rename_func("VAR_SAMP"),
exp.Xor: bool_xor_sql,
exp.UnixToTime: _unix_to_time_sql,
}
TRANSFORMS.pop(exp.CommentColumnConstraint)
@ -593,3 +609,15 @@ class Postgres(Dialect):
expressions = [f"{self.sql(e)} @@ {this}" for e in expression.expressions]
sql = " OR ".join(expressions)
return f"({sql})" if len(expressions) > 1 else sql
def alterset_sql(self, expression: exp.AlterSet) -> str:
exprs = self.expressions(expression, flat=True)
exprs = f"({exprs})" if exprs else ""
access_method = self.sql(expression, "access_method")
access_method = f"ACCESS METHOD {access_method}" if access_method else ""
tablespace = self.sql(expression, "tablespace")
tablespace = f"TABLESPACE {tablespace}" if tablespace else ""
option = self.sql(expression, "option")
return f"SET {exprs}{access_method}{tablespace}{option}"

sqlglot/dialects/presto.py

@ -32,6 +32,7 @@ from sqlglot.dialects.hive import Hive
from sqlglot.dialects.mysql import MySQL
from sqlglot.helper import apply_index_offset, seq_get
from sqlglot.tokens import TokenType
from sqlglot.transforms import unqualify_columns
def _explode_to_unnest_sql(self: Presto.Generator, expression: exp.Lateral) -> str:
@ -276,11 +277,13 @@ class Presto(Dialect):
this=seq_get(args, 0), substr=seq_get(args, 1), instance=seq_get(args, 2)
),
"TO_CHAR": _build_to_char,
"TO_HEX": exp.Hex.from_arg_list,
"TO_UNIXTIME": exp.TimeToUnix.from_arg_list,
"TO_UTF8": lambda args: exp.Encode(
this=seq_get(args, 0), charset=exp.Literal.string("utf-8")
),
"MD5": exp.MD5Digest.from_arg_list,
"SHA256": lambda args: exp.SHA2(this=seq_get(args, 0), length=exp.Literal.number(256)),
"SHA512": lambda args: exp.SHA2(this=seq_get(args, 0), length=exp.Literal.number(512)),
}
FUNCTION_PARSERS = parser.Parser.FUNCTION_PARSERS.copy()
@ -300,6 +303,7 @@ class Presto(Dialect):
LIKE_PROPERTY_INSIDE_SCHEMA = True
MULTI_ARG_DISTINCT = False
SUPPORTS_TO_NUMBER = False
HEX_FUNC = "TO_HEX"
PROPERTIES_LOCATION = {
**generator.Generator.PROPERTIES_LOCATION,
@ -381,7 +385,6 @@ class Presto(Dialect):
exp.GroupConcat: lambda self, e: self.func(
"ARRAY_JOIN", self.func("ARRAY_AGG", e.this), e.args.get("separator")
),
exp.Hex: rename_func("TO_HEX"),
exp.If: if_sql(),
exp.ILike: no_ilike_sql,
exp.Initcap: _initcap_sql,
@ -417,7 +420,7 @@ class Presto(Dialect):
exp.StructExtract: struct_extract_sql,
exp.Table: transforms.preprocess([_unnest_sequence]),
exp.Timestamp: no_timestamp_sql,
exp.TimestampTrunc: timestamptrunc_sql,
exp.TimestampTrunc: timestamptrunc_sql(),
exp.TimeStrToDate: timestrtotime_sql,
exp.TimeStrToTime: timestrtotime_sql,
exp.TimeStrToUnix: lambda self, e: self.func(
@ -444,6 +447,14 @@ class Presto(Dialect):
[transforms.remove_within_group_for_percentiles]
),
exp.Xor: bool_xor_sql,
exp.MD5: lambda self, e: self.func(
"LOWER", self.func("TO_HEX", self.func("MD5", self.sql(e, "this")))
),
exp.MD5Digest: rename_func("MD5"),
exp.SHA: rename_func("SHA1"),
exp.SHA2: lambda self, e: self.func(
"SHA256" if e.text("length") == "256" else "SHA512", e.this
),
}
RESERVED_KEYWORDS = {
@ -615,3 +626,25 @@ class Presto(Dialect):
if kind == "VIEW" and schema.expressions:
expression.this.set("expressions", None)
return super().create_sql(expression)
def delete_sql(self, expression: exp.Delete) -> str:
"""
Presto only supports DELETE FROM for a single table without an alias, so we need
to remove the unnecessary parts. If the original DELETE statement contains more
than one table to be deleted, we can't safely map it 1-1 to a Presto statement.
"""
tables = expression.args.get("tables") or [expression.this]
if len(tables) > 1:
return super().delete_sql(expression)
table = tables[0]
expression.set("this", table)
expression.set("tables", None)
if isinstance(table, exp.Table):
table_alias = table.args.get("alias")
if table_alias:
table_alias.pop()
expression = t.cast(exp.Delete, expression.transform(unqualify_columns))
return super().delete_sql(expression)
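An illustrative single-table DELETE rewrite (input statement assumed):

import sqlglot

sql = "DELETE FROM t AS a WHERE a.x = 1"
# The alias is dropped and its column qualifications removed so the statement
# fits Presto's DELETE FROM grammar.
print(sqlglot.transpile(sql, read="postgres", write="presto")[0])
# expected roughly: DELETE FROM t WHERE x = 1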

sqlglot/dialects/redshift.py

@ -39,6 +39,7 @@ class Redshift(Postgres):
SUPPORTS_USER_DEFINED_TYPES = False
INDEX_OFFSET = 0
COPY_PARAMS_ARE_CSV = False
HEX_LOWERCASE = True
TIME_FORMAT = "'YYYY-MM-DD HH:MI:SS'"
TIME_MAPPING = {
@ -140,6 +141,9 @@ class Redshift(Postgres):
CAN_IMPLEMENT_ARRAY_ANY = False
MULTI_ARG_DISTINCT = True
COPY_PARAMS_ARE_WRAPPED = False
HEX_FUNC = "TO_HEX"
# Redshift doesn't have `WITH` as part of their with_properties so we remove it
WITH_PROPERTIES_PREFIX = " "
TYPE_MAPPING = {
**Postgres.Generator.TYPE_MAPPING,
@ -169,6 +173,7 @@ class Redshift(Postgres):
exp.JSONExtract: json_extract_segments("JSON_EXTRACT_PATH_TEXT"),
exp.JSONExtractScalar: json_extract_segments("JSON_EXTRACT_PATH_TEXT"),
exp.GroupConcat: rename_func("LISTAGG"),
exp.Hex: lambda self, e: self.func("UPPER", self.func("TO_HEX", self.sql(e, "this"))),
exp.ParseJSON: rename_func("JSON_PARSE"),
exp.Select: transforms.preprocess(
[
@ -372,10 +377,6 @@ class Redshift(Postgres):
alias = self.expressions(expression.args.get("alias"), key="columns", flat=True)
return f"{arg} AS {alias}" if alias else arg
def with_properties(self, properties: exp.Properties) -> str:
"""Redshift doesn't have `WITH` as part of their with_properties so we remove it"""
return self.properties(properties, prefix=" ", suffix="")
def cast_sql(self, expression: exp.Cast, safe_prefix: t.Optional[str] = None) -> str:
if expression.is_type(exp.DataType.Type.JSON):
# Redshift doesn't support a JSON type, so casting to it is treated as a noop
@ -398,3 +399,13 @@ class Redshift(Postgres):
expression.append("expressions", exp.var("MAX"))
return super().datatype_sql(expression)
def alterset_sql(self, expression: exp.AlterSet) -> str:
exprs = self.expressions(expression, flat=True)
exprs = f" TABLE PROPERTIES ({exprs})" if exprs else ""
location = self.sql(expression, "location")
location = f" LOCATION {location}" if location else ""
file_format = self.expressions(expression, key="file_format", flat=True, sep=" ")
file_format = f" FILE FORMAT {file_format}" if file_format else ""
return f"SET{exprs}{location}{file_format}"

sqlglot/dialects/snowflake.py

@ -7,6 +7,7 @@ from sqlglot.dialects.dialect import (
Dialect,
NormalizationStrategy,
binary_from_function,
build_default_decimal_type,
date_delta_sql,
date_trunc_to_time,
datestrtodate_sql,
@ -334,6 +335,7 @@ class Snowflake(Dialect):
class Parser(parser.Parser):
IDENTIFY_PIVOT_STRINGS = True
DEFAULT_SAMPLING_METHOD = "BERNOULLI"
ID_VAR_TOKENS = {
*parser.Parser.ID_VAR_TOKENS,
@ -345,6 +347,7 @@ class Snowflake(Dialect):
FUNCTIONS = {
**parser.Parser.FUNCTIONS,
"APPROX_PERCENTILE": exp.ApproxQuantile.from_arg_list,
"ARRAYAGG": exp.ArrayAgg.from_arg_list,
"ARRAY_CONSTRUCT": exp.Array.from_arg_list,
"ARRAY_CONTAINS": lambda args: exp.ArrayContains(
@ -423,7 +426,6 @@ class Snowflake(Dialect):
ALTER_PARSERS = {
**parser.Parser.ALTER_PARSERS,
"SET": lambda self: self._parse_set(tag=self._match_text_seq("TAG")),
"UNSET": lambda self: self.expression(
exp.Set,
tag=self._match_text_seq("TAG"),
@ -443,6 +445,11 @@ class Snowflake(Dialect):
"LOCATION": lambda self: self._parse_location_property(),
}
TYPE_CONVERTER = {
# https://docs.snowflake.com/en/sql-reference/data-types-numeric#number
exp.DataType.Type.DECIMAL: build_default_decimal_type(precision=38, scale=0),
}
SHOW_PARSERS = {
"SCHEMAS": _show_parser("SCHEMAS"),
"TERSE SCHEMAS": _show_parser("SCHEMAS"),
@ -475,6 +482,14 @@ class Snowflake(Dialect):
SCHEMA_KINDS = {"OBJECTS", "TABLES", "VIEWS", "SEQUENCES", "UNIQUE KEYS", "IMPORTED KEYS"}
def _parse_create(self) -> exp.Create | exp.Command:
expression = super()._parse_create()
if isinstance(expression, exp.Create) and expression.kind == "TAG":
# Replace the Table node with the enclosed Identifier
expression.this.replace(expression.this.this)
return expression
def _parse_column_ops(self, this: t.Optional[exp.Expression]) -> t.Optional[exp.Expression]:
this = super()._parse_column_ops(this)
@ -600,8 +615,8 @@ class Snowflake(Dialect):
file_format = None
pattern = None
self._match(TokenType.L_PAREN)
while self._curr and not self._match(TokenType.R_PAREN):
wrapped = self._match(TokenType.L_PAREN)
while self._curr and wrapped and not self._match(TokenType.R_PAREN):
if self._match_text_seq("FILE_FORMAT", "=>"):
file_format = self._parse_string() or super()._parse_table_parts(
is_db_reference=is_db_reference
@ -681,14 +696,22 @@ class Snowflake(Dialect):
return self.expression(exp.LocationProperty, this=self._parse_location_path())
def _parse_file_location(self) -> t.Optional[exp.Expression]:
return self._parse_table_parts()
# Parse either a subquery or a staged file
return (
self._parse_select(table=True)
if self._match(TokenType.L_PAREN, advance=False)
else self._parse_table_parts()
)
def _parse_location_path(self) -> exp.Var:
parts = [self._advance_any(ignore_reserved=True)]
# We avoid consuming a comma token because external tables like @foo and @bar
# can be joined in a query with a comma separator.
while self._is_connected() and not self._match(TokenType.COMMA, advance=False):
# can be joined in a query with a comma separator, as well as closing paren
# in case of subqueries
while self._is_connected() and not self._match_set(
(TokenType.COMMA, TokenType.R_PAREN), advance=False
):
parts.append(self._advance_any(ignore_reserved=True))
return exp.var("".join(part.text for part in parts if part))
@ -713,12 +736,12 @@ class Snowflake(Dialect):
"NCHAR VARYING": TokenType.VARCHAR,
"PUT": TokenType.COMMAND,
"REMOVE": TokenType.COMMAND,
"RENAME": TokenType.REPLACE,
"RM": TokenType.COMMAND,
"SAMPLE": TokenType.TABLE_SAMPLE,
"SQL_DOUBLE": TokenType.DOUBLE,
"SQL_VARCHAR": TokenType.VARCHAR,
"STORAGE INTEGRATION": TokenType.STORAGE_INTEGRATION,
"TAG": TokenType.TAG,
"TIMESTAMP_TZ": TokenType.TIMESTAMPTZ,
"TOP": TokenType.TOP,
}
@ -748,6 +771,7 @@ class Snowflake(Dialect):
STRUCT_DELIMITER = ("(", ")")
COPY_PARAMS_ARE_WRAPPED = False
COPY_PARAMS_EQ_REQUIRED = True
STAR_EXCEPT = "EXCLUDE"
TRANSFORMS = {
**generator.Generator.TRANSFORMS,
@ -818,7 +842,7 @@ class Snowflake(Dialect):
exp.TimestampDiff: lambda self, e: self.func(
"TIMESTAMPDIFF", e.unit, e.expression, e.this
),
exp.TimestampTrunc: timestamptrunc_sql,
exp.TimestampTrunc: timestamptrunc_sql(),
exp.TimeStrToTime: timestrtotime_sql,
exp.TimeToStr: lambda self, e: self.func(
"TO_CHAR", exp.cast(e.this, exp.DataType.Type.TIMESTAMP), self.format_time(e)
@ -850,11 +874,6 @@ class Snowflake(Dialect):
exp.DataType.Type.STRUCT: "OBJECT",
}
STAR_MAPPING = {
"except": "EXCLUDE",
"replace": "RENAME",
}
PROPERTIES_LOCATION = {
**generator.Generator.PROPERTIES_LOCATION,
exp.SetProperty: exp.Properties.Location.UNSUPPORTED,
@ -862,9 +881,15 @@ class Snowflake(Dialect):
}
UNSUPPORTED_VALUES_EXPRESSIONS = {
exp.Map,
exp.StarMap,
exp.Struct,
exp.VarMap,
}
def with_properties(self, properties: exp.Properties) -> str:
return self.properties(properties, wrapped=False, prefix=self.sep(""), sep=" ")
def values_sql(self, expression: exp.Values, values_as_table: bool = True) -> str:
if expression.find(*self.UNSUPPORTED_VALUES_EXPRESSIONS):
values_as_table = False
@ -1019,9 +1044,6 @@ class Snowflake(Dialect):
this = self.sql(expression, "this")
return f"SWAP WITH {this}"
def with_properties(self, properties: exp.Properties) -> str:
return self.properties(properties, wrapped=False, prefix=self.seg(""), sep=" ")
def cluster_sql(self, expression: exp.Cluster) -> str:
return f"CLUSTER BY ({self.expressions(expression, flat=True)})"
@ -1041,10 +1063,22 @@ class Snowflake(Dialect):
return self.func("OBJECT_CONSTRUCT", *flatten(zip(keys, values)))
def copyparameter_sql(self, expression: exp.CopyParameter) -> str:
option = self.sql(expression, "this").upper()
if option == "FILE_FORMAT":
values = self.expressions(expression, key="expression", flat=True, sep=" ")
return f"{option} = ({values})"
def approxquantile_sql(self, expression: exp.ApproxQuantile) -> str:
if expression.args.get("weight") or expression.args.get("accuracy"):
self.unsupported(
"APPROX_PERCENTILE with weight and/or accuracy arguments are not supported in Snowflake"
)
return super().copyparameter_sql(expression)
return self.func("APPROX_PERCENTILE", expression.this, expression.args.get("quantile"))
def alterset_sql(self, expression: exp.AlterSet) -> str:
exprs = self.expressions(expression, flat=True)
exprs = f" {exprs}" if exprs else ""
file_format = self.expressions(expression, key="file_format", flat=True, sep=" ")
file_format = f" STAGE_FILE_FORMAT = ({file_format})" if file_format else ""
copy_options = self.expressions(expression, key="copy_options", flat=True, sep=" ")
copy_options = f" STAGE_COPY_OPTIONS = ({copy_options})" if copy_options else ""
tag = self.expressions(expression, key="tag", flat=True)
tag = f" TAG {tag}" if tag else ""
return f"SET{exprs}{file_format}{copy_options}{tag}"

sqlglot/dialects/spark.py

@ -96,6 +96,9 @@ class Spark(Spark2):
TRANSFORMS = {
**Spark2.Generator.TRANSFORMS,
exp.ArrayConstructCompact: lambda self, e: self.func(
"ARRAY_COMPACT", self.func("ARRAY", *e.expressions)
),
exp.Create: preprocess(
[
remove_unique_constraints,
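An illustrative rendering of the new node (constructed by hand):

from sqlglot import exp

e = exp.ArrayConstructCompact(expressions=[exp.column("a"), exp.Null()])
print(e.sql(dialect="spark"))  # expected: ARRAY_COMPACT(ARRAY(a, NULL))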

sqlglot/dialects/tsql.py

@ -3,6 +3,7 @@ from __future__ import annotations
import datetime
import re
import typing as t
from functools import partial
from sqlglot import exp, generator, parser, tokens, transforms
from sqlglot.dialects.dialect import (
@ -211,7 +212,8 @@ def _string_agg_sql(self: TSQL.Generator, expression: exp.GroupConcat) -> str:
if isinstance(expression.this, exp.Order):
if expression.this.this:
this = expression.this.this.pop()
order = f" WITHIN GROUP ({self.sql(expression.this)[1:]})" # Order has a leading space
# Order has a leading space
order = f" WITHIN GROUP ({self.sql(expression.this)[1:]})"
separator = expression.args.get("separator") or exp.Literal.string(",")
return f"STRING_AGG({self.format_args(this, separator)}){order}"
@ -451,11 +453,14 @@ class TSQL(Dialect):
**tokens.Tokenizer.KEYWORDS,
"DATETIME2": TokenType.DATETIME,
"DATETIMEOFFSET": TokenType.TIMESTAMPTZ,
"DECLARE": TokenType.COMMAND,
"DECLARE": TokenType.DECLARE,
"EXEC": TokenType.COMMAND,
"FOR SYSTEM_TIME": TokenType.TIMESTAMP_SNAPSHOT,
"IMAGE": TokenType.IMAGE,
"MONEY": TokenType.MONEY,
"NTEXT": TokenType.TEXT,
"OPTION": TokenType.OPTION,
"OUTPUT": TokenType.RETURNING,
"PRINT": TokenType.COMMAND,
"PROC": TokenType.PROCEDURE,
"REAL": TokenType.FLOAT,
@ -463,17 +468,17 @@ class TSQL(Dialect):
"SMALLDATETIME": TokenType.DATETIME,
"SMALLMONEY": TokenType.SMALLMONEY,
"SQL_VARIANT": TokenType.VARIANT,
"SYSTEM_USER": TokenType.CURRENT_USER,
"TOP": TokenType.TOP,
"TIMESTAMP": TokenType.ROWVERSION,
"TINYINT": TokenType.UTINYINT,
"UNIQUEIDENTIFIER": TokenType.UNIQUEIDENTIFIER,
"UPDATE STATISTICS": TokenType.COMMAND,
"XML": TokenType.XML,
"OUTPUT": TokenType.RETURNING,
"SYSTEM_USER": TokenType.CURRENT_USER,
"FOR SYSTEM_TIME": TokenType.TIMESTAMP_SNAPSHOT,
"OPTION": TokenType.OPTION,
}
COMMANDS = {*tokens.Tokenizer.COMMANDS, TokenType.END}
class Parser(parser.Parser):
SET_REQUIRES_ASSIGNMENT_DELIMITER = False
LOG_DEFAULTS_TO_LN = True
@ -526,7 +531,7 @@ class TSQL(Dialect):
STATEMENT_PARSERS = {
**parser.Parser.STATEMENT_PARSERS,
TokenType.END: lambda self: self._parse_command(),
TokenType.DECLARE: lambda self: self._parse_declare(),
}
def _parse_options(self) -> t.Optional[t.List[exp.Expression]]:
@ -711,6 +716,32 @@ class TSQL(Dialect):
return partition
def _parse_declare(self) -> exp.Declare | exp.Command:
index = self._index
expressions = self._try_parse(partial(self._parse_csv, self._parse_declareitem))
if not expressions or self._curr:
self._retreat(index)
return self._parse_as_command(self._prev)
return self.expression(exp.Declare, expressions=expressions)
def _parse_declareitem(self) -> t.Optional[exp.DeclareItem]:
var = self._parse_id_var()
if not var:
return None
value = None
self._match(TokenType.ALIAS)
if self._match(TokenType.TABLE):
data_type = self._parse_schema()
else:
data_type = self._parse_types()
if self._match(TokenType.EQ):
value = self._parse_bitwise()
return self.expression(exp.DeclareItem, this=var, kind=data_type, default=value)
class Generator(generator.Generator):
LIMIT_IS_TOP = True
QUERY_HINTS = False
@ -753,11 +784,12 @@ class TSQL(Dialect):
exp.DataType.Type.DATETIME: "DATETIME2",
exp.DataType.Type.DOUBLE: "FLOAT",
exp.DataType.Type.INT: "INTEGER",
exp.DataType.Type.ROWVERSION: "ROWVERSION",
exp.DataType.Type.TEXT: "VARCHAR(MAX)",
exp.DataType.Type.TIMESTAMP: "DATETIME2",
exp.DataType.Type.TIMESTAMPTZ: "DATETIMEOFFSET",
exp.DataType.Type.UTINYINT: "TINYINT",
exp.DataType.Type.VARIANT: "SQL_VARIANT",
exp.DataType.Type.ROWVERSION: "ROWVERSION",
}
TYPE_MAPPING.pop(exp.DataType.Type.NCHAR)
@ -920,6 +952,10 @@ class TSQL(Dialect):
def create_sql(self, expression: exp.Create) -> str:
kind = expression.kind
exists = expression.args.pop("exists", None)
if kind == "VIEW":
expression.this.set("catalog", None)
sql = super().create_sql(expression)
like_property = expression.find(exp.LikeProperty)
@ -1061,3 +1097,22 @@ class TSQL(Dialect):
if isinstance(action, exp.RenameTable):
return f"EXEC sp_rename '{self.sql(expression.this)}', '{action.this.name}'"
return super().altertable_sql(expression)
def drop_sql(self, expression: exp.Drop) -> str:
if expression.args["kind"] == "VIEW":
expression.this.set("catalog", None)
return super().drop_sql(expression)
def declare_sql(self, expression: exp.Declare) -> str:
return f"DECLARE {self.expressions(expression, flat=True)}"
def declareitem_sql(self, expression: exp.DeclareItem) -> str:
variable = self.sql(expression, "this")
default = self.sql(expression, "default")
default = f" = {default}" if default else ""
kind = self.sql(expression, "kind")
if isinstance(expression.args.get("kind"), exp.Schema):
kind = f"TABLE {kind}"
return f"{variable} AS {kind}{default}"

sqlglot/executor/env.py

@ -240,5 +240,7 @@ ENV = {
for x in range(0, len(args), 2)
if (args[x + 1] is not None and args[x] is not None)
},
"UNIXTOTIME": null_if_any(lambda arg: datetime.datetime.utcfromtimestamp(arg)),
"UNIXTOTIME": null_if_any(
lambda arg: datetime.datetime.fromtimestamp(arg, datetime.timezone.utc)
),
}
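For context (standard-library behavior, not part of the diff): utcfromtimestamp() is deprecated since Python 3.12, and the timezone-aware replacement yields the equivalent UTC value.

import datetime

print(datetime.datetime.fromtimestamp(0, datetime.timezone.utc))
# 1970-01-01 00:00:00+00:00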

sqlglot/expressions.py

@ -1445,6 +1445,14 @@ class Pragma(Expression):
pass
class Declare(Expression):
arg_types = {"expressions": True}
class DeclareItem(Expression):
arg_types = {"this": True, "kind": True, "default": False}
class Set(Expression):
arg_types = {"expressions": False, "unset": False, "tag": False}
@ -1520,6 +1528,10 @@ class CTE(DerivedTable):
}
class ProjectionDef(Expression):
arg_types = {"this": True, "expression": True}
class TableAlias(Expression):
arg_types = {"this": False, "columns": False}
@ -1623,6 +1635,29 @@ class AlterColumn(Expression):
}
# https://docs.aws.amazon.com/redshift/latest/dg/r_ALTER_TABLE.html
class AlterDistStyle(Expression):
pass
class AlterSortKey(Expression):
arg_types = {"this": False, "expressions": False, "compound": False}
class AlterSet(Expression):
arg_types = {
"expressions": False,
"option": False,
"tablespace": False,
"access_method": False,
"file_format": False,
"copy_options": False,
"tag": False,
"location": False,
"serde": False,
}
class RenameColumn(Expression):
arg_types = {"this": True, "to": True, "exists": False}
@ -1939,6 +1974,7 @@ class Drop(Expression):
"cascade": False,
"constraints": False,
"purge": False,
"cluster": False,
}
@ -2177,6 +2213,11 @@ class PartitionRange(Expression):
arg_types = {"this": True, "expression": True}
# https://clickhouse.com/docs/en/sql-reference/statements/alter/partition#how-to-set-partition-expression
class PartitionId(Expression):
pass
class Fetch(Expression):
arg_types = {
"direction": False,
@ -2422,6 +2463,10 @@ class Property(Expression):
arg_types = {"this": True, "value": True}
class AllowedValuesProperty(Expression):
arg_types = {"expressions": True}
class AlgorithmProperty(Property):
arg_types = {"this": True}
@ -2475,6 +2520,10 @@ class DataBlocksizeProperty(Property):
}
class DataDeletionProperty(Property):
arg_types = {"on": True, "filter_col": False, "retention_period": False}
class DefinerProperty(Property):
arg_types = {"this": True}
@ -2651,7 +2700,11 @@ class RemoteWithConnectionModelProperty(Property):
class ReturnsProperty(Property):
arg_types = {"this": True, "is_table": False, "table": False}
arg_types = {"this": False, "is_table": False, "table": False, "null": False}
class StrictProperty(Property):
arg_types = {}
class RowFormatProperty(Property):
@ -2697,7 +2750,7 @@ class SchemaCommentProperty(Property):
class SerdeProperties(Property):
arg_types = {"expressions": True}
arg_types = {"expressions": True, "with": False}
class SetProperty(Property):
@ -2766,8 +2819,13 @@ class WithJournalTableProperty(Property):
class WithSystemVersioningProperty(Property):
# this -> history table name, expression -> data consistency check
arg_types = {"this": False, "expression": False}
arg_types = {
"on": False,
"this": False,
"data_consistency": False,
"retention_period": False,
"with": True,
}
class Properties(Expression):
@ -3801,7 +3859,7 @@ class Where(Expression):
class Star(Expression):
arg_types = {"except": False, "replace": False}
arg_types = {"except": False, "replace": False, "rename": False}
@property
def name(self) -> str:
@ -4175,6 +4233,7 @@ class AlterTable(Expression):
"exists": False,
"only": False,
"options": False,
"cluster": False,
}
@ -4186,6 +4245,11 @@ class DropPartition(Expression):
arg_types = {"expressions": True, "exists": False}
# https://clickhouse.com/docs/en/sql-reference/statements/alter/partition#replace-partition
class ReplacePartition(Expression):
arg_types = {"expression": True, "source": True}
# Binary expressions like (ADD a b)
class Binary(Condition):
arg_types = {"this": True, "expression": True}
@ -4738,6 +4802,11 @@ class ArrayConcat(Func):
is_var_len_args = True
class ArrayConstructCompact(Func):
arg_types = {"expressions": True}
is_var_len_args = True
class ArrayContains(Binary, Func):
pass
@ -5172,6 +5241,10 @@ class Hex(Func):
pass
class LowerHex(Hex):
pass
class Xor(Connector, Func):
arg_types = {"this": False, "expression": False, "expressions": False}
@ -5902,6 +5975,12 @@ class NextValueFor(Func):
arg_types = {"this": True, "order": False}
# Refers to a trailing semi-colon. This is only used to preserve trailing comments
# select 1; -- my comment
class Semicolon(Expression):
arg_types = {}
def _norm_arg(arg):
return arg.lower() if type(arg) is str else arg
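A sketch of the widened Star node (constructed by hand): Snowflake's RENAME clause now gets its own arg instead of overloading replace.

from sqlglot import exp

star = exp.Star(rename=[exp.alias_(exp.column("a"), "b")])
print(star.sql(dialect="snowflake"))  # expected: * RENAME (a AS b)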

sqlglot/generator.py

@ -8,7 +8,7 @@ from functools import reduce
from sqlglot import exp
from sqlglot.errors import ErrorLevel, UnsupportedError, concat_messages
from sqlglot.helper import apply_index_offset, csv, seq_get
from sqlglot.helper import apply_index_offset, csv, name_sequence, seq_get
from sqlglot.jsonpath import ALL_JSON_PATH_PARTS, JSON_PATH_PART_TRANSFORMS
from sqlglot.time import format_time
from sqlglot.tokens import TokenType
@ -74,6 +74,8 @@ class Generator(metaclass=_Generator):
TRANSFORMS: t.Dict[t.Type[exp.Expression], t.Callable[..., str]] = {
**JSON_PATH_PART_TRANSFORMS,
exp.AllowedValuesProperty: lambda self,
e: f"ALLOWED_VALUES {self.expressions(e, flat=True)}",
exp.AutoRefreshProperty: lambda self, e: f"AUTO REFRESH {self.sql(e, 'this')}",
exp.BackupProperty: lambda self, e: f"BACKUP {self.sql(e, 'this')}",
exp.CaseSpecificColumnConstraint: lambda _,
@ -123,7 +125,9 @@ class Generator(metaclass=_Generator):
exp.PathColumnConstraint: lambda self, e: f"PATH {self.sql(e, 'this')}",
exp.RemoteWithConnectionModelProperty: lambda self,
e: f"REMOTE WITH CONNECTION {self.sql(e, 'this')}",
exp.ReturnsProperty: lambda self, e: self.naked_property(e),
exp.ReturnsProperty: lambda self, e: (
"RETURNS NULL ON NULL INPUT" if e.args.get("null") else self.naked_property(e)
),
exp.SampleProperty: lambda self, e: f"SAMPLE BY {self.sql(e, 'this')}",
exp.SetConfigProperty: lambda self, e: self.sql(e, "this"),
exp.SetProperty: lambda _, e: f"{'MULTI' if e.args.get('multi') else ''}SET",
@ -133,6 +137,7 @@ class Generator(metaclass=_Generator):
exp.SqlSecurityProperty: lambda _,
e: f"SQL SECURITY {'DEFINER' if e.args.get('definer') else 'INVOKER'}",
exp.StabilityProperty: lambda _, e: e.name,
exp.StrictProperty: lambda *_: "STRICT",
exp.TemporaryProperty: lambda *_: "TEMPORARY",
exp.TitleColumnConstraint: lambda self, e: f"TITLE {self.sql(e, 'this')}",
exp.Timestamp: lambda self, e: self.func("TIMESTAMP", e.this, e.expression),
@ -351,6 +356,15 @@ class Generator(metaclass=_Generator):
# Whether the conditional TRY(expression) function is supported
TRY_SUPPORTED = True
# The keyword to use when generating a star projection with excluded columns
STAR_EXCEPT = "EXCEPT"
# The HEX function name
HEX_FUNC = "HEX"
# The keywords to use when prefixing & separating WITH based properties
WITH_PROPERTIES_PREFIX = "WITH"
TYPE_MAPPING = {
exp.DataType.Type.NCHAR: "CHAR",
exp.DataType.Type.NVARCHAR: "VARCHAR",
@ -364,11 +378,6 @@ class Generator(metaclass=_Generator):
exp.DataType.Type.ROWVERSION: "VARBINARY",
}
STAR_MAPPING = {
"except": "EXCEPT",
"replace": "REPLACE",
}
TIME_PART_SINGULARS = {
"MICROSECONDS": "MICROSECOND",
"SECONDS": "SECOND",
@ -401,6 +410,7 @@ class Generator(metaclass=_Generator):
NAMED_PLACEHOLDER_TOKEN = ":"
PROPERTIES_LOCATION = {
exp.AllowedValuesProperty: exp.Properties.Location.POST_SCHEMA,
exp.AlgorithmProperty: exp.Properties.Location.POST_CREATE,
exp.AutoIncrementProperty: exp.Properties.Location.POST_SCHEMA,
exp.AutoRefreshProperty: exp.Properties.Location.POST_SCHEMA,
@ -413,6 +423,7 @@ class Generator(metaclass=_Generator):
exp.Cluster: exp.Properties.Location.POST_SCHEMA,
exp.ClusteredByProperty: exp.Properties.Location.POST_SCHEMA,
exp.DataBlocksizeProperty: exp.Properties.Location.POST_NAME,
exp.DataDeletionProperty: exp.Properties.Location.POST_SCHEMA,
exp.DefinerProperty: exp.Properties.Location.POST_CREATE,
exp.DictRange: exp.Properties.Location.POST_SCHEMA,
exp.DictProperty: exp.Properties.Location.POST_SCHEMA,
@ -466,6 +477,7 @@ class Generator(metaclass=_Generator):
exp.SqlReadWriteProperty: exp.Properties.Location.POST_SCHEMA,
exp.SqlSecurityProperty: exp.Properties.Location.POST_CREATE,
exp.StabilityProperty: exp.Properties.Location.POST_SCHEMA,
exp.StrictProperty: exp.Properties.Location.POST_SCHEMA,
exp.TemporaryProperty: exp.Properties.Location.POST_CREATE,
exp.ToTableProperty: exp.Properties.Location.POST_SCHEMA,
exp.TransientProperty: exp.Properties.Location.POST_CREATE,
@ -539,6 +551,7 @@ class Generator(metaclass=_Generator):
"unsupported_messages",
"_escaped_quote_end",
"_escaped_identifier_end",
"_next_name",
)
def __init__(
@ -584,6 +597,8 @@ class Generator(metaclass=_Generator):
self.dialect.tokenizer_class.IDENTIFIER_ESCAPES[0] + self.dialect.IDENTIFIER_END
)
self._next_name = name_sequence("_t")
def generate(self, expression: exp.Expression, copy: bool = True) -> str:
"""
Generates the SQL string corresponding to the given syntax tree.
@ -687,15 +702,15 @@ class Generator(metaclass=_Generator):
return f"{sql} {comments_sql}"
def wrap(self, expression: exp.Expression | str) -> str:
this_sql = self.indent(
(
self.sql(expression)
if isinstance(expression, exp.UNWRAPPED_QUERIES)
else self.sql(expression, "this")
),
level=1,
pad=0,
this_sql = (
self.sql(expression)
if isinstance(expression, exp.UNWRAPPED_QUERIES)
else self.sql(expression, "this")
)
if not this_sql:
return "()"
this_sql = self.indent(this_sql, level=1, pad=0)
return f"({self.sep('')}{this_sql}{self.seg(')', sep='')}"
def no_identify(self, func: t.Callable[..., str], *args, **kwargs) -> str:
@ -720,7 +735,7 @@ class Generator(metaclass=_Generator):
skip_first: bool = False,
skip_last: bool = False,
) -> str:
if not self.pretty:
if not self.pretty or not sql:
return sql
pad = self.pad if pad is None else pad
@ -951,6 +966,12 @@ class Generator(metaclass=_Generator):
)
)
if properties_locs.get(exp.Properties.Location.POST_SCHEMA):
properties_sql = self.sep() + properties_sql
elif not self.pretty:
# Standalone POST_WITH properties need a leading whitespace in non-pretty mode
properties_sql = f" {properties_sql}"
begin = " BEGIN" if expression.args.get("begin") else ""
end = " END" if expression.args.get("end") else ""
@ -1095,7 +1116,7 @@ class Generator(metaclass=_Generator):
self.unsupported("Named columns are not supported in table alias.")
if not alias and not self.dialect.UNNEST_COLUMN_ONLY:
alias = "_t"
alias = self._next_name()
return f"{alias}{columns}"
@ -1208,12 +1229,14 @@ class Generator(metaclass=_Generator):
expressions = f" ({expressions})" if expressions else ""
kind = expression.args["kind"]
exists_sql = " IF EXISTS " if expression.args.get("exists") else " "
on_cluster = self.sql(expression, "cluster")
on_cluster = f" {on_cluster}" if on_cluster else ""
temporary = " TEMPORARY" if expression.args.get("temporary") else ""
materialized = " MATERIALIZED" if expression.args.get("materialized") else ""
cascade = " CASCADE" if expression.args.get("cascade") else ""
constraints = " CONSTRAINTS" if expression.args.get("constraints") else ""
purge = " PURGE" if expression.args.get("purge") else ""
return f"DROP{temporary}{materialized} {kind}{exists_sql}{this}{expressions}{cascade}{constraints}{purge}"
return f"DROP{temporary}{materialized} {kind}{exists_sql}{this}{on_cluster}{expressions}{cascade}{constraints}{purge}"
def except_sql(self, expression: exp.Except) -> str:
return self.set_operations(expression)
@ -1296,6 +1319,19 @@ class Generator(metaclass=_Generator):
text = f"{self.dialect.IDENTIFIER_START}{text}{self.dialect.IDENTIFIER_END}"
return text
def hex_sql(self, expression: exp.Hex) -> str:
text = self.func(self.HEX_FUNC, self.sql(expression, "this"))
if self.dialect.HEX_LOWERCASE:
text = self.func("LOWER", text)
return text
def lowerhex_sql(self, expression: exp.LowerHex) -> str:
text = self.func(self.HEX_FUNC, self.sql(expression, "this"))
if not self.dialect.HEX_LOWERCASE:
text = self.func("LOWER", text)
return text
def inputoutputformat_sql(self, expression: exp.InputOutputFormat) -> str:
input_format = self.sql(expression, "input_format")
input_format = f"INPUTFORMAT {input_format}" if input_format else ""
@ -1321,13 +1357,17 @@ class Generator(metaclass=_Generator):
elif p_loc == exp.Properties.Location.POST_SCHEMA:
root_properties.append(p)
return self.root_properties(
exp.Properties(expressions=root_properties)
) + self.with_properties(exp.Properties(expressions=with_properties))
root_props = self.root_properties(exp.Properties(expressions=root_properties))
with_props = self.with_properties(exp.Properties(expressions=with_properties))
if root_props and with_props and not self.pretty:
with_props = " " + with_props
return root_props + with_props
def root_properties(self, properties: exp.Properties) -> str:
if properties.expressions:
return self.sep() + self.expressions(properties, indent=False, sep=" ")
return self.expressions(properties, indent=False, sep=" ")
return ""
def properties(
@ -1346,7 +1386,7 @@ class Generator(metaclass=_Generator):
return ""
def with_properties(self, properties: exp.Properties) -> str:
return self.properties(properties, prefix=self.seg("WITH"))
return self.properties(properties, prefix=self.seg(self.WITH_PROPERTIES_PREFIX, sep=""))
def locate_properties(self, properties: exp.Properties) -> t.DefaultDict:
properties_locs = defaultdict(list)
@ -1514,19 +1554,25 @@ class Generator(metaclass=_Generator):
return f"{data_sql}{statistics_sql}"
def withsystemversioningproperty_sql(self, expression: exp.WithSystemVersioningProperty) -> str:
sql = "WITH(SYSTEM_VERSIONING=ON"
this = self.sql(expression, "this")
this = f"HISTORY_TABLE={this}" if this else ""
data_consistency: t.Optional[str] = self.sql(expression, "data_consistency")
data_consistency = (
f"DATA_CONSISTENCY_CHECK={data_consistency}" if data_consistency else None
)
retention_period: t.Optional[str] = self.sql(expression, "retention_period")
retention_period = (
f"HISTORY_RETENTION_PERIOD={retention_period}" if retention_period else None
)
if expression.this:
history_table = self.sql(expression, "this")
sql = f"{sql}(HISTORY_TABLE={history_table}"
if this:
on_sql = self.func("ON", this, data_consistency, retention_period)
else:
on_sql = "ON" if expression.args.get("on") else "OFF"
if expression.expression:
data_consistency_check = self.sql(expression, "expression")
sql = f"{sql}, DATA_CONSISTENCY_CHECK={data_consistency_check}"
sql = f"SYSTEM_VERSIONING={on_sql}"
sql = f"{sql})"
return f"{sql})"
return f"WITH({sql})" if expression.args.get("with") else sql
def insert_sql(self, expression: exp.Insert) -> str:
hint = self.sql(expression, "hint")
@ -2300,10 +2346,12 @@ class Generator(metaclass=_Generator):
def star_sql(self, expression: exp.Star) -> str:
except_ = self.expressions(expression, key="except", flat=True)
except_ = f"{self.seg(self.STAR_MAPPING['except'])} ({except_})" if except_ else ""
except_ = f"{self.seg(self.STAR_EXCEPT)} ({except_})" if except_ else ""
replace = self.expressions(expression, key="replace", flat=True)
replace = f"{self.seg(self.STAR_MAPPING['replace'])} ({replace})" if replace else ""
return f"*{except_}{replace}"
replace = f"{self.seg('REPLACE')} ({replace})" if replace else ""
rename = self.expressions(expression, key="rename", flat=True)
rename = f"{self.seg('RENAME')} ({rename})" if rename else ""
return f"*{except_}{replace}{rename}"
def parameter_sql(self, expression: exp.Parameter) -> str:
this = self.sql(expression, "this")
@ -2843,9 +2891,10 @@ class Generator(metaclass=_Generator):
stack.append(self.expressions(expression, sep=f" {op} "))
else:
stack.append(expression.right)
if expression.comments:
if expression.comments and self.comments:
for comment in expression.comments:
op += f" /*{self.pad_comment(comment)}*/"
if comment:
op += f" /*{self.pad_comment(comment)}*/"
stack.extend((op, expression.left))
return op
@ -2978,6 +3027,19 @@ class Generator(metaclass=_Generator):
return f"ALTER COLUMN {this} DROP DEFAULT"
def alterdiststyle_sql(self, expression: exp.AlterDistStyle) -> str:
this = self.sql(expression, "this")
if not isinstance(expression.this, exp.Var):
this = f"KEY DISTKEY {this}"
return f"ALTER DISTSTYLE {this}"
def altersortkey_sql(self, expression: exp.AlterSortKey) -> str:
compound = " COMPOUND" if expression.args.get("compound") else ""
this = self.sql(expression, "this")
expressions = self.expressions(expression, flat=True)
expressions = f"({expressions})" if expressions else ""
return f"ALTER{compound} SORTKEY {this or expressions}"
def renametable_sql(self, expression: exp.RenameTable) -> str:
if not self.RENAME_TABLE_WITH_DB:
# Remove db from tables
@ -2993,6 +3055,10 @@ class Generator(metaclass=_Generator):
new_column = self.sql(expression, "to")
return f"RENAME COLUMN{exists} {old_column} TO {new_column}"
def alterset_sql(self, expression: exp.AlterSet) -> str:
exprs = self.expressions(expression, flat=True)
return f"SET {exprs}"
def altertable_sql(self, expression: exp.AlterTable) -> str:
actions = expression.args["actions"]
@ -3006,10 +3072,12 @@ class Generator(metaclass=_Generator):
actions = self.expressions(expression, key="actions", flat=True)
exists = " IF EXISTS" if expression.args.get("exists") else ""
on_cluster = self.sql(expression, "cluster")
on_cluster = f" {on_cluster}" if on_cluster else ""
only = " ONLY" if expression.args.get("only") else ""
options = self.expressions(expression, key="options")
options = f", {options}" if options else ""
return f"ALTER TABLE{exists}{only} {self.sql(expression, 'this')} {actions}{options}"
return f"ALTER TABLE{exists}{only} {self.sql(expression, 'this')}{on_cluster} {actions}{options}"
def add_column_sql(self, expression: exp.AlterTable) -> str:
if self.ALTER_TABLE_INCLUDE_COLUMN_KEYWORD:
@ -3781,6 +3849,11 @@ class Generator(metaclass=_Generator):
def copyparameter_sql(self, expression: exp.CopyParameter) -> str:
option = self.sql(expression, "this")
if option.upper() == "FILE_FORMAT":
values = self.expressions(expression, key="expression", flat=True, sep=" ")
return f"{option} = ({values})"
value = self.sql(expression, "expression")
if not value:
@ -3802,7 +3875,6 @@ class Generator(metaclass=_Generator):
credentials = f"CREDENTIALS = ({credentials})" if credentials else ""
storage = self.sql(expression, "storage")
storage = f" {storage}" if storage else ""
encryption = self.expressions(expression, key="encryption", flat=True, sep=" ")
encryption = f" ENCRYPTION = ({encryption})" if encryption else ""
@ -3820,13 +3892,40 @@ class Generator(metaclass=_Generator):
this = f" INTO {this}" if self.COPY_HAS_INTO_KEYWORD else f" {this}"
credentials = self.sql(expression, "credentials")
credentials = f" {credentials}" if credentials else ""
kind = " FROM " if expression.args.get("kind") else " TO "
credentials = self.seg(credentials) if credentials else ""
kind = self.seg("FROM" if expression.args.get("kind") else "TO")
files = self.expressions(expression, key="files", flat=True)
sep = ", " if self.dialect.COPY_PARAMS_ARE_CSV else " "
params = self.expressions(expression, key="params", flat=True, sep=sep)
if params:
params = f" WITH ({params})" if self.COPY_PARAMS_ARE_WRAPPED else f" {params}"
params = self.expressions(
expression,
key="params",
sep=sep,
new_line=True,
skip_last=True,
skip_first=True,
indent=self.COPY_PARAMS_ARE_WRAPPED,
)
return f"COPY{this}{kind}{files}{credentials}{params}"
if params:
if self.COPY_PARAMS_ARE_WRAPPED:
params = f" WITH ({params})"
elif not self.pretty:
params = f" {params}"
return f"COPY{this}{kind} {files}{credentials}{params}"
def semicolon_sql(self, expression: exp.Semicolon) -> str:
return ""
def datadeletionproperty_sql(self, expression: exp.DataDeletionProperty) -> str:
on_sql = "ON" if expression.args.get("on") else "OFF"
filter_col: t.Optional[str] = self.sql(expression, "filter_column")
filter_col = f"FILTER_COLUMN={filter_col}" if filter_col else None
retention_period: t.Optional[str] = self.sql(expression, "retention_period")
retention_period = f"RETENTION_PERIOD={retention_period}" if retention_period else None
if filter_col or retention_period:
on_sql = self.func("ON", filter_col, retention_period)
return f"DATA_DELETION={on_sql}"

View file

@ -229,7 +229,9 @@ def to_node(
for source in scope.sources.values():
if isinstance(source, Scope):
source = source.expression
node.downstream.append(
Node(name=select.sql(comments=False), source=source, expression=source)
)
# Find all columns that went into creating this one to list their lineage nodes.
source_columns = set(find_all_in_scope(select, exp.Column))
@ -278,7 +280,9 @@ def to_node(
# it means this column's lineage is unknown. This can happen if the definition of a source used in a query
# is not passed into the `sources` map.
source = source or exp.Placeholder()
node.downstream.append(
Node(name=c.sql(comments=False), source=source, expression=source)
)
return node
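Node names double as lineage identifiers, so comments are stripped to keep them stable; roughly:

import sqlglot

proj = sqlglot.parse_one("SELECT a /* pk */ FROM t").selects[0]
print(proj.sql())                # a /* pk */
print(proj.sql(comments=False))  # a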

View file

@ -1,6 +1,6 @@
from sqlglot import exp
from sqlglot.helper import name_sequence
from sqlglot.optimizer.scope import ScopeType, find_in_scope, traverse_scope
def unnest_subqueries(expression):
@ -64,7 +64,7 @@ def unnest(select, parent_select, next_alias_name):
(not clause or clause_parent_select is not parent_select)
and (
parent_select.args.get("group")
or any(find_in_scope(select, exp.AggFunc) for select in parent_select.selects)
)
):
column = exp.Max(this=column)
@ -101,7 +101,7 @@ def unnest(select, parent_select, next_alias_name):
if group:
if {value.this} != set(group.expressions):
select = (
exp.select(exp.column(value.alias, "_q"))
exp.select(exp.alias_(exp.column(value.alias, "_q"), value.alias))
.from_(select.subquery("_q", copy=False), copy=False)
.group_by(exp.column(value.alias, "_q"), copy=False)
)
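The added exp.alias_ keeps the projected name stable once the subquery is wrapped in the _q derived table; a usage sketch (aliases like _u_0 and _q are internal details):

import sqlglot
from sqlglot.optimizer.unnest_subqueries import unnest_subqueries

sql = "SELECT * FROM x WHERE x.a = (SELECT SUM(y.b) FROM y WHERE y.id = x.id)"
print(unnest_subqueries(sqlglot.parse_one(sql)).sql())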
@ -152,7 +152,9 @@ def decorrelate(select, parent_select, external_columns, next_alias_name):
return
is_subquery_projection = any(
node is select.parent
for node in map(lambda s: s.unalias(), parent_select.selects)
if isinstance(node, exp.Subquery)
)
value = select.selects[0]
@ -200,19 +202,25 @@ def decorrelate(select, parent_select, external_columns, next_alias_name):
alias = exp.column(value.alias, table_alias)
other = _other_operand(parent_predicate)
op_type = type(parent_predicate.parent) if parent_predicate else None
if isinstance(parent_predicate, exp.Exists):
alias = exp.column(list(key_aliases.values())[0], table_alias)
parent_predicate = _replace(parent_predicate, f"NOT {alias} IS NULL")
elif isinstance(parent_predicate, exp.All):
assert issubclass(op_type, exp.Binary)
predicate = op_type(this=other, expression=exp.column("_x"))
parent_predicate = _replace(
parent_predicate.parent, f"ARRAY_ALL({alias}, _x -> _x = {other})"
parent_predicate.parent, f"ARRAY_ALL({alias}, _x -> {predicate})"
)
elif isinstance(parent_predicate, exp.Any):
assert issubclass(op_type, exp.Binary)
if value.this in group_by:
parent_predicate = _replace(parent_predicate.parent, f"{other} = {alias}")
predicate = op_type(this=other, expression=alias)
parent_predicate = _replace(parent_predicate.parent, predicate)
else:
parent_predicate = _replace(parent_predicate, f"ARRAY_ANY({alias}, _x -> _x = {other})")
predicate = op_type(this=other, expression=exp.column("_x"))
parent_predicate = _replace(parent_predicate, f"ARRAY_ANY({alias}, _x -> {predicate})")
elif isinstance(parent_predicate, exp.In):
if value.this in group_by:
parent_predicate = _replace(parent_predicate, f"{other} = {alias}")
@ -222,7 +230,7 @@ def decorrelate(select, parent_select, external_columns, next_alias_name):
f"ARRAY_ANY({alias}, _x -> _x = {parent_predicate.this})",
)
else:
if is_subquery_projection:
if is_subquery_projection and select.parent.alias:
alias = exp.alias_(alias, select.parent.alias)
# COUNT always returns 0 on empty datasets, so we need take that into consideration here
@ -236,10 +244,7 @@ def decorrelate(select, parent_select, external_columns, next_alias_name):
return exp.null()
return node
alias = exp.Coalesce(this=alias, expressions=[value.this.transform(remove_aggs)])
select.parent.replace(alias)
@ -249,6 +254,8 @@ def decorrelate(select, parent_select, external_columns, next_alias_name):
if is_subquery_projection:
key.replace(nested)
if not isinstance(predicate, exp.EQ):
parent_select.where(predicate, copy=False)
continue
if key in group_by:

View file

@ -61,6 +61,23 @@ def build_logarithm(args: t.List, dialect: Dialect) -> exp.Func:
return (exp.Ln if dialect.parser_class.LOG_DEFAULTS_TO_LN else exp.Log)(this=this)
def build_hex(args: t.List, dialect: Dialect) -> exp.Hex | exp.LowerHex:
arg = seq_get(args, 0)
return exp.LowerHex(this=arg) if dialect.HEX_LOWERCASE else exp.Hex(this=arg)
def build_lower(args: t.List) -> exp.Lower | exp.LowerHex:
# LOWER(HEX(..)) is parsed into LowerHex to simplify its transpilation
arg = seq_get(args, 0)
return exp.LowerHex(this=arg.this) if isinstance(arg, exp.Hex) else exp.Lower(this=arg)
def build_upper(args: t.List) -> exp.Upper | exp.Hex:
# UPPER(HEX(..)) is parsed back into a plain Hex to simplify its transpilation
arg = seq_get(args, 0)
return exp.Hex(this=arg.this) if isinstance(arg, exp.Hex) else exp.Upper(this=arg)
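These builders fold the case function into the hex node at parse time, so e.g. LOWER(HEX(x)) arrives as a single LowerHex:

import sqlglot
from sqlglot import exp

node = sqlglot.parse_one("SELECT LOWER(HEX(x))").selects[0]
print(isinstance(node, exp.LowerHex))  # True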
def build_extract_json_with_path(expr_type: t.Type[E]) -> t.Callable[[t.List, Dialect], E]:
def _builder(args: t.List, dialect: Dialect) -> E:
expression = expr_type(
@ -74,6 +91,17 @@ def build_extract_json_with_path(expr_type: t.Type[E]) -> t.Callable[[t.List, Di
return _builder
def build_mod(args: t.List) -> exp.Mod:
this = seq_get(args, 0)
expression = seq_get(args, 1)
# Wrap the operands if they are binary nodes, e.g. MOD(a + 1, 7) -> (a + 1) % 7
this = exp.Paren(this=this) if isinstance(this, exp.Binary) else this
expression = exp.Paren(this=expression) if isinstance(expression, exp.Binary) else expression
return exp.Mod(this=this, expression=expression)
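With the operands parenthesized, precedence survives dialects that render MOD as %:

import sqlglot

print(sqlglot.transpile("SELECT MOD(a + 1, 7)", read="mysql", write="postgres")[0])
# SELECT (a + 1) % 7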
class _Parser(type):
def __new__(cls, clsname, bases, attrs):
klass = super().__new__(cls, clsname, bases, attrs)
@ -123,7 +151,7 @@ class Parser(metaclass=_Parser):
"LOG": build_logarithm,
"LOG2": lambda args: exp.Log(this=exp.Literal.number(2), expression=seq_get(args, 0)),
"LOG10": lambda args: exp.Log(this=exp.Literal.number(10), expression=seq_get(args, 0)),
"MOD": lambda args: exp.Mod(this=seq_get(args, 0), expression=seq_get(args, 1)),
"MOD": build_mod,
"TIME_TO_TIME_STR": lambda args: exp.Cast(
this=seq_get(args, 0),
to=exp.DataType(this=exp.DataType.Type.TEXT),
@ -137,6 +165,10 @@ class Parser(metaclass=_Parser):
length=exp.Literal.number(10),
),
"VAR_MAP": build_var_map,
"LOWER": build_lower,
"UPPER": build_upper,
"HEX": build_hex,
"TO_HEX": build_hex,
}
NO_PAREN_FUNCTIONS = {
@ -295,22 +327,23 @@ class Parser(metaclass=_Parser):
DB_CREATABLES = {
TokenType.DATABASE,
TokenType.DICTIONARY,
TokenType.MODEL,
TokenType.SCHEMA,
TokenType.SEQUENCE,
TokenType.STORAGE_INTEGRATION,
TokenType.TABLE,
TokenType.TAG,
TokenType.VIEW,
}
CREATABLES = {
TokenType.COLUMN,
TokenType.CONSTRAINT,
TokenType.FOREIGN_KEY,
TokenType.FUNCTION,
TokenType.INDEX,
TokenType.PROCEDURE,
*DB_CREATABLES,
}
@ -373,6 +406,7 @@ class Parser(metaclass=_Parser):
TokenType.REFRESH,
TokenType.REPLACE,
TokenType.RIGHT,
TokenType.ROLLUP,
TokenType.ROW,
TokenType.ROWS,
TokenType.SEMI,
@ -467,7 +501,6 @@ class Parser(metaclass=_Parser):
}
EQUALITY = {
TokenType.EQ: exp.EQ,
TokenType.NEQ: exp.NEQ,
TokenType.NULLSAFE_EQ: exp.NullSafeEQ,
@ -653,6 +686,7 @@ class Parser(metaclass=_Parser):
kind=self._parse_var_from_options(self.USABLES, raise_unmatched=False),
this=self._parse_table(schema=False),
),
TokenType.SEMICOLON: lambda self: self.expression(exp.Semicolon),
}
UNARY_PARSERS = {
@ -700,7 +734,12 @@ class Parser(metaclass=_Parser):
TokenType.FALSE: lambda self, _: self.expression(exp.Boolean, this=False),
TokenType.SESSION_PARAMETER: lambda self, _: self._parse_session_parameter(),
TokenType.STAR: lambda self, _: self.expression(
exp.Star, **{"except": self._parse_except(), "replace": self._parse_replace()}
exp.Star,
**{
"except": self._parse_star_op("EXCEPT", "EXCLUDE"),
"replace": self._parse_star_op("REPLACE"),
"rename": self._parse_star_op("RENAME"),
},
),
}
@ -729,6 +768,9 @@ class Parser(metaclass=_Parser):
}
PROPERTY_PARSERS: t.Dict[str, t.Callable] = {
"ALLOWED_VALUES": lambda self: self.expression(
exp.AllowedValuesProperty, expressions=self._parse_csv(self._parse_primary)
),
"ALGORITHM": lambda self: self._parse_property_assignment(exp.AlgorithmProperty),
"AUTO": lambda self: self._parse_auto_property(),
"AUTO_INCREMENT": lambda self: self._parse_property_assignment(exp.AutoIncrementProperty),
@ -748,6 +790,7 @@ class Parser(metaclass=_Parser):
"CONTAINS": lambda self: self._parse_contains_property(),
"COPY": lambda self: self._parse_copy_property(),
"DATABLOCKSIZE": lambda self, **kwargs: self._parse_datablocksize(**kwargs),
"DATA_DELETION": lambda self: self._parse_data_deletion_property(),
"DEFINER": lambda self: self._parse_definer(),
"DETERMINISTIC": lambda self: self.expression(
exp.StabilityProperty, this=exp.Literal.string("IMMUTABLE")
@ -796,6 +839,7 @@ class Parser(metaclass=_Parser):
"READS": lambda self: self._parse_reads_property(),
"REMOTE": lambda self: self._parse_remote_with_connection(),
"RETURNS": lambda self: self._parse_returns(),
"STRICT": lambda self: self.expression(exp.StrictProperty),
"ROW": lambda self: self._parse_row(),
"ROW_FORMAT": lambda self: self._parse_property_assignment(exp.RowFormatProperty),
"SAMPLE": lambda self: self.expression(
@ -900,6 +944,14 @@ class Parser(metaclass=_Parser):
"DELETE": lambda self: self.expression(exp.Delete, where=self._parse_where()),
"DROP": lambda self: self._parse_alter_table_drop(),
"RENAME": lambda self: self._parse_alter_table_rename(),
"SET": lambda self: self._parse_alter_table_set(),
}
ALTER_ALTER_PARSERS = {
"DISTKEY": lambda self: self._parse_alter_diststyle(),
"DISTSTYLE": lambda self: self._parse_alter_diststyle(),
"SORTKEY": lambda self: self._parse_alter_sortkey(),
"COMPOUND": lambda self: self._parse_alter_sortkey(compound=True),
}
SCHEMA_UNNAMED_CONSTRAINTS = {
@ -990,6 +1042,8 @@ class Parser(metaclass=_Parser):
exp.DataType.Type.JSON: lambda self, this, _: self.expression(exp.ParseJSON, this=this),
}
TYPE_CONVERTER: t.Dict[exp.DataType.Type, t.Callable[[exp.DataType], exp.DataType]] = {}
DDL_SELECT_TOKENS = {TokenType.SELECT, TokenType.WITH, TokenType.L_PAREN}
PRE_VOLATILE_TOKENS = {TokenType.CREATE, TokenType.REPLACE, TokenType.UNIQUE}
@ -1084,6 +1138,9 @@ class Parser(metaclass=_Parser):
# Whether the table sample clause expects CSV syntax
TABLESAMPLE_CSV = False
# The default method used for table sampling
DEFAULT_SAMPLING_METHOD: t.Optional[str] = None
# Whether the SET command needs a delimiter (e.g. "=") for assignments
SET_REQUIRES_ASSIGNMENT_DELIMITER = True
@ -1228,6 +1285,9 @@ class Parser(metaclass=_Parser):
for i, token in enumerate(raw_tokens):
if token.token_type == TokenType.SEMICOLON:
if token.comments:
chunks.append([token])
if i < total - 1:
chunks.append([])
else:
@ -1471,7 +1531,7 @@ class Parser(metaclass=_Parser):
if self._match_set(self.STATEMENT_PARSERS):
return self.STATEMENT_PARSERS[self._prev.token_type](self)
if self._match_set(self.dialect.tokenizer.COMMANDS):
return self._parse_command()
expression = self._parse_expression()
@ -1492,6 +1552,8 @@ class Parser(metaclass=_Parser):
schema=True, is_db_reference=self._prev.token_type == TokenType.SCHEMA
)
cluster = self._parse_on_property() if self._match(TokenType.ON) else None
if self._match(TokenType.L_PAREN, advance=False):
expressions = self._parse_wrapped_csv(self._parse_types)
else:
@ -1503,12 +1565,13 @@ class Parser(metaclass=_Parser):
exists=if_exists,
this=table,
expressions=expressions,
kind=kind.upper(),
temporary=temporary,
materialized=materialized,
cascade=self._match_text_seq("CASCADE"),
constraints=self._match_text_seq("CONSTRAINTS"),
purge=self._match_text_seq("PURGE"),
cluster=cluster,
)
def _parse_exists(self, not_: bool = False) -> t.Optional[bool]:
@ -1651,7 +1714,7 @@ class Parser(metaclass=_Parser):
exp.Clone, this=self._parse_table(schema=True), shallow=shallow, copy=copy
)
if self._curr and not self._match_set((TokenType.R_PAREN, TokenType.COMMA), advance=False):
return self._parse_as_command(start)
return self.expression(
@ -1678,6 +1741,7 @@ class Parser(metaclass=_Parser):
index = self._index
while self._curr:
self._match(TokenType.COMMA)
if self._match_text_seq("INCREMENT"):
self._match_text_seq("BY")
self._match_text_seq("=")
@ -1822,23 +1886,65 @@ class Parser(metaclass=_Parser):
return self.expression(exp.StabilityProperty, this=exp.Literal.string("VOLATILE"))
def _parse_retention_period(self) -> exp.Var:
# Parse TSQL's HISTORY_RETENTION_PERIOD: {INFINITE | <number> DAY | DAYS | MONTH ...}
number = self._parse_number()
number_str = f"{number} " if number else ""
unit = self._parse_var(any_token=True)
return exp.var(f"{number_str}{unit}")
def _parse_system_versioning_property(
self, with_: bool = False
) -> exp.WithSystemVersioningProperty:
self._match(TokenType.EQ)
prop = self.expression(
exp.WithSystemVersioningProperty,
**{ # type: ignore
"on": True,
"with": with_,
},
)
if self._match_text_seq("OFF"):
prop.set("on", False)
return prop
self._match(TokenType.ON)
if self._match(TokenType.L_PAREN):
self._match_text_seq("HISTORY_TABLE", "=")
prop.set("this", self._parse_table_parts())
while self._curr and not self._match(TokenType.R_PAREN):
if self._match_text_seq("HISTORY_TABLE", "="):
prop.set("this", self._parse_table_parts())
elif self._match_text_seq("DATA_CONSISTENCY_CHECK", "="):
prop.set("data_consistency", self._advance_any() and self._prev.text.upper())
elif self._match_text_seq("HISTORY_RETENTION_PERIOD", "="):
prop.set("retention_period", self._parse_retention_period())
self._match(TokenType.COMMA)
return prop
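A T-SQL input covering the new options (table name and period are illustrative):

import sqlglot

sql = """CREATE TABLE t (id INT) WITH (
    SYSTEM_VERSIONING = ON (
        HISTORY_TABLE = dbo.t_history,
        DATA_CONSISTENCY_CHECK = ON,
        HISTORY_RETENTION_PERIOD = 6 MONTHS
    )
)"""
print(sqlglot.parse_one(sql, read="tsql").sql(dialect="tsql"))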
def _parse_data_deletion_property(self) -> exp.DataDeletionProperty:
self._match(TokenType.EQ)
on = self._match_text_seq("ON") or not self._match_text_seq("OFF")
prop = self.expression(exp.DataDeletionProperty, on=on)
if self._match(TokenType.L_PAREN):
while self._curr and not self._match(TokenType.R_PAREN):
if self._match_text_seq("FILTER_COLUMN", "="):
prop.set("filter_column", self._parse_column())
elif self._match_text_seq("RETENTION_PERIOD", "="):
prop.set("retention_period", self._parse_retention_period())
self._match(TokenType.COMMA)
return prop
def _parse_with_property(self) -> t.Optional[exp.Expression] | t.List[exp.Expression]:
if self._match_text_seq("(", "SYSTEM_VERSIONING"):
prop = self._parse_system_versioning_property(with_=True)
self._match_r_paren()
return prop
if self._match(TokenType.L_PAREN, advance=False):
return self._parse_wrapped_properties()
@ -1853,6 +1959,9 @@ class Parser(metaclass=_Parser):
elif self._match_text_seq("NO", "DATA"):
return self._parse_withdata(no=True)
if self._match(TokenType.SERDE_PROPERTIES, advance=False):
return self._parse_serde_properties(with_=True)
if not self._next:
return None
@ -2201,6 +2310,7 @@ class Parser(metaclass=_Parser):
def _parse_returns(self) -> exp.ReturnsProperty:
value: t.Optional[exp.Expression]
null = None
is_table = self._match(TokenType.TABLE)
if is_table:
@ -2214,10 +2324,13 @@ class Parser(metaclass=_Parser):
self.raise_error("Expecting >")
else:
value = self._parse_schema(exp.var("TABLE"))
elif self._match_text_seq("NULL", "ON", "NULL", "INPUT"):
null = True
value = None
else:
value = self._parse_types()
return self.expression(exp.ReturnsProperty, this=value, is_table=is_table, null=null)
def _parse_describe(self) -> exp.Describe:
kind = self._match_set(self.CREATABLES) and self._prev.text
@ -2340,6 +2453,21 @@ class Parser(metaclass=_Parser):
return None
return self._parse_row_format()
def _parse_serde_properties(self, with_: bool = False) -> t.Optional[exp.SerdeProperties]:
index = self._index
with_ = with_ or self._match_text_seq("WITH")
if not self._match(TokenType.SERDE_PROPERTIES):
self._retreat(index)
return None
return self.expression(
exp.SerdeProperties,
**{ # type: ignore
"expressions": self._parse_wrapped_properties(),
"with": with_,
},
)
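A Hive-style statement that goes through this path (a sketch; the SerDe class is illustrative):

import sqlglot

sql = """CREATE TABLE t (c STRING)
ROW FORMAT SERDE 'org.openx.data.jsonserde.JsonSerDe'
WITH SERDEPROPERTIES ('ignore.malformed.json' = 'true')"""
print(sqlglot.parse_one(sql, read="hive").sql(dialect="hive"))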
def _parse_row_format(
self, match_row: bool = False
) -> t.Optional[exp.RowFormatSerdeProperty | exp.RowFormatDelimitedProperty]:
@ -2349,11 +2477,7 @@ class Parser(metaclass=_Parser):
if self._match_text_seq("SERDE"):
this = self._parse_string()
serde_properties = self._parse_serde_properties()
return self.expression(
exp.RowFormatSerdeProperty, this=this, serde_properties=serde_properties
@ -2672,9 +2796,7 @@ class Parser(metaclass=_Parser):
)
def _implicit_unnests_to_explicit(self, this: E) -> E:
from sqlglot.optimizer.normalize_identifiers import normalize_identifiers as _norm
refs = {_norm(this.args["from"].this.copy(), dialect=self.dialect).alias_or_name}
for i, join in enumerate(this.args.get("joins") or []):
@ -3366,6 +3488,9 @@ class Parser(metaclass=_Parser):
elif self._match_texts(("SEED", "REPEATABLE")):
seed = self._parse_wrapped(self._parse_number)
if not method and self.DEFAULT_SAMPLING_METHOD:
method = exp.var(self.DEFAULT_SAMPLING_METHOD)
return self.expression(
exp.TableSample,
expressions=expressions,
@ -3519,7 +3644,11 @@ class Parser(metaclass=_Parser):
elements["all"] = False
while True:
expressions = self._parse_csv(
lambda: None
if self._match(TokenType.ROLLUP, advance=False)
else self._parse_conjunction()
)
if expressions:
elements["expressions"].extend(expressions)
@ -3817,7 +3946,24 @@ class Parser(metaclass=_Parser):
return self._parse_alias(self._parse_conjunction())
def _parse_conjunction(self) -> t.Optional[exp.Expression]:
this = self._parse_equality()
if self._match(TokenType.COLON_EQ):
this = self.expression(
exp.PropertyEQ,
this=this,
comments=self._prev_comments,
expression=self._parse_conjunction(),
)
while self._match_set(self.CONJUNCTION):
this = self.expression(
self.CONJUNCTION[self._prev.token_type],
this=this,
comments=self._prev_comments,
expression=self._parse_equality(),
)
return this
def _parse_equality(self) -> t.Optional[exp.Expression]:
return self._parse_tokens(self._parse_comparison, self.EQUALITY)
@ -4061,6 +4207,7 @@ class Parser(metaclass=_Parser):
) -> t.Optional[exp.Expression]:
index = self._index
this: t.Optional[exp.Expression] = None
prefix = self._match_text_seq("SYSUDTLIB", ".")
if not self._match_set(self.TYPE_TOKENS):
@ -4081,7 +4228,7 @@ class Parser(metaclass=_Parser):
while self._match(TokenType.DOT):
type_name = f"{type_name}.{self._advance_any() and self._prev.text}"
this = exp.DataType.build(type_name, udt=True)
else:
self._retreat(self._index - 1)
return None
@ -4134,7 +4281,6 @@ class Parser(metaclass=_Parser):
maybe_func = True
this: t.Optional[exp.Expression] = None
values: t.Optional[t.List[exp.Expression]] = None
if nested and self._match(TokenType.LT):
@ -4203,10 +4349,17 @@ class Parser(metaclass=_Parser):
values=values,
prefix=prefix,
)
elif expressions:
this.set("expressions", expressions)
while self._match_pair(TokenType.L_BRACKET, TokenType.R_BRACKET):
this = exp.DataType(this=exp.DataType.Type.ARRAY, expressions=[this], nested=True)
if self.TYPE_CONVERTER and isinstance(this.this, exp.DataType.Type):
converter = self.TYPE_CONVERTER.get(this.this)
if converter:
this = converter(t.cast(exp.DataType, this))
return this
def _parse_struct_types(self, type_required: bool = False) -> t.Optional[exp.Expression]:
@ -4326,9 +4479,7 @@ class Parser(metaclass=_Parser):
if not this and self._match(TokenType.R_PAREN, advance=False):
this = self.expression(exp.Tuple)
elif isinstance(this, exp.UNWRAPPED_QUERIES):
this = self._parse_subquery(this=this, parse_alias=False)
elif isinstance(this, exp.Subquery):
this = self._parse_subquery(
this=self._parse_set_operations(this), parse_alias=False
@ -5625,13 +5776,8 @@ class Parser(metaclass=_Parser):
return self._parse_placeholder()
def _parse_parameter(self) -> exp.Parameter:
this = self._parse_identifier() or self._parse_primary_or_var()
return self.expression(exp.Parameter, this=this)
def _parse_placeholder(self) -> t.Optional[exp.Expression]:
if self._match_set(self.PLACEHOLDER_PARSERS):
@ -5641,23 +5787,14 @@ class Parser(metaclass=_Parser):
self._advance(-1)
return None
def _parse_star_op(self, *keywords: str) -> t.Optional[t.List[exp.Expression]]:
if not self._match_texts(keywords):
return None
if self._match(TokenType.L_PAREN, advance=False):
return self._parse_wrapped_csv(self._parse_expression)
expression = self._parse_expression()
return [expression] if expression else None
def _parse_csv(
self, parse_method: t.Callable, sep: TokenType = TokenType.COMMA
@ -5812,7 +5949,12 @@ class Parser(metaclass=_Parser):
return self._parse_wrapped_csv(self._parse_field_def, optional=True)
return self._parse_wrapped_csv(self._parse_add_column, optional=True)
def _parse_alter_table_alter(self) -> t.Optional[exp.Expression]:
if self._match_texts(self.ALTER_ALTER_PARSERS):
return self.ALTER_ALTER_PARSERS[self._prev.text.upper()](self)
# Many dialects support the ALTER [COLUMN] syntax, so if there is no
# keyword after ALTER we default to parsing this statement
self._match(TokenType.COLUMN)
column = self._parse_field(any_token=True)
@ -5833,6 +5975,27 @@ class Parser(metaclass=_Parser):
using=self._match(TokenType.USING) and self._parse_conjunction(),
)
def _parse_alter_diststyle(self) -> exp.AlterDistStyle:
if self._match_texts(("ALL", "EVEN", "AUTO")):
return self.expression(exp.AlterDistStyle, this=exp.var(self._prev.text.upper()))
self._match_text_seq("KEY", "DISTKEY")
return self.expression(exp.AlterDistStyle, this=self._parse_column())
def _parse_alter_sortkey(self, compound: t.Optional[bool] = None) -> exp.AlterSortKey:
if compound:
self._match_text_seq("SORTKEY")
if self._match(TokenType.L_PAREN, advance=False):
return self.expression(
exp.AlterSortKey, expressions=self._parse_wrapped_id_vars(), compound=compound
)
self._match_texts(("AUTO", "NONE"))
return self.expression(
exp.AlterSortKey, this=exp.var(self._prev.text.upper()), compound=compound
)
def _parse_alter_table_drop(self) -> t.List[exp.Expression]:
index = self._index - 1
@ -5858,6 +6021,41 @@ class Parser(metaclass=_Parser):
self._match_text_seq("TO")
return self.expression(exp.RenameTable, this=self._parse_table(schema=True))
def _parse_alter_table_set(self) -> exp.AlterSet:
alter_set = self.expression(exp.AlterSet)
if self._match(TokenType.L_PAREN, advance=False) or self._match_text_seq(
"TABLE", "PROPERTIES"
):
alter_set.set("expressions", self._parse_wrapped_csv(self._parse_conjunction))
elif self._match_text_seq("FILESTREAM_ON", advance=False):
alter_set.set("expressions", [self._parse_conjunction()])
elif self._match_texts(("LOGGED", "UNLOGGED")):
alter_set.set("option", exp.var(self._prev.text.upper()))
elif self._match_text_seq("WITHOUT") and self._match_texts(("CLUSTER", "OIDS")):
alter_set.set("option", exp.var(f"WITHOUT {self._prev.text.upper()}"))
elif self._match_text_seq("LOCATION"):
alter_set.set("location", self._parse_field())
elif self._match_text_seq("ACCESS", "METHOD"):
alter_set.set("access_method", self._parse_field())
elif self._match_text_seq("TABLESPACE"):
alter_set.set("tablespace", self._parse_field())
elif self._match_text_seq("FILE", "FORMAT") or self._match_text_seq("FILEFORMAT"):
alter_set.set("file_format", [self._parse_field()])
elif self._match_text_seq("STAGE_FILE_FORMAT"):
alter_set.set("file_format", self._parse_wrapped_options())
elif self._match_text_seq("STAGE_COPY_OPTIONS"):
alter_set.set("copy_options", self._parse_wrapped_options())
elif self._match_text_seq("TAG") or self._match_text_seq("TAGS"):
alter_set.set("tag", self._parse_csv(self._parse_conjunction))
else:
if self._match_text_seq("SERDE"):
alter_set.set("serde", self._parse_field())
alter_set.set("expressions", [self._parse_properties()])
return alter_set
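Illustrative inputs for a few of the branches above (dialect coverage varies; treat these as sketches, not guaranteed round trips):

import sqlglot

sqlglot.parse_one("ALTER TABLE t SET LOCATION 's3://bucket/prefix'", read="hive")
sqlglot.parse_one("ALTER TABLE t SET TABLESPACE fast_ssd", read="postgres")
sqlglot.parse_one("ALTER TABLE t SET TAG cost_center = '42'", read="snowflake")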
def _parse_alter(self) -> exp.AlterTable | exp.Command:
start = self._prev
@ -5867,6 +6065,7 @@ class Parser(metaclass=_Parser):
exists = self._parse_exists()
only = self._match_text_seq("ONLY")
this = self._parse_table(schema=True)
cluster = self._parse_on_property() if self._match(TokenType.ON) else None
if self._next:
self._advance()
@ -5884,6 +6083,7 @@ class Parser(metaclass=_Parser):
actions=actions,
only=only,
options=options,
cluster=cluster,
)
return self._parse_as_command(start)
@ -5974,7 +6174,7 @@ class Parser(metaclass=_Parser):
if kind in ("GLOBAL", "SESSION") and self._match_text_seq("TRANSACTION"):
return self._parse_set_transaction(global_=kind == "GLOBAL")
left = self._parse_primary() or self._parse_column()
assignment_delimiter = self._match_texts(("=", "TO"))
if not left or (self.SET_REQUIRES_ASSIGNMENT_DELIMITER and not assignment_delimiter):

View file

@ -6,7 +6,7 @@ import typing as t
from sqlglot import expressions as exp
from sqlglot.dialects.dialect import Dialect
from sqlglot.errors import SchemaError
from sqlglot.helper import dict_depth
from sqlglot.helper import dict_depth, first
from sqlglot.trie import TrieResult, in_trie, new_trie
if t.TYPE_CHECKING:
@ -174,7 +174,7 @@ class AbstractMappingSchema:
return None
if value == TrieResult.PREFIX:
possibilities = flatten_schema(trie)
if len(possibilities) == 1:
parts.extend(possibilities[0])
@ -362,14 +362,19 @@ class MappingSchema(AbstractMappingSchema, Schema):
The normalized schema mapping.
"""
normalized_mapping: t.Dict = {}
flattened_schema = flatten_schema(schema)
error_msg = "Table {} must match the schema's nesting level: {}."
for keys in flattened_schema:
columns = nested_get(schema, *zip(keys, keys))
if not isinstance(columns, dict):
raise SchemaError(error_msg.format(".".join(keys[:-1]), len(flattened_schema[0])))
if isinstance(first(columns.values()), dict):
raise SchemaError(
f"Table {'.'.join(keys[:-1])} must match the schema's nesting level: {len(flattened_schema[0])}."
error_msg.format(
".".join(keys + flatten_schema(columns)[0]), len(flattened_schema[0])
),
)
normalized_keys = [self._normalize_name(key, is_table=True) for key in keys]
@ -494,16 +499,17 @@ def ensure_column_mapping(mapping: t.Optional[ColumnMapping]) -> t.Dict:
def flatten_schema(
schema: t.Dict, depth: t.Optional[int] = None, keys: t.Optional[t.List[str]] = None
) -> t.List[t.List[str]]:
tables = []
keys = keys or []
depth = dict_depth(schema) - 1 if depth is None else depth
for k, v in schema.items():
if depth == 1 or not isinstance(v, dict):
tables.append(keys + [k])
elif depth >= 2:
tables.extend(flatten_schema(v, depth - 1, keys + [k]))
return tables
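With depth now optional, callers can pass a nested schema directly; for example:

from sqlglot.schema import flatten_schema

schema = {"db": {"t1": {"a": "int"}, "t2": {"b": "text"}}}
print(flatten_schema(schema))  # [['db', 't1'], ['db', 't2']]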

View file

@ -234,6 +234,7 @@ class TokenType(AutoName):
CURRENT_TIME = auto()
CURRENT_TIMESTAMP = auto()
CURRENT_USER = auto()
DECLARE = auto()
DEFAULT = auto()
DELETE = auto()
DESC = auto()
@ -358,6 +359,7 @@ class TokenType(AutoName):
STORAGE_INTEGRATION = auto()
STRUCT = auto()
TABLE_SAMPLE = auto()
TAG = auto()
TEMPORARY = auto()
TOP = auto()
THEN = auto()
@ -1187,6 +1189,8 @@ class Tokenizer(metaclass=_Tokenizer):
if self._peek.isdigit():
self._advance()
elif self._peek == "." and not decimal:
if self.tokens and self.tokens[-1].token_type == TokenType.PARAMETER:
return self._add(TokenType.NUMBER)
decimal = True
self._advance()
elif self._peek in ("-", "+") and scientific == 1:

View file

@ -105,7 +105,14 @@ def eliminate_qualify(expression: exp.Expression) -> exp.Expression:
select.replace(exp.alias_(select, alias))
taken.add(alias)
def _select_alias_or_name(select: exp.Expression) -> str | exp.Column:
alias_or_name = select.alias_or_name
identifier = select.args.get("alias") or select.this
if isinstance(identifier, exp.Identifier):
return exp.column(alias_or_name, quoted=identifier.args.get("quoted"))
return alias_or_name
outer_selects = exp.select(*list(map(_select_alias_or_name, expression.selects)))
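The helper re-quotes the outer projection when the inner alias or name was quoted; a sketch (Presto is one target that rewrites QUALIFY away):

import sqlglot

sql = 'SELECT "My Col" FROM t QUALIFY ROW_NUMBER() OVER (ORDER BY "My Col") = 1'
print(sqlglot.transpile(sql, read="snowflake", write="presto")[0])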
qualify_filters = expression.args["qualify"].pop().this
expression_by_alias = {
select.alias: select.this
@ -465,19 +472,28 @@ def move_ctes_to_top_level(expression: exp.Expression) -> exp.Expression:
TODO: handle name clashes whilst moving CTEs (it can get quite tricky & costly).
"""
top_level_with = expression.args.get("with")
for inner_with in expression.find_all(exp.With):
if inner_with.parent is expression:
continue
if not top_level_with:
top_level_with = inner_with.pop()
expression.set("with", top_level_with)
else:
if inner_with.recursive:
top_level_with.set("recursive", True)
top_level_with.set("expressions", inner_with.expressions + top_level_with.expressions)
parent_cte = inner_with.find_ancestor(exp.CTE)
inner_with.pop()
if parent_cte:
i = top_level_with.expressions.index(parent_cte)
top_level_with.expressions[i:i] = inner_with.expressions
top_level_with.set("expressions", top_level_with.expressions)
else:
top_level_with.set(
"expressions", top_level_with.expressions + inner_with.expressions
)
return expression
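Nested WITHs are now spliced in before the CTE that contained them, keeping references defined left to right; roughly:

import sqlglot
from sqlglot.transforms import move_ctes_to_top_level

sql = "WITH a AS (WITH b AS (SELECT 1 AS x) SELECT * FROM b) SELECT * FROM a"
print(move_ctes_to_top_level(sqlglot.parse_one(sql)).sql())
# WITH b AS (SELECT 1 AS x), a AS (SELECT * FROM b) SELECT * FROM a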