Merging upstream version 18.5.1.
Signed-off-by: Daniel Baumann <daniel@debian.org>
parent ad94fdbf21
commit 11b24b93ea
67 changed files with 32690 additions and 32450 deletions
sqlglot/dialects/bigquery.py
@@ -9,6 +9,7 @@ from sqlglot._typing import E
 from sqlglot.dialects.dialect import (
     Dialect,
     binary_from_function,
+    date_add_interval_sql,
     datestrtodate_sql,
     format_time_lambda,
     inline_array_sql,
@@ -28,19 +29,6 @@ from sqlglot.tokens import TokenType

 logger = logging.getLogger("sqlglot")


-def _date_add_sql(
-    data_type: str, kind: str
-) -> t.Callable[[BigQuery.Generator, exp.Expression], str]:
-    def func(self: BigQuery.Generator, expression: exp.Expression) -> str:
-        this = self.sql(expression, "this")
-        unit = expression.args.get("unit")
-        unit = exp.var(unit.name.upper() if unit else "DAY")
-        interval = exp.Interval(this=expression.expression.copy(), unit=unit)
-        return f"{data_type}_{kind}({this}, {self.sql(interval)})"
-
-    return func
-
-
 def _derived_table_values_to_unnest(self: BigQuery.Generator, expression: exp.Values) -> str:
     if not expression.find_ancestor(exp.From, exp.Join):
         return self.values_sql(expression)
@@ -187,6 +175,7 @@ def _parse_to_hex(args: t.List) -> exp.Hex | exp.MD5:

 class BigQuery(Dialect):
     UNNEST_COLUMN_ONLY = True
+    SUPPORTS_USER_DEFINED_TYPES = False

     # https://cloud.google.com/bigquery/docs/reference/standard-sql/lexical#case_sensitivity
     RESOLVES_IDENTIFIERS_AS_UPPERCASE = None
@@ -278,8 +267,6 @@ class BigQuery(Dialect):
         LOG_BASE_FIRST = False
         LOG_DEFAULTS_TO_LN = True

-        SUPPORTS_USER_DEFINED_TYPES = False
-
         FUNCTIONS = {
             **parser.Parser.FUNCTIONS,
             "DATE": _parse_date,
@@ -436,13 +423,13 @@ class BigQuery(Dialect):
             exp.Cast: transforms.preprocess([transforms.remove_precision_parameterized_types]),
             exp.Create: _create_sql,
             exp.CTE: transforms.preprocess([_pushdown_cte_column_names]),
-            exp.DateAdd: _date_add_sql("DATE", "ADD"),
+            exp.DateAdd: date_add_interval_sql("DATE", "ADD"),
             exp.DateDiff: lambda self, e: f"DATE_DIFF({self.sql(e, 'this')}, {self.sql(e, 'expression')}, {self.sql(e.args.get('unit', 'DAY'))})",
             exp.DateFromParts: rename_func("DATE"),
             exp.DateStrToDate: datestrtodate_sql,
-            exp.DateSub: _date_add_sql("DATE", "SUB"),
-            exp.DatetimeAdd: _date_add_sql("DATETIME", "ADD"),
-            exp.DatetimeSub: _date_add_sql("DATETIME", "SUB"),
+            exp.DateSub: date_add_interval_sql("DATE", "SUB"),
+            exp.DatetimeAdd: date_add_interval_sql("DATETIME", "ADD"),
+            exp.DatetimeSub: date_add_interval_sql("DATETIME", "SUB"),
             exp.DateTrunc: lambda self, e: self.func("DATE_TRUNC", e.this, e.text("unit")),
             exp.GenerateSeries: rename_func("GENERATE_ARRAY"),
             exp.GroupConcat: rename_func("STRING_AGG"),
@@ -484,13 +471,13 @@ class BigQuery(Dialect):
             exp.StrToTime: lambda self, e: self.func(
                 "PARSE_TIMESTAMP", self.format_time(e), e.this, e.args.get("zone")
             ),
-            exp.TimeAdd: _date_add_sql("TIME", "ADD"),
-            exp.TimeSub: _date_add_sql("TIME", "SUB"),
-            exp.TimestampAdd: _date_add_sql("TIMESTAMP", "ADD"),
-            exp.TimestampSub: _date_add_sql("TIMESTAMP", "SUB"),
+            exp.TimeAdd: date_add_interval_sql("TIME", "ADD"),
+            exp.TimeSub: date_add_interval_sql("TIME", "SUB"),
+            exp.TimestampAdd: date_add_interval_sql("TIMESTAMP", "ADD"),
+            exp.TimestampSub: date_add_interval_sql("TIMESTAMP", "SUB"),
             exp.TimeStrToTime: timestrtotime_sql,
             exp.Trim: lambda self, e: self.func(f"TRIM", e.this, e.expression),
-            exp.TsOrDsAdd: _date_add_sql("DATE", "ADD"),
+            exp.TsOrDsAdd: date_add_interval_sql("DATE", "ADD"),
             exp.TsOrDsToDate: ts_or_ds_to_date_sql("bigquery"),
             exp.Unhex: rename_func("FROM_HEX"),
             exp.Values: _derived_table_values_to_unnest,
@@ -640,13 +627,6 @@ class BigQuery(Dialect):

         return super().attimezone_sql(expression)

-    def cast_sql(self, expression: exp.Cast, safe_prefix: t.Optional[str] = None) -> str:
-        # https://cloud.google.com/bigquery/docs/reference/standard-sql/lexical#json_literals
-        if expression.is_type("json"):
-            return f"JSON {self.sql(expression, 'this')}"
-
-        return super().cast_sql(expression, safe_prefix=safe_prefix)
-
     def trycast_sql(self, expression: exp.TryCast) -> str:
         return self.cast_sql(expression, safe_prefix="SAFE_")

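The BigQuery hunks above replace the dialect-private _date_add_sql helper with
the shared date_add_interval_sql factory from sqlglot.dialects.dialect. A
minimal round-trip sketch, assuming sqlglot 18.5.x (output shape may vary
slightly between versions):

import sqlglot

# Date arithmetic still renders as DATE_SUB(col, INTERVAL n unit);
# only the helper that produces it moved into the shared dialect module.
sql = sqlglot.transpile(
    "SELECT DATE_SUB(d, INTERVAL 7 DAY)", read="bigquery", write="bigquery"
)[0]
print(sql)  # expected: SELECT DATE_SUB(d, INTERVAL 7 DAY)
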
sqlglot/dialects/clickhouse.py
@@ -25,6 +25,7 @@ class ClickHouse(Dialect):
     NORMALIZE_FUNCTIONS: bool | str = False
     NULL_ORDERING = "nulls_are_last"
     STRICT_STRING_CONCAT = True
+    SUPPORTS_USER_DEFINED_TYPES = False

     class Tokenizer(tokens.Tokenizer):
         COMMENTS = ["--", "#", "#!", ("/*", "*/")]
@@ -64,8 +65,6 @@ class ClickHouse(Dialect):
         }

     class Parser(parser.Parser):
-        SUPPORTS_USER_DEFINED_TYPES = False
-
         FUNCTIONS = {
             **parser.Parser.FUNCTIONS,
             "ANY": exp.AnyValue.from_arg_list,

sqlglot/dialects/dialect.py
@@ -153,6 +153,9 @@ class Dialect(metaclass=_Dialect):
     # Determines whether or not CONCAT's arguments must be strings
     STRICT_STRING_CONCAT = False

+    # Determines whether or not user-defined data types are supported
+    SUPPORTS_USER_DEFINED_TYPES = True
+
     # Determines how function names are going to be normalized
     NORMALIZE_FUNCTIONS: bool | str = "upper"

@@ -546,6 +549,19 @@ def date_trunc_to_time(args: t.List) -> exp.DateTrunc | exp.TimestampTrunc:
     return exp.TimestampTrunc(this=this, unit=unit)


+def date_add_interval_sql(
+    data_type: str, kind: str
+) -> t.Callable[[Generator, exp.Expression], str]:
+    def func(self: Generator, expression: exp.Expression) -> str:
+        this = self.sql(expression, "this")
+        unit = expression.args.get("unit")
+        unit = exp.var(unit.name.upper() if unit else "DAY")
+        interval = exp.Interval(this=expression.expression.copy(), unit=unit)
+        return f"{data_type}_{kind}({this}, {self.sql(interval)})"
+
+    return func
+
+
 def timestamptrunc_sql(self: Generator, expression: exp.TimestampTrunc) -> str:
     return self.func(
         "DATE_TRUNC", exp.Literal.string(expression.text("unit") or "day"), expression.this
@@ -736,5 +752,15 @@ def any_value_to_max_sql(self: Generator, expression: exp.AnyValue) -> str:


 # Used to generate JSON_OBJECT with a comma in BigQuery and MySQL instead of colon
-def json_keyvalue_comma_sql(self, expression: exp.JSONKeyValue) -> str:
+def json_keyvalue_comma_sql(self: Generator, expression: exp.JSONKeyValue) -> str:
     return f"{self.sql(expression, 'this')}, {self.sql(expression, 'expression')}"
+
+
+def is_parse_json(expression: exp.Expression) -> bool:
+    return isinstance(expression, exp.ParseJSON) or (
+        isinstance(expression, exp.Cast) and expression.is_type("json")
+    )
+
+
+def isnull_to_is_null(args: t.List) -> exp.Expression:
+    return exp.Paren(this=exp.Is(this=seq_get(args, 0), expression=exp.null()))

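The new date_add_interval_sql helper is a closure factory: calling it once per
(data_type, kind) pair produces the generator callback that each dialect
registers in its TRANSFORMS table. A standalone sketch of the pattern (the
names below are illustrative, not part of sqlglot):

from typing import Callable

def make_renderer(data_type: str, kind: str) -> Callable[[str, str], str]:
    # Returns a callback specialized to one SQL function name, mirroring
    # how date_add_interval_sql("DATE", "ADD") specializes DATE_ADD.
    def render(this: str, interval: str) -> str:
        return f"{data_type}_{kind}({this}, {interval})"

    return render

date_add = make_renderer("DATE", "ADD")
print(date_add("d", "INTERVAL 1 DAY"))  # DATE_ADD(d, INTERVAL 1 DAY)
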
sqlglot/dialects/drill.py
@@ -39,6 +39,7 @@ class Drill(Dialect):
     DATE_FORMAT = "'yyyy-MM-dd'"
     DATEINT_FORMAT = "'yyyyMMdd'"
     TIME_FORMAT = "'yyyy-MM-dd HH:mm:ss'"
+    SUPPORTS_USER_DEFINED_TYPES = False

     TIME_MAPPING = {
         "y": "%Y",
@@ -80,7 +81,6 @@ class Drill(Dialect):
     class Parser(parser.Parser):
         STRICT_CAST = False
         CONCAT_NULL_OUTPUTS_STRING = True
-        SUPPORTS_USER_DEFINED_TYPES = False

         FUNCTIONS = {
             **parser.Parser.FUNCTIONS,

sqlglot/dialects/duckdb.py
@@ -105,6 +105,7 @@ def _json_format_sql(self: DuckDB.Generator, expression: exp.JSONFormat) -> str:

 class DuckDB(Dialect):
     NULL_ORDERING = "nulls_are_last"
+    SUPPORTS_USER_DEFINED_TYPES = False

     # https://duckdb.org/docs/sql/introduction.html#creating-a-new-table
     RESOLVES_IDENTIFIERS_AS_UPPERCASE = None
@@ -135,7 +136,6 @@ class DuckDB(Dialect):

     class Parser(parser.Parser):
         CONCAT_NULL_OUTPUTS_STRING = True
-        SUPPORTS_USER_DEFINED_TYPES = False

         BITWISE = {
             **parser.Parser.BITWISE,
@@ -158,6 +158,11 @@ class DuckDB(Dialect):
             "LIST_REVERSE_SORT": _sort_array_reverse,
             "LIST_SORT": exp.SortArray.from_arg_list,
             "LIST_VALUE": exp.Array.from_arg_list,
+            "MEDIAN": lambda args: exp.PercentileCont(
+                this=seq_get(args, 0), expression=exp.Literal.number(0.5)
+            ),
+            "QUANTILE_CONT": exp.PercentileCont.from_arg_list,
+            "QUANTILE_DISC": exp.PercentileDisc.from_arg_list,
             "REGEXP_EXTRACT": lambda args: exp.RegexpExtract(
                 this=seq_get(args, 0), expression=seq_get(args, 1), group=seq_get(args, 2)
             ),
@@ -266,6 +271,9 @@ class DuckDB(Dialect):
                 exp.cast(e.expression, "timestamp", copy=True),
                 exp.cast(e.this, "timestamp", copy=True),
             ),
+            exp.ParseJSON: rename_func("JSON"),
+            exp.PercentileCont: rename_func("QUANTILE_CONT"),
+            exp.PercentileDisc: rename_func("QUANTILE_DISC"),
             exp.Properties: no_properties_sql,
             exp.RegexpExtract: regexp_extract_sql,
             exp.RegexpReplace: regexp_replace_sql,

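With MEDIAN now parsed as a 0.5 PercentileCont, cross-dialect output falls out
of the percentile machinery added elsewhere in this release. A hedged sketch
(exact output may vary across 18.x versions):

import sqlglot

# DuckDB's MEDIAN(x) parses to PercentileCont(x, 0.5); Postgres renders it
# with WITHIN GROUP via the new add_within_group_for_percentiles transform.
print(sqlglot.transpile("SELECT MEDIAN(x) FROM t", read="duckdb", write="postgres")[0])
# expected: SELECT PERCENTILE_CONT(0.5) WITHIN GROUP (ORDER BY x) FROM t
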
sqlglot/dialects/hive.py
@@ -9,6 +9,7 @@ from sqlglot.dialects.dialect import (
     create_with_partitions_sql,
     format_time_lambda,
     if_sql,
+    is_parse_json,
     left_to_substring_sql,
     locate_to_strposition,
     max_or_greatest,
@@ -89,7 +90,7 @@ def _date_diff_sql(self: Hive.Generator, expression: exp.DateDiff) -> str:

 def _json_format_sql(self: Hive.Generator, expression: exp.JSONFormat) -> str:
     this = expression.this
-    if isinstance(this, exp.Cast) and this.is_type("json") and this.this.is_string:
+    if is_parse_json(this) and this.this.is_string:
         # Since FROM_JSON requires a nested type, we always wrap the json string with
         # an array to ensure that "naked" strings like "'a'" will be handled correctly
         wrapped_json = exp.Literal.string(f"[{this.this.name}]")
@@ -150,6 +151,7 @@ def _to_date_sql(self: Hive.Generator, expression: exp.TsOrDsToDate) -> str:
 class Hive(Dialect):
     ALIAS_POST_TABLESAMPLE = True
     IDENTIFIERS_CAN_START_WITH_DIGIT = True
+    SUPPORTS_USER_DEFINED_TYPES = False

     # https://spark.apache.org/docs/latest/sql-ref-identifier.html#description
     RESOLVES_IDENTIFIERS_AS_UPPERCASE = None
@@ -222,7 +224,6 @@ class Hive(Dialect):
     class Parser(parser.Parser):
         LOG_DEFAULTS_TO_LN = True
         STRICT_CAST = False
-        SUPPORTS_USER_DEFINED_TYPES = False

         FUNCTIONS = {
             **parser.Parser.FUNCTIONS,

sqlglot/dialects/mysql.py
@@ -6,8 +6,10 @@ from sqlglot import exp, generator, parser, tokens, transforms
 from sqlglot.dialects.dialect import (
     Dialect,
     arrow_json_extract_scalar_sql,
+    date_add_interval_sql,
     datestrtodate_sql,
     format_time_lambda,
+    isnull_to_is_null,
     json_keyvalue_comma_sql,
     locate_to_strposition,
     max_or_greatest,
@@ -99,6 +101,7 @@ class MySQL(Dialect):

     TIME_FORMAT = "'%Y-%m-%d %T'"
     DPIPE_IS_STRING_CONCAT = False
+    SUPPORTS_USER_DEFINED_TYPES = False

     # https://prestodb.io/docs/current/functions/datetime.html#mysql-date-functions
     TIME_MAPPING = {
@@ -129,6 +132,7 @@ class MySQL(Dialect):
             "ENUM": TokenType.ENUM,
             "FORCE": TokenType.FORCE,
             "IGNORE": TokenType.IGNORE,
+            "LOCK TABLES": TokenType.COMMAND,
             "LONGBLOB": TokenType.LONGBLOB,
             "LONGTEXT": TokenType.LONGTEXT,
             "MEDIUMBLOB": TokenType.MEDIUMBLOB,
@@ -141,6 +145,7 @@ class MySQL(Dialect):
             "START": TokenType.BEGIN,
             "SIGNED": TokenType.BIGINT,
             "SIGNED INTEGER": TokenType.BIGINT,
+            "UNLOCK TABLES": TokenType.COMMAND,
             "UNSIGNED": TokenType.UBIGINT,
             "UNSIGNED INTEGER": TokenType.UBIGINT,
             "YEAR": TokenType.YEAR,
@@ -193,8 +198,6 @@ class MySQL(Dialect):
         COMMANDS = tokens.Tokenizer.COMMANDS - {TokenType.SHOW}

     class Parser(parser.Parser):
-        SUPPORTS_USER_DEFINED_TYPES = False
-
         FUNC_TOKENS = {
             *parser.Parser.FUNC_TOKENS,
             TokenType.DATABASE,
@@ -233,7 +236,12 @@ class MySQL(Dialect):
             "DATE_FORMAT": format_time_lambda(exp.TimeToStr, "mysql"),
             "DATE_SUB": parse_date_delta_with_interval(exp.DateSub),
             "INSTR": lambda args: exp.StrPosition(substr=seq_get(args, 1), this=seq_get(args, 0)),
+            "ISNULL": isnull_to_is_null,
             "LOCATE": locate_to_strposition,
+            "MONTHNAME": lambda args: exp.TimeToStr(
+                this=seq_get(args, 0),
+                format=exp.Literal.string("%B"),
+            ),
             "STR_TO_DATE": _str_to_date,
         }

@@ -374,7 +382,7 @@ class MySQL(Dialect):
             self._match_texts({"INDEX", "KEY"})

             this = self._parse_id_var(any_token=False)
-            type_ = self._match(TokenType.USING) and self._advance_any() and self._prev.text
+            index_type = self._match(TokenType.USING) and self._advance_any() and self._prev.text
             schema = self._parse_schema()

             options = []
@@ -414,7 +422,7 @@ class MySQL(Dialect):
                 this=this,
                 schema=schema,
                 kind=kind,
-                type=type_,
+                index_type=index_type,
                 options=options,
             )

@@ -558,6 +566,8 @@ class MySQL(Dialect):
             exp.StrToTime: _str_to_date_sql,
             exp.Stuff: rename_func("INSERT"),
             exp.TableSample: no_tablesample_sql,
+            exp.TimestampAdd: date_add_interval_sql("DATE", "ADD"),
+            exp.TimestampSub: date_add_interval_sql("DATE", "SUB"),
             exp.TimeStrToUnix: rename_func("UNIX_TIMESTAMP"),
             exp.TimeStrToTime: lambda self, e: self.sql(exp.cast(e.this, "datetime", copy=True)),
             exp.TimeToStr: lambda self, e: self.func("DATE_FORMAT", e.this, self.format_time(e)),

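The new ISNULL mapping is easiest to see end to end; a hedged sketch:

import sqlglot

# isnull_to_is_null builds exp.Paren(exp.Is(...)), so the rewritten
# predicate comes back parenthesized.
print(sqlglot.transpile("SELECT ISNULL(x)", read="mysql", write="mysql")[0])
# expected: SELECT (x IS NULL)
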
sqlglot/dialects/oracle.py
@@ -32,6 +32,7 @@ def _parse_xml_table(self: Oracle.Parser) -> exp.XMLTable:

 class Oracle(Dialect):
     ALIAS_POST_TABLESAMPLE = True
+    LOCKING_READS_SUPPORTED = True

     # See section 8: https://docs.oracle.com/cd/A97630_01/server.920/a96540/sql_elements9a.htm
     RESOLVES_IDENTIFIERS_AS_UPPERCASE = True

sqlglot/dialects/postgres.py
@@ -381,6 +381,9 @@ class Postgres(Dialect):
             **generator.Generator.TRANSFORMS,
             exp.AnyValue: any_value_to_max_sql,
             exp.ArrayConcat: rename_func("ARRAY_CAT"),
+            exp.ArrayContained: lambda self, e: self.binary(e, "<@"),
+            exp.ArrayContains: lambda self, e: self.binary(e, "@>"),
+            exp.ArrayOverlaps: lambda self, e: self.binary(e, "&&"),
             exp.BitwiseXor: lambda self, e: self.binary(e, "#"),
             exp.ColumnDef: transforms.preprocess([_auto_increment_to_serial, _serial_to_generated]),
             exp.Explode: rename_func("UNNEST"),
@@ -401,10 +404,13 @@ class Postgres(Dialect):
             exp.Max: max_or_greatest,
             exp.MapFromEntries: no_map_from_entries_sql,
             exp.Min: min_or_least,
-            exp.ArrayOverlaps: lambda self, e: self.binary(e, "&&"),
-            exp.ArrayContains: lambda self, e: self.binary(e, "@>"),
-            exp.ArrayContained: lambda self, e: self.binary(e, "<@"),
             exp.Merge: transforms.preprocess([_remove_target_from_merge]),
+            exp.PercentileCont: transforms.preprocess(
+                [transforms.add_within_group_for_percentiles]
+            ),
+            exp.PercentileDisc: transforms.preprocess(
+                [transforms.add_within_group_for_percentiles]
+            ),
             exp.Pivot: no_pivot_sql,
             exp.RegexpLike: lambda self, e: self.binary(e, "~"),
             exp.RegexpILike: lambda self, e: self.binary(e, "~*"),

sqlglot/dialects/presto.py
@@ -237,6 +237,7 @@ class Presto(Dialect):
                 this=seq_get(args, 0), charset=exp.Literal.string("utf-8")
             ),
         }

         FUNCTION_PARSERS = parser.Parser.FUNCTION_PARSERS.copy()
+        FUNCTION_PARSERS.pop("TRIM")

@@ -310,6 +311,7 @@ class Presto(Dialect):
             exp.If: if_sql,
             exp.ILike: no_ilike_sql,
             exp.Initcap: _initcap_sql,
+            exp.ParseJSON: rename_func("JSON_PARSE"),
             exp.Last: _first_last_sql,
             exp.Lateral: _explode_to_unnest_sql,
             exp.Left: left_to_substring_sql,
@@ -360,6 +362,7 @@ class Presto(Dialect):
             exp.WithinGroup: transforms.preprocess(
                 [transforms.remove_within_group_for_percentiles]
             ),
+            exp.Timestamp: transforms.preprocess([transforms.timestamp_to_cast]),
         }

         def interval_sql(self, expression: exp.Interval) -> str:

sqlglot/dialects/redshift.py
@@ -30,6 +30,8 @@ class Redshift(Postgres):
     # https://docs.aws.amazon.com/redshift/latest/dg/r_names.html
     RESOLVES_IDENTIFIERS_AS_UPPERCASE = None

+    SUPPORTS_USER_DEFINED_TYPES = False
+
     TIME_FORMAT = "'YYYY-MM-DD HH:MI:SS'"
     TIME_MAPPING = {
         **Postgres.TIME_MAPPING,
@@ -38,8 +40,6 @@ class Redshift(Postgres):
     }

     class Parser(Postgres.Parser):
-        SUPPORTS_USER_DEFINED_TYPES = False
-
         FUNCTIONS = {
             **Postgres.Parser.FUNCTIONS,
             "ADD_MONTHS": lambda args: exp.DateAdd(

sqlglot/dialects/snowflake.py
@@ -202,6 +202,7 @@ class Snowflake(Dialect):
     RESOLVES_IDENTIFIERS_AS_UPPERCASE = True
     NULL_ORDERING = "nulls_are_large"
     TIME_FORMAT = "'YYYY-MM-DD HH24:MI:SS'"
+    SUPPORTS_USER_DEFINED_TYPES = False

     TIME_MAPPING = {
         "YYYY": "%Y",
@@ -234,7 +235,6 @@ class Snowflake(Dialect):

     class Parser(parser.Parser):
         IDENTIFY_PIVOT_STRINGS = True
-        SUPPORTS_USER_DEFINED_TYPES = False

         FUNCTIONS = {
             **parser.Parser.FUNCTIONS,

sqlglot/dialects/spark2.py
@@ -7,6 +7,7 @@ from sqlglot.dialects.dialect import (
     binary_from_function,
     create_with_partitions_sql,
     format_time_lambda,
+    is_parse_json,
     pivot_column_names,
     rename_func,
     trim_sql,
@@ -242,10 +243,11 @@ class Spark2(Hive):
         CREATE_FUNCTION_RETURN_AS = False

         def cast_sql(self, expression: exp.Cast, safe_prefix: t.Optional[str] = None) -> str:
-            if isinstance(expression.this, exp.Cast) and expression.this.is_type("json"):
+            if is_parse_json(expression.this):
                 schema = f"'{self.sql(expression, 'to')}'"
                 return self.func("FROM_JSON", expression.this.this, schema)
-            if expression.is_type("json"):
+
+            if is_parse_json(expression):
                 return self.func("TO_JSON", expression.this)

             return super(Hive.Generator, self).cast_sql(expression, safe_prefix=safe_prefix)

sqlglot/dialects/trino.py
@@ -5,6 +5,8 @@ from sqlglot.dialects.presto import Presto


 class Trino(Presto):
+    SUPPORTS_USER_DEFINED_TYPES = False
+
     class Generator(Presto.Generator):
         TRANSFORMS = {
             **Presto.Generator.TRANSFORMS,
@@ -13,6 +15,3 @@ class Trino(Presto):

     class Tokenizer(Presto.Tokenizer):
         HEX_STRINGS = [("X'", "'")]
-
-    class Parser(Presto.Parser):
-        SUPPORTS_USER_DEFINED_TYPES = False

sqlglot/dialects/tsql.py
@@ -580,7 +580,6 @@ class TSQL(Dialect):
         )

     class Generator(generator.Generator):
-        LOCKING_READS_SUPPORTED = True
         LIMIT_IS_TOP = True
         QUERY_HINTS = False
         RETURNING_END = False

sqlglot/expressions.py
@@ -1321,7 +1321,13 @@ class GeneratedAsIdentityColumnConstraint(ColumnConstraintKind):

 # https://dev.mysql.com/doc/refman/8.0/en/create-table.html
 class IndexColumnConstraint(ColumnConstraintKind):
-    arg_types = {"this": False, "schema": True, "kind": False, "type": False, "options": False}
+    arg_types = {
+        "this": False,
+        "schema": True,
+        "kind": False,
+        "index_type": False,
+        "options": False,
+    }


 class InlineLengthColumnConstraint(ColumnConstraintKind):
@@ -1354,7 +1360,7 @@ class TitleColumnConstraint(ColumnConstraintKind):


 class UniqueColumnConstraint(ColumnConstraintKind):
-    arg_types = {"this": False}
+    arg_types = {"this": False, "index_type": False}


 class UppercaseColumnConstraint(ColumnConstraintKind):
@@ -4366,6 +4372,10 @@ class Extract(Func):
     arg_types = {"this": True, "expression": True}


+class Timestamp(Func):
+    arg_types = {"this": False, "expression": False}
+
+
 class TimestampAdd(Func, TimeUnit):
     arg_types = {"this": True, "expression": True, "unit": False}

@@ -4579,6 +4589,11 @@ class JSONArrayContains(Binary, Predicate, Func):
     _sql_names = ["JSON_ARRAY_CONTAINS"]


+class ParseJSON(Func):
+    # BigQuery, Snowflake have PARSE_JSON, Presto has JSON_PARSE
+    _sql_names = ["PARSE_JSON", "JSON_PARSE"]
+
+
 class Least(Func):
     arg_types = {"this": True, "expressions": False}
     is_var_len_args = True

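Since PARSE_JSON and JSON_PARSE now share the exp.ParseJSON node, the
dialect-specific spellings become interchangeable at the generator level. A
hedged sketch, assuming sqlglot 18.5.x:

import sqlglot

# Presto spells the function JSON_PARSE; DuckDB's generator renames
# exp.ParseJSON to its JSON() function (see the duckdb hunk above).
print(sqlglot.transpile("SELECT JSON_PARSE(x)", read="presto", write="duckdb")[0])
# expected: SELECT JSON(x)
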
sqlglot/generator.py
@@ -705,7 +705,9 @@ class Generator:
     def uniquecolumnconstraint_sql(self, expression: exp.UniqueColumnConstraint) -> str:
         this = self.sql(expression, "this")
         this = f" {this}" if this else ""
-        return f"UNIQUE{this}"
+        index_type = expression.args.get("index_type")
+        index_type = f" USING {index_type}" if index_type else ""
+        return f"UNIQUE{this}{index_type}"

     def createable_sql(self, expression: exp.Create, locations: t.DefaultDict) -> str:
         return self.sql(expression, "this")
@@ -2740,13 +2742,13 @@ class Generator:
         kind = f"{kind} INDEX" if kind else "INDEX"
         this = self.sql(expression, "this")
         this = f" {this}" if this else ""
-        type_ = self.sql(expression, "type")
-        type_ = f" USING {type_}" if type_ else ""
+        index_type = self.sql(expression, "index_type")
+        index_type = f" USING {index_type}" if index_type else ""
         schema = self.sql(expression, "schema")
         schema = f" {schema}" if schema else ""
         options = self.expressions(expression, key="options", sep=" ")
         options = f" {options}" if options else ""
-        return f"{kind}{this}{type_}{schema}{options}"
+        return f"{kind}{this}{index_type}{schema}{options}"

     def nvl2_sql(self, expression: exp.Nvl2) -> str:
         if self.NVL2_SUPPORTED:

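Together with the _parse_unique change further down, UNIQUE constraints can
now carry a MySQL-style USING clause. A hedged sketch (the exact rendering of
the constraint name may differ):

import sqlglot

# index_type flows from the parser into uniquecolumnconstraint_sql, so the
# USING BTREE clause should survive a MySQL round trip.
sql = "CREATE TABLE t (a INT, UNIQUE KEY uk (a) USING BTREE)"
print(sqlglot.transpile(sql, read="mysql", write="mysql")[0])
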
sqlglot/optimizer/qualify.py
@@ -60,8 +60,8 @@ def qualify(
         The qualified expression.
     """
     schema = ensure_schema(schema, dialect=dialect)
-    expression = qualify_tables(expression, db=db, catalog=catalog, schema=schema)
     expression = normalize_identifiers(expression, dialect=dialect)
+    expression = qualify_tables(expression, db=db, catalog=catalog, schema=schema)

     if isolate_tables:
         expression = isolate_table_selects(expression, schema=schema)

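Normalizing identifiers before qualifying tables means schema lookups see
identifiers in their dialect-normalized case. A hedged usage sketch:

import sqlglot
from sqlglot.optimizer.qualify import qualify

# Upper-case identifiers are normalized first, so the schema entry for
# table "t" and column "a" resolves as expected.
ast = qualify(sqlglot.parse_one("SELECT A FROM T"), schema={"t": {"a": "int"}})
print(ast.sql())
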
sqlglot/parser.py
@@ -820,7 +820,9 @@ class Parser(metaclass=_Parser):

     SHOW_PARSERS: t.Dict[str, t.Callable] = {}

-    TYPE_LITERAL_PARSERS: t.Dict[exp.DataType.Type, t.Callable] = {}
+    TYPE_LITERAL_PARSERS = {
+        exp.DataType.Type.JSON: lambda self, this, _: self.expression(exp.ParseJSON, this=this),
+    }

     MODIFIABLES = (exp.Subquery, exp.Subqueryable, exp.Table)

@@ -848,6 +850,8 @@ class Parser(metaclass=_Parser):
     WINDOW_BEFORE_PAREN_TOKENS = {TokenType.OVER}
     WINDOW_SIDES = {"FOLLOWING", "PRECEDING"}

+    FETCH_TOKENS = ID_VAR_TOKENS - {TokenType.ROW, TokenType.ROWS, TokenType.PERCENT}
+
     ADD_CONSTRAINT_TOKENS = {TokenType.CONSTRAINT, TokenType.PRIMARY_KEY, TokenType.FOREIGN_KEY}

     DISTINCT_TOKENS = {TokenType.DISTINCT}
@@ -863,8 +867,6 @@ class Parser(metaclass=_Parser):
     LOG_BASE_FIRST = True
     LOG_DEFAULTS_TO_LN = False

-    SUPPORTS_USER_DEFINED_TYPES = True
-
     # Whether or not ADD is present for each column added by ALTER TABLE
     ALTER_TABLE_ADD_COLUMN_KEYWORD = True

@@ -892,6 +894,7 @@ class Parser(metaclass=_Parser):
     UNNEST_COLUMN_ONLY: bool = False
     ALIAS_POST_TABLESAMPLE: bool = False
     STRICT_STRING_CONCAT = False
+    SUPPORTS_USER_DEFINED_TYPES = True
     NORMALIZE_FUNCTIONS = "upper"
     NULL_ORDERING: str = "nulls_are_small"
     SHOW_TRIE: t.Dict = {}
@@ -2692,7 +2695,7 @@ class Parser(metaclass=_Parser):
             expressions = self._parse_csv(self._parse_primary)
         else:
             expressions = None
-            num = self._parse_number()
+            num = self._parse_primary()

         if self._match_text_seq("BUCKET"):
             bucket_numerator = self._parse_number()
@@ -2914,6 +2917,10 @@ class Parser(metaclass=_Parser):
             )
             connect = self._parse_conjunction()
             self.NO_PAREN_FUNCTION_PARSERS.pop("PRIOR")
+
+        if not start and self._match(TokenType.START_WITH):
+            start = self._parse_conjunction()
+
         return self.expression(exp.Connect, start=start, connect=connect)

     def _parse_order(
@@ -2985,7 +2992,7 @@ class Parser(metaclass=_Parser):
         direction = self._match_set((TokenType.FIRST, TokenType.NEXT))
         direction = self._prev.text if direction else "FIRST"

-        count = self._parse_number()
+        count = self._parse_field(tokens=self.FETCH_TOKENS)
         percent = self._match(TokenType.PERCENT)

         self._match_set((TokenType.ROW, TokenType.ROWS))
@@ -3272,7 +3279,12 @@ class Parser(metaclass=_Parser):
             if tokens[0].token_type in self.TYPE_TOKENS:
                 self._prev = tokens[0]
             elif self.SUPPORTS_USER_DEFINED_TYPES:
-                return exp.DataType.build(identifier.name, udt=True)
+                type_name = identifier.name
+
+                while self._match(TokenType.DOT):
+                    type_name = f"{type_name}.{self._advance_any() and self._prev.text}"
+
+                return exp.DataType.build(type_name, udt=True)
             else:
                 return None
         else:
@@ -3816,7 +3828,9 @@ class Parser(metaclass=_Parser):
     def _parse_unique(self) -> exp.UniqueColumnConstraint:
         self._match_text_seq("KEY")
         return self.expression(
-            exp.UniqueColumnConstraint, this=self._parse_schema(self._parse_id_var(any_token=False))
+            exp.UniqueColumnConstraint,
+            this=self._parse_schema(self._parse_id_var(any_token=False)),
+            index_type=self._match(TokenType.USING) and self._advance_any() and self._prev.text,
        )

     def _parse_key_constraint_options(self) -> t.List[str]:

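The user-defined-type branch now consumes dotted names. A hedged sketch
against a dialect that leaves SUPPORTS_USER_DEFINED_TYPES enabled (Postgres):

import sqlglot

# Schema-qualified custom types such as my_schema.my_type should now parse
# as a single user-defined DataType instead of failing on the dot.
ast = sqlglot.parse_one("CREATE TABLE t (c my_schema.my_type)", read="postgres")
print(ast.sql(dialect="postgres"))
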
sqlglot/schema.py
@@ -398,9 +398,10 @@ class MappingSchema(AbstractMappingSchema[t.Dict[str, str]], Schema):
         """
         if schema_type not in self._type_mapping_cache:
             dialect = dialect or self.dialect
+            udt = Dialect.get_or_raise(dialect).SUPPORTS_USER_DEFINED_TYPES

             try:
-                expression = exp.DataType.build(schema_type, dialect=dialect)
+                expression = exp.DataType.build(schema_type, dialect=dialect, udt=udt)
                 self._type_mapping_cache[schema_type] = expression
             except AttributeError:
                 in_dialect = f" in dialect {dialect}" if dialect else ""

sqlglot/transforms.py
@@ -224,10 +224,27 @@ def explode_to_unnest(expression: exp.Expression) -> exp.Expression:
     return expression


+PERCENTILES = (exp.PercentileCont, exp.PercentileDisc)
+
+
+def add_within_group_for_percentiles(expression: exp.Expression) -> exp.Expression:
+    if (
+        isinstance(expression, PERCENTILES)
+        and not isinstance(expression.parent, exp.WithinGroup)
+        and expression.expression
+    ):
+        column = expression.this.pop()
+        expression.set("this", expression.expression.pop())
+        order = exp.Order(expressions=[exp.Ordered(this=column)])
+        expression = exp.WithinGroup(this=expression, expression=order)
+
+    return expression
+
+
 def remove_within_group_for_percentiles(expression: exp.Expression) -> exp.Expression:
     if (
         isinstance(expression, exp.WithinGroup)
-        and isinstance(expression.this, (exp.PercentileCont, exp.PercentileDisc))
+        and isinstance(expression.this, PERCENTILES)
         and isinstance(expression.expression, exp.Order)
     ):
         quantile = expression.this.this
@@ -294,10 +311,13 @@ def preprocess(

         transforms_handler = self.TRANSFORMS.get(type(expression))
         if transforms_handler:
-            # Ensures we don't enter an infinite loop. This can happen when the original expression
-            # has the same type as the final expression and there's no _sql method available for it,
-            # because then it'd re-enter _to_sql.
             if expression_type is type(expression):
+                if isinstance(expression, exp.Func):
+                    return self.function_fallback_sql(expression)
+
+                # Ensures we don't enter an infinite loop. This can happen when the original expression
+                # has the same type as the final expression and there's no _sql method available for it,
+                # because then it'd re-enter _to_sql.
                 raise ValueError(
                     f"Expression type {expression.__class__.__name__} requires a _sql method in order to be transformed."
                 )
@@ -307,3 +327,12 @@
     raise ValueError(f"Unsupported expression type {expression.__class__.__name__}.")

     return _to_sql
+
+
+def timestamp_to_cast(expression: exp.Expression) -> exp.Expression:
+    if isinstance(expression, exp.Timestamp) and not expression.expression:
+        return exp.cast(
+            expression.this,
+            to=exp.DataType.Type.TIMESTAMP,
+        )
+    return expression

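The new timestamp_to_cast transform, registered by Presto above, rewrites a
bare exp.Timestamp call into an explicit cast. A hedged sketch (it assumes
TIMESTAMP(x) parses to exp.Timestamp on the reading dialect):

import sqlglot

# Without a second argument, exp.Timestamp is preprocessed into a CAST
# before Presto generates SQL.
print(sqlglot.transpile("SELECT TIMESTAMP(x)", read="mysql", write="presto")[0])
# expected: SELECT CAST(x AS TIMESTAMP)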