Merging upstream version 18.7.0.

Signed-off-by: Daniel Baumann <daniel@debian.org>

parent 77523b6777
commit d1b976f442

96 changed files with 59037 additions and 52828 deletions
@@ -4,5 +4,10 @@ import typing as t

import sqlglot

+# A little hack for backwards compatibility with Python 3.7.
+# For example, we might want a TypeVar for objects that support comparison e.g. SupportsRichComparisonT from typeshed.
+# But Python 3.7 doesn't support Protocols, so we'd also need typing_extensions, which we don't want as a dependency.
+A = t.TypeVar("A", bound=t.Any)
+
E = t.TypeVar("E", bound="sqlglot.exp.Expression")
T = t.TypeVar("T")
@@ -212,7 +212,15 @@ class Column:
        return self.expression.sql(**{"dialect": SparkSession().dialect, **kwargs})

    def alias(self, name: str) -> Column:
-       new_expression = exp.alias_(self.column_expression, name)
+       from sqlglot.dataframe.sql.session import SparkSession
+
+       dialect = SparkSession().dialect
+       alias: exp.Expression = sqlglot.maybe_parse(name, dialect=dialect)
+       new_expression = exp.alias_(
+           self.column_expression,
+           alias.this if isinstance(alias, exp.Column) else name,
+           dialect=dialect,
+       )
        return Column(new_expression)

    def asc(self) -> Column:
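Review note: `alias` now round-trips the name through the dialect's parser, so dialect-specific quoting rules apply to `df.col.alias(...)`. A minimal sketch of the mechanism (the backtick-quoted name is an invented example):

```python
import sqlglot
from sqlglot import exp

# A quoted alias parses into a Column wrapping an Identifier; using its `.this`
# preserves the raw name instead of re-quoting the whole string.
alias = sqlglot.maybe_parse("`my col`", dialect="spark")
name = alias.this if isinstance(alias, exp.Column) else "my col"
print(exp.alias_(exp.column("x"), name, dialect="spark").sql(dialect="spark"))
# expected: x AS `my col`
```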
@@ -12,6 +12,7 @@ from sqlglot.dialects.dialect import (
    date_add_interval_sql,
    datestrtodate_sql,
    format_time_lambda,
+   if_sql,
    inline_array_sql,
    json_keyvalue_comma_sql,
    max_or_greatest,

@@ -176,6 +177,8 @@ def _parse_to_hex(args: t.List) -> exp.Hex | exp.MD5:
class BigQuery(Dialect):
    UNNEST_COLUMN_ONLY = True
    SUPPORTS_USER_DEFINED_TYPES = False
+   SUPPORTS_SEMI_ANTI_JOIN = False
+   LOG_BASE_FIRST = False

    # https://cloud.google.com/bigquery/docs/reference/standard-sql/lexical#case_sensitivity
    RESOLVES_IDENTIFIERS_AS_UPPERCASE = None

@@ -256,7 +259,6 @@ class BigQuery(Dialect):
            "RECORD": TokenType.STRUCT,
            "TIMESTAMP": TokenType.TIMESTAMPTZ,
            "NOT DETERMINISTIC": TokenType.VOLATILE,
            "UNKNOWN": TokenType.NULL,
            "FOR SYSTEM_TIME": TokenType.TIMESTAMP_SNAPSHOT,
        }
        KEYWORDS.pop("DIV")

@@ -264,7 +266,6 @@ class BigQuery(Dialect):
    class Parser(parser.Parser):
        PREFIXED_PIVOT_COLUMNS = True

-       LOG_BASE_FIRST = False
        LOG_DEFAULTS_TO_LN = True

        FUNCTIONS = {

@@ -292,9 +293,7 @@ class BigQuery(Dialect):
                expression=seq_get(args, 1),
                position=seq_get(args, 2),
                occurrence=seq_get(args, 3),
-               group=exp.Literal.number(1)
-               if re.compile(str(seq_get(args, 1))).groups == 1
-               else None,
+               group=exp.Literal.number(1) if re.compile(args[1].name).groups == 1 else None,
            ),
            "SHA256": lambda args: exp.SHA2(this=seq_get(args, 0), length=exp.Literal.number(256)),
            "SHA512": lambda args: exp.SHA2(this=seq_get(args, 0), length=exp.Literal.number(512)),
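Review note: the rewritten REGEXP_EXTRACT mapping only injects a default `group` argument when the pattern literal has exactly one capturing group, using Python's own regex compiler to count them. In isolation:

```python
import re

# re.Pattern.groups reports the number of capturing groups in the pattern,
# which decides whether BigQuery's implicit group-1 extraction applies.
print(re.compile(r"a(b+)c").groups)  # 1 -> default group injected
print(re.compile(r"(a)(b)").groups)  # 2 -> no default group
```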
@@ -344,6 +343,11 @@ class BigQuery(Dialect):
            "OPTIONS": lambda self: exp.Properties(expressions=self._parse_with_property()),
        }

+       RANGE_PARSERS = parser.Parser.RANGE_PARSERS.copy()
+       RANGE_PARSERS.pop(TokenType.OVERLAPS, None)
+
+       NULL_TOKENS = {TokenType.NULL, TokenType.UNKNOWN}
+
        def _parse_table_part(self, schema: bool = False) -> t.Optional[exp.Expression]:
            this = super()._parse_table_part(schema=schema) or self._parse_number()

@@ -413,8 +417,8 @@ class BigQuery(Dialect):
        TABLE_HINTS = False
        LIMIT_FETCH = "LIMIT"
        RENAME_TABLE_WITH_DB = False
-       ESCAPE_LINE_BREAK = True
        NVL2_SUPPORTED = False
+       UNNEST_WITH_ORDINALITY = False

        TRANSFORMS = {
            **generator.Generator.TRANSFORMS,

@@ -434,6 +438,7 @@ class BigQuery(Dialect):
            exp.GenerateSeries: rename_func("GENERATE_ARRAY"),
            exp.GroupConcat: rename_func("STRING_AGG"),
            exp.Hex: rename_func("TO_HEX"),
+           exp.If: if_sql(false_value="NULL"),
            exp.ILike: no_ilike_sql,
            exp.IntDiv: rename_func("DIV"),
            exp.JSONFormat: rename_func("TO_JSON_STRING"),

@@ -455,10 +460,11 @@ class BigQuery(Dialect):
            exp.ReturnsProperty: _returnsproperty_sql,
            exp.Select: transforms.preprocess(
                [
-                   transforms.explode_to_unnest,
+                   transforms.explode_to_unnest(),
                    _unqualify_unnest,
                    transforms.eliminate_distinct_on,
                    _alias_ordered_group,
+                   transforms.eliminate_semi_and_anti_joins,
                ]
            ),
            exp.SHA2: lambda self, e: self.func(

@@ -514,6 +520,18 @@ class BigQuery(Dialect):
            exp.VolatileProperty: exp.Properties.Location.UNSUPPORTED,
        }

+       UNESCAPED_SEQUENCE_TABLE = str.maketrans(  # type: ignore
+           {
+               "\a": "\\a",
+               "\b": "\\b",
+               "\f": "\\f",
+               "\n": "\\n",
+               "\r": "\\r",
+               "\t": "\\t",
+               "\v": "\\v",
+           }
+       )
+
        # from: https://cloud.google.com/bigquery/docs/reference/standard-sql/lexical#reserved_keywords
        RESERVED_KEYWORDS = {
            *generator.Generator.RESERVED_KEYWORDS,
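Review note: `UNESCAPED_SEQUENCE_TABLE` is a plain `str.maketrans` table, applied in one pass by `str.translate` in the new `escape_str` (see the generator changes below). The mechanism, standalone:

```python
# Map raw control characters to their backslash-escaped spellings, as BigQuery
# string literals require; translate() applies the whole table in one pass.
table = str.maketrans({"\n": "\\n", "\t": "\\t"})
print("a\tb\nc".translate(table))  # a\tb\nc rendered escaped, on one line
```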
@@ -113,15 +113,11 @@ class ClickHouse(Dialect):
            *parser.Parser.JOIN_KINDS,
            TokenType.ANY,
            TokenType.ASOF,
            TokenType.ANTI,
            TokenType.SEMI,
            TokenType.ARRAY,
        }

-       TABLE_ALIAS_TOKENS = {*parser.Parser.TABLE_ALIAS_TOKENS} - {
+       TABLE_ALIAS_TOKENS = parser.Parser.TABLE_ALIAS_TOKENS - {
            TokenType.ANY,
            TokenType.SEMI,
            TokenType.ANTI,
            TokenType.SETTINGS,
            TokenType.FORMAT,
            TokenType.ARRAY,
@@ -51,8 +51,6 @@ class Databricks(Spark):
            exp.ToChar: lambda self, e: self.function_fallback_sql(e),
        }

-       PARAMETER_TOKEN = "$"
-
        class Tokenizer(Spark.Tokenizer):
            HEX_STRINGS = []
@@ -125,6 +125,12 @@ class _Dialect(type):
        if not klass.STRICT_STRING_CONCAT and klass.DPIPE_IS_STRING_CONCAT:
            klass.parser_class.BITWISE[TokenType.DPIPE] = exp.SafeDPipe

+       if not klass.SUPPORTS_SEMI_ANTI_JOIN:
+           klass.parser_class.TABLE_ALIAS_TOKENS = klass.parser_class.TABLE_ALIAS_TOKENS | {
+               TokenType.ANTI,
+               TokenType.SEMI,
+           }
+
        klass.generator_class.can_identify = klass.can_identify

        return klass
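Review note: with this hook a dialect only declares `SUPPORTS_SEMI_ANTI_JOIN = False`; SEMI/ANTI then become legal table-alias tokens for its parser, and the matching `eliminate_semi_and_anti_joins` transform (wired into the generators below) rewrites such joins away. A sketch of the end-to-end effect (exact output SQL may differ by version):

```python
import sqlglot

# Spark understands LEFT SEMI JOIN; Presto does not, so the new transform
# rewrites it into an equivalent EXISTS predicate during generation.
sql = "SELECT a.x FROM a LEFT SEMI JOIN b ON a.id = b.id"
print(sqlglot.transpile(sql, read="spark", write="presto")[0])
```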
@@ -156,9 +162,15 @@ class Dialect(metaclass=_Dialect):
    # Determines whether or not user-defined data types are supported
    SUPPORTS_USER_DEFINED_TYPES = True

+   # Determines whether or not SEMI/ANTI JOINs are supported
+   SUPPORTS_SEMI_ANTI_JOIN = True
+
    # Determines how function names are going to be normalized
    NORMALIZE_FUNCTIONS: bool | str = "upper"

+   # Determines whether the base comes first in the LOG function
+   LOG_BASE_FIRST = True
+
    # Indicates the default null ordering method to use if not explicitly set
    # Options are: "nulls_are_small", "nulls_are_large", "nulls_are_last"
    NULL_ORDERING = "nulls_are_small"
@@ -331,10 +343,18 @@ def approx_count_distinct_sql(self: Generator, expression: exp.ApproxDistinct) -> str:
    return self.func("APPROX_COUNT_DISTINCT", expression.this)


-def if_sql(self: Generator, expression: exp.If) -> str:
-    return self.func(
-        "IF", expression.this, expression.args.get("true"), expression.args.get("false")
-    )
+def if_sql(
+    name: str = "IF", false_value: t.Optional[exp.Expression | str] = None
+) -> t.Callable[[Generator, exp.If], str]:
+    def _if_sql(self: Generator, expression: exp.If) -> str:
+        return self.func(
+            name,
+            expression.this,
+            expression.args.get("true"),
+            expression.args.get("false") or false_value,
+        )
+
+    return _if_sql


def arrow_json_extract_sql(self: Generator, expression: exp.JSONExtract | exp.JSONBExtract) -> str:
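Review note: `if_sql` is now a factory returning the generator callable, so dialects can parameterize both the function name and a default false branch, as the BigQuery and Snowflake entries elsewhere in this diff do:

```python
from sqlglot import exp
from sqlglot.dialects.dialect import if_sql

# Each call returns a closure suitable as a TRANSFORMS value; an absent false
# branch falls back to the configured default.
TRANSFORMS = {exp.If: if_sql(name="IFF", false_value="NULL")}
```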
@@ -751,6 +771,12 @@ def any_value_to_max_sql(self: Generator, expression: exp.AnyValue) -> str:
    return self.func("MAX", expression.this)


+def bool_xor_sql(self: Generator, expression: exp.Xor) -> str:
+    a = self.sql(expression.left)
+    b = self.sql(expression.right)
+    return f"({a} AND (NOT {b})) OR ((NOT {a}) AND {b})"
+
+
# Used to generate JSON_OBJECT with a comma in BigQuery and MySQL instead of colon
def json_keyvalue_comma_sql(self: Generator, expression: exp.JSONKeyValue) -> str:
    return f"{self.sql(expression, 'this')}, {self.sql(expression, 'expression')}"

@@ -764,3 +790,10 @@ def is_parse_json(expression: exp.Expression) -> bool:

def isnull_to_is_null(args: t.List) -> exp.Expression:
    return exp.Paren(this=exp.Is(this=seq_get(args, 0), expression=exp.null()))
+
+
+def move_insert_cte_sql(self: Generator, expression: exp.Insert) -> str:
+    if expression.expression.args.get("with"):
+        expression = expression.copy()
+        expression.set("with", expression.expression.args["with"].pop())
+    return self.insert_sql(expression)
@@ -40,6 +40,7 @@ class Drill(Dialect):
    DATEINT_FORMAT = "'yyyyMMdd'"
    TIME_FORMAT = "'yyyy-MM-dd HH:mm:ss'"
    SUPPORTS_USER_DEFINED_TYPES = False
+   SUPPORTS_SEMI_ANTI_JOIN = False

    TIME_MAPPING = {
        "y": "%Y",

@@ -135,7 +136,9 @@ class Drill(Dialect):
            exp.StrPosition: str_position_sql,
            exp.StrToDate: _str_to_date,
            exp.Pow: rename_func("POW"),
-           exp.Select: transforms.preprocess([transforms.eliminate_distinct_on]),
+           exp.Select: transforms.preprocess(
+               [transforms.eliminate_distinct_on, transforms.eliminate_semi_and_anti_joins]
+           ),
            exp.StrToTime: lambda self, e: f"TO_TIMESTAMP({self.sql(e, 'this')}, {self.format_time(e)})",
            exp.TimeStrToDate: lambda self, e: f"CAST({self.sql(e, 'this')} AS DATE)",
            exp.TimeStrToTime: timestrtotime_sql,
@@ -9,6 +9,7 @@ from sqlglot.dialects.dialect import (
    arrow_json_extract_scalar_sql,
    arrow_json_extract_sql,
    binary_from_function,
+   bool_xor_sql,
    date_trunc_to_time,
    datestrtodate_sql,
    encode_decode_sql,

@@ -190,6 +191,11 @@ class DuckDB(Dialect):
            ),
        }

+       TABLE_ALIAS_TOKENS = parser.Parser.TABLE_ALIAS_TOKENS - {
+           TokenType.SEMI,
+           TokenType.ANTI,
+       }
+
        def _parse_types(
            self, check_func: bool = False, schema: bool = False, allow_identifiers: bool = True
        ) -> t.Optional[exp.Expression]:

@@ -224,6 +230,7 @@ class DuckDB(Dialect):
        STRUCT_DELIMITER = ("(", ")")
        RENAME_TABLE_WITH_DB = False
        NVL2_SUPPORTED = False
+       SEMI_ANTI_JOIN_WITH_SIDE = False

        TRANSFORMS = {
            **generator.Generator.TRANSFORMS,

@@ -234,7 +241,7 @@ class DuckDB(Dialect):
            exp.ArraySize: rename_func("ARRAY_LENGTH"),
            exp.ArraySort: _array_sort_sql,
            exp.ArraySum: rename_func("LIST_SUM"),
-           exp.BitwiseXor: lambda self, e: self.func("XOR", e.this, e.expression),
+           exp.BitwiseXor: rename_func("XOR"),
            exp.CommentColumnConstraint: no_comment_column_constraint_sql,
            exp.CurrentDate: lambda self, e: "CURRENT_DATE",
            exp.CurrentTime: lambda self, e: "CURRENT_TIME",

@@ -301,6 +308,7 @@ class DuckDB(Dialect):
            exp.UnixToTimeStr: lambda self, e: f"CAST(TO_TIMESTAMP({self.sql(e, 'this')}) AS TEXT)",
            exp.VariancePop: rename_func("VAR_POP"),
            exp.WeekOfYear: rename_func("WEEKOFYEAR"),
+           exp.Xor: bool_xor_sql,
        }

        TYPE_MAPPING = {
@@ -111,7 +111,7 @@ def _array_sort_sql(self: Hive.Generator, expression: exp.ArraySort) -> str:


def _property_sql(self: Hive.Generator, expression: exp.Property) -> str:
-   return f"'{expression.name}'={self.sql(expression, 'value')}"
+   return f"{self.property_name(expression, string_key=True)}={self.sql(expression, 'value')}"


def _str_to_unix_sql(self: Hive.Generator, expression: exp.StrToUnix) -> str:

@@ -413,7 +413,7 @@ class Hive(Dialect):
            exp.DiToDate: lambda self, e: f"TO_DATE(CAST({self.sql(e, 'this')} AS STRING), {Hive.DATEINT_FORMAT})",
            exp.FileFormatProperty: lambda self, e: f"STORED AS {self.sql(e, 'this') if isinstance(e.this, exp.InputOutputFormat) else e.name.upper()}",
            exp.FromBase64: rename_func("UNBASE64"),
-           exp.If: if_sql,
+           exp.If: if_sql(),
            exp.ILike: no_ilike_sql,
            exp.IsNan: rename_func("ISNAN"),
            exp.JSONExtract: rename_func("GET_JSON_OBJECT"),

@@ -466,6 +466,11 @@ class Hive(Dialect):
            exp.NumberToStr: rename_func("FORMAT_NUMBER"),
            exp.LastDateOfMonth: rename_func("LAST_DAY"),
            exp.National: lambda self, e: self.national_sql(e, prefix=""),
+           exp.ClusteredColumnConstraint: lambda self, e: f"({self.expressions(e, 'this', indent=False)})",
+           exp.NonClusteredColumnConstraint: lambda self, e: f"({self.expressions(e, 'this', indent=False)})",
+           exp.NotForReplicationColumnConstraint: lambda self, e: "",
+           exp.OnProperty: lambda self, e: "",
+           exp.PrimaryKeyColumnConstraint: lambda self, e: "PRIMARY KEY",
        }

        PROPERTIES_LOCATION = {

@@ -475,6 +480,35 @@ class Hive(Dialect):
            exp.VolatileProperty: exp.Properties.Location.UNSUPPORTED,
        }

+       def parameter_sql(self, expression: exp.Parameter) -> str:
+           this = self.sql(expression, "this")
+           parent = expression.parent
+
+           if isinstance(parent, exp.EQ) and isinstance(parent.parent, exp.SetItem):
+               # We need to produce SET key = value instead of SET ${key} = value
+               return this
+
+           return f"${{{this}}}"
+
+       def schema_sql(self, expression: exp.Schema) -> str:
+           expression = expression.copy()
+
+           for ordered in expression.find_all(exp.Ordered):
+               if ordered.args.get("desc") is False:
+                   ordered.set("desc", None)
+
+           return super().schema_sql(expression)
+
+       def constraint_sql(self, expression: exp.Constraint) -> str:
+           expression = expression.copy()
+
+           for prop in list(expression.find_all(exp.Properties)):
+               prop.pop()
+
+           this = self.sql(expression, "this")
+           expressions = self.expressions(expression, sep=" ", flat=True)
+           return f"CONSTRAINT {this} {expressions}"
+
        def rowformatserdeproperty_sql(self, expression: exp.RowFormatSerdeProperty) -> str:
            serde_props = self.sql(expression, "serde_properties")
            serde_props = f" {serde_props}" if serde_props else ""
@@ -102,6 +102,7 @@ class MySQL(Dialect):
    TIME_FORMAT = "'%Y-%m-%d %T'"
    DPIPE_IS_STRING_CONCAT = False
    SUPPORTS_USER_DEFINED_TYPES = False
+   SUPPORTS_SEMI_ANTI_JOIN = False

    # https://prestodb.io/docs/current/functions/datetime.html#mysql-date-functions
    TIME_MAPPING = {

@@ -519,7 +520,7 @@ class MySQL(Dialect):

            return self.expression(exp.SetItem, this=charset, collate=collate, kind="NAMES")

-       def _parse_type(self) -> t.Optional[exp.Expression]:
+       def _parse_type(self, parse_interval: bool = True) -> t.Optional[exp.Expression]:
            # mysql binary is special and can work anywhere, even in order by operations
            # it operates like a no paren func
            if self._match(TokenType.BINARY, advance=False):

@@ -528,7 +529,7 @@ class MySQL(Dialect):
            if isinstance(data_type, exp.DataType):
                return self.expression(exp.Cast, this=self._parse_column(), to=data_type)

-           return super()._parse_type()
+           return super()._parse_type(parse_interval=parse_interval)

    class Generator(generator.Generator):
        LOCKING_READS_SUPPORTED = True

@@ -560,7 +561,9 @@ class MySQL(Dialect):
            exp.NullSafeEQ: lambda self, e: self.binary(e, "<=>"),
            exp.NullSafeNEQ: lambda self, e: self.not_sql(self.binary(e, "<=>")),
            exp.Pivot: no_pivot_sql,
-           exp.Select: transforms.preprocess([transforms.eliminate_distinct_on]),
+           exp.Select: transforms.preprocess(
+               [transforms.eliminate_distinct_on, transforms.eliminate_semi_and_anti_joins]
+           ),
            exp.StrPosition: strposition_to_locate_sql,
            exp.StrToDate: _str_to_date_sql,
            exp.StrToTime: _str_to_date_sql,
@@ -8,6 +8,7 @@ from sqlglot.dialects.dialect import (
    any_value_to_max_sql,
    arrow_json_extract_scalar_sql,
    arrow_json_extract_sql,
+   bool_xor_sql,
    datestrtodate_sql,
    format_time_lambda,
    max_or_greatest,

@@ -110,7 +111,7 @@ def _string_agg_sql(self: Postgres.Generator, expression: exp.GroupConcat) -> str:

def _datatype_sql(self: Postgres.Generator, expression: exp.DataType) -> str:
    if expression.is_type("array"):
-       return f"{self.expressions(expression, flat=True)}[]"
+       return f"{self.expressions(expression, flat=True)}[]" if expression.expressions else "ARRAY"
    return self.datatype_sql(expression)

@@ -380,25 +381,29 @@ class Postgres(Dialect):
        TRANSFORMS = {
            **generator.Generator.TRANSFORMS,
            exp.AnyValue: any_value_to_max_sql,
+           exp.Array: lambda self, e: f"{self.normalize_func('ARRAY')}({self.sql(e.expressions[0])})"
+           if isinstance(seq_get(e.expressions, 0), exp.Select)
+           else f"{self.normalize_func('ARRAY')}[{self.expressions(e, flat=True)}]",
            exp.ArrayConcat: rename_func("ARRAY_CAT"),
            exp.ArrayContained: lambda self, e: self.binary(e, "<@"),
            exp.ArrayContains: lambda self, e: self.binary(e, "@>"),
            exp.ArrayOverlaps: lambda self, e: self.binary(e, "&&"),
            exp.BitwiseXor: lambda self, e: self.binary(e, "#"),
            exp.ColumnDef: transforms.preprocess([_auto_increment_to_serial, _serial_to_generated]),
+           exp.CurrentDate: no_paren_current_date_sql,
+           exp.CurrentTimestamp: lambda *_: "CURRENT_TIMESTAMP",
+           exp.DateAdd: _date_add_sql("+"),
+           exp.DateDiff: _date_diff_sql,
+           exp.DateStrToDate: datestrtodate_sql,
+           exp.DataType: _datatype_sql,
+           exp.DateSub: _date_add_sql("-"),
            exp.Explode: rename_func("UNNEST"),
+           exp.GroupConcat: _string_agg_sql,
            exp.JSONExtract: arrow_json_extract_sql,
            exp.JSONExtractScalar: arrow_json_extract_scalar_sql,
            exp.JSONBExtract: lambda self, e: self.binary(e, "#>"),
            exp.JSONBExtractScalar: lambda self, e: self.binary(e, "#>>"),
            exp.JSONBContains: lambda self, e: self.binary(e, "?"),
-           exp.Pow: lambda self, e: self.binary(e, "^"),
-           exp.CurrentDate: no_paren_current_date_sql,
-           exp.CurrentTimestamp: lambda *_: "CURRENT_TIMESTAMP",
-           exp.DateAdd: _date_add_sql("+"),
-           exp.DateStrToDate: datestrtodate_sql,
-           exp.DateSub: _date_add_sql("-"),
-           exp.DateDiff: _date_diff_sql,
            exp.LogicalOr: rename_func("BOOL_OR"),
            exp.LogicalAnd: rename_func("BOOL_AND"),
            exp.Max: max_or_greatest,

@@ -412,8 +417,10 @@ class Postgres(Dialect):
                [transforms.add_within_group_for_percentiles]
            ),
            exp.Pivot: no_pivot_sql,
+           exp.Pow: lambda self, e: self.binary(e, "^"),
            exp.RegexpLike: lambda self, e: self.binary(e, "~"),
            exp.RegexpILike: lambda self, e: self.binary(e, "~*"),
+           exp.Select: transforms.preprocess([transforms.eliminate_semi_and_anti_joins]),
            exp.StrPosition: str_position_sql,
            exp.StrToTime: lambda self, e: f"TO_TIMESTAMP({self.sql(e, 'this')}, {self.format_time(e)})",
            exp.Substring: _substring_sql,

@@ -426,11 +433,7 @@ class Postgres(Dialect):
            exp.TryCast: no_trycast_sql,
            exp.TsOrDsToDate: ts_or_ds_to_date_sql("postgres"),
            exp.UnixToTime: lambda self, e: f"TO_TIMESTAMP({self.sql(e, 'this')})",
-           exp.DataType: _datatype_sql,
-           exp.GroupConcat: _string_agg_sql,
-           exp.Array: lambda self, e: f"{self.normalize_func('ARRAY')}({self.sql(e.expressions[0])})"
-           if isinstance(seq_get(e.expressions, 0), exp.Select)
-           else f"{self.normalize_func('ARRAY')}[{self.expressions(e, flat=True)}]",
+           exp.Xor: bool_xor_sql,
        }

        PROPERTIES_LOCATION = {
@@ -6,6 +6,7 @@ from sqlglot import exp, generator, parser, tokens, transforms
from sqlglot.dialects.dialect import (
    Dialect,
    binary_from_function,
+   bool_xor_sql,
    date_trunc_to_time,
    encode_decode_sql,
    format_time_lambda,

@@ -40,7 +41,7 @@ def _explode_to_unnest_sql(self: Presto.Generator, expression: exp.Lateral) -> str:
        this=exp.Unnest(
            expressions=[expression.this.this],
            alias=expression.args.get("alias"),
-           ordinality=isinstance(expression.this, exp.Posexplode),
+           offset=isinstance(expression.this, exp.Posexplode),
        ),
        kind="cross",
    )

@@ -173,6 +174,7 @@ class Presto(Dialect):
    TIME_FORMAT = MySQL.TIME_FORMAT
    TIME_MAPPING = MySQL.TIME_MAPPING
    STRICT_STRING_CONCAT = True
+   SUPPORTS_SEMI_ANTI_JOIN = False

    # https://github.com/trinodb/trino/issues/17
    # https://github.com/trinodb/trino/issues/12289

@@ -308,7 +310,7 @@ class Presto(Dialect):
            exp.First: _first_last_sql,
            exp.Group: transforms.preprocess([transforms.unalias_group]),
            exp.Hex: rename_func("TO_HEX"),
-           exp.If: if_sql,
+           exp.If: if_sql(),
            exp.ILike: no_ilike_sql,
            exp.Initcap: _initcap_sql,
            exp.ParseJSON: rename_func("JSON_PARSE"),

@@ -331,7 +333,8 @@ class Presto(Dialect):
                [
                    transforms.eliminate_qualify,
                    transforms.eliminate_distinct_on,
-                   transforms.explode_to_unnest,
+                   transforms.explode_to_unnest(1),
+                   transforms.eliminate_semi_and_anti_joins,
                ]
            ),
            exp.SortArray: _no_sort_array,

@@ -340,7 +343,6 @@ class Presto(Dialect):
            exp.StrToMap: rename_func("SPLIT_TO_MAP"),
            exp.StrToTime: _str_to_time_sql,
            exp.StrToUnix: lambda self, e: f"TO_UNIXTIME(DATE_PARSE({self.sql(e, 'this')}, {self.format_time(e)}))",
-           exp.Struct: rename_func("ROW"),
            exp.StructExtract: struct_extract_sql,
            exp.Table: transforms.preprocess([_unnest_sequence]),
            exp.TimestampTrunc: timestamptrunc_sql,

@@ -363,8 +365,16 @@ class Presto(Dialect):
                [transforms.remove_within_group_for_percentiles]
            ),
            exp.Timestamp: transforms.preprocess([transforms.timestamp_to_cast]),
+           exp.Xor: bool_xor_sql,
        }

+       def struct_sql(self, expression: exp.Struct) -> str:
+           if any(isinstance(arg, (exp.EQ, exp.Slice)) for arg in expression.expressions):
+               self.unsupported("Struct with key-value definitions is unsupported.")
+               return self.function_fallback_sql(expression)
+
+           return rename_func("ROW")(self, expression)
+
        def interval_sql(self, expression: exp.Interval) -> str:
            unit = self.sql(expression, "unit")
            if expression.this and unit.lower().startswith("week"):
@@ -138,7 +138,9 @@ class Redshift(Postgres):
            exp.JSONExtract: _json_sql,
            exp.JSONExtractScalar: _json_sql,
            exp.SafeConcat: concat_to_dpipe_sql,
-           exp.Select: transforms.preprocess([transforms.eliminate_distinct_on]),
+           exp.Select: transforms.preprocess(
+               [transforms.eliminate_distinct_on, transforms.eliminate_semi_and_anti_joins]
+           ),
            exp.SortKeyProperty: lambda self, e: f"{'COMPOUND ' if e.args['compound'] else ''}SORTKEY({self.format_args(*e.this)})",
            exp.TsOrDsToDate: ts_or_ds_to_date_sql("redshift"),
        }
@@ -5,9 +5,11 @@ import typing as t
from sqlglot import exp, generator, parser, tokens, transforms
from sqlglot.dialects.dialect import (
    Dialect,
+   binary_from_function,
    date_trunc_to_time,
    datestrtodate_sql,
    format_time_lambda,
+   if_sql,
    inline_array_sql,
    max_or_greatest,
    min_or_least,

@@ -203,6 +205,7 @@ class Snowflake(Dialect):
    NULL_ORDERING = "nulls_are_large"
    TIME_FORMAT = "'YYYY-MM-DD HH24:MI:SS'"
    SUPPORTS_USER_DEFINED_TYPES = False
+   SUPPORTS_SEMI_ANTI_JOIN = False

    TIME_MAPPING = {
        "YYYY": "%Y",

@@ -240,7 +243,16 @@ class Snowflake(Dialect):
            **parser.Parser.FUNCTIONS,
            "ARRAYAGG": exp.ArrayAgg.from_arg_list,
            "ARRAY_CONSTRUCT": exp.Array.from_arg_list,
+           "ARRAY_GENERATE_RANGE": lambda args: exp.GenerateSeries(
+               # ARRAY_GENERATE_RANGE has an exlusive end; we normalize it to be inclusive
+               start=seq_get(args, 0),
+               end=exp.Sub(this=seq_get(args, 1), expression=exp.Literal.number(1)),
+               step=seq_get(args, 2),
+           ),
            "ARRAY_TO_STRING": exp.ArrayJoin.from_arg_list,
+           "BITXOR": binary_from_function(exp.BitwiseXor),
+           "BIT_XOR": binary_from_function(exp.BitwiseXor),
+           "BOOLXOR": binary_from_function(exp.Xor),
            "CONVERT_TIMEZONE": _parse_convert_timezone,
            "DATE_TRUNC": date_trunc_to_time,
            "DATEADD": lambda args: exp.DateAdd(

@@ -277,7 +289,7 @@ class Snowflake(Dialect):
            ),
        }

-       TIMESTAMPS = parser.Parser.TIMESTAMPS.copy() - {TokenType.TIME}
+       TIMESTAMPS = parser.Parser.TIMESTAMPS - {TokenType.TIME}

        RANGE_PARSERS = {
            **parser.Parser.RANGE_PARSERS,

@@ -381,6 +393,7 @@ class Snowflake(Dialect):
        JOIN_HINTS = False
        TABLE_HINTS = False
        QUERY_HINTS = False
+       AGGREGATE_FILTER_SUPPORTED = False

        TRANSFORMS = {
            **generator.Generator.TRANSFORMS,

@@ -390,6 +403,7 @@ class Snowflake(Dialect):
            exp.AtTimeZone: lambda self, e: self.func(
                "CONVERT_TIMEZONE", e.args.get("zone"), e.this
            ),
+           exp.BitwiseXor: rename_func("BITXOR"),
            exp.DateAdd: lambda self, e: self.func("DATEADD", e.text("unit"), e.expression, e.this),
            exp.DateDiff: lambda self, e: self.func(
                "DATEDIFF", e.text("unit"), e.expression, e.this

@@ -398,8 +412,11 @@ class Snowflake(Dialect):
            exp.DataType: _datatype_sql,
            exp.DayOfWeek: rename_func("DAYOFWEEK"),
            exp.Extract: rename_func("DATE_PART"),
+           exp.GenerateSeries: lambda self, e: self.func(
+               "ARRAY_GENERATE_RANGE", e.args["start"], e.args["end"] + 1, e.args.get("step")
+           ),
            exp.GroupConcat: rename_func("LISTAGG"),
-           exp.If: rename_func("IFF"),
+           exp.If: if_sql(name="IFF", false_value="NULL"),
            exp.LogicalAnd: rename_func("BOOLAND_AGG"),
            exp.LogicalOr: rename_func("BOOLOR_AGG"),
            exp.Map: lambda self, e: var_map_sql(self, e, "OBJECT_CONSTRUCT"),

@@ -407,7 +424,13 @@ class Snowflake(Dialect):
            exp.Min: min_or_least,
            exp.PartitionedByProperty: lambda self, e: f"PARTITION BY {self.sql(e, 'this')}",
            exp.RegexpILike: _regexpilike_sql,
-           exp.Select: transforms.preprocess([transforms.eliminate_distinct_on]),
+           exp.Select: transforms.preprocess(
+               [
+                   transforms.eliminate_distinct_on,
+                   transforms.explode_to_unnest(0),
+                   transforms.eliminate_semi_and_anti_joins,
+               ]
+           ),
            exp.StarMap: rename_func("OBJECT_CONSTRUCT"),
            exp.StartsWith: rename_func("STARTSWITH"),
            exp.StrPosition: lambda self, e: self.func(

@@ -431,6 +454,7 @@ class Snowflake(Dialect):
            exp.UnixToTime: _unix_to_time_sql,
            exp.VarMap: lambda self, e: var_map_sql(self, e, "OBJECT_CONSTRUCT"),
            exp.WeekOfYear: rename_func("WEEKOFYEAR"),
+           exp.Xor: rename_func("BOOLXOR"),
        }

        TYPE_MAPPING = {
@@ -449,6 +473,27 @@ class Snowflake(Dialect):
            exp.VolatileProperty: exp.Properties.Location.UNSUPPORTED,
        }

+       def unnest_sql(self, expression: exp.Unnest) -> str:
+           selects = ["value"]
+           unnest_alias = expression.args.get("alias")
+
+           offset = expression.args.get("offset")
+           if offset:
+               if unnest_alias:
+                   expression = expression.copy()
+                   unnest_alias.append("columns", offset.pop())
+
+               selects.append("index")
+
+           subquery = exp.Subquery(
+               this=exp.select(*selects).from_(
+                   f"TABLE(FLATTEN(INPUT => {self.sql(expression.expressions[0])}))"
+               ),
+           )
+           alias = self.sql(unnest_alias)
+           alias = f" AS {alias}" if alias else ""
+           return f"{self.sql(subquery)}{alias}"
+
        def show_sql(self, expression: exp.Show) -> str:
            scope = self.sql(expression, "scope")
            scope = f" {scope}" if scope else ""
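Review note: UNNEST is lowered onto Snowflake's FLATTEN table function, with an offset alias mapped to FLATTEN's `index` column. Roughly (the exact output shape may vary by version):

```python
import sqlglot

# A BigQuery-style UNNEST becomes a (SELECT value FROM TABLE(FLATTEN(...)))
# subquery when targeting Snowflake.
sql = "SELECT value FROM UNNEST([1, 2, 3]) AS value"
print(sqlglot.transpile(sql, read="bigquery", write="snowflake")[0])
```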
@@ -8,6 +8,7 @@ from sqlglot.dialects.dialect import (
    create_with_partitions_sql,
    format_time_lambda,
    is_parse_json,
+   move_insert_cte_sql,
    pivot_column_names,
    rename_func,
    trim_sql,

@@ -115,13 +116,6 @@ def _unqualify_pivot_columns(expression: exp.Expression) -> exp.Expression:
    return expression


-def _insert_sql(self: Spark2.Generator, expression: exp.Insert) -> str:
-    if expression.expression.args.get("with"):
-        expression = expression.copy()
-        expression.set("with", expression.expression.args.pop("with"))
-    return self.insert_sql(expression)
-
-
class Spark2(Hive):
    class Parser(Hive.Parser):
        FUNCTIONS = {

@@ -206,7 +200,7 @@ class Spark2(Hive):
            exp.DayOfYear: rename_func("DAYOFYEAR"),
            exp.FileFormatProperty: lambda self, e: f"USING {e.name.upper()}",
            exp.From: transforms.preprocess([_unalias_pivot]),
-           exp.Insert: _insert_sql,
+           exp.Insert: move_insert_cte_sql,
            exp.LogicalAnd: rename_func("BOOL_AND"),
            exp.LogicalOr: rename_func("BOOL_OR"),
            exp.Map: _map_sql,
@@ -64,6 +64,7 @@ def _transform_create(expression: exp.Expression) -> exp.Expression:
class SQLite(Dialect):
    # https://sqlite.org/forum/forumpost/5e575586ac5c711b?raw
    RESOLVES_IDENTIFIERS_AS_UPPERCASE = None
+   SUPPORTS_SEMI_ANTI_JOIN = False

    class Tokenizer(tokens.Tokenizer):
        IDENTIFIERS = ['"', ("[", "]"), "`"]

@@ -125,7 +126,11 @@ class SQLite(Dialect):
            exp.Pivot: no_pivot_sql,
            exp.SafeConcat: concat_to_dpipe_sql,
            exp.Select: transforms.preprocess(
-               [transforms.eliminate_distinct_on, transforms.eliminate_qualify]
+               [
+                   transforms.eliminate_distinct_on,
+                   transforms.eliminate_qualify,
+                   transforms.eliminate_semi_and_anti_joins,
+               ]
            ),
            exp.TableSample: no_tablesample_sql,
            exp.TimeStrToTime: lambda self, e: self.sql(e, "this"),
@@ -8,6 +8,8 @@ from sqlglot.tokens import TokenType


class Teradata(Dialect):
+   SUPPORTS_SEMI_ANTI_JOIN = False
+
    TIME_MAPPING = {
        "Y": "%Y",
        "YYYY": "%Y",

@@ -168,7 +170,9 @@ class Teradata(Dialect):
            **generator.Generator.TRANSFORMS,
            exp.Max: max_or_greatest,
            exp.Min: min_or_least,
-           exp.Select: transforms.preprocess([transforms.eliminate_distinct_on]),
+           exp.Select: transforms.preprocess(
+               [transforms.eliminate_distinct_on, transforms.eliminate_semi_and_anti_joins]
+           ),
            exp.StrToDate: lambda self, e: f"CAST({self.sql(e, 'this')} AS DATE FORMAT {self.format_time(e)})",
            exp.ToChar: lambda self, e: self.function_fallback_sql(e),
            exp.Use: lambda self, e: f"DATABASE {self.sql(e, 'this')}",
@@ -10,6 +10,7 @@ from sqlglot.dialects.dialect import (
    any_value_to_max_sql,
    max_or_greatest,
    min_or_least,
+   move_insert_cte_sql,
    parse_date_delta,
    rename_func,
    timestrtotime_sql,

@@ -206,6 +207,8 @@ class TSQL(Dialect):
    RESOLVES_IDENTIFIERS_AS_UPPERCASE = None
    NULL_ORDERING = "nulls_are_small"
    TIME_FORMAT = "'yyyy-mm-dd hh:mm:ss'"
+   SUPPORTS_SEMI_ANTI_JOIN = False
+   LOG_BASE_FIRST = False

    TIME_MAPPING = {
        "year": "%Y",

@@ -345,6 +348,8 @@ class TSQL(Dialect):
    }

    class Parser(parser.Parser):
+       SET_REQUIRES_ASSIGNMENT_DELIMITER = False
+
        FUNCTIONS = {
            **parser.Parser.FUNCTIONS,
            "CHARINDEX": lambda args: exp.StrPosition(

@@ -396,7 +401,6 @@ class TSQL(Dialect):
            TokenType.END: lambda self: self._parse_command(),
        }

-       LOG_BASE_FIRST = False
        LOG_DEFAULTS_TO_LN = True

        CONCAT_NULL_OUTPUTS_STRING = True

@@ -609,11 +613,14 @@ class TSQL(Dialect):
            exp.Extract: rename_func("DATEPART"),
            exp.GroupConcat: _string_agg_sql,
            exp.If: rename_func("IIF"),
+           exp.Insert: move_insert_cte_sql,
            exp.Max: max_or_greatest,
            exp.MD5: lambda self, e: self.func("HASHBYTES", exp.Literal.string("MD5"), e.this),
            exp.Min: min_or_least,
            exp.NumberToStr: _format_sql,
-           exp.Select: transforms.preprocess([transforms.eliminate_distinct_on]),
+           exp.Select: transforms.preprocess(
+               [transforms.eliminate_distinct_on, transforms.eliminate_semi_and_anti_joins]
+           ),
            exp.SHA: lambda self, e: self.func("HASHBYTES", exp.Literal.string("SHA1"), e.this),
            exp.SHA2: lambda self, e: self.func(
                "HASHBYTES",

@@ -632,6 +639,14 @@ class TSQL(Dialect):
            exp.VolatileProperty: exp.Properties.Location.UNSUPPORTED,
        }

+       def setitem_sql(self, expression: exp.SetItem) -> str:
+           this = expression.this
+           if isinstance(this, exp.EQ) and not isinstance(this.left, exp.Parameter):
+               # T-SQL does not use '=' in SET command, except when the LHS is a variable.
+               return f"{self.sql(this.left)} {self.sql(this.right)}"
+
+           return super().setitem_sql(expression)
+
        def boolean_sql(self, expression: exp.Boolean) -> str:
            if type(expression.parent) in BIT_TYPES:
                return "1" if expression.this else "0"
@@ -661,16 +676,27 @@ class TSQL(Dialect):
            exists = expression.args.pop("exists", None)
            sql = super().create_sql(expression)

+           table = expression.find(exp.Table)
+
+           if kind == "TABLE" and expression.expression:
+               sql = f"SELECT * INTO {self.sql(table)} FROM ({self.sql(expression.expression)}) AS temp"
+
            if exists:
-               table = expression.find(exp.Table)
                identifier = self.sql(exp.Literal.string(exp.table_name(table) if table else ""))
+               sql = self.sql(exp.Literal.string(sql))
                if kind == "SCHEMA":
-                   sql = f"""IF NOT EXISTS (SELECT * FROM information_schema.schemata WHERE schema_name = {identifier}) EXEC('{sql}')"""
+                   sql = f"""IF NOT EXISTS (SELECT * FROM information_schema.schemata WHERE schema_name = {identifier}) EXEC({sql})"""
                elif kind == "TABLE":
-                   sql = f"""IF NOT EXISTS (SELECT * FROM information_schema.tables WHERE table_name = {identifier}) EXEC('{sql}')"""
+                   assert table
+                   where = exp.and_(
+                       exp.column("table_name").eq(table.name),
+                       exp.column("table_schema").eq(table.db) if table.db else None,
+                       exp.column("table_catalog").eq(table.catalog) if table.catalog else None,
+                   )
+                   sql = f"""IF NOT EXISTS (SELECT * FROM information_schema.tables WHERE {where}) EXEC({sql})"""
                elif kind == "INDEX":
                    index = self.sql(exp.Literal.string(expression.this.text("this")))
-                   sql = f"""IF NOT EXISTS (SELECT * FROM sys.indexes WHERE object_id = object_id({identifier}) AND name = {index}) EXEC('{sql}')"""
+                   sql = f"""IF NOT EXISTS (SELECT * FROM sys.indexes WHERE object_id = object_id({identifier}) AND name = {index}) EXEC({sql})"""
            elif expression.args.get("replace"):
                sql = sql.replace("CREATE OR REPLACE ", "CREATE OR ALTER ", 1)
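Review note: two fixes travel together here — the guarded statement is serialized once through `exp.Literal.string(sql)`, so embedded single quotes are escaped correctly inside `EXEC(...)`, and the table-existence probe now also matches on schema and catalog when the table name is qualified. The quoting mechanism in isolation:

```python
from sqlglot import exp

# Literal.string emits a proper SQL string literal, doubling embedded quotes
# rather than naively wrapping the text in '...'.
stmt = "CREATE TABLE t (c VARCHAR(10) DEFAULT 'x')"
print(exp.Literal.string(stmt).sql(dialect="tsql"))
# 'CREATE TABLE t (c VARCHAR(10) DEFAULT ''x'')'
```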
@@ -664,16 +664,6 @@ class Expression(metaclass=_Expression):

        return load(obj)


-IntoType = t.Union[
-    str,
-    t.Type[Expression],
-    t.Collection[t.Union[str, t.Type[Expression]]],
-]
-ExpOrStr = t.Union[str, Expression]
-
-
class Condition(Expression):
    def and_(
        self,
        *expressions: t.Optional[ExpOrStr],

@@ -762,11 +752,19 @@ class Condition(Expression):
            return klass(this=other, expression=this)
        return klass(this=this, expression=other)

-   def __getitem__(self, other: ExpOrStr | t.Tuple[ExpOrStr]):
+   def __getitem__(self, other: ExpOrStr | t.Tuple[ExpOrStr]) -> Bracket:
        return Bracket(
            this=self.copy(), expressions=[convert(e, copy=True) for e in ensure_list(other)]
        )

+   def __iter__(self) -> t.Iterator:
+       if "expressions" in self.arg_types:
+           return iter(self.args.get("expressions") or [])
+       # We define this because __getitem__ converts Expression into an iterable, which is
+       # problematic because one can hit infinite loops if they do "for x in some_expr: ..."
+       # See: https://peps.python.org/pep-0234/
+       raise TypeError(f"'{self.__class__.__name__}' object is not iterable")
+
    def isin(
        self,
        *expressions: t.Any,
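Review note: `__getitem__` makes every expression satisfy Python's legacy iteration protocol, so `__iter__` is defined to keep `for x in some_expr` from looping forever — collections iterate their `expressions` arg, everything else raises:

```python
from sqlglot import exp, parse_one

print(list(parse_one("(1, 2, 3)")))  # the tuple's three literals
try:
    iter(exp.column("x"))            # no "expressions" arg
except TypeError as e:
    print(e)                         # 'Column' object is not iterable
```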
@@ -886,6 +884,18 @@ class Condition(Expression):
        return not_(self.copy())


+IntoType = t.Union[
+    str,
+    t.Type[Expression],
+    t.Collection[t.Union[str, t.Type[Expression]]],
+]
+ExpOrStr = t.Union[str, Expression]
+
+
+class Condition(Expression):
+    """Logical conditions like x AND y, or simply x"""
+
+
class Predicate(Condition):
    """Relationships like x = y, x > 1, x >= y."""

@@ -1045,6 +1055,10 @@ class Describe(Expression):
    arg_types = {"this": True, "kind": False, "expressions": False}


+class Kill(Expression):
+    arg_types = {"this": True, "kind": False}
+
+
class Pragma(Expression):
    pass

@@ -1161,7 +1175,7 @@ class Column(Condition):
            if self.args.get(part)
        ]

-   def to_dot(self) -> Dot:
+   def to_dot(self) -> Dot | Identifier:
        """Converts the column into a dot expression."""
        parts = self.parts
        parent = self.parent

@@ -1171,7 +1185,7 @@ class Column(Condition):
            parts.append(parent.expression)
            parent = parent.parent

-       return Dot.build(deepcopy(parts))
+       return Dot.build(deepcopy(parts)) if len(parts) > 1 else parts[0]


class ColumnPosition(Expression):

@@ -1607,6 +1621,7 @@ class Index(Expression):
        "primary": False,
        "amp": False,  # teradata
        "partition_by": False,  # teradata
+       "where": False,  # postgres partial indexes
    }

@@ -1917,7 +1932,7 @@ class Sort(Order):


class Ordered(Expression):
-   arg_types = {"this": True, "desc": True, "nulls_first": True}
+   arg_types = {"this": True, "desc": False, "nulls_first": True}


class Property(Expression):

@@ -2569,7 +2584,6 @@ class Intersect(Union):
class Unnest(UDTF):
    arg_types = {
        "expressions": True,
-       "ordinality": False,
        "alias": False,
        "offset": False,
    }

@@ -2862,6 +2876,7 @@ class Select(Subqueryable):
            prefix="LIMIT",
            dialect=dialect,
            copy=copy,
+           into_arg="expression",
            **opts,
        )

@@ -4007,6 +4022,10 @@ class TimeUnit(Expression):

        super().__init__(**args)

+   @property
+   def unit(self) -> t.Optional[Var]:
+       return self.args.get("unit")
+

# https://www.oracletutorial.com/oracle-basics/oracle-interval/
# https://trino.io/docs/current/language/types.html#interval-day-to-second

@@ -4018,10 +4037,6 @@ class IntervalSpan(Expression):
class Interval(TimeUnit):
    arg_types = {"this": False, "unit": False}

-   @property
-   def unit(self) -> t.Optional[Var]:
-       return self.args.get("unit")
-

class IgnoreNulls(Expression):
    pass

@@ -4327,6 +4342,10 @@ class DateDiff(Func, TimeUnit):
class DateTrunc(Func):
    arg_types = {"unit": True, "this": True, "zone": False}

+   @property
+   def unit(self) -> Expression:
+       return self.args["unit"]
+

class DatetimeAdd(Func, TimeUnit):
    arg_types = {"this": True, "expression": True, "unit": False}

@@ -4427,7 +4446,8 @@ class DateToDi(Func):

# https://cloud.google.com/bigquery/docs/reference/standard-sql/date_functions#date
class Date(Func):
-   arg_types = {"this": True, "zone": False}
+   arg_types = {"this": False, "zone": False, "expressions": False}
+   is_var_len_args = True


class Day(Func):

@@ -5131,10 +5151,11 @@ def _apply_builder(
    prefix=None,
    into=None,
    dialect=None,
+   into_arg="this",
    **opts,
):
    if _is_wrong_expression(expression, into):
-       expression = into(this=expression)
+       expression = into(**{into_arg: expression})
    instance = maybe_copy(instance, copy)
    expression = maybe_parse(
        sql_or_expression=expression,
@@ -5926,7 +5947,10 @@ def cast(expression: ExpOrStr, to: str | DataType | DataType.Type, **opts) -> Cast:
        The new Cast instance.
    """
    expression = maybe_parse(expression, **opts)
-   return Cast(this=expression, to=DataType.build(to, **opts))
+   data_type = DataType.build(to, **opts)
+   expression = Cast(this=expression, to=data_type)
+   expression.type = data_type
+   return expression


def table_(
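Review note: the builder now stamps the target type onto the new `Cast` eagerly, so consumers such as the type annotator see it without another inference pass:

```python
from sqlglot import exp

node = exp.cast("x", "int")
print(node.sql())  # CAST(x AS INT)
print(node.type)   # the INT DataType, set by the builder itself
```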
@@ -3,6 +3,7 @@ from __future__ import annotations
import logging
import typing as t
from collections import defaultdict
+from functools import reduce

from sqlglot import exp
from sqlglot.errors import ErrorLevel, UnsupportedError, concat_messages

@@ -99,6 +100,9 @@ class Generator:
        exp.WithJournalTableProperty: lambda self, e: f"WITH JOURNAL TABLE={self.sql(e, 'this')}",
    }

+   # Whether the base comes first
+   LOG_BASE_FIRST = True
+
    # Whether or not null ordering is supported in order by
    NULL_ORDERING_SUPPORTED = True

@@ -188,6 +192,18 @@ class Generator:
    # Whether or not the word COLUMN is included when adding a column with ALTER TABLE
    ALTER_TABLE_ADD_COLUMN_KEYWORD = True

+   # UNNEST WITH ORDINALITY (presto) instead of UNNEST WITH OFFSET (bigquery)
+   UNNEST_WITH_ORDINALITY = True
+
+   # Whether or not FILTER (WHERE cond) can be used for conditional aggregation
+   AGGREGATE_FILTER_SUPPORTED = True
+
+   # Whether or not JOIN sides (LEFT, RIGHT) are supported in conjunction with SEMI/ANTI join kinds
+   SEMI_ANTI_JOIN_WITH_SIDE = True
+
+   # Whether or not session variables / parameters are supported, e.g. @x in T-SQL
+   SUPPORTS_PARAMETERS = True
+
    TYPE_MAPPING = {
        exp.DataType.Type.NCHAR: "CHAR",
        exp.DataType.Type.NVARCHAR: "VARCHAR",

@@ -308,6 +324,8 @@ class Generator:
        exp.Paren,
    )

+   UNESCAPED_SEQUENCE_TABLE = None  # type: ignore
+
    SENTINEL_LINE_BREAK = "__SQLGLOT__LB__"

    # Autofilled

@@ -320,7 +338,6 @@ class Generator:
    STRICT_STRING_CONCAT = False
    NORMALIZE_FUNCTIONS: bool | str = "upper"
    NULL_ORDERING = "nulls_are_small"
-   ESCAPE_LINE_BREAK = False

    can_identify: t.Callable[[str, str | bool], bool]
@@ -955,9 +972,16 @@ class Generator:
        return f"{self.seg('FETCH')}{direction}{count} ROWS {with_ties_or_only}"

    def filter_sql(self, expression: exp.Filter) -> str:
-       this = self.sql(expression, "this")
-       where = self.sql(expression, "expression").strip()
-       return f"{this} FILTER({where})"
+       if self.AGGREGATE_FILTER_SUPPORTED:
+           this = self.sql(expression, "this")
+           where = self.sql(expression, "expression").strip()
+           return f"{this} FILTER({where})"
+
+       agg = expression.this.copy()
+       agg_arg = agg.this
+       cond = expression.expression.this
+       agg_arg.replace(exp.If(this=cond.copy(), true=agg_arg.copy()))
+       return self.sql(agg)

    def hint_sql(self, expression: exp.Hint) -> str:
        if not self.QUERY_HINTS:
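Review note: when `AGGREGATE_FILTER_SUPPORTED` is off (Snowflake sets it above), the FILTER clause is folded into the aggregate's argument as a conditional. Sketch of the observable effect (the exact spelling of the conditional varies by target dialect):

```python
import sqlglot

# FILTER(WHERE ...) has no Snowflake equivalent; the argument is wrapped in a
# conditional instead, e.g. SUM(IFF(x > 1, x, NULL)).
sql = "SELECT SUM(x) FILTER(WHERE x > 1) FROM t"
print(sqlglot.transpile(sql, read="postgres", write="snowflake")[0])
```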
@@ -975,13 +999,14 @@ class Generator:
        table = self.sql(expression, "table")
        table = f"{self.INDEX_ON} {table}" if table else ""
        using = self.sql(expression, "using")
-       using = f" USING {using} " if using else ""
+       using = f" USING {using}" if using else ""
        index = "INDEX " if not table else ""
        columns = self.expressions(expression, key="columns", flat=True)
        columns = f"({columns})" if columns else ""
        partition_by = self.expressions(expression, key="partition_by", flat=True)
        partition_by = f" PARTITION BY {partition_by}" if partition_by else ""
-       return f"{unique}{primary}{amp}{index}{name}{table}{using}{columns}{partition_by}"
+       where = self.sql(expression, "where")
+       return f"{unique}{primary}{amp}{index}{name}{table}{using}{columns}{partition_by}{where}"

    def identifier_sql(self, expression: exp.Identifier) -> str:
        text = expression.name

@@ -1060,10 +1085,15 @@ class Generator:

        return properties_locs

+   def property_name(self, expression: exp.Property, string_key: bool = False) -> str:
+       if isinstance(expression.this, exp.Dot):
+           return self.sql(expression, "this")
+       return f"'{expression.name}'" if string_key else expression.name
+
    def property_sql(self, expression: exp.Property) -> str:
        property_cls = expression.__class__
        if property_cls == exp.Property:
-           return f"{expression.name}={self.sql(expression, 'value')}"
+           return f"{self.property_name(expression)}={self.sql(expression, 'value')}"

        property_name = exp.Properties.PROPERTY_TO_NAME.get(property_cls)
        if not property_name:

@@ -1224,6 +1254,13 @@ class Generator:
    def introducer_sql(self, expression: exp.Introducer) -> str:
        return f"{self.sql(expression, 'this')} {self.sql(expression, 'expression')}"

+   def kill_sql(self, expression: exp.Kill) -> str:
+       kind = self.sql(expression, "kind")
+       kind = f" {kind}" if kind else ""
+       this = self.sql(expression, "this")
+       this = f" {this}" if this else ""
+       return f"KILL{kind}{this}"
+
    def pseudotype_sql(self, expression: exp.PseudoType) -> str:
        return expression.name.upper()

@@ -1386,13 +1423,11 @@ class Generator:
        return f"{values} AS {alias}" if alias else values

    # Converts `VALUES...` expression into a series of select unions.
-   # Note: If you have a lot of unions then this will result in a large number of recursive statements to
-   # evaluate the expression. You may need to increase `sys.setrecursionlimit` to run and it can also be
-   # very slow.
    expression = expression.copy()
-   column_names = expression.alias and expression.args["alias"].columns
+   alias_node = expression.args.get("alias")
+   column_names = alias_node and alias_node.columns

-   selects = []
+   selects: t.List[exp.Subqueryable] = []

    for i, tup in enumerate(expression.expressions):
        row = tup.expressions

@@ -1404,14 +1439,18 @@ class Generator:

        selects.append(exp.Select(expressions=row))

-   subquery_expression: exp.Select | exp.Union = selects[0]
-   if len(selects) > 1:
-       for select in selects[1:]:
-           subquery_expression = exp.union(
-               subquery_expression, select, distinct=False, copy=False
-           )
+   if self.pretty:
+       # This may result in poor performance for large-cardinality `VALUES` tables, due to
+       # the deep nesting of the resulting exp.Unions. If this is a problem, either increase
+       # `sys.setrecursionlimit` to avoid RecursionErrors, or don't set `pretty`.
+       subqueryable = reduce(lambda x, y: exp.union(x, y, distinct=False, copy=False), selects)
+       return self.subquery_sql(
+           subqueryable.subquery(alias_node and alias_node.this, copy=False)
+       )

-   return self.subquery_sql(subquery_expression.subquery(expression.alias, copy=False))
+   alias = f" AS {self.sql(alias_node, 'this')}" if alias_node else ""
+   unions = " UNION ALL ".join(self.sql(select) for select in selects)
+   return f"({unions}){alias}"

    def var_sql(self, expression: exp.Var) -> str:
        return self.sql(expression, "this")

@@ -1477,12 +1516,17 @@ class Generator:
        return f"PRIOR {self.sql(expression, 'this')}"

    def join_sql(self, expression: exp.Join) -> str:
+       if not self.SEMI_ANTI_JOIN_WITH_SIDE and expression.kind in ("SEMI", "ANTI"):
+           side = None
+       else:
+           side = expression.side
+
        op_sql = " ".join(
            op
            for op in (
                expression.method,
                "GLOBAL" if expression.args.get("global") else None,
-               expression.side,
+               side,
                expression.kind,
                expression.hint if self.JOIN_HINTS else None,
            )

@@ -1594,8 +1638,8 @@ class Generator:

    def escape_str(self, text: str) -> str:
        text = text.replace(self.QUOTE_END, self._escaped_quote_end)
-       if self.ESCAPE_LINE_BREAK:
-           text = text.replace("\n", "\\n")
+       if self.UNESCAPED_SEQUENCE_TABLE:
+           text = text.translate(self.UNESCAPED_SEQUENCE_TABLE)
        elif self.pretty:
            text = text.replace("\n", self.SENTINEL_LINE_BREAK)
        return text

@@ -1643,7 +1687,7 @@ class Generator:
        nulls_are_small = self.NULL_ORDERING == "nulls_are_small"
        nulls_are_last = self.NULL_ORDERING == "nulls_are_last"

-       sort_order = " DESC" if desc else ""
+       sort_order = " DESC" if desc else (" ASC" if desc is False else "")
        nulls_sort_change = ""
        if nulls_first and (
            (asc and nulls_are_large) or (desc and nulls_are_small) or nulls_are_last

@@ -1817,8 +1861,7 @@ class Generator:

    def parameter_sql(self, expression: exp.Parameter) -> str:
        this = self.sql(expression, "this")
        this = f"{{{this}}}" if expression.args.get("wrapped") else f"{this}"
-       return f"{self.PARAMETER_TOKEN}{this}"
+       return f"{self.PARAMETER_TOKEN}{this}" if self.SUPPORTS_PARAMETERS else this

    def sessionparameter_sql(self, expression: exp.SessionParameter) -> str:
        this = self.sql(expression, "this")
@@ -1858,17 +1901,33 @@ class Generator:

    def unnest_sql(self, expression: exp.Unnest) -> str:
        args = self.expressions(expression, flat=True)

        alias = expression.args.get("alias")
+       offset = expression.args.get("offset")
+
+       if self.UNNEST_WITH_ORDINALITY:
+           if alias and isinstance(offset, exp.Expression):
+               alias = alias.copy()
+               alias.append("columns", offset.copy())
+
        if alias and self.UNNEST_COLUMN_ONLY:
            columns = alias.columns
            alias = self.sql(columns[0]) if columns else ""
        else:
-           alias = self.sql(expression, "alias")
+           alias = self.sql(alias)

        alias = f" AS {alias}" if alias else alias
-       ordinality = " WITH ORDINALITY" if expression.args.get("ordinality") else ""
-       offset = expression.args.get("offset")
-       offset = f" WITH OFFSET AS {self.sql(offset)}" if offset else ""
-       return f"UNNEST({args}){ordinality}{alias}{offset}"
+       if self.UNNEST_WITH_ORDINALITY:
+           suffix = f" WITH ORDINALITY{alias}" if offset else alias
+       else:
+           if isinstance(offset, exp.Expression):
+               suffix = f"{alias} WITH OFFSET AS {self.sql(offset)}"
+           elif offset:
+               suffix = f"{alias} WITH OFFSET"
+           else:
+               suffix = alias
+
+       return f"UNNEST({args}){suffix}"

    def where_sql(self, expression: exp.Where) -> str:
        this = self.indent(self.sql(expression, "this"))
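Review note: the `UNNEST_WITH_ORDINALITY` switch lets one AST node render in either spelling — Presto's `WITH ORDINALITY` or BigQuery's `WITH OFFSET`. For instance (output shape may differ slightly by version):

```python
import sqlglot

# BigQuery keeps WITH OFFSET; Presto renders the same Unnest node using
# WITH ORDINALITY, per the flag each generator sets.
sql = "SELECT x, i FROM UNNEST([1, 2]) AS x WITH OFFSET AS i"
print(sqlglot.transpile(sql, read="bigquery", write="presto")[0])
```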
@@ -2471,6 +2530,12 @@ class Generator:
    def trycast_sql(self, expression: exp.TryCast) -> str:
        return self.cast_sql(expression, safe_prefix="TRY_")

+   def log_sql(self, expression: exp.Log) -> str:
+       args = list(expression.args.values())
+       if not self.LOG_BASE_FIRST:
+           args.reverse()
+       return self.func("LOG", *args)
+
    def use_sql(self, expression: exp.Use) -> str:
        kind = self.sql(expression, "kind")
        kind = f" {kind}" if kind else ""
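Review note: `log_sql` pairs with the `LOG_BASE_FIRST` flag this release moved onto the dialect (BigQuery and T-SQL turn it off above), so two-argument `LOG` round-trips with the base on the correct side:

```python
import sqlglot

# BigQuery's LOG(x, base) puts the base second; Postgres expects it first.
print(sqlglot.transpile("SELECT LOG(x, 10)", read="bigquery", write="postgres")[0])
# expected: SELECT LOG(10, x)
```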
@@ -13,9 +13,10 @@ from itertools import count

if t.TYPE_CHECKING:
    from sqlglot import exp
-   from sqlglot._typing import E, T
+   from sqlglot._typing import A, E, T
    from sqlglot.expressions import Expression


CAMEL_CASE_PATTERN = re.compile("(?<!^)(?=[A-Z])")
PYTHON_VERSION = sys.version_info[:2]
logger = logging.getLogger("sqlglot")

@@ -379,7 +380,9 @@ def is_iterable(value: t.Any) -> bool:
    Returns:
        A `bool` value indicating if it is an iterable.
    """
-   return hasattr(value, "__iter__") and not isinstance(value, (str, bytes))
+   from sqlglot import Expression
+
+   return hasattr(value, "__iter__") and not isinstance(value, (str, bytes, Expression))


def flatten(values: t.Iterable[t.Iterable[t.Any] | t.Any]) -> t.Iterator[t.Any]:
@@ -435,3 +438,22 @@ def dict_depth(d: t.Dict) -> int:
def first(it: t.Iterable[T]) -> T:
    """Returns the first element from an iterable (useful for sets)."""
    return next(i for i in it)
+
+
+def merge_ranges(ranges: t.List[t.Tuple[A, A]]) -> t.List[t.Tuple[A, A]]:
+    if not ranges:
+        return []
+
+    ranges = sorted(ranges)
+
+    merged = [ranges[0]]
+
+    for start, end in ranges[1:]:
+        last_start, last_end = merged[-1]
+
+        if start <= last_end:
+            merged[-1] = (last_start, max(last_end, end))
+        else:
+            merged.append((start, end))
+
+    return merged
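Review note: the helper sorts the ranges and folds overlaps left to right; the `A` TypeVar added in `_typing.py` exists so the endpoints only need to be mutually comparable. For example:

```python
from sqlglot.helper import merge_ranges

# Overlapping or touching ranges collapse into their envelope.
print(merge_ranges([(8, 10), (1, 3), (2, 6)]))
# [(1, 6), (8, 10)]
```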
@@ -158,6 +158,7 @@ class TypeAnnotator(metaclass=_TypeAnnotator):
        exp.TimeAdd,
        exp.TimeStrToTime,
        exp.TimeSub,
+       exp.Timestamp,
        exp.TimestampAdd,
        exp.TimestampSub,
        exp.UnixToTime,

@@ -177,6 +178,7 @@ class TypeAnnotator(metaclass=_TypeAnnotator):
        exp.Initcap,
        exp.Lower,
        exp.SafeConcat,
+       exp.SafeDPipe,
        exp.Substring,
        exp.TimeToStr,
        exp.TimeToTimeStr,

@@ -242,6 +244,13 @@ class TypeAnnotator(metaclass=_TypeAnnotator):
        self.annotators = annotators or self.ANNOTATORS
        self.coerces_to = coerces_to or self.COERCES_TO

+       # Caches the ids of annotated sub-Expressions, to ensure we only visit them once
+       self._visited: t.Set[int] = set()
+
+   def _set_type(self, expression: exp.Expression, target_type: exp.DataType) -> None:
+       expression.type = target_type
+       self._visited.add(id(expression))
+
    def annotate(self, expression: E) -> E:
        for scope in traverse_scope(expression):
            selects = {}
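Review note: tracking visited nodes by `id()` means only expressions annotated in this run are skipped, instead of anything that merely carries a type already. Typical usage is unchanged:

```python
from sqlglot import parse_one
from sqlglot.optimizer.annotate_types import annotate_types

# Types are attached in place; the INT literal coerces with the DOUBLE one.
expression = annotate_types(parse_one("SELECT 1 + 2.5 AS x"))
print(expression.selects[0].type)  # DOUBLE
```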
@ -279,9 +288,9 @@ class TypeAnnotator(metaclass=_TypeAnnotator):
|
|||
|
||||
source = scope.sources.get(col.table)
|
||||
if isinstance(source, exp.Table):
|
||||
col.type = self.schema.get_column_type(source, col)
|
||||
self._set_type(col, self.schema.get_column_type(source, col))
|
||||
elif source and col.table in selects and col.name in selects[col.table]:
|
||||
col.type = selects[col.table][col.name].type
|
||||
self._set_type(col, selects[col.table][col.name].type)
|
||||
|
||||
# Then (possibly) annotate the remaining expressions in the scope
|
||||
self._maybe_annotate(scope.expression)
|
||||
|
@ -289,7 +298,7 @@ class TypeAnnotator(metaclass=_TypeAnnotator):
|
|||
return self._maybe_annotate(expression) # This takes care of non-traversable expressions
|
||||
|
||||
def _maybe_annotate(self, expression: E) -> E:
|
||||
if expression.type:
|
||||
if id(expression) in self._visited:
|
||||
return expression # We've already inferred the expression's type
|
||||
|
||||
annotator = self.annotators.get(expression.__class__)
|
||||
|
@ -338,17 +347,18 @@ class TypeAnnotator(metaclass=_TypeAnnotator):
|
|||
|
||||
if isinstance(expression, exp.Connector):
|
||||
if left_type == exp.DataType.Type.NULL and right_type == exp.DataType.Type.NULL:
|
||||
expression.type = exp.DataType.Type.NULL
|
||||
self._set_type(expression, exp.DataType.Type.NULL)
|
||||
elif exp.DataType.Type.NULL in (left_type, right_type):
|
||||
expression.type = exp.DataType.build(
|
||||
"NULLABLE", expressions=exp.DataType.build("BOOLEAN")
|
||||
self._set_type(
|
||||
expression,
|
||||
exp.DataType.build("NULLABLE", expressions=exp.DataType.build("BOOLEAN")),
|
||||
)
|
||||
else:
|
||||
expression.type = exp.DataType.Type.BOOLEAN
|
||||
self._set_type(expression, exp.DataType.Type.BOOLEAN)
|
||||
elif isinstance(expression, exp.Predicate):
|
||||
expression.type = exp.DataType.Type.BOOLEAN
|
||||
self._set_type(expression, exp.DataType.Type.BOOLEAN)
|
||||
else:
|
||||
expression.type = self._maybe_coerce(left_type, right_type)
|
||||
self._set_type(expression, self._maybe_coerce(left_type, right_type))
|
||||
|
||||
return expression
|
||||
|
||||
|
@ -357,26 +367,26 @@ class TypeAnnotator(metaclass=_TypeAnnotator):
|
|||
self._annotate_args(expression)
|
||||
|
||||
if isinstance(expression, exp.Condition) and not isinstance(expression, exp.Paren):
|
||||
expression.type = exp.DataType.Type.BOOLEAN
|
||||
self._set_type(expression, exp.DataType.Type.BOOLEAN)
|
||||
else:
|
||||
expression.type = expression.this.type
|
||||
self._set_type(expression, expression.this.type)
|
||||
|
||||
return expression
|
||||
|
||||
@t.no_type_check
|
||||
def _annotate_literal(self, expression: exp.Literal) -> exp.Literal:
|
||||
if expression.is_string:
|
||||
expression.type = exp.DataType.Type.VARCHAR
|
||||
self._set_type(expression, exp.DataType.Type.VARCHAR)
|
||||
elif expression.is_int:
|
||||
expression.type = exp.DataType.Type.INT
|
||||
self._set_type(expression, exp.DataType.Type.INT)
|
||||
else:
|
||||
expression.type = exp.DataType.Type.DOUBLE
|
||||
self._set_type(expression, exp.DataType.Type.DOUBLE)
|
||||
|
||||
return expression
|
||||
|
||||
@t.no_type_check
|
||||
def _annotate_with_type(self, expression: E, target_type: exp.DataType.Type) -> E:
|
||||
expression.type = target_type
|
||||
self._set_type(expression, target_type)
|
||||
return self._annotate_args(expression)
|
||||
|
||||
@t.no_type_check
|
||||
|
@ -394,17 +404,20 @@ class TypeAnnotator(metaclass=_TypeAnnotator):
|
|||
for expr in expressions:
|
||||
last_datatype = self._maybe_coerce(last_datatype or expr.type, expr.type)
|
||||
|
||||
expression.type = last_datatype or exp.DataType.Type.UNKNOWN
|
||||
self._set_type(expression, last_datatype or exp.DataType.Type.UNKNOWN)
|
||||
|
||||
if promote:
|
||||
if expression.type.this in exp.DataType.INTEGER_TYPES:
|
||||
expression.type = exp.DataType.Type.BIGINT
|
||||
self._set_type(expression, exp.DataType.Type.BIGINT)
|
||||
elif expression.type.this in exp.DataType.FLOAT_TYPES:
|
||||
expression.type = exp.DataType.Type.DOUBLE
|
||||
self._set_type(expression, exp.DataType.Type.DOUBLE)
|
||||
|
||||
if array:
|
||||
expression.type = exp.DataType(
|
||||
this=exp.DataType.Type.ARRAY, expressions=[expression.type], nested=True
|
||||
self._set_type(
|
||||
expression,
|
||||
exp.DataType(
|
||||
this=exp.DataType.Type.ARRAY, expressions=[expression.type], nested=True
|
||||
),
|
||||
)
|
||||
|
||||
return expression
|
||||
|
|
|
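The literal rules above can be exercised through the optimizer entry point (a minimal sketch; the exact rendering of the type names may vary by version):

import sqlglot
from sqlglot.optimizer.annotate_types import annotate_types

expr = annotate_types(sqlglot.parse_one("SELECT 'a', 1, 1.5"))
print([select.type.sql() for select in expr.selects])  # e.g. ['VARCHAR', 'INT', 'DOUBLE']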
@ -17,9 +17,11 @@ def canonicalize(expression: exp.Expression) -> exp.Expression:
    exp.replace_children(expression, canonicalize)

    expression = add_text_to_concat(expression)
    expression = replace_date_funcs(expression)
    expression = coerce_type(expression)
    expression = remove_redundant_casts(expression)
    expression = ensure_bool_predicates(expression)
    expression = remove_ascending_order(expression)

    return expression

@ -30,6 +32,14 @@ def add_text_to_concat(node: exp.Expression) -> exp.Expression:
    return node


def replace_date_funcs(node: exp.Expression) -> exp.Expression:
    if isinstance(node, exp.Date) and not node.expressions and not node.args.get("zone"):
        return exp.cast(node.this, to=exp.DataType.Type.DATE)
    if isinstance(node, exp.Timestamp) and not node.expression:
        return exp.cast(node.this, to=exp.DataType.Type.TIMESTAMP)
    return node


def coerce_type(node: exp.Expression) -> exp.Expression:
    if isinstance(node, exp.Binary):
        _coerce_date(node.left, node.right)

@ -63,6 +73,14 @@ def ensure_bool_predicates(expression: exp.Expression) -> exp.Expression:
    return expression


def remove_ascending_order(expression: exp.Expression) -> exp.Expression:
    if isinstance(expression, exp.Ordered) and expression.args.get("desc") is False:
        # Convert ORDER BY a ASC to ORDER BY a
        expression.set("desc", None)

    return expression


def _coerce_date(a: exp.Expression, b: exp.Expression) -> None:
    for a, b in itertools.permutations([a, b]):
        if (

@ -75,10 +93,7 @@ def _coerce_date(a: exp.Expression, b: exp.Expression) -> None:


def _replace_cast(node: exp.Expression, to: str) -> None:
    data_type = exp.DataType.build(to)
    cast = exp.Cast(this=node.copy(), to=data_type)
    cast.type = data_type
    node.replace(cast)
    node.replace(exp.cast(node.copy(), to=to))


def _replace_int_predicate(expression: exp.Expression) -> None:
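A sketch of `replace_date_funcs` applied on its own via `Expression.transform` (the `canonicalize` pass normally runs it together with the other rules):

import sqlglot
from sqlglot.optimizer.canonicalize import replace_date_funcs

expr = sqlglot.parse_one("SELECT DATE('2021-01-01')").transform(replace_date_funcs)
print(expr.sql())  # should print something like: SELECT CAST('2021-01-01' AS DATE)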
@ -1,17 +1,22 @@
import datetime
import functools
import itertools
import typing as t
from collections import deque
from decimal import Decimal

from sqlglot import exp
from sqlglot.generator import cached_generator
from sqlglot.helper import first, while_changing
from sqlglot.helper import first, merge_ranges, while_changing

# Final means that an expression should not be simplified
FINAL = "final"


class UnsupportedUnit(Exception):
    pass


def simplify(expression):
    """
    Rewrite sqlglot AST to simplify expressions.

@ -72,7 +77,9 @@ def simplify(expression):
        node = simplify_coalesce(node)
        node.parent = expression.parent
        node = simplify_literals(node, root)
        node = simplify_equality(node)
        node = simplify_parens(node)
        node = simplify_datetrunc_predicate(node)

        if root:
            expression.replace(node)

@ -84,6 +91,21 @@ def simplify(expression):
    return expression


def catch(*exceptions):
    """Decorator that ignores a simplification function if any of `exceptions` are raised"""

    def decorator(func):
        def wrapped(expression, *args, **kwargs):
            try:
                return func(expression, *args, **kwargs)
            except exceptions:
                return expression

        return wrapped

    return decorator


def rewrite_between(expression: exp.Expression) -> exp.Expression:
    """Rewrite x between y and z to x >= y AND x <= z.

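`catch` simply hands back its input unchanged when one of the listed exceptions fires; a standalone sketch of the same pattern:

def catch(*exceptions):
    def decorator(func):
        def wrapped(expression, *args, **kwargs):
            try:
                return func(expression, *args, **kwargs)
            except exceptions:
                return expression  # fall back to the unmodified input
        return wrapped
    return decorator

@catch(ZeroDivisionError)
def broken_simplification(x):
    return x / 0  # always raises, so the decorator returns x as-is

print(broken_simplification(10))  # 10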
@ -196,7 +218,7 @@ COMPARISONS = (
    exp.Is,
)

INVERSE_COMPARISONS = {
INVERSE_COMPARISONS: t.Dict[t.Type[exp.Expression], t.Type[exp.Expression]] = {
    exp.LT: exp.GT,
    exp.GT: exp.LT,
    exp.LTE: exp.GTE,

@ -347,6 +369,87 @@ def absorb_and_eliminate(expression, root=True):
    return expression


INVERSE_DATE_OPS: t.Dict[t.Type[exp.Expression], t.Type[exp.Expression]] = {
    exp.DateAdd: exp.Sub,
    exp.DateSub: exp.Add,
    exp.DatetimeAdd: exp.Sub,
    exp.DatetimeSub: exp.Add,
}

INVERSE_OPS: t.Dict[t.Type[exp.Expression], t.Type[exp.Expression]] = {
    **INVERSE_DATE_OPS,
    exp.Add: exp.Sub,
    exp.Sub: exp.Add,
}


def _is_number(expression: exp.Expression) -> bool:
    return expression.is_number


def _is_date(expression: exp.Expression) -> bool:
    return isinstance(expression, exp.Cast) and extract_date(expression) is not None


def _is_interval(expression: exp.Expression) -> bool:
    return isinstance(expression, exp.Interval) and extract_interval(expression) is not None


@catch(ModuleNotFoundError, UnsupportedUnit)
def simplify_equality(expression: exp.Expression) -> exp.Expression:
    """
    Use the subtraction and addition properties of equality to simplify expressions:

        x + 1 = 3 becomes x = 2

    There are two binary operations in the above expression: + and =
    Here's how we reference all the operands in the code below:

          l     r
        x + 1 = 3
        a   b
    """
    if isinstance(expression, COMPARISONS):
        l, r = expression.left, expression.right

        if l.__class__ in INVERSE_OPS:
            pass
        elif r.__class__ in INVERSE_OPS:
            l, r = r, l
        else:
            return expression

        if r.is_number:
            a_predicate = _is_number
            b_predicate = _is_number
        elif _is_date(r):
            a_predicate = _is_date
            b_predicate = _is_interval
        else:
            return expression

        if l.__class__ in INVERSE_DATE_OPS:
            a = l.this
            b = exp.Interval(
                this=l.expression.copy(),
                unit=l.unit.copy(),
            )
        else:
            a, b = l.left, l.right

        if not a_predicate(a) and b_predicate(b):
            pass
        elif not a_predicate(b) and b_predicate(a):
            a, b = b, a
        else:
            return expression

        return expression.__class__(
            this=a, expression=INVERSE_OPS[l.__class__](this=r, expression=b)
        )
    return expression


def simplify_literals(expression, root=True):
    if isinstance(expression, exp.Binary) and not isinstance(expression, exp.Connector):
        return _flat_simplify(expression, _simplify_binary, root)
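The docstring's example, end to end (`simplify` is the module's public entry point):

import sqlglot
from sqlglot.optimizer.simplify import simplify

print(simplify(sqlglot.parse_one("x + 1 = 3")).sql())  # x = 2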
@ -530,6 +633,123 @@ def simplify_concat(expression):
    return new_args[0] if len(new_args) == 1 else concat_type(expressions=new_args)


DateRange = t.Tuple[datetime.date, datetime.date]


def _datetrunc_range(date: datetime.date, unit: str) -> t.Optional[DateRange]:
    """
    Get the date range for a DATE_TRUNC equality comparison:

    Example:
        _datetrunc_range(date(2021-01-01), 'year') == (date(2021-01-01), date(2022-01-01))
    Returns:
        tuple of [min, max) or None if a value can never be equal to `date` for `unit`
    """
    floor = date_floor(date, unit)

    if date != floor:
        # This will always be False, except for NULL values.
        return None

    return floor, floor + interval(unit)


def _datetrunc_eq_expression(left: exp.Expression, drange: DateRange) -> exp.Expression:
    """Get the logical expression for a date range"""
    return exp.and_(
        left >= date_literal(drange[0]),
        left < date_literal(drange[1]),
        copy=False,
    )


def _datetrunc_eq(
    left: exp.Expression, date: datetime.date, unit: str
) -> t.Optional[exp.Expression]:
    drange = _datetrunc_range(date, unit)
    if not drange:
        return None

    return _datetrunc_eq_expression(left, drange)


def _datetrunc_neq(
    left: exp.Expression, date: datetime.date, unit: str
) -> t.Optional[exp.Expression]:
    drange = _datetrunc_range(date, unit)
    if not drange:
        return None

    return exp.and_(
        left < date_literal(drange[0]),
        left >= date_literal(drange[1]),
        copy=False,
    )


DateTruncBinaryTransform = t.Callable[
    [exp.Expression, datetime.date, str], t.Optional[exp.Expression]
]
DATETRUNC_BINARY_COMPARISONS: t.Dict[t.Type[exp.Expression], DateTruncBinaryTransform] = {
    exp.LT: lambda l, d, u: l < date_literal(date_floor(d, u)),
    exp.GT: lambda l, d, u: l >= date_literal(date_floor(d, u) + interval(u)),
    exp.LTE: lambda l, d, u: l < date_literal(date_floor(d, u) + interval(u)),
    exp.GTE: lambda l, d, u: l >= date_literal(date_ceil(d, u)),
    exp.EQ: _datetrunc_eq,
    exp.NEQ: _datetrunc_neq,
}
DATETRUNC_COMPARISONS = {exp.In, *DATETRUNC_BINARY_COMPARISONS}


def _is_datetrunc_predicate(left: exp.Expression, right: exp.Expression) -> bool:
    return (
        isinstance(left, (exp.DateTrunc, exp.TimestampTrunc))
        and isinstance(right, exp.Cast)
        and right.is_type(*exp.DataType.TEMPORAL_TYPES)
    )


@catch(ModuleNotFoundError, UnsupportedUnit)
def simplify_datetrunc_predicate(expression: exp.Expression) -> exp.Expression:
    """Simplify expressions like `DATE_TRUNC('year', x) >= CAST('2021-01-01' AS DATE)`"""
    comparison = expression.__class__

    if comparison not in DATETRUNC_COMPARISONS:
        return expression

    if isinstance(expression, exp.Binary):
        l, r = expression.left, expression.right

        if _is_datetrunc_predicate(l, r):
            pass
        elif _is_datetrunc_predicate(r, l):
            comparison = INVERSE_COMPARISONS.get(comparison, comparison)
            l, r = r, l
        else:
            return expression

        unit = l.unit.name.lower()
        date = extract_date(r)

        return DATETRUNC_BINARY_COMPARISONS[comparison](l.this, date, unit) or expression
    elif isinstance(expression, exp.In):
        l = expression.this
        rs = expression.expressions

        if all(_is_datetrunc_predicate(l, r) for r in rs):
            unit = l.unit.name.lower()

            ranges = [r for r in [_datetrunc_range(extract_date(r), unit) for r in rs] if r]
            if not ranges:
                return expression

            ranges = merge_ranges(ranges)

            return exp.or_(*[_datetrunc_eq_expression(l, drange) for drange in ranges], copy=False)

    return expression


# CROSS joins result in an empty table if the right table is empty.
# So we can only simplify certain types of joins to CROSS.
# Or in other words, LEFT JOIN x ON TRUE != CROSS JOIN x
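A sketch of the new rewrite in action (requires python-dateutil; without it the `@catch` fallback leaves the expression untouched, and the operand order in the output may vary):

import sqlglot
from sqlglot.optimizer.simplify import simplify

sql = "DATE_TRUNC('YEAR', x) = CAST('2021-01-01' AS DATE)"
print(simplify(sqlglot.parse_one(sql)).sql())
# expected shape: x >= CAST('2021-01-01' AS DATE) AND x < CAST('2022-01-01' AS DATE)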
@ -603,25 +823,15 @@ def extract_date(cast):
    return None


def extract_interval(interval):
def extract_interval(expression):
    n = int(expression.name)
    unit = expression.text("unit").lower()

    try:
        from dateutil.relativedelta import relativedelta  # type: ignore
    except ModuleNotFoundError:
        return interval(unit, n)
    except (UnsupportedUnit, ModuleNotFoundError):
        return None

    n = int(interval.name)
    unit = interval.text("unit").lower()

    if unit == "year":
        return relativedelta(years=n)
    if unit == "month":
        return relativedelta(months=n)
    if unit == "week":
        return relativedelta(weeks=n)
    if unit == "day":
        return relativedelta(days=n)
    return None


def date_literal(date):
    return exp.cast(
@ -630,6 +840,61 @@ def date_literal(date):
    )


def interval(unit: str, n: int = 1):
    from dateutil.relativedelta import relativedelta

    if unit == "year":
        return relativedelta(years=1 * n)
    if unit == "quarter":
        return relativedelta(months=3 * n)
    if unit == "month":
        return relativedelta(months=1 * n)
    if unit == "week":
        return relativedelta(weeks=1 * n)
    if unit == "day":
        return relativedelta(days=1 * n)
    if unit == "hour":
        return relativedelta(hours=1 * n)
    if unit == "minute":
        return relativedelta(minutes=1 * n)
    if unit == "second":
        return relativedelta(seconds=1 * n)

    raise UnsupportedUnit(f"Unsupported unit: {unit}")


def date_floor(d: datetime.date, unit: str) -> datetime.date:
    if unit == "year":
        return d.replace(month=1, day=1)
    if unit == "quarter":
        if d.month <= 3:
            return d.replace(month=1, day=1)
        elif d.month <= 6:
            return d.replace(month=4, day=1)
        elif d.month <= 9:
            return d.replace(month=7, day=1)
        else:
            return d.replace(month=10, day=1)
    if unit == "month":
        return d.replace(month=d.month, day=1)
    if unit == "week":
        # Assuming week starts on Monday (0) and ends on Sunday (6)
        return d - datetime.timedelta(days=d.weekday())
    if unit == "day":
        return d

    raise UnsupportedUnit(f"Unsupported unit: {unit}")


def date_ceil(d: datetime.date, unit: str) -> datetime.date:
    floor = date_floor(d, unit)

    if floor == d:
        return d

    return floor + interval(unit)


def boolean_literal(condition):
    return exp.true() if condition else exp.false()
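A sketch of the unit helpers above (assuming they remain importable from `sqlglot.optimizer.simplify`; `interval` needs python-dateutil):

import datetime

from sqlglot.optimizer.simplify import date_ceil, date_floor, interval

d = datetime.date(2021, 5, 20)
print(date_floor(d, "quarter"))  # 2021-04-01
print(date_ceil(d, "month"))     # 2021-06-01
print(d + interval("week", 2))   # 2021-06-03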
@ -43,7 +43,11 @@ def unnest(select, parent_select, next_alias_name):
    predicate = select.find_ancestor(exp.Condition)
    alias = next_alias_name()

    if not predicate or parent_select is not predicate.parent_select:
    if (
        not predicate
        or parent_select is not predicate.parent_select
        or not parent_select.args.get("from")
    ):
        return

    # This subquery returns a scalar and can just be converted to a cross join
@ -278,6 +278,7 @@ class Parser(metaclass=_Parser):
        TokenType.ISNULL,
        TokenType.INTERVAL,
        TokenType.KEEP,
        TokenType.KILL,
        TokenType.LEFT,
        TokenType.LOAD,
        TokenType.MERGE,

@ -285,6 +286,7 @@ class Parser(metaclass=_Parser):
        TokenType.NEXT,
        TokenType.OFFSET,
        TokenType.ORDINALITY,
        TokenType.OVERLAPS,
        TokenType.OVERWRITE,
        TokenType.PARTITION,
        TokenType.PERCENT,

@ -316,6 +318,7 @@ class Parser(metaclass=_Parser):
    INTERVAL_VARS = ID_VAR_TOKENS - {TokenType.END}

    TABLE_ALIAS_TOKENS = ID_VAR_TOKENS - {
        TokenType.ANTI,
        TokenType.APPLY,
        TokenType.ASOF,
        TokenType.FULL,

@ -324,6 +327,7 @@ class Parser(metaclass=_Parser):
        TokenType.NATURAL,
        TokenType.OFFSET,
        TokenType.RIGHT,
        TokenType.SEMI,
        TokenType.WINDOW,
    }

@ -541,6 +545,7 @@ class Parser(metaclass=_Parser):
        TokenType.DESCRIBE: lambda self: self._parse_describe(),
        TokenType.DROP: lambda self: self._parse_drop(),
        TokenType.INSERT: lambda self: self._parse_insert(),
        TokenType.KILL: lambda self: self._parse_kill(),
        TokenType.LOAD: lambda self: self._parse_load(),
        TokenType.MERGE: lambda self: self._parse_merge(),
        TokenType.PIVOT: lambda self: self._parse_simplified_pivot(),

@ -856,6 +861,8 @@ class Parser(metaclass=_Parser):

    DISTINCT_TOKENS = {TokenType.DISTINCT}

    NULL_TOKENS = {TokenType.NULL}

    STRICT_CAST = True

    # A NULL arg in CONCAT yields NULL by default

@ -873,6 +880,9 @@ class Parser(metaclass=_Parser):
    # Whether or not the table sample clause expects CSV syntax
    TABLESAMPLE_CSV = False

    # Whether or not the SET command needs a delimiter (e.g. "=") for assignments.
    SET_REQUIRES_ASSIGNMENT_DELIMITER = True

    __slots__ = (
        "error_level",
        "error_message_context",

@ -1280,7 +1290,14 @@ class Parser(metaclass=_Parser):
            else:
                begin = self._match(TokenType.BEGIN)
                return_ = self._match_text_seq("RETURN")
                expression = self._parse_statement()

                if self._match(TokenType.STRING, advance=False):
                    # Takes care of BigQuery's JavaScript UDF definitions that end in an OPTIONS property
                    # https://cloud.google.com/bigquery/docs/reference/standard-sql/data-definition-language#create_function_statement
                    expression = self._parse_string()
                    extend_props(self._parse_properties())
                else:
                    expression = self._parse_statement()

                if return_:
                    expression = self.expression(exp.Return, this=expression)

@ -1400,20 +1417,18 @@ class Parser(metaclass=_Parser):
        if self._match_text_seq("SQL", "SECURITY"):
            return self.expression(exp.SqlSecurityProperty, definer=self._match_text_seq("DEFINER"))

        assignment = self._match_pair(
            TokenType.VAR, TokenType.EQ, advance=False
        ) or self._match_pair(TokenType.STRING, TokenType.EQ, advance=False)
        index = self._index
        key = self._parse_column()

        if assignment:
            key = self._parse_var_or_string()
            self._match(TokenType.EQ)
            return self.expression(
                exp.Property,
                this=key,
                value=self._parse_column() or self._parse_var(any_token=True),
            )
        if not self._match(TokenType.EQ):
            self._retreat(index)
            return None

        return None
        return self.expression(
            exp.Property,
            this=key.to_dot() if isinstance(key, exp.Column) else key,
            value=self._parse_column() or self._parse_var(any_token=True),
        )

    def _parse_stored(self) -> exp.FileFormatProperty:
        self._match(TokenType.ALIAS)
@ -1818,6 +1833,15 @@ class Parser(metaclass=_Parser):
            ignore=ignore,
        )

    def _parse_kill(self) -> exp.Kill:
        kind = exp.var(self._prev.text) if self._match_texts(("CONNECTION", "QUERY")) else None

        return self.expression(
            exp.Kill,
            this=self._parse_primary(),
            kind=kind,
        )

    def _parse_on_conflict(self) -> t.Optional[exp.OnConflict]:
        conflict = self._match_text_seq("ON", "CONFLICT")
        duplicate = self._match_text_seq("ON", "DUPLICATE", "KEY")
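With the KILL keyword, token type and statement parser wired up (see the tokenizer hunks below), the new statement should round-trip (a sketch):

import sqlglot

expr = sqlglot.parse_one("KILL QUERY 123")
print(type(expr).__name__, "->", expr.sql())  # Kill -> KILL QUERY 123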
@ -2459,7 +2483,7 @@ class Parser(metaclass=_Parser):
        index = self._parse_id_var()
        table = None

        using = self._parse_field() if self._match(TokenType.USING) else None
        using = self._parse_var(any_token=True) if self._match(TokenType.USING) else None

        if self._match(TokenType.L_PAREN, advance=False):
            columns = self._parse_wrapped_csv(self._parse_ordered)

@ -2476,6 +2500,7 @@ class Parser(metaclass=_Parser):
            primary=primary,
            amp=amp,
            partition_by=self._parse_partition_by(),
            where=self._parse_where(),
        )

    def _parse_table_hints(self) -> t.Optional[t.List[exp.Expression]]:

@ -2634,25 +2659,27 @@ class Parser(metaclass=_Parser):
            return None

        expressions = self._parse_wrapped_csv(self._parse_type)
        ordinality = self._match_pair(TokenType.WITH, TokenType.ORDINALITY)
        offset = self._match_pair(TokenType.WITH, TokenType.ORDINALITY)

        alias = self._parse_table_alias() if with_alias else None

        if alias and self.UNNEST_COLUMN_ONLY:
            if alias.args.get("columns"):
                self.raise_error("Unexpected extra column alias in unnest.")
        if alias:
            if self.UNNEST_COLUMN_ONLY:
                if alias.args.get("columns"):
                    self.raise_error("Unexpected extra column alias in unnest.")

            alias.set("columns", [alias.this])
            alias.set("this", None)
                alias.set("columns", [alias.this])
                alias.set("this", None)

        offset = None
        if self._match_pair(TokenType.WITH, TokenType.OFFSET):
            columns = alias.args.get("columns") or []
            if offset and len(expressions) < len(columns):
                offset = columns.pop()

        if not offset and self._match_pair(TokenType.WITH, TokenType.OFFSET):
            self._match(TokenType.ALIAS)
            offset = self._parse_id_var() or exp.to_identifier("offset")

        return self.expression(
            exp.Unnest, expressions=expressions, ordinality=ordinality, alias=alias, offset=offset
        )
        return self.expression(exp.Unnest, expressions=expressions, alias=alias, offset=offset)

    def _parse_derived_table_values(self) -> t.Optional[exp.Values]:
        is_derived = self._match_pair(TokenType.L_PAREN, TokenType.VALUES)

@ -2940,20 +2967,20 @@ class Parser(metaclass=_Parser):

    def _parse_ordered(self) -> exp.Ordered:
        this = self._parse_conjunction()
        self._match(TokenType.ASC)

        is_desc = self._match(TokenType.DESC)
        asc = self._match(TokenType.ASC)
        desc = self._match(TokenType.DESC) or (asc and False)

        is_nulls_first = self._match_text_seq("NULLS", "FIRST")
        is_nulls_last = self._match_text_seq("NULLS", "LAST")
        desc = is_desc or False
        asc = not desc

        nulls_first = is_nulls_first or False
        explicitly_null_ordered = is_nulls_first or is_nulls_last

        if (
            not explicitly_null_ordered
            and (
                (asc and self.NULL_ORDERING == "nulls_are_small")
                (not desc and self.NULL_ORDERING == "nulls_are_small")
                or (desc and self.NULL_ORDERING != "nulls_are_small")
            )
            and self.NULL_ORDERING != "nulls_are_last"

@ -3227,8 +3254,8 @@ class Parser(metaclass=_Parser):
            return self.UNARY_PARSERS[self._prev.token_type](self)
        return self._parse_at_time_zone(self._parse_type())

    def _parse_type(self) -> t.Optional[exp.Expression]:
        interval = self._parse_interval()
    def _parse_type(self, parse_interval: bool = True) -> t.Optional[exp.Expression]:
        interval = parse_interval and self._parse_interval()
        if interval:
            return interval

@ -3247,7 +3274,7 @@ class Parser(metaclass=_Parser):
                return self._parse_column()
            return self._parse_column_ops(data_type)

        return this
        return this and self._parse_column_ops(this)

    def _parse_type_size(self) -> t.Optional[exp.DataTypeParam]:
        this = self._parse_type()

@ -3404,7 +3431,7 @@ class Parser(metaclass=_Parser):
        return this

    def _parse_struct_types(self) -> t.Optional[exp.Expression]:
        this = self._parse_type() or self._parse_id_var()
        this = self._parse_type(parse_interval=False) or self._parse_id_var()
        self._match(TokenType.COLON)
        return self._parse_column_def(this)

@ -3847,6 +3874,8 @@ class Parser(metaclass=_Parser):
            action = "NO ACTION"
        elif self._match_text_seq("CASCADE"):
            action = "CASCADE"
        elif self._match_text_seq("RESTRICT"):
            action = "RESTRICT"
        elif self._match_pair(TokenType.SET, TokenType.NULL):
            action = "SET NULL"
        elif self._match_pair(TokenType.SET, TokenType.DEFAULT):

@ -4573,7 +4602,7 @@ class Parser(metaclass=_Parser):
        return self._parse_var() or self._parse_string()

    def _parse_null(self) -> t.Optional[exp.Expression]:
        if self._match(TokenType.NULL):
        if self._match_set(self.NULL_TOKENS):
            return self.PRIMARY_PARSERS[TokenType.NULL](self, self._prev)
        return self._parse_placeholder()

@ -4608,14 +4637,18 @@ class Parser(metaclass=_Parser):
            return None
        if self._match(TokenType.L_PAREN, advance=False):
            return self._parse_wrapped_csv(self._parse_column)
        return self._parse_csv(self._parse_column)

        except_column = self._parse_column()
        return [except_column] if except_column else None

    def _parse_replace(self) -> t.Optional[t.List[exp.Expression]]:
        if not self._match(TokenType.REPLACE):
            return None
        if self._match(TokenType.L_PAREN, advance=False):
            return self._parse_wrapped_csv(self._parse_expression)
        return self._parse_expressions()

        replace_expression = self._parse_expression()
        return [replace_expression] if replace_expression else None

    def _parse_csv(
        self, parse_method: t.Callable, sep: TokenType = TokenType.COMMA
@ -4931,8 +4964,9 @@ class Parser(metaclass=_Parser):
            return self._parse_set_transaction(global_=kind == "GLOBAL")

        left = self._parse_primary() or self._parse_id_var()
        assignment_delimiter = self._match_texts(("=", "TO"))

        if not self._match_texts(("=", "TO")):
        if not left or (self.SET_REQUIRES_ASSIGNMENT_DELIMITER and not assignment_delimiter):
            self._retreat(index)
            return None

@ -247,6 +247,7 @@ class TokenType(AutoName):
    JOIN = auto()
    JOIN_MARKER = auto()
    KEEP = auto()
    KILL = auto()
    LANGUAGE = auto()
    LATERAL = auto()
    LEFT = auto()

@ -595,6 +596,7 @@ class Tokenizer(metaclass=_Tokenizer):
        "ISNULL": TokenType.ISNULL,
        "JOIN": TokenType.JOIN,
        "KEEP": TokenType.KEEP,
        "KILL": TokenType.KILL,
        "LATERAL": TokenType.LATERAL,
        "LEFT": TokenType.LEFT,
        "LIKE": TokenType.LIKE,
@ -146,7 +146,7 @@ def unnest_to_explode(expression: exp.Expression) -> exp.Expression:

            if isinstance(unnest, exp.Unnest):
                alias = unnest.args.get("alias")
                udtf = exp.Posexplode if unnest.args.get("ordinality") else exp.Explode
                udtf = exp.Posexplode if unnest.args.get("offset") else exp.Explode

                expression.args["joins"].remove(join)

@ -163,65 +163,134 @@ def unnest_to_explode(expression: exp.Expression) -> exp.Expression:
    return expression


def explode_to_unnest(expression: exp.Expression) -> exp.Expression:
    """Convert explode/posexplode into unnest (used in hive -> presto)."""
    if isinstance(expression, exp.Select):
        from sqlglot.optimizer.scope import Scope
def explode_to_unnest(index_offset: int = 0) -> t.Callable[[exp.Expression], exp.Expression]:
    def _explode_to_unnest(expression: exp.Expression) -> exp.Expression:
        """Convert explode/posexplode into unnest (used in hive -> presto)."""
        if isinstance(expression, exp.Select):
            from sqlglot.optimizer.scope import Scope

        taken_select_names = set(expression.named_selects)
        taken_source_names = {name for name, _ in Scope(expression).references}
            taken_select_names = set(expression.named_selects)
            taken_source_names = {name for name, _ in Scope(expression).references}

        for select in expression.selects:
            to_replace = select
            def new_name(names: t.Set[str], name: str) -> str:
                name = find_new_name(names, name)
                names.add(name)
                return name

            pos_alias = ""
            explode_alias = ""
            arrays: t.List[exp.Condition] = []
            series_alias = new_name(taken_select_names, "pos")
            series = exp.alias_(
                exp.Unnest(
                    expressions=[exp.GenerateSeries(start=exp.Literal.number(index_offset))]
                ),
                new_name(taken_source_names, "_u"),
                table=[series_alias],
            )

            if isinstance(select, exp.Alias):
                explode_alias = select.alias
                select = select.this
            elif isinstance(select, exp.Aliases):
                pos_alias = select.aliases[0].name
                explode_alias = select.aliases[1].name
                select = select.this
            # we use list here because expression.selects is mutated inside the loop
            for select in expression.selects.copy():
                explode = select.find(exp.Explode, exp.Posexplode)

            if isinstance(select, (exp.Explode, exp.Posexplode)):
                is_posexplode = isinstance(select, exp.Posexplode)
                if isinstance(explode, (exp.Explode, exp.Posexplode)):
                    pos_alias = ""
                    explode_alias = ""

                explode_arg = select.this
                unnest = exp.Unnest(expressions=[explode_arg.copy()], ordinality=is_posexplode)
                    if isinstance(select, exp.Alias):
                        explode_alias = select.alias
                        alias = select
                    elif isinstance(select, exp.Aliases):
                        pos_alias = select.aliases[0].name
                        explode_alias = select.aliases[1].name
                        alias = select.replace(exp.alias_(select.this, "", copy=False))
                    else:
                        alias = select.replace(exp.alias_(select, ""))
                        explode = alias.find(exp.Explode, exp.Posexplode)
                        assert explode

                # This ensures that we won't use [POS]EXPLODE's argument as a new selection
                if isinstance(explode_arg, exp.Column):
                    taken_select_names.add(explode_arg.output_name)
                    is_posexplode = isinstance(explode, exp.Posexplode)
                    explode_arg = explode.this

                unnest_source_alias = find_new_name(taken_source_names, "_u")
                taken_source_names.add(unnest_source_alias)
                    # This ensures that we won't use [POS]EXPLODE's argument as a new selection
                    if isinstance(explode_arg, exp.Column):
                        taken_select_names.add(explode_arg.output_name)

                if not explode_alias:
                    explode_alias = find_new_name(taken_select_names, "col")
                    taken_select_names.add(explode_alias)
                    unnest_source_alias = new_name(taken_source_names, "_u")

                    if not explode_alias:
                        explode_alias = new_name(taken_select_names, "col")

                        if is_posexplode:
                            pos_alias = new_name(taken_select_names, "pos")

                    if not pos_alias:
                        pos_alias = new_name(taken_select_names, "pos")

                    alias.set("alias", exp.to_identifier(explode_alias))

                    column = exp.If(
                        this=exp.column(series_alias).eq(exp.column(pos_alias)),
                        true=exp.column(explode_alias),
                    )

                    explode.replace(column)

                    if is_posexplode:
                        pos_alias = find_new_name(taken_select_names, "pos")
                        taken_select_names.add(pos_alias)
                        expressions = expression.expressions
                        expressions.insert(
                            expressions.index(alias) + 1,
                            exp.If(
                                this=exp.column(series_alias).eq(exp.column(pos_alias)),
                                true=exp.column(pos_alias),
                            ).as_(pos_alias),
                        )
                        expression.set("expressions", expressions)

                if is_posexplode:
                    column_names = [explode_alias, pos_alias]
                    to_replace.pop()
                    expression.select(pos_alias, explode_alias, copy=False)
                else:
                    column_names = [explode_alias]
                    to_replace.replace(exp.column(explode_alias))
                    if not arrays:
                        if expression.args.get("from"):
                            expression.join(series, copy=False)
                        else:
                            expression.from_(series, copy=False)

                unnest = exp.alias_(unnest, unnest_source_alias, table=column_names)
                    size: exp.Condition = exp.ArraySize(this=explode_arg.copy())
                    arrays.append(size)

                if not expression.args.get("from"):
                    expression.from_(unnest, copy=False)
                else:
                    expression.join(unnest, join_type="CROSS", copy=False)
                    # trino doesn't support left join unnest with on conditions
                    # if it did, this would be much simpler
                    expression.join(
                        exp.alias_(
                            exp.Unnest(
                                expressions=[explode_arg.copy()],
                                offset=exp.to_identifier(pos_alias),
                            ),
                            unnest_source_alias,
                            table=[explode_alias],
                        ),
                        join_type="CROSS",
                        copy=False,
                    )

    return expression
                    if index_offset != 1:
                        size = size - 1

                    expression.where(
                        exp.column(series_alias)
                        .eq(exp.column(pos_alias))
                        .or_(
                            (exp.column(series_alias) > size).and_(exp.column(pos_alias).eq(size))
                        ),
                        copy=False,
                    )

            if arrays:
                end: exp.Condition = exp.Greatest(this=arrays[0], expressions=arrays[1:])

                if index_offset != 1:
                    end = end - (1 - index_offset)
                series.expressions[0].set("end", end)

        return expression

    return _explode_to_unnest


PERCENTILES = (exp.PercentileCont, exp.PercentileDisc)
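The transform is now a factory parameterized by the target dialect's array `index_offset` (1-indexed engines such as Presto would pass 1). A sketch of a path that exercises it; the `_u`/`pos`/`col` aliases are generated:

import sqlglot

print(sqlglot.transpile("SELECT EXPLODE(arr) FROM t", read="spark", write="presto")[0])
# emits a CROSS JOIN UNNEST(...) rewrite guarded by the pos/col IF logic shown above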
@ -283,6 +352,31 @@ def epoch_cast_to_ts(expression: exp.Expression) -> exp.Expression:
    return expression


def timestamp_to_cast(expression: exp.Expression) -> exp.Expression:
    if isinstance(expression, exp.Timestamp) and not expression.expression:
        return exp.cast(
            expression.this,
            to=exp.DataType.Type.TIMESTAMP,
        )
    return expression


def eliminate_semi_and_anti_joins(expression: exp.Expression) -> exp.Expression:
    if isinstance(expression, exp.Select):
        for join in expression.args.get("joins") or []:
            on = join.args.get("on")
            if on and join.kind in ("SEMI", "ANTI"):
                subquery = exp.select("1").from_(join.this).where(on)
                exists = exp.Exists(this=subquery)
                if join.kind == "ANTI":
                    exists = exists.not_(copy=False)

                join.pop()
                expression.where(exists, copy=False)

    return expression


def preprocess(
    transforms: t.List[t.Callable[[exp.Expression], exp.Expression]],
) -> t.Callable[[Generator, exp.Expression], str]:
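A sketch of `eliminate_semi_and_anti_joins` applied by hand (it is meant to run as a generator preprocess step for dialects without native SEMI/ANTI join support):

import sqlglot
from sqlglot.transforms import eliminate_semi_and_anti_joins

expr = sqlglot.parse_one("SELECT a FROM x LEFT SEMI JOIN y ON x.id = y.id", read="hive")
print(eliminate_semi_and_anti_joins(expr).sql())
# should print something like: SELECT a FROM x WHERE EXISTS(SELECT 1 FROM y WHERE x.id = y.id)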
@ -327,12 +421,3 @@ def preprocess(
            raise ValueError(f"Unsupported expression type {expression.__class__.__name__}.")

    return _to_sql


def timestamp_to_cast(expression: exp.Expression) -> exp.Expression:
    if isinstance(expression, exp.Timestamp) and not expression.expression:
        return exp.cast(
            expression.this,
            to=exp.DataType.Type.TIMESTAMP,
        )
    return expression