Merging upstream version 23.13.1.

Signed-off-by: Daniel Baumann <daniel@debian.org>

parent 63a75c51ff
commit 64041d1d66

85 changed files with 53899 additions and 50390 deletions
@@ -18,6 +18,7 @@ def _timestamp_diff(
 class Databricks(Spark):
     SAFE_DIVISION = False
+    COPY_PARAMS_ARE_CSV = False
 
     class Parser(Spark.Parser):
         LOG_DEFAULTS_TO_LN = True
@@ -38,6 +39,8 @@ class Databricks(Spark):
     class Generator(Spark.Generator):
         TABLESAMPLE_SEED_KEYWORD = "REPEATABLE"
+        COPY_PARAMS_ARE_WRAPPED = False
+        COPY_PARAMS_EQ_REQUIRED = True
 
         TRANSFORMS = {
             **Spark.Generator.TRANSFORMS,
@@ -161,6 +161,9 @@ class _Dialect(type):
         if enum not in ("", "bigquery"):
             klass.generator_class.SELECT_KINDS = ()
 
+        if enum not in ("", "athena", "presto", "trino"):
+            klass.generator_class.TRY_SUPPORTED = False
+
         if enum not in ("", "databricks", "hive", "spark", "spark2"):
             modifier_transforms = klass.generator_class.AFTER_HAVING_MODIFIER_TRANSFORMS.copy()
             for modifier in ("cluster", "distribute", "sort"):
@@ -318,6 +321,9 @@ class Dialect(metaclass=_Dialect):
     UNICODE_START: t.Optional[str] = None
     UNICODE_END: t.Optional[str] = None
 
+    # Separator of COPY statement parameters
+    COPY_PARAMS_ARE_CSV = True
+
     @classmethod
     def get_or_raise(cls, dialect: DialectType) -> Dialect:
         """
@@ -897,9 +903,7 @@ def regexp_extract_sql(self: Generator, expression: exp.RegexpExtract) -> str:
 
 def regexp_replace_sql(self: Generator, expression: exp.RegexpReplace) -> str:
-    bad_args = list(
-        filter(expression.args.get, ("position", "occurrence", "parameters", "modifiers"))
-    )
+    bad_args = list(filter(expression.args.get, ("position", "occurrence", "modifiers")))
     if bad_args:
         self.unsupported(f"REGEXP_REPLACE does not support the following arg(s): {bad_args}")
 
 
@@ -280,6 +280,8 @@ class DuckDB(Dialect):
             "RANGE": _build_generate_series(end_exclusive=True),
         }
 
+        FUNCTIONS.pop("DATE_SUB")
+
         FUNCTION_PARSERS = parser.Parser.FUNCTION_PARSERS.copy()
         FUNCTION_PARSERS.pop("DECODE")
 
@@ -365,6 +367,7 @@ class DuckDB(Dialect):
         MULTI_ARG_DISTINCT = False
         CAN_IMPLEMENT_ARRAY_ANY = True
         SUPPORTS_TO_NUMBER = False
+        COPY_HAS_INTO_KEYWORD = False
 
         TRANSFORMS = {
             **generator.Generator.TRANSFORMS,
@@ -668,6 +668,7 @@ class MySQL(Dialect):
 
         TRANSFORMS = {
             **generator.Generator.TRANSFORMS,
+            exp.ArrayAgg: rename_func("GROUP_CONCAT"),
             exp.CurrentDate: no_paren_current_date_sql,
             exp.DateDiff: _remove_ts_or_ds_to_date(
                 lambda self, e: self.func("DATEDIFF", e.this, e.expression), ("this", "expression")
@@ -766,15 +767,37 @@ class MySQL(Dialect):
 
         LIMIT_ONLY_LITERALS = True
 
+        CHAR_CAST_MAPPING = dict.fromkeys(
+            (
+                exp.DataType.Type.LONGTEXT,
+                exp.DataType.Type.LONGBLOB,
+                exp.DataType.Type.MEDIUMBLOB,
+                exp.DataType.Type.MEDIUMTEXT,
+                exp.DataType.Type.TEXT,
+                exp.DataType.Type.TINYBLOB,
+                exp.DataType.Type.TINYTEXT,
+                exp.DataType.Type.VARCHAR,
+            ),
+            "CHAR",
+        )
+        SIGNED_CAST_MAPPING = dict.fromkeys(
+            (
+                exp.DataType.Type.BIGINT,
+                exp.DataType.Type.BOOLEAN,
+                exp.DataType.Type.INT,
+                exp.DataType.Type.SMALLINT,
+                exp.DataType.Type.TINYINT,
+                exp.DataType.Type.MEDIUMINT,
+            ),
+            "SIGNED",
+        )
 
         # MySQL doesn't support many datatypes in cast.
         # https://dev.mysql.com/doc/refman/8.0/en/cast-functions.html#function_cast
         CAST_MAPPING = {
-            exp.DataType.Type.BIGINT: "SIGNED",
-            exp.DataType.Type.BOOLEAN: "SIGNED",
-            exp.DataType.Type.INT: "SIGNED",
-            exp.DataType.Type.TEXT: "CHAR",
+            **CHAR_CAST_MAPPING,
+            **SIGNED_CAST_MAPPING,
             exp.DataType.Type.UBIGINT: "UNSIGNED",
             exp.DataType.Type.VARCHAR: "CHAR",
         }
 
         TIMESTAMP_FUNC_TYPES = {
@@ -782,6 +805,13 @@ class MySQL(Dialect):
             exp.DataType.Type.TIMESTAMPLTZ,
         }
 
+        def extract_sql(self, expression: exp.Extract) -> str:
+            unit = expression.name
+            if unit and unit.lower() == "epoch":
+                return self.func("UNIX_TIMESTAMP", expression.expression)
+
+            return super().extract_sql(expression)
+
         def datatype_sql(self, expression: exp.DataType) -> str:
             # https://dev.mysql.com/doc/refman/8.0/en/numeric-type-syntax.html
             result = super().datatype_sql(expression)
@@ -867,3 +897,16 @@ class MySQL(Dialect):
             charset = expression.args.get("charset")
             using = f" USING {self.sql(charset)}" if charset else ""
             return f"CHAR({this}{using})"
+
+        def timestamptrunc_sql(self, expression: exp.TimestampTrunc) -> str:
+            unit = expression.args.get("unit")
+
+            # Pick an old-enough date to avoid negative timestamp diffs
+            start_ts = "'0000-01-01 00:00:00'"
+
+            # Source: https://stackoverflow.com/a/32955740
+            timestamp_diff = build_date_delta(exp.TimestampDiff)([unit, start_ts, expression.this])
+            interval = exp.Interval(this=timestamp_diff, unit=unit)
+            dateadd = build_date_delta_with_interval(exp.DateAdd)([start_ts, interval])
+
+            return self.sql(dateadd)
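
A quick way to exercise the new MySQL DATE_TRUNC emulation above — a minimal sketch, assuming a checkout that includes this merge (the exact SQL text emitted may vary between versions):

    import sqlglot

    # MySQL has no DATE_TRUNC; per timestamptrunc_sql above, it should be
    # rewritten into DATE_ADD/TIMESTAMPDIFF arithmetic anchored at
    # '0000-01-01 00:00:00'.
    print(sqlglot.transpile("SELECT DATE_TRUNC('day', created_at)", read="postgres", write="mysql")[0])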
@@ -32,7 +32,7 @@ from sqlglot.dialects.dialect import (
     trim_sql,
     ts_or_ds_add_cast,
 )
-from sqlglot.helper import seq_get
+from sqlglot.helper import is_int, seq_get
 from sqlglot.parser import binary_range_parser
 from sqlglot.tokens import TokenType
 
@@ -204,6 +204,29 @@ def _json_extract_sql(
     return _generate
 
 
+def _build_regexp_replace(args: t.List) -> exp.RegexpReplace:
+    # The signature of REGEXP_REPLACE is:
+    # regexp_replace(source, pattern, replacement [, start [, N ]] [, flags ])
+    #
+    # Any one of `start`, `N` and `flags` can be column references, meaning that
+    # unless we can statically see that the last argument is a non-integer string
+    # (eg. not '0'), then it's not possible to construct the correct AST
+    if len(args) > 3:
+        last = args[-1]
+        if not is_int(last.name):
+            if not last.type or last.is_type(exp.DataType.Type.UNKNOWN, exp.DataType.Type.NULL):
+                from sqlglot.optimizer.annotate_types import annotate_types
+
+                last = annotate_types(last)
+
+            if last.is_type(*exp.DataType.TEXT_TYPES):
+                regexp_replace = exp.RegexpReplace.from_arg_list(args[:-1])
+                regexp_replace.set("modifiers", last)
+                return regexp_replace
+
+    return exp.RegexpReplace.from_arg_list(args)
+
+
 class Postgres(Dialect):
     INDEX_OFFSET = 1
     TYPED_DIVISION = True
@@ -266,24 +289,25 @@ class Postgres(Dialect):
             "BIGSERIAL": TokenType.BIGSERIAL,
             "CHARACTER VARYING": TokenType.VARCHAR,
+            "CONSTRAINT TRIGGER": TokenType.COMMAND,
+            "CSTRING": TokenType.PSEUDO_TYPE,
             "DECLARE": TokenType.COMMAND,
             "DO": TokenType.COMMAND,
+            "EXEC": TokenType.COMMAND,
             "HSTORE": TokenType.HSTORE,
             "INT8": TokenType.BIGINT,
             "JSONB": TokenType.JSONB,
             "MONEY": TokenType.MONEY,
+            "NAME": TokenType.NAME,
+            "OID": TokenType.OBJECT_IDENTIFIER,
+            "ONLY": TokenType.ONLY,
+            "OPERATOR": TokenType.OPERATOR,
             "REFRESH": TokenType.COMMAND,
             "REINDEX": TokenType.COMMAND,
             "RESET": TokenType.COMMAND,
             "REVOKE": TokenType.COMMAND,
             "SERIAL": TokenType.SERIAL,
             "SMALLSERIAL": TokenType.SMALLSERIAL,
-            "NAME": TokenType.NAME,
             "TEMP": TokenType.TEMPORARY,
-            "CSTRING": TokenType.PSEUDO_TYPE,
-            "OID": TokenType.OBJECT_IDENTIFIER,
-            "ONLY": TokenType.ONLY,
-            "OPERATOR": TokenType.OPERATOR,
             "REGCLASS": TokenType.OBJECT_IDENTIFIER,
             "REGCOLLATION": TokenType.OBJECT_IDENTIFIER,
             "REGCONFIG": TokenType.OBJECT_IDENTIFIER,
@@ -320,6 +344,7 @@ class Postgres(Dialect):
             "MAKE_TIME": exp.TimeFromParts.from_arg_list,
             "MAKE_TIMESTAMP": exp.TimestampFromParts.from_arg_list,
             "NOW": exp.CurrentTimestamp.from_arg_list,
+            "REGEXP_REPLACE": _build_regexp_replace,
             "TO_CHAR": build_formatted_time(exp.TimeToStr, "postgres"),
             "TO_TIMESTAMP": _build_to_timestamp,
             "UNNEST": exp.Explode.from_arg_list,
@@ -417,6 +442,7 @@ class Postgres(Dialect):
         LIKE_PROPERTY_INSIDE_SCHEMA = True
         MULTI_ARG_DISTINCT = False
         CAN_IMPLEMENT_ARRAY_ANY = True
+        COPY_HAS_INTO_KEYWORD = False
 
         SUPPORTED_JSON_PATH_PARTS = {
             exp.JSONPathKey,
@@ -518,6 +544,7 @@ class Postgres(Dialect):
             exp.Variance: rename_func("VAR_SAMP"),
             exp.Xor: bool_xor_sql,
         }
+        TRANSFORMS.pop(exp.CommentColumnConstraint)
 
         PROPERTIES_LOCATION = {
             **generator.Generator.PROPERTIES_LOCATION,
@@ -526,6 +553,14 @@ class Postgres(Dialect):
             exp.VolatileProperty: exp.Properties.Location.UNSUPPORTED,
         }
 
+        def schemacommentproperty_sql(self, expression: exp.SchemaCommentProperty) -> str:
+            self.unsupported("Table comments are not supported in the CREATE statement")
+            return ""
+
+        def commentcolumnconstraint_sql(self, expression: exp.CommentColumnConstraint) -> str:
+            self.unsupported("Column comments are not supported in the CREATE statement")
+            return ""
+
         def unnest_sql(self, expression: exp.Unnest) -> str:
             if len(expression.expressions) == 1:
                 from sqlglot.optimizer.annotate_types import annotate_types
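
A minimal sketch of what the new `_build_regexp_replace` hook enables, assuming a build that includes this merge — a trailing textual argument is routed into the `modifiers` slot instead of being misread as a start position:

    import sqlglot

    # 'gi' is a non-integer string, so it should be annotated as text and
    # attached as regex flags (modifiers), round-tripping unchanged.
    ast = sqlglot.parse_one("SELECT REGEXP_REPLACE(col, 'a+', 'b', 'gi')", read="postgres")
    print(ast.sql(dialect="postgres"))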
@@ -222,6 +222,8 @@ class Presto(Dialect):
             "ROW": TokenType.STRUCT,
             "IPADDRESS": TokenType.IPADDRESS,
             "IPPREFIX": TokenType.IPPREFIX,
+            "TDIGEST": TokenType.TDIGEST,
+            "HYPERLOGLOG": TokenType.HLLSKETCH,
         }
 
         KEYWORDS.pop("QUALIFY")
@@ -316,6 +318,7 @@ class Presto(Dialect):
             exp.DataType.Type.STRUCT: "ROW",
             exp.DataType.Type.DATETIME: "TIMESTAMP",
             exp.DataType.Type.DATETIME64: "TIMESTAMP",
+            exp.DataType.Type.HLLSKETCH: "HYPERLOGLOG",
         }
 
         TRANSFORMS = {
@@ -4,6 +4,7 @@ import typing as t
 
 from sqlglot import exp, parser, tokens
 from sqlglot.dialects.dialect import Dialect
+from sqlglot.helper import seq_get
 from sqlglot.tokens import TokenType
 
 
@@ -53,6 +54,15 @@ class PRQL(Dialect):
                 _select_all(self._parse_table()), distinct=False, copy=False
             ),
             "SORT": lambda self, query: self._parse_order_by(query),
+            "AGGREGATE": lambda self, query: self._parse_selection(
+                query, parse_method=self._parse_aggregate, append=False
+            ),
         }
 
+        FUNCTIONS = {
+            **parser.Parser.FUNCTIONS,
+            "AVERAGE": exp.Avg.from_arg_list,
+            "SUM": lambda args: exp.func("COALESCE", exp.Sum(this=seq_get(args, 0)), 0),
+        }
+
         def _parse_equality(self) -> t.Optional[exp.Expression]:
@@ -87,14 +97,20 @@ class PRQL(Dialect):
 
             return query
 
-        def _parse_selection(self, query: exp.Query, append: bool = True) -> exp.Query:
+        def _parse_selection(
+            self,
+            query: exp.Query,
+            parse_method: t.Optional[t.Callable] = None,
+            append: bool = True,
+        ) -> exp.Query:
+            parse_method = parse_method if parse_method else self._parse_expression
             if self._match(TokenType.L_BRACE):
-                selects = self._parse_csv(self._parse_expression)
+                selects = self._parse_csv(parse_method)
 
                 if not self._match(TokenType.R_BRACE, expression=query):
                     self.raise_error("Expecting }")
             else:
-                expression = self._parse_expression()
+                expression = parse_method()
                 selects = [expression] if expression else []
 
             projections = {
@@ -136,6 +152,24 @@ class PRQL(Dialect):
                 self.raise_error("Expecting }")
             return query.order_by(self.expression(exp.Order, expressions=expressions), copy=False)
 
+        def _parse_aggregate(self) -> t.Optional[exp.Expression]:
+            alias = None
+            if self._next and self._next.token_type == TokenType.ALIAS:
+                alias = self._parse_id_var(any_token=True)
+                self._match(TokenType.ALIAS)
+
+            name = self._curr and self._curr.text.upper()
+            func_builder = self.FUNCTIONS.get(name)
+            if func_builder:
+                self._advance()
+                args = self._parse_column()
+                func = func_builder([args])
+            else:
+                self.raise_error(f"Unsupported aggregation function {name}")
+            if alias:
+                return self.expression(exp.Alias, this=func, alias=alias)
+            return func
+
         def _parse_expression(self) -> t.Optional[exp.Expression]:
             if self._next and self._next.token_type == TokenType.ALIAS:
                 alias = self._parse_id_var(True)
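
A speculative sketch of the new PRQL `aggregate` transform; the exact PRQL surface accepted here is an assumption rather than something this diff confirms:

    import sqlglot

    # `aggregate` maps onto _parse_aggregate above, and AVERAGE comes from the
    # new FUNCTIONS table, so this should come out as AVG(salary).
    print(sqlglot.transpile("from employees\naggregate {average salary}", read="prql", write="duckdb")[0])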
@@ -38,6 +38,7 @@ class Redshift(Postgres):
 
     SUPPORTS_USER_DEFINED_TYPES = False
     INDEX_OFFSET = 0
+    COPY_PARAMS_ARE_CSV = False
 
     TIME_FORMAT = "'YYYY-MM-DD HH:MI:SS'"
     TIME_MAPPING = {
@@ -138,6 +139,7 @@ class Redshift(Postgres):
         LAST_DAY_SUPPORTS_DATE_PART = False
         CAN_IMPLEMENT_ARRAY_ANY = False
         MULTI_ARG_DISTINCT = True
+        COPY_PARAMS_ARE_WRAPPED = False
 
         TYPE_MAPPING = {
             **Postgres.Generator.TYPE_MAPPING,
@@ -289,6 +289,7 @@ class Snowflake(Dialect):
     SUPPORTS_SEMI_ANTI_JOIN = False
     PREFER_CTE_ALIAS_COLUMN = True
     TABLESAMPLE_SIZE_IS_PERCENT = True
+    COPY_PARAMS_ARE_CSV = False
 
     TIME_MAPPING = {
         "YYYY": "%Y",
@@ -439,7 +440,7 @@ class Snowflake(Dialect):
 
         PROPERTY_PARSERS = {
             **parser.Parser.PROPERTY_PARSERS,
-            "LOCATION": lambda self: self._parse_location(),
+            "LOCATION": lambda self: self._parse_location_property(),
         }
 
         SHOW_PARSERS = {
@@ -675,10 +676,13 @@ class Snowflake(Dialect):
             self._match_text_seq("WITH")
             return self.expression(exp.SwapTable, this=self._parse_table(schema=True))
 
-        def _parse_location(self) -> exp.LocationProperty:
+        def _parse_location_property(self) -> exp.LocationProperty:
             self._match(TokenType.EQ)
             return self.expression(exp.LocationProperty, this=self._parse_location_path())
 
+        def _parse_file_location(self) -> t.Optional[exp.Expression]:
+            return self._parse_table_parts()
+
         def _parse_location_path(self) -> exp.Var:
             parts = [self._advance_any(ignore_reserved=True)]
 
@@ -715,10 +719,7 @@ class Snowflake(Dialect):
             "SQL_DOUBLE": TokenType.DOUBLE,
             "SQL_VARCHAR": TokenType.VARCHAR,
+            "STORAGE INTEGRATION": TokenType.STORAGE_INTEGRATION,
-            "TIMESTAMP_LTZ": TokenType.TIMESTAMPLTZ,
-            "TIMESTAMP_NTZ": TokenType.TIMESTAMP,
-            "TIMESTAMP_TZ": TokenType.TIMESTAMPTZ,
-            "TIMESTAMPNTZ": TokenType.TIMESTAMP,
             "TOP": TokenType.TOP,
         }
 
@@ -745,6 +746,8 @@ class Snowflake(Dialect):
         JSON_KEY_VALUE_PAIR_SEP = ","
         INSERT_OVERWRITE = " OVERWRITE INTO"
         STRUCT_DELIMITER = ("(", ")")
+        COPY_PARAMS_ARE_WRAPPED = False
+        COPY_PARAMS_EQ_REQUIRED = True
 
         TRANSFORMS = {
             **generator.Generator.TRANSFORMS,
@@ -845,7 +848,6 @@ class Snowflake(Dialect):
             **generator.Generator.TYPE_MAPPING,
-            exp.DataType.Type.NESTED: "OBJECT",
             exp.DataType.Type.STRUCT: "OBJECT",
             exp.DataType.Type.TIMESTAMP: "TIMESTAMPNTZ",
         }
 
         STAR_MAPPING = {
@@ -1038,3 +1040,11 @@ class Snowflake(Dialect):
                 values.append(e)
 
             return self.func("OBJECT_CONSTRUCT", *flatten(zip(keys, values)))
+
+        def copyparameter_sql(self, expression: exp.CopyParameter) -> str:
+            option = self.sql(expression, "this")
+            if option.upper() == "FILE_FORMAT":
+                values = self.expressions(expression, key="expression", flat=True, sep=" ")
+                return f"{option} = ({values})"
+
+            return super().copyparameter_sql(expression)
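
A rough round-trip of the new Snowflake COPY handling; the stage and format names below are invented for illustration:

    import sqlglot

    # FILE_FORMAT is special-cased by copyparameter_sql above, so it should be
    # re-emitted in the FILE_FORMAT = (...) form.
    sql = "COPY INTO my_table FROM @my_stage FILE_FORMAT = (FORMAT_NAME = my_csv_format)"
    print(sqlglot.parse_one(sql, read="snowflake").sql(dialect="snowflake"))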
@@ -5,7 +5,7 @@ import typing as t
 from sqlglot import exp
 from sqlglot.dialects.dialect import rename_func, unit_to_var
 from sqlglot.dialects.hive import _build_with_ignore_nulls
-from sqlglot.dialects.spark2 import Spark2, temporary_storage_provider
+from sqlglot.dialects.spark2 import Spark2, temporary_storage_provider, _build_as_cast
 from sqlglot.helper import ensure_list, seq_get
 from sqlglot.transforms import (
     ctas_with_tmp_tables_to_create_tmp_view,
@@ -63,6 +63,8 @@ class Spark(Spark2):
             **Spark2.Parser.FUNCTIONS,
             "ANY_VALUE": _build_with_ignore_nulls(exp.AnyValue),
             "DATEDIFF": _build_datediff,
+            "TIMESTAMP_LTZ": _build_as_cast("TIMESTAMP_LTZ"),
+            "TIMESTAMP_NTZ": _build_as_cast("TIMESTAMP_NTZ"),
             "TRY_ELEMENT_AT": lambda args: exp.Bracket(
                 this=seq_get(args, 0), expressions=ensure_list(seq_get(args, 1)), safe=True
             ),
@@ -88,6 +90,8 @@ class Spark(Spark2):
             exp.DataType.Type.MONEY: "DECIMAL(15, 4)",
             exp.DataType.Type.SMALLMONEY: "DECIMAL(6, 4)",
             exp.DataType.Type.UNIQUEIDENTIFIER: "STRING",
+            exp.DataType.Type.TIMESTAMPLTZ: "TIMESTAMP_LTZ",
+            exp.DataType.Type.TIMESTAMPNTZ: "TIMESTAMP_NTZ",
         }
 
         TRANSFORMS = {
@@ -259,12 +259,15 @@ class Spark2(Hive):
             return Generator.struct_sql(self, expression)
 
         def cast_sql(self, expression: exp.Cast, safe_prefix: t.Optional[str] = None) -> str:
-            if is_parse_json(expression.this):
+            arg = expression.this
+            is_json_extract = isinstance(arg, (exp.JSONExtract, exp.JSONExtractScalar))
+
+            if is_parse_json(arg) or is_json_extract:
                 schema = f"'{self.sql(expression, 'to')}'"
-                return self.func("FROM_JSON", expression.this.this, schema)
+                return self.func("FROM_JSON", arg if is_json_extract else arg.this, schema)
 
             if is_parse_json(expression):
-                return self.func("TO_JSON", expression.this)
+                return self.func("TO_JSON", arg)
 
             return super(Hive.Generator, self).cast_sql(expression, safe_prefix=safe_prefix)
@@ -1,7 +1,7 @@
 from __future__ import annotations
 
 from sqlglot import exp
-from sqlglot.dialects.dialect import merge_without_target_sql
+from sqlglot.dialects.dialect import merge_without_target_sql, trim_sql
 from sqlglot.dialects.presto import Presto
 
 
@@ -9,12 +9,19 @@ class Trino(Presto):
     SUPPORTS_USER_DEFINED_TYPES = False
     LOG_BASE_FIRST = True
 
+    class Parser(Presto.Parser):
+        FUNCTION_PARSERS = {
+            **Presto.Parser.FUNCTION_PARSERS,
+            "TRIM": lambda self: self._parse_trim(),
+        }
+
     class Generator(Presto.Generator):
         TRANSFORMS = {
             **Presto.Generator.TRANSFORMS,
             exp.ArraySum: lambda self,
             e: f"REDUCE({self.sql(e, 'this')}, 0, (acc, x) -> acc + x, acc -> acc)",
             exp.Merge: merge_without_target_sql,
+            exp.Trim: trim_sql,
         }
 
         SUPPORTED_JSON_PATH_PARTS = {
@@ -728,6 +728,7 @@ class TSQL(Dialect):
         JSON_PATH_BRACKETED_KEY_SUPPORTED = False
         SUPPORTS_TO_NUMBER = False
         OUTER_UNION_MODIFIERS = False
+        COPY_PARAMS_EQ_REQUIRED = True
 
         EXPRESSIONS_WITHOUT_NESTED_CTES = {
             exp.Delete,
@@ -912,7 +913,7 @@ class TSQL(Dialect):
             isinstance(prop, exp.TemporaryProperty)
             for prop in (properties.expressions if properties else [])
         ):
-            sql = f"#{sql}"
+            sql = f"[#{sql[1:]}" if sql.startswith("[") else f"#{sql}"
 
         return sql
 
@@ -1955,6 +1955,31 @@ class Connect(Expression):
     arg_types = {"start": False, "connect": True, "nocycle": False}
 
 
+class CopyParameter(Expression):
+    arg_types = {"this": True, "expression": False}
+
+
+class Copy(Expression):
+    arg_types = {
+        "this": True,
+        "kind": True,
+        "files": True,
+        "credentials": False,
+        "format": False,
+        "params": False,
+    }
+
+
+class Credentials(Expression):
+    arg_types = {
+        "credentials": False,
+        "encryption": False,
+        "storage": False,
+        "iam_role": False,
+        "region": False,
+    }
+
+
 class Prior(Expression):
     pass
 
@@ -2058,7 +2083,7 @@ class Insert(DDL, DML):
         "hint": False,
         "with": False,
         "is_function": False,
-        "this": True,
+        "this": False,
         "expression": False,
         "conflict": False,
         "returning": False,
@@ -3909,6 +3934,7 @@ class DataType(Expression):
         TIME = auto()
         TIMETZ = auto()
         TIMESTAMP = auto()
+        TIMESTAMPNTZ = auto()
         TIMESTAMPLTZ = auto()
         TIMESTAMPTZ = auto()
         TIMESTAMP_S = auto()
@@ -3936,6 +3962,7 @@ class DataType(Expression):
         VARIANT = auto()
         XML = auto()
         YEAR = auto()
+        TDIGEST = auto()
 
     STRUCT_TYPES = {
         Type.NESTED,
@@ -4010,6 +4037,7 @@ class DataType(Expression):
         Type.DATETIME64,
         Type.TIME,
         Type.TIMESTAMP,
+        Type.TIMESTAMPNTZ,
         Type.TIMESTAMPLTZ,
         Type.TIMESTAMPTZ,
         Type.TIMESTAMP_MS,
@@ -4847,6 +4875,10 @@ class TryCast(Cast):
     pass
 
 
+class Try(Func):
+    pass
+
+
 class CastToStrType(Func):
     arg_types = {"this": True, "to": True}
 
@@ -5538,7 +5570,6 @@ class RegexpReplace(Func):
         "replacement": False,
         "position": False,
         "occurrence": False,
-        "parameters": False,
         "modifiers": False,
     }
 
@@ -6506,7 +6537,7 @@ def and_(
         **opts: other options to use to parse the input expressions.
 
     Returns:
-        And: the new condition
+        The new condition
     """
     return t.cast(Condition, _combine(expressions, And, dialect, copy=copy, **opts))
 
@@ -6529,11 +6560,34 @@ def or_(
         **opts: other options to use to parse the input expressions.
 
     Returns:
-        Or: the new condition
+        The new condition
     """
     return t.cast(Condition, _combine(expressions, Or, dialect, copy=copy, **opts))
 
 
+def xor(
+    *expressions: t.Optional[ExpOrStr], dialect: DialectType = None, copy: bool = True, **opts
+) -> Condition:
+    """
+    Combine multiple conditions with an XOR logical operator.
+
+    Example:
+        >>> xor("x=1", xor("y=1", "z=1")).sql()
+        'x = 1 XOR (y = 1 XOR z = 1)'
+
+    Args:
+        *expressions: the SQL code strings to parse.
+            If an Expression instance is passed, this is used as-is.
+        dialect: the dialect used to parse the input expression.
+        copy: whether to copy `expressions` (only applies to Expressions).
+        **opts: other options to use to parse the input expressions.
+
+    Returns:
+        The new condition
+    """
+    return t.cast(Condition, _combine(expressions, Xor, dialect, copy=copy, **opts))
+
+
 def not_(expression: ExpOrStr, dialect: DialectType = None, copy: bool = True, **opts) -> Not:
     """
     Wrap a condition with a NOT operator.
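
The new `xor()` builder works like the existing `and_()`/`or_()` helpers; per its doctest, nesting controls the parenthesization:

    from sqlglot import exp

    cond = exp.xor("x=1", exp.xor("y=1", "z=1"))
    print(cond.sql())  # x = 1 XOR (y = 1 XOR z = 1)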
@@ -339,6 +339,18 @@ class Generator(metaclass=_Generator):
     # True means limit 1 happens after the union, False means it happens on y.
     OUTER_UNION_MODIFIERS = True
 
+    # Whether parameters from COPY statement are wrapped in parentheses
+    COPY_PARAMS_ARE_WRAPPED = True
+
+    # Whether values of params are set with "=" token or empty space
+    COPY_PARAMS_EQ_REQUIRED = False
+
+    # Whether COPY statement has INTO keyword
+    COPY_HAS_INTO_KEYWORD = True
+
+    # Whether the conditional TRY(expression) function is supported
+    TRY_SUPPORTED = True
+
     TYPE_MAPPING = {
         exp.DataType.Type.NCHAR: "CHAR",
         exp.DataType.Type.NVARCHAR: "VARCHAR",
@@ -3158,6 +3170,13 @@ class Generator(metaclass=_Generator):
     def trycast_sql(self, expression: exp.TryCast) -> str:
         return self.cast_sql(expression, safe_prefix="TRY_")
 
+    def try_sql(self, expression: exp.Try) -> str:
+        if not self.TRY_SUPPORTED:
+            self.unsupported("Unsupported TRY function")
+            return self.sql(expression, "this")
+
+        return self.func("TRY", expression.this)
+
     def log_sql(self, expression: exp.Log) -> str:
         this = expression.this
         expr = expression.expression
@@ -3334,9 +3353,10 @@ class Generator(metaclass=_Generator):
 
         then_expression = expression.args.get("then")
         if isinstance(then_expression, exp.Insert):
-            then = f"INSERT {self.sql(then_expression, 'this')}"
-            if "expression" in then_expression.args:
-                then += f" VALUES {self.sql(then_expression, 'expression')}"
+            this = self.sql(then_expression, "this")
+            this = f"INSERT {this}" if this else "INSERT"
+            then = self.sql(then_expression, "expression")
+            then = f"{this} VALUES {then}" if then else this
         elif isinstance(then_expression, exp.Update):
             if isinstance(then_expression.args.get("expressions"), exp.Star):
                 then = f"UPDATE {self.sql(then_expression, 'expressions')}"
@@ -3358,10 +3378,11 @@ class Generator(metaclass=_Generator):
         this = self.sql(table)
         using = f"USING {self.sql(expression, 'using')}"
         on = f"ON {self.sql(expression, 'on')}"
-        expressions = self.expressions(expression, sep=" ")
+        expressions = self.expressions(expression, sep=" ", indent=False)
+        sep = self.sep()
 
         return self.prepend_ctes(
-            expression, f"MERGE INTO {this}{table_alias} {using} {on} {expressions}"
+            expression, f"MERGE INTO {this}{table_alias}{sep}{using}{sep}{on}{sep}{expressions}"
         )
 
     def tochar_sql(self, expression: exp.ToChar) -> str:
@@ -3757,3 +3778,55 @@ class Generator(metaclass=_Generator):
         if self.pretty:
             return string.replace("\n", self.SENTINEL_LINE_BREAK)
         return string
+
+    def copyparameter_sql(self, expression: exp.CopyParameter) -> str:
+        option = self.sql(expression, "this")
+        value = self.sql(expression, "expression")
+
+        if not value:
+            return option
+
+        op = " = " if self.COPY_PARAMS_EQ_REQUIRED else " "
+
+        return f"{option}{op}{value}"
+
+    def credentials_sql(self, expression: exp.Credentials) -> str:
+        cred_expr = expression.args.get("credentials")
+        if isinstance(cred_expr, exp.Literal):
+            # Redshift case: CREDENTIALS <string>
+            credentials = self.sql(expression, "credentials")
+            credentials = f"CREDENTIALS {credentials}" if credentials else ""
+        else:
+            # Snowflake case: CREDENTIALS = (...)
+            credentials = self.expressions(expression, key="credentials", flat=True, sep=" ")
+            credentials = f"CREDENTIALS = ({credentials})" if credentials else ""
+
+        storage = self.sql(expression, "storage")
+        storage = f" {storage}" if storage else ""
+
+        encryption = self.expressions(expression, key="encryption", flat=True, sep=" ")
+        encryption = f" ENCRYPTION = ({encryption})" if encryption else ""
+
+        iam_role = self.sql(expression, "iam_role")
+        iam_role = f"IAM_ROLE {iam_role}" if iam_role else ""
+
+        region = self.sql(expression, "region")
+        region = f" REGION {region}" if region else ""
+
+        return f"{credentials}{storage}{encryption}{iam_role}{region}"
+
+    def copy_sql(self, expression: exp.Copy) -> str:
+        this = self.sql(expression, "this")
+        this = f" INTO {this}" if self.COPY_HAS_INTO_KEYWORD else f" {this}"
+
+        credentials = self.sql(expression, "credentials")
+        credentials = f" {credentials}" if credentials else ""
+        kind = " FROM " if expression.args.get("kind") else " TO "
+        files = self.expressions(expression, key="files", flat=True)
+
+        sep = ", " if self.dialect.COPY_PARAMS_ARE_CSV else " "
+        params = self.expressions(expression, key="params", flat=True, sep=sep)
+        if params:
+            params = f" WITH ({params})" if self.COPY_PARAMS_ARE_WRAPPED else f" {params}"
+
+        return f"COPY{this}{kind}{files}{credentials}{params}"
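
A small sketch of the new TRY_SUPPORTED switch in action; the outputs in the comments are the expected shapes, not verified against this exact build:

    import sqlglot

    # Presto/Trino keep the conditional TRY() wrapper...
    print(sqlglot.transpile("SELECT TRY(1 / 0)", read="presto", write="trino")[0])

    # ...while a dialect with TRY_SUPPORTED = False should warn
    # ("Unsupported TRY function") and fall back to the bare argument.
    print(sqlglot.transpile("SELECT TRY(1 / 0)", read="presto", write="duckdb")[0])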
@@ -8,6 +8,7 @@ from dataclasses import dataclass, field
 from sqlglot import Schema, exp, maybe_parse
 from sqlglot.errors import SqlglotError
 from sqlglot.optimizer import Scope, build_scope, find_all_in_scope, normalize_identifiers, qualify
+from sqlglot.optimizer.scope import ScopeType
 
 if t.TYPE_CHECKING:
     from sqlglot.dialects.dialect import DialectType
@@ -129,12 +130,6 @@ def to_node(
     reference_node_name: t.Optional[str] = None,
     trim_selects: bool = True,
 ) -> Node:
-    source_names = {
-        dt.alias: dt.comments[0].split()[1]
-        for dt in scope.derived_tables
-        if dt.comments and dt.comments[0].startswith("source: ")
-    }
-
     # Find the specific select clause that is the source of the column we want.
     # This can either be a specific, named select or a generic `*` clause.
     select = (
@@ -242,13 +237,31 @@ def to_node(
     # If the source is a UDTF find columns used in the UDTF to generate the table
     if isinstance(source, exp.UDTF):
         source_columns |= set(source.find_all(exp.Column))
+        derived_tables = [
+            source.expression.parent
+            for source in scope.sources.values()
+            if isinstance(source, Scope) and source.is_derived_table
+        ]
+    else:
+        derived_tables = scope.derived_tables
+
+    source_names = {
+        dt.alias: dt.comments[0].split()[1]
+        for dt in derived_tables
+        if dt.comments and dt.comments[0].startswith("source: ")
+    }
 
     for c in source_columns:
         table = c.table
         source = scope.sources.get(table)
 
         if isinstance(source, Scope):
-            selected_node, _ = scope.selected_sources.get(table, (None, None))
+            reference_node_name = None
+            if source.scope_type == ScopeType.DERIVED_TABLE and table not in source_names:
+                reference_node_name = table
+            elif source.scope_type == ScopeType.CTE:
+                selected_node, _ = scope.selected_sources.get(table, (None, None))
+                reference_node_name = selected_node.name if selected_node else None
             # The table itself came from a more specific scope. Recurse into that one using the unaliased column name.
             to_node(
                 c.name,
@@ -257,7 +270,7 @@ def to_node(
                 scope_name=table,
                 upstream=node,
                 source_name=source_names.get(table) or source_name,
-                reference_node_name=selected_node.name if selected_node else None,
+                reference_node_name=reference_node_name,
                 trim_selects=trim_selects,
             )
         else:
@@ -1,11 +1,11 @@
 # ruff: noqa: F401
 
-from sqlglot.optimizer.optimizer import RULES, optimize
+from sqlglot.optimizer.optimizer import RULES as RULES, optimize as optimize
 from sqlglot.optimizer.scope import (
-    Scope,
-    build_scope,
-    find_all_in_scope,
-    find_in_scope,
-    traverse_scope,
-    walk_in_scope,
+    Scope as Scope,
+    build_scope as build_scope,
+    find_all_in_scope as find_all_in_scope,
+    find_in_scope as find_in_scope,
+    traverse_scope as traverse_scope,
+    walk_in_scope as walk_in_scope,
 )
@@ -89,7 +89,8 @@ def eliminate_subqueries(expression):
         new_ctes.append(new_cte)
 
     if new_ctes:
-        expression.set("with", exp.With(expressions=new_ctes, recursive=recursive))
+        query = expression.expression if isinstance(expression, exp.DDL) else expression
+        query.set("with", exp.With(expressions=new_ctes, recursive=recursive))
 
     return expression
 
@@ -12,6 +12,8 @@ from sqlglot.helper import ensure_collection, find_new_name, seq_get
 
 logger = logging.getLogger("sqlglot")
 
+TRAVERSABLES = (exp.Query, exp.DDL, exp.DML)
+
 
 class ScopeType(Enum):
     ROOT = auto()
@@ -495,25 +497,8 @@ def traverse_scope(expression: exp.Expression) -> t.List[Scope]:
     Returns:
         A list of the created scope instances
     """
-    if isinstance(expression, exp.DDL) and isinstance(expression.expression, exp.Query):
-        # We ignore the DDL expression and build a scope for its query instead
-        ddl_with = expression.args.get("with")
-        expression = expression.expression
-
-        # If the DDL has CTEs attached, we need to add them to the query, or
-        # prepend them if the query itself already has CTEs attached to it
-        if ddl_with:
-            ddl_with.pop()
-            query_ctes = expression.ctes
-            if not query_ctes:
-                expression.set("with", ddl_with)
-            else:
-                expression.args["with"].set("recursive", ddl_with.recursive)
-                expression.args["with"].set("expressions", [*ddl_with.expressions, *query_ctes])
-
-    if isinstance(expression, exp.Query):
+    if isinstance(expression, TRAVERSABLES):
         return list(_traverse_scope(Scope(expression)))
 
     return []
@@ -531,25 +516,37 @@ def build_scope(expression: exp.Expression) -> t.Optional[Scope]:
 
 
 def _traverse_scope(scope):
-    if isinstance(scope.expression, exp.Select):
+    expression = scope.expression
+
+    if isinstance(expression, exp.Select):
         yield from _traverse_select(scope)
-    elif isinstance(scope.expression, exp.Union):
+    elif isinstance(expression, exp.Union):
+        yield from _traverse_ctes(scope)
         yield from _traverse_union(scope)
-    elif isinstance(scope.expression, exp.Subquery):
+        return
+    elif isinstance(expression, exp.Subquery):
         if scope.is_root:
             yield from _traverse_select(scope)
         else:
             yield from _traverse_subqueries(scope)
-    elif isinstance(scope.expression, exp.Table):
+    elif isinstance(expression, exp.Table):
         yield from _traverse_tables(scope)
-    elif isinstance(scope.expression, exp.UDTF):
+    elif isinstance(expression, exp.UDTF):
         yield from _traverse_udtfs(scope)
+    elif isinstance(expression, exp.DDL):
+        if isinstance(expression.expression, exp.Query):
+            yield from _traverse_ctes(scope)
+            yield from _traverse_scope(Scope(expression.expression, cte_sources=scope.cte_sources))
+        return
+    elif isinstance(expression, exp.DML):
+        yield from _traverse_ctes(scope)
+        for query in find_all_in_scope(expression, exp.Query):
+            # This check ensures we don't yield the CTE queries twice
+            if not isinstance(query.parent, exp.CTE):
+                yield from _traverse_scope(Scope(query, cte_sources=scope.cte_sources))
+        return
     else:
-        logger.warning(
-            "Cannot traverse scope %s with type '%s'", scope.expression, type(scope.expression)
-        )
+        logger.warning("Cannot traverse scope %s with type '%s'", expression, type(expression))
         return
 
     yield scope
@@ -749,7 +746,7 @@ def _traverse_udtfs(scope):
         for child_scope in _traverse_scope(
             scope.branch(
                 expression,
-                scope_type=ScopeType.DERIVED_TABLE,
+                scope_type=ScopeType.SUBQUERY,
                 outer_columns=expression.alias_column_names,
             )
         ):
@@ -757,8 +754,7 @@ def _traverse_udtfs(scope):
                 top = child_scope
             sources[expression.alias] = child_scope
 
-            scope.derived_table_scopes.append(top)
-            scope.table_scopes.append(top)
+            scope.subquery_scopes.append(top)
 
     scope.sources.update(sources)
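
With TRAVERSABLES now covering DDL and DML, scope traversal reaches queries embedded in statements like UPDATE. A minimal check, assuming a checkout with this merge:

    from sqlglot import parse_one
    from sqlglot.optimizer.scope import traverse_scope

    # Previously only exp.Query roots were traversed; the DML branch above now
    # yields a scope for the embedded subquery as well.
    scopes = traverse_scope(parse_one("UPDATE t SET x = (SELECT MAX(y) FROM s)"))
    print(len(scopes))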
@@ -224,6 +224,8 @@ def flatten(expression):
 def simplify_connectors(expression, root=True):
     def _simplify_connectors(expression, left, right):
         if left == right:
+            if isinstance(expression, exp.Xor):
+                return exp.false()
             return left
         if isinstance(expression, exp.And):
             if is_false(left) or is_false(right):
@@ -365,10 +367,17 @@ def uniq_sort(expression, root=True):
     C AND A AND B AND B -> A AND B AND C
     """
     if isinstance(expression, exp.Connector) and (root or not expression.same_parent):
-        result_func = exp.and_ if isinstance(expression, exp.And) else exp.or_
         flattened = tuple(expression.flatten())
-        deduped = {gen(e): e for e in flattened}
-        arr = tuple(deduped.items())
+
+        if isinstance(expression, exp.Xor):
+            result_func = exp.xor
+            # Do not deduplicate XOR as A XOR A != A if A == True
+            deduped = None
+            arr = tuple((gen(e), e) for e in flattened)
+        else:
+            result_func = exp.and_ if isinstance(expression, exp.And) else exp.or_
+            deduped = {gen(e): e for e in flattened}
+            arr = tuple(deduped.items())
 
         # check if the operands are already sorted, if not sort them
         # A AND C AND B -> A AND B AND C
@@ -378,7 +387,7 @@ def uniq_sort(expression, root=True):
             break
     else:
         # we didn't have to sort but maybe we need to dedup
-        if len(deduped) < len(flattened):
+        if deduped and len(deduped) < len(flattened):
             expression = result_func(*deduped.values(), copy=False)
 
     return expression
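
Why XOR is exempt from deduplication: when A is TRUE, A XOR A is FALSE rather than A, so collapsing duplicate operands would change the result. The connector simplification above folds the identical-operand case directly instead:

    from sqlglot import parse_one
    from sqlglot.optimizer.simplify import simplify

    # _simplify_connectors: identical XOR operands now reduce to FALSE.
    print(simplify(parse_one("x = 1 XOR x = 1")).sql())  # expected: FALSE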
@@ -217,6 +217,7 @@ class Parser(metaclass=_Parser):
         TokenType.TIMESTAMP_NS,
         TokenType.TIMESTAMPTZ,
         TokenType.TIMESTAMPLTZ,
+        TokenType.TIMESTAMPNTZ,
         TokenType.DATETIME,
         TokenType.DATETIME64,
         TokenType.DATE,
@@ -265,6 +266,7 @@ class Parser(metaclass=_Parser):
         TokenType.UNKNOWN,
         TokenType.NULL,
         TokenType.NAME,
+        TokenType.TDIGEST,
         *ENUM_TYPE_TOKENS,
         *NESTED_TYPE_TOKENS,
         *AGGREGATE_TYPE_TOKENS,
@@ -329,6 +331,7 @@ class Parser(metaclass=_Parser):
         TokenType.COMMENT,
         TokenType.COMMIT,
         TokenType.CONSTRAINT,
+        TokenType.COPY,
        TokenType.DEFAULT,
         TokenType.DELETE,
         TokenType.DESC,
@@ -597,7 +600,7 @@ class Parser(metaclass=_Parser):
         exp.Condition: lambda self: self._parse_conjunction(),
         exp.DataType: lambda self: self._parse_types(allow_identifiers=False),
         exp.Expression: lambda self: self._parse_expression(),
-        exp.From: lambda self: self._parse_from(),
+        exp.From: lambda self: self._parse_from(joins=True),
         exp.Group: lambda self: self._parse_group(),
         exp.Having: lambda self: self._parse_having(),
         exp.Identifier: lambda self: self._parse_id_var(),
@@ -627,6 +630,7 @@ class Parser(metaclass=_Parser):
         TokenType.CACHE: lambda self: self._parse_cache(),
         TokenType.COMMENT: lambda self: self._parse_comment(),
         TokenType.COMMIT: lambda self: self._parse_commit_or_rollback(),
+        TokenType.COPY: lambda self: self._parse_copy(),
         TokenType.CREATE: lambda self: self._parse_create(),
         TokenType.DELETE: lambda self: self._parse_delete(),
         TokenType.DESC: lambda self: self._parse_describe(),
@@ -1585,7 +1589,15 @@ class Parser(metaclass=_Parser):
         if return_:
             expression = self.expression(exp.Return, this=expression)
         elif create_token.token_type == TokenType.INDEX:
-            this = self._parse_index(index=self._parse_id_var())
+            # Postgres allows anonymous indexes, eg. CREATE INDEX IF NOT EXISTS ON t(c)
+            if not self._match(TokenType.ON):
+                index = self._parse_id_var()
+                anonymous = False
+            else:
+                index = None
+                anonymous = True
+
+            this = self._parse_index(index=index, anonymous=anonymous)
         elif create_token.token_type in self.DB_CREATABLES:
             table_parts = self._parse_table_parts(
                 schema=True, is_db_reference=create_token.token_type == TokenType.SCHEMA
@@ -1764,14 +1776,18 @@ class Parser(metaclass=_Parser):
             ),
         )
 
-    def _parse_property_assignment(self, exp_class: t.Type[E], **kwargs: t.Any) -> E:
-        self._match(TokenType.EQ)
-        self._match(TokenType.ALIAS)
+    def _parse_unquoted_field(self):
         field = self._parse_field()
         if isinstance(field, exp.Identifier) and not field.quoted:
             field = exp.var(field)
 
-        return self.expression(exp_class, this=field, **kwargs)
+        return field
+
+    def _parse_property_assignment(self, exp_class: t.Type[E], **kwargs: t.Any) -> E:
+        self._match(TokenType.EQ)
+        self._match(TokenType.ALIAS)
+
+        return self.expression(exp_class, this=self._parse_unquoted_field(), **kwargs)
 
     def _parse_properties(self, before: t.Optional[bool] = None) -> t.Optional[exp.Properties]:
         properties = []
@@ -2206,9 +2222,9 @@ class Parser(metaclass=_Parser):
     def _parse_describe(self) -> exp.Describe:
         kind = self._match_set(self.CREATABLES) and self._prev.text
         style = self._match_texts(("EXTENDED", "FORMATTED", "HISTORY")) and self._prev.text.upper()
-        if not self._match_set(self.ID_VAR_TOKENS, advance=False):
+        if self._match(TokenType.DOT):
             style = None
-            self._retreat(self._index - 1)
+            self._retreat(self._index - 2)
         this = self._parse_table(schema=True)
         properties = self._parse_properties()
         expressions = properties.expressions if properties else None
@@ -2461,14 +2477,17 @@ class Parser(metaclass=_Parser):
             exp.Partition, expressions=self._parse_wrapped_csv(self._parse_conjunction)
         )
 
-    def _parse_value(self) -> exp.Tuple:
+    def _parse_value(self) -> t.Optional[exp.Tuple]:
         if self._match(TokenType.L_PAREN):
             expressions = self._parse_csv(self._parse_expression)
             self._match_r_paren()
             return self.expression(exp.Tuple, expressions=expressions)
 
         # In some dialects we can have VALUES 1, 2 which results in 1 column & 2 rows.
-        return self.expression(exp.Tuple, expressions=[self._parse_expression()])
+        expression = self._parse_expression()
+        if expression:
+            return self.expression(exp.Tuple, expressions=[expression])
+        return None
 
     def _parse_projections(self) -> t.List[exp.Expression]:
         return self._parse_expressions()
@@ -3010,10 +3029,9 @@ class Parser(metaclass=_Parser):
         )
 
     def _parse_index(
-        self,
-        index: t.Optional[exp.Expression] = None,
+        self, index: t.Optional[exp.Expression] = None, anonymous: bool = False
     ) -> t.Optional[exp.Index]:
-        if index:
+        if index or anonymous:
             unique = None
             primary = None
             amp = None
@@ -4305,7 +4323,9 @@ class Parser(metaclass=_Parser):
 
         this = self._parse_query_modifiers(seq_get(expressions, 0))
 
-        if isinstance(this, exp.UNWRAPPED_QUERIES):
+        if not this and self._match(TokenType.R_PAREN, advance=False):
+            this = self.expression(exp.Tuple)
+        elif isinstance(this, exp.UNWRAPPED_QUERIES):
             this = self._parse_set_operations(
                 self._parse_subquery(this=this, parse_alias=False)
             )
@@ -4675,7 +4695,7 @@ class Parser(metaclass=_Parser):
             this.set("cycle", False)
 
         if not identity:
-            this.set("expression", self._parse_bitwise())
+            this.set("expression", self._parse_range())
         elif not this.args.get("start") and self._match(TokenType.NUMBER, advance=False):
             args = self._parse_csv(self._parse_bitwise)
             this.set("start", seq_get(args, 0))
@@ -5309,8 +5329,10 @@ class Parser(metaclass=_Parser):
 
         if self._match(TokenType.FROM):
             args.append(self._parse_bitwise())
-            if self._match(TokenType.FOR):
-                args.append(self._parse_bitwise())
+
+        if self._match(TokenType.FOR):
+            if len(args) == 1:
+                args.append(exp.Literal.number(1))
+            args.append(self._parse_bitwise())
 
         return self.validate_expression(exp.Substring.from_arg_list(args), args)
@@ -6292,3 +6314,97 @@ class Parser(metaclass=_Parser):
         op = self._parse_var(any_token=True)
 
         return self.expression(exp.WithOperator, this=this, op=op)
+
+    def _parse_wrapped_options(self) -> t.List[t.Optional[exp.Expression]]:
+        opts = []
+        self._match(TokenType.EQ)
+        self._match(TokenType.L_PAREN)
+        while self._curr and not self._match(TokenType.R_PAREN):
+            opts.append(self._parse_conjunction())
+            self._match(TokenType.COMMA)
+        return opts
+
+    def _parse_copy_parameters(self) -> t.List[exp.CopyParameter]:
+        sep = TokenType.COMMA if self.dialect.COPY_PARAMS_ARE_CSV else None
+
+        options = []
+        while self._curr and not self._match(TokenType.R_PAREN, advance=False):
+            option = self._parse_unquoted_field()
+            value = None
+
+            # Some options are defined as functions with the values as params
+            if not isinstance(option, exp.Func):
+                prev = self._prev.text.upper()
+                # Different dialects might separate options and values by white space, "=" and "AS"
+                self._match(TokenType.EQ)
+                self._match(TokenType.ALIAS)
+
+                if prev == "FILE_FORMAT" and self._match(TokenType.L_PAREN):
+                    # Snowflake FILE_FORMAT case
+                    value = self._parse_wrapped_options()
+                else:
+                    value = self._parse_unquoted_field()
+
+            param = self.expression(exp.CopyParameter, this=option, expression=value)
+            options.append(param)
+
+            if sep:
+                self._match(sep)
+
+        return options
+
+    def _parse_credentials(self) -> t.Optional[exp.Credentials]:
+        expr = self.expression(exp.Credentials)
+
+        if self._match_text_seq("STORAGE_INTEGRATION", advance=False):
+            expr.set("storage", self._parse_conjunction())
+        if self._match_text_seq("CREDENTIALS"):
+            # Snowflake supports CREDENTIALS = (...), while Redshift CREDENTIALS <string>
+            creds = (
+                self._parse_wrapped_options() if self._match(TokenType.EQ) else self._parse_field()
+            )
+            expr.set("credentials", creds)
+        if self._match_text_seq("ENCRYPTION"):
+            expr.set("encryption", self._parse_wrapped_options())
+        if self._match_text_seq("IAM_ROLE"):
+            expr.set("iam_role", self._parse_field())
+        if self._match_text_seq("REGION"):
+            expr.set("region", self._parse_field())
+
+        return expr
+
+    def _parse_file_location(self) -> t.Optional[exp.Expression]:
+        return self._parse_field()
+
+    def _parse_copy(self) -> exp.Copy | exp.Command:
+        start = self._prev
+
+        self._match(TokenType.INTO)
+
+        this = (
+            self._parse_conjunction()
+            if self._match(TokenType.L_PAREN, advance=False)
+            else self._parse_table(schema=True)
+        )
+
+        kind = self._match(TokenType.FROM) or not self._match_text_seq("TO")
+
+        files = self._parse_csv(self._parse_file_location)
+        credentials = self._parse_credentials()
+
+        self._match_text_seq("WITH")
+
+        params = self._parse_wrapped(self._parse_copy_parameters, optional=True)
+
+        # Fallback case
+        if self._curr:
+            return self._parse_as_command(start)
+
+        return self.expression(
+            exp.Copy,
+            this=this,
+            kind=kind,
+            credentials=credentials,
+            files=files,
+            params=params,
+        )
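
An end-to-end sketch of `_parse_copy` using the Postgres flavor; the column list and options below are illustrative:

    import sqlglot

    ast = sqlglot.parse_one("COPY t (c1, c2) FROM 'data.csv' WITH (FORMAT CSV)", read="postgres")
    print(type(ast).__name__)           # Copy
    print(ast.sql(dialect="postgres"))  # should round-trip the statement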
@@ -145,6 +145,7 @@ class TokenType(AutoName):
     TIMESTAMP = auto()
     TIMESTAMPTZ = auto()
     TIMESTAMPLTZ = auto()
+    TIMESTAMPNTZ = auto()
     TIMESTAMP_S = auto()
     TIMESTAMP_MS = auto()
     TIMESTAMP_NS = auto()
@@ -197,6 +198,7 @@ class TokenType(AutoName):
     NESTED = auto()
     AGGREGATEFUNCTION = auto()
     SIMPLEAGGREGATEFUNCTION = auto()
+    TDIGEST = auto()
     UNKNOWN = auto()
 
     # keywords
@@ -223,6 +225,7 @@ class TokenType(AutoName):
     COMMIT = auto()
     CONNECT_BY = auto()
     CONSTRAINT = auto()
+    COPY = auto()
     CREATE = auto()
     CROSS = auto()
     CUBE = auto()
@@ -647,6 +650,7 @@ class Tokenizer(metaclass=_Tokenizer):
         "COMMIT": TokenType.COMMIT,
         "CONNECT BY": TokenType.CONNECT_BY,
         "CONSTRAINT": TokenType.CONSTRAINT,
+        "COPY": TokenType.COPY,
         "CREATE": TokenType.CREATE,
         "CROSS": TokenType.CROSS,
         "CUBE": TokenType.CUBE,
@@ -845,6 +849,9 @@ class Tokenizer(metaclass=_Tokenizer):
         "TIMESTAMP": TokenType.TIMESTAMP,
         "TIMESTAMPTZ": TokenType.TIMESTAMPTZ,
         "TIMESTAMPLTZ": TokenType.TIMESTAMPLTZ,
+        "TIMESTAMP_LTZ": TokenType.TIMESTAMPLTZ,
+        "TIMESTAMPNTZ": TokenType.TIMESTAMPNTZ,
+        "TIMESTAMP_NTZ": TokenType.TIMESTAMPNTZ,
         "DATE": TokenType.DATE,
         "DATETIME": TokenType.DATETIME,
         "INT4RANGE": TokenType.INT4RANGE,
@@ -867,7 +874,6 @@ class Tokenizer(metaclass=_Tokenizer):
         "ANALYZE": TokenType.COMMAND,
         "CALL": TokenType.COMMAND,
         "COMMENT": TokenType.COMMENT,
-        "COPY": TokenType.COMMAND,
         "EXPLAIN": TokenType.COMMAND,
         "GRANT": TokenType.COMMAND,
         "OPTIMIZE": TokenType.COMMAND,