Merging upstream version 15.2.0.
Signed-off-by: Daniel Baumann <daniel@debian.org>
This commit is contained in: parent 2e6df1bcfa, commit 3d4adf9c16
81 changed files with 40321 additions and 37940 deletions
@@ -327,6 +327,8 @@ class BigQuery(Dialect):
            exp.VolatileProperty: exp.Properties.Location.UNSUPPORTED,
        }

+       RESERVED_KEYWORDS = {*generator.Generator.RESERVED_KEYWORDS, "hash"}
+
        def array_sql(self, expression: exp.Array) -> str:
            first_arg = seq_get(expression.expressions, 0)
            if isinstance(first_arg, exp.Subqueryable):

@@ -27,14 +27,15 @@ class ClickHouse(Dialect):
    class Tokenizer(tokens.Tokenizer):
        COMMENTS = ["--", "#", "#!", ("/*", "*/")]
        IDENTIFIERS = ['"', "`"]
        STRING_ESCAPES = ["'", "\\"]
        BIT_STRINGS = [("0b", "")]
        HEX_STRINGS = [("0x", ""), ("0X", "")]

        KEYWORDS = {
            **tokens.Tokenizer.KEYWORDS,
            "ASOF": TokenType.ASOF,
            "ATTACH": TokenType.COMMAND,
            "DATETIME64": TokenType.DATETIME64,
            "DICTIONARY": TokenType.DICTIONARY,
            "FINAL": TokenType.FINAL,
            "FLOAT32": TokenType.FLOAT,
            "FLOAT64": TokenType.DOUBLE,

@@ -97,7 +98,6 @@ class ClickHouse(Dialect):

        TABLE_ALIAS_TOKENS = {*parser.Parser.TABLE_ALIAS_TOKENS} - {
            TokenType.ANY,
            TokenType.ASOF,
            TokenType.SEMI,
            TokenType.ANTI,
            TokenType.SETTINGS,

@@ -182,7 +182,7 @@ class ClickHouse(Dialect):
            return self.expression(exp.CTE, this=statement, alias=statement and statement.this)

-       def _parse_join_side_and_kind(
+       def _parse_join_parts(
            self,
        ) -> t.Tuple[t.Optional[Token], t.Optional[Token], t.Optional[Token]]:
            is_global = self._match(TokenType.GLOBAL) and self._prev

@@ -201,7 +201,7 @@ class ClickHouse(Dialect):
            join = super()._parse_join(skip_join_token)

            if join:
-               join.set("global", join.args.pop("natural", None))
+               join.set("global", join.args.pop("method", None))
            return join

        def _parse_function(

@@ -245,6 +245,23 @@ class ClickHouse(Dialect):
        ) -> t.List[t.Optional[exp.Expression]]:
            return super()._parse_wrapped_id_vars(optional=True)

+       def _parse_primary_key(
+           self, wrapped_optional: bool = False, in_props: bool = False
+       ) -> exp.Expression:
+           return super()._parse_primary_key(
+               wrapped_optional=wrapped_optional or in_props, in_props=in_props
+           )
+
+       def _parse_on_property(self) -> t.Optional[exp.Property]:
+           index = self._index
+           if self._match_text_seq("CLUSTER"):
+               this = self._parse_id_var()
+               if this:
+                   return self.expression(exp.OnCluster, this=this)
+               else:
+                   self._retreat(index)
+           return None
+
    class Generator(generator.Generator):
        STRUCT_DELIMITER = ("(", ")")

@@ -292,6 +309,7 @@ class ClickHouse(Dialect):
            **generator.Generator.PROPERTIES_LOCATION,
            exp.VolatileProperty: exp.Properties.Location.UNSUPPORTED,
            exp.PartitionedByProperty: exp.Properties.Location.POST_SCHEMA,
+           exp.OnCluster: exp.Properties.Location.POST_NAME,
        }

        JOIN_HINTS = False

@@ -299,6 +317,18 @@ class ClickHouse(Dialect):
        EXPLICIT_UNION = True
        GROUPINGS_SEP = ""

+       # there's no list in docs, but it can be found in Clickhouse code
+       # see `ClickHouse/src/Parsers/ParserCreate*.cpp`
+       ON_CLUSTER_TARGETS = {
+           "DATABASE",
+           "TABLE",
+           "VIEW",
+           "DICTIONARY",
+           "INDEX",
+           "FUNCTION",
+           "NAMED COLLECTION",
+       }
+
        def cte_sql(self, expression: exp.CTE) -> str:
            if isinstance(expression.this, exp.Alias):
                return self.sql(expression, "this")

@@ -321,3 +351,21 @@ class ClickHouse(Dialect):

        def placeholder_sql(self, expression: exp.Placeholder) -> str:
            return f"{{{expression.name}: {self.sql(expression, 'kind')}}}"
+
+       def oncluster_sql(self, expression: exp.OnCluster) -> str:
+           return f"ON CLUSTER {self.sql(expression, 'this')}"
+
+       def createable_sql(
+           self,
+           expression: exp.Create,
+           locations: dict[exp.Properties.Location, list[exp.Property]],
+       ) -> str:
+           kind = self.sql(expression, "kind").upper()
+           if kind in self.ON_CLUSTER_TARGETS and locations.get(exp.Properties.Location.POST_NAME):
+               this_name = self.sql(expression.this, "this")
+               this_properties = " ".join(
+                   [self.sql(prop) for prop in locations[exp.Properties.Location.POST_NAME]]
+               )
+               this_schema = self.schema_columns_sql(expression.this)
+               return f"{this_name}{self.sep()}{this_properties}{self.sep()}{this_schema}"
+           return super().createable_sql(expression, locations)

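Note: taken together, the OnCluster parser and generator changes above let ClickHouse distributed DDL round-trip through sqlglot. A minimal sketch of the intended behavior; the table definition and the exact output formatting are assumptions, not taken from this diff:

    import sqlglot

    # ON CLUSTER is captured as the new exp.OnCluster property and re-emitted
    # right after the object name (exp.Properties.Location.POST_NAME).
    sql = "CREATE TABLE foo ON CLUSTER my_cluster (id String)"
    print(sqlglot.transpile(sql, read="clickhouse", write="clickhouse")[0])
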
@@ -104,6 +104,10 @@ class _Dialect(type):
        klass.byte_start, klass.byte_end = get_start_end(TokenType.BYTE_STRING)
        klass.raw_start, klass.raw_end = get_start_end(TokenType.RAW_STRING)

+       klass.tokenizer_class.identifiers_can_start_with_digit = (
+           klass.identifiers_can_start_with_digit
+       )
+
        return klass


@@ -111,6 +115,7 @@ class Dialect(metaclass=_Dialect):
    index_offset = 0
    unnest_column_only = False
    alias_post_tablesample = False
+   identifiers_can_start_with_digit = False
    normalize_functions: t.Optional[str] = "upper"
    null_ordering = "nulls_are_small"

@@ -231,6 +236,7 @@ class Dialect(metaclass=_Dialect):
            "time_trie": self.inverse_time_trie,
            "unnest_column_only": self.unnest_column_only,
            "alias_post_tablesample": self.alias_post_tablesample,
+           "identifiers_can_start_with_digit": self.identifiers_can_start_with_digit,
            "normalize_functions": self.normalize_functions,
            "null_ordering": self.null_ordering,
            **opts,

@@ -443,7 +449,7 @@ def date_trunc_to_time(args: t.List) -> exp.DateTrunc | exp.TimestampTrunc:
    unit = seq_get(args, 0)
    this = seq_get(args, 1)

-   if isinstance(this, exp.Cast) and this.is_type(exp.DataType.Type.DATE):
+   if isinstance(this, exp.Cast) and this.is_type("date"):
        return exp.DateTrunc(unit=unit, this=this)
    return exp.TimestampTrunc(this=this, unit=unit)

@@ -468,6 +474,25 @@ def strposition_to_locate_sql(self: Generator, expression: exp.StrPosition) -> str:
    )


+def left_to_substring_sql(self: Generator, expression: exp.Left) -> str:
+   expression = expression.copy()
+   return self.sql(
+       exp.Substring(
+           this=expression.this, start=exp.Literal.number(1), length=expression.expression
+       )
+   )
+
+
+def right_to_substring_sql(self: Generator, expression: exp.Left) -> str:
+   expression = expression.copy()
+   return self.sql(
+       exp.Substring(
+           this=expression.this,
+           start=exp.Length(this=expression.this) - exp.paren(expression.expression - 1),
+       )
+   )
+
+
def timestrtotime_sql(self: Generator, expression: exp.TimeStrToTime) -> str:
    return f"CAST({self.sql(expression, 'this')} AS TIMESTAMP)"

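Note: these two helpers let dialects without LEFT/RIGHT render the new exp.Left/exp.Right nodes via SUBSTRING; Hive and Presto wire them in below. A hedged sketch of the expected rewrite (the output strings shown in the comment are a hand-worked assumption):

    import sqlglot

    # LEFT(x, 3) becomes a 1-based SUBSTRING; RIGHT(x, 3) anchors its start
    # at LENGTH(x) - (3 - 1), per the expression built above.
    print(sqlglot.transpile("SELECT LEFT(x, 3), RIGHT(x, 3)", write="presto")[0])
    # e.g. SELECT SUBSTRING(x, 1, 3), SUBSTRING(x, LENGTH(x) - (3 - 1))
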
@@ -71,7 +71,7 @@ def _struct_sql(self: generator.Generator, expression: exp.Struct) -> str:


def _datatype_sql(self: generator.Generator, expression: exp.DataType) -> str:
-   if expression.this == exp.DataType.Type.ARRAY:
+   if expression.is_type("array"):
        return f"{self.expressions(expression, flat=True)}[]"
    return self.datatype_sql(expression)

@@ -9,6 +9,7 @@ from sqlglot.dialects.dialect import (
    create_with_partitions_sql,
    format_time_lambda,
    if_sql,
+   left_to_substring_sql,
    locate_to_strposition,
    max_or_greatest,
    min_or_least,

@@ -17,6 +18,7 @@ from sqlglot.dialects.dialect import (
    no_safe_divide_sql,
    no_trycast_sql,
    rename_func,
+   right_to_substring_sql,
    strposition_to_locate_sql,
    struct_extract_sql,
    timestrtotime_sql,

@@ -89,7 +91,7 @@ def _json_format_sql(self: generator.Generator, expression: exp.JSONFormat) -> str:

        annotate_types(this)

-       if this.type.is_type(exp.DataType.Type.JSON):
+       if this.type.is_type("json"):
            return self.sql(this)
    return self.func("TO_JSON", this, expression.args.get("options"))

@@ -149,6 +151,7 @@ def _to_date_sql(self: generator.Generator, expression: exp.TsOrDsToDate) -> str:

class Hive(Dialect):
    alias_post_tablesample = True
+   identifiers_can_start_with_digit = True

    time_mapping = {
        "y": "%Y",

@@ -190,7 +193,6 @@ class Hive(Dialect):
        IDENTIFIERS = ["`"]
        STRING_ESCAPES = ["\\"]
        ENCODE = "utf-8"
-       IDENTIFIER_CAN_START_WITH_DIGIT = True

        KEYWORDS = {
            **tokens.Tokenizer.KEYWORDS,

@@ -276,6 +278,39 @@ class Hive(Dialect):
            "cluster": lambda self: self._parse_sort(exp.Cluster, "CLUSTER", "BY"),
        }

+       def _parse_types(
+           self, check_func: bool = False, schema: bool = False
+       ) -> t.Optional[exp.Expression]:
+           """
+           Spark (and most likely Hive) treats casts to CHAR(length) and VARCHAR(length) as casts to
+           STRING in all contexts except for schema definitions. For example, this is in Spark v3.4.0:
+
+               spark-sql (default)> select cast(1234 as varchar(2));
+               23/06/06 15:51:18 WARN CharVarcharUtils: The Spark cast operator does not support
+               char/varchar type and simply treats them as string type. Please use string type
+               directly to avoid confusion. Otherwise, you can set spark.sql.legacy.charVarcharAsString
+               to true, so that Spark treat them as string type as same as Spark 3.0 and earlier
+
+               1234
+               Time taken: 4.265 seconds, Fetched 1 row(s)
+
+           This shows that Spark doesn't truncate the value into '12', which is inconsistent with
+           what other dialects (e.g. postgres) do, so we need to drop the length to transpile correctly.
+
+           Reference: https://spark.apache.org/docs/latest/sql-ref-datatypes.html
+           """
+           this = super()._parse_types(check_func=check_func, schema=schema)
+
+           if this and not schema:
+               return this.transform(
+                   lambda node: node.replace(exp.DataType.build("text"))
+                   if isinstance(node, exp.DataType) and node.is_type("char", "varchar")
+                   else node,
+                   copy=False,
+               )
+
+           return this
+
    class Generator(generator.Generator):
        LIMIT_FETCH = "LIMIT"
        TABLESAMPLE_WITH_METHOD = False

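Note: the effect of the new schema flag is easiest to see on a cast: outside a schema the length is dropped, inside a column definition it survives. A hedged sketch; the expected outputs in the comments are assumptions based on the transform above:

    import sqlglot

    # Outside a schema, VARCHAR(2) is rebuilt as TEXT, which Hive/Spark render as STRING
    print(sqlglot.transpile("SELECT CAST(1234 AS VARCHAR(2))", read="spark")[0])
    # e.g. SELECT CAST(1234 AS STRING)

    # In a schema definition (schema=True) the length is preserved
    print(sqlglot.transpile("CREATE TABLE t (c VARCHAR(2))", read="spark")[0])
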
@@ -323,6 +358,7 @@ class Hive(Dialect):
            exp.JSONExtract: rename_func("GET_JSON_OBJECT"),
            exp.JSONExtractScalar: rename_func("GET_JSON_OBJECT"),
            exp.JSONFormat: _json_format_sql,
+           exp.Left: left_to_substring_sql,
            exp.Map: var_map_sql,
            exp.Max: max_or_greatest,
            exp.Min: min_or_least,

@@ -332,6 +368,7 @@ class Hive(Dialect):
            exp.ApproxQuantile: rename_func("PERCENTILE_APPROX"),
            exp.RegexpLike: lambda self, e: self.binary(e, "RLIKE"),
            exp.RegexpSplit: rename_func("SPLIT"),
+           exp.Right: right_to_substring_sql,
            exp.SafeDivide: no_safe_divide_sql,
            exp.SchemaCommentProperty: lambda self, e: self.naked_property(e),
            exp.SetAgg: rename_func("COLLECT_SET"),

@@ -186,9 +186,6 @@ class MySQL(Dialect):
        "DATE_FORMAT": format_time_lambda(exp.TimeToStr, "mysql"),
        "DATE_SUB": parse_date_delta_with_interval(exp.DateSub),
        "INSTR": lambda args: exp.StrPosition(substr=seq_get(args, 1), this=seq_get(args, 0)),
-       "LEFT": lambda args: exp.Substring(
-           this=seq_get(args, 0), start=exp.Literal.number(1), length=seq_get(args, 1)
-       ),
        "LOCATE": locate_to_strposition,
        "STR_TO_DATE": _str_to_date,
    }

@@ -18,7 +18,9 @@ from sqlglot.dialects.dialect import (
    rename_func,
    str_position_sql,
    timestamptrunc_sql,
+   timestrtotime_sql,
    trim_sql,
    ts_or_ds_to_date_sql,
)
from sqlglot.helper import seq_get
+from sqlglot.parser import binary_range_parser

@@ -104,7 +106,7 @@ def _string_agg_sql(self: generator.Generator, expression: exp.GroupConcat) -> str:


def _datatype_sql(self: generator.Generator, expression: exp.DataType) -> str:
-   if expression.this == exp.DataType.Type.ARRAY:
+   if expression.is_type("array"):
        return f"{self.expressions(expression, flat=True)}[]"
    return self.datatype_sql(expression)

@@ -353,12 +355,13 @@ class Postgres(Dialect):
            exp.StrToTime: lambda self, e: f"TO_TIMESTAMP({self.sql(e, 'this')}, {self.format_time(e)})",
            exp.Substring: _substring_sql,
+           exp.TimestampTrunc: timestamptrunc_sql,
-           exp.TimeStrToTime: lambda self, e: f"CAST({self.sql(e, 'this')} AS TIMESTAMP)",
+           exp.TimeStrToTime: timestrtotime_sql,
            exp.TimeToStr: lambda self, e: f"TO_CHAR({self.sql(e, 'this')}, {self.format_time(e)})",
            exp.TableSample: no_tablesample_sql,
            exp.ToChar: lambda self, e: self.function_fallback_sql(e),
            exp.Trim: trim_sql,
            exp.TryCast: no_trycast_sql,
            exp.TsOrDsToDate: ts_or_ds_to_date_sql("postgres"),
            exp.UnixToTime: lambda self, e: f"TO_TIMESTAMP({self.sql(e, 'this')})",
            exp.DataType: _datatype_sql,
            exp.GroupConcat: _string_agg_sql,

@@ -8,10 +8,12 @@ from sqlglot.dialects.dialect import (
    date_trunc_to_time,
    format_time_lambda,
    if_sql,
+   left_to_substring_sql,
    no_ilike_sql,
    no_pivot_sql,
    no_safe_divide_sql,
    rename_func,
+   right_to_substring_sql,
    struct_extract_sql,
    timestamptrunc_sql,
    timestrtotime_sql,

@@ -30,7 +32,7 @@ def _approx_distinct_sql(self: generator.Generator, expression: exp.ApproxDistinct) -> str:

def _datatype_sql(self: generator.Generator, expression: exp.DataType) -> str:
    sql = self.datatype_sql(expression)
-   if expression.this == exp.DataType.Type.TIMESTAMPTZ:
+   if expression.is_type("timestamptz"):
        sql = f"{sql} WITH TIME ZONE"
    return sql

@@ -240,6 +242,7 @@ class Presto(Dialect):
        INTERVAL_ALLOWS_PLURAL_FORM = False
        JOIN_HINTS = False
        TABLE_HINTS = False
+       IS_BOOL = False
        STRUCT_DELIMITER = ("(", ")")

        PROPERTIES_LOCATION = {

@@ -272,6 +275,7 @@ class Presto(Dialect):
            exp.BitwiseOr: lambda self, e: f"BITWISE_OR({self.sql(e, 'this')}, {self.sql(e, 'expression')})",
            exp.BitwiseRightShift: lambda self, e: f"BITWISE_ARITHMETIC_SHIFT_RIGHT({self.sql(e, 'this')}, {self.sql(e, 'expression')})",
            exp.BitwiseXor: lambda self, e: f"BITWISE_XOR({self.sql(e, 'this')}, {self.sql(e, 'expression')})",
+           exp.Cast: transforms.preprocess([transforms.epoch_cast_to_ts]),
            exp.CurrentTimestamp: lambda *_: "CURRENT_TIMESTAMP",
            exp.DataType: _datatype_sql,
            exp.DateAdd: lambda self, e: self.func(

@@ -292,11 +296,13 @@ class Presto(Dialect):
            exp.ILike: no_ilike_sql,
            exp.Initcap: _initcap_sql,
            exp.Lateral: _explode_to_unnest_sql,
+           exp.Left: left_to_substring_sql,
            exp.Levenshtein: rename_func("LEVENSHTEIN_DISTANCE"),
            exp.LogicalAnd: rename_func("BOOL_AND"),
            exp.LogicalOr: rename_func("BOOL_OR"),
            exp.Pivot: no_pivot_sql,
            exp.Quantile: _quantile_sql,
+           exp.Right: right_to_substring_sql,
            exp.SafeDivide: no_safe_divide_sql,
            exp.Schema: _schema_sql,
            exp.Select: transforms.preprocess(

@@ -319,6 +325,7 @@ class Presto(Dialect):
            exp.TimeStrToUnix: lambda self, e: f"TO_UNIXTIME(DATE_PARSE({self.sql(e, 'this')}, {Presto.time_format}))",
            exp.TimeToStr: lambda self, e: f"DATE_FORMAT({self.sql(e, 'this')}, {self.format_time(e)})",
            exp.TimeToUnix: rename_func("TO_UNIXTIME"),
+           exp.TryCast: transforms.preprocess([transforms.epoch_cast_to_ts]),
            exp.TsOrDiToDi: lambda self, e: f"CAST(SUBSTR(REPLACE(CAST({self.sql(e, 'this')} AS VARCHAR), '-', ''), 1, 8) AS INT)",
            exp.TsOrDsAdd: _ts_or_ds_add_sql,
            exp.TsOrDsToDate: _ts_or_ds_to_date_sql,

@@ -356,7 +363,7 @@ class Presto(Dialect):
            else:
                target_type = None

-           if target_type and target_type.is_type(exp.DataType.Type.TIMESTAMP):
+           if target_type and target_type.is_type("timestamp"):
                to = target_type.copy()

                if target_type is start.to:

@@ -3,6 +3,7 @@ from __future__ import annotations
import typing as t

from sqlglot import exp, transforms
+from sqlglot.dialects.dialect import rename_func
from sqlglot.dialects.postgres import Postgres
from sqlglot.helper import seq_get
from sqlglot.tokens import TokenType

@@ -24,26 +25,29 @@ class Redshift(Postgres):
        FUNCTIONS = {
            **Postgres.Parser.FUNCTIONS,
            "DATEADD": lambda args: exp.DateAdd(
-               this=seq_get(args, 2),
+               this=exp.TsOrDsToDate(this=seq_get(args, 2)),
                expression=seq_get(args, 1),
                unit=seq_get(args, 0),
            ),
            "DATEDIFF": lambda args: exp.DateDiff(
-               this=seq_get(args, 2),
-               expression=seq_get(args, 1),
+               this=exp.TsOrDsToDate(this=seq_get(args, 2)),
+               expression=exp.TsOrDsToDate(this=seq_get(args, 1)),
                unit=seq_get(args, 0),
            ),
            "NVL": exp.Coalesce.from_arg_list,
+           "STRTOL": exp.FromBase.from_arg_list,
        }

        CONVERT_TYPE_FIRST = True

-       def _parse_types(self, check_func: bool = False) -> t.Optional[exp.Expression]:
-           this = super()._parse_types(check_func=check_func)
+       def _parse_types(
+           self, check_func: bool = False, schema: bool = False
+       ) -> t.Optional[exp.Expression]:
+           this = super()._parse_types(check_func=check_func, schema=schema)

            if (
                isinstance(this, exp.DataType)
-               and this.this == exp.DataType.Type.VARCHAR
+               and this.is_type("varchar")
                and this.expressions
                and this.expressions[0].this == exp.column("MAX")
            ):

@@ -99,10 +103,12 @@ class Redshift(Postgres):
            ),
            exp.DistKeyProperty: lambda self, e: f"DISTKEY({e.name})",
            exp.DistStyleProperty: lambda self, e: self.naked_property(e),
+           exp.FromBase: rename_func("STRTOL"),
            exp.JSONExtract: _json_sql,
            exp.JSONExtractScalar: _json_sql,
            exp.Select: transforms.preprocess([transforms.eliminate_distinct_on]),
            exp.SortKeyProperty: lambda self, e: f"{'COMPOUND ' if e.args['compound'] else ''}SORTKEY({self.format_args(*e.this)})",
+           exp.TsOrDsToDate: lambda self, e: self.sql(e.this),
        }

        # Postgres maps exp.Pivot to no_pivot_sql, but Redshift support pivots

@@ -158,7 +164,7 @@ class Redshift(Postgres):
        without precision we convert it to `VARCHAR(max)` and if it does have precision then we just convert
        `TEXT` to `VARCHAR`.
        """
-       if expression.this == exp.DataType.Type.TEXT:
+       if expression.is_type("text"):
            expression = expression.copy()
            expression.set("this", exp.DataType.Type.VARCHAR)
            precision = expression.args.get("expressions")

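Note: wrapping DATEADD/DATEDIFF arguments in exp.TsOrDsToDate records that Redshift accepts date-like strings in those positions; Redshift's own generator unwraps the node again (the exp.TsOrDsToDate transform above), while other targets can materialize an explicit cast. A hedged sketch; the round-trip shape in the comment is an assumption:

    import sqlglot

    # The string argument is wrapped on parse, so non-Redshift targets can cast it
    ast = sqlglot.parse_one("SELECT DATEADD(month, 18, '2008-02-28')", read="redshift")
    print(ast.sql(dialect="redshift"))  # should round-trip as DATEADD(month, 18, '2008-02-28')
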
@@ -153,9 +153,9 @@ def _nullifzero_to_if(args: t.List) -> exp.Expression:


def _datatype_sql(self: generator.Generator, expression: exp.DataType) -> str:
-   if expression.this == exp.DataType.Type.ARRAY:
+   if expression.is_type("array"):
        return "ARRAY"
-   elif expression.this == exp.DataType.Type.MAP:
+   elif expression.is_type("map"):
        return "OBJECT"
    return self.datatype_sql(expression)

@@ -110,11 +110,6 @@ class Spark2(Hive):
            **Hive.Parser.FUNCTIONS,
            "MAP_FROM_ARRAYS": exp.Map.from_arg_list,
            "TO_UNIX_TIMESTAMP": exp.StrToUnix.from_arg_list,
-           "LEFT": lambda args: exp.Substring(
-               this=seq_get(args, 0),
-               start=exp.Literal.number(1),
-               length=seq_get(args, 1),
-           ),
            "SHIFTLEFT": lambda args: exp.BitwiseLeftShift(
                this=seq_get(args, 0),
                expression=seq_get(args, 1),

@@ -123,14 +118,6 @@ class Spark2(Hive):
                this=seq_get(args, 0),
                expression=seq_get(args, 1),
            ),
-           "RIGHT": lambda args: exp.Substring(
-               this=seq_get(args, 0),
-               start=exp.Sub(
-                   this=exp.Length(this=seq_get(args, 0)),
-                   expression=exp.Add(this=seq_get(args, 1), expression=exp.Literal.number(1)),
-               ),
-               length=seq_get(args, 1),
-           ),
            "APPROX_PERCENTILE": exp.ApproxQuantile.from_arg_list,
            "IIF": exp.If.from_arg_list,
            "AGGREGATE": exp.Reduce.from_arg_list,

@@ -240,17 +227,17 @@ class Spark2(Hive):
        TRANSFORMS.pop(exp.ArrayJoin)
        TRANSFORMS.pop(exp.ArraySort)
        TRANSFORMS.pop(exp.ILike)
+       TRANSFORMS.pop(exp.Left)
+       TRANSFORMS.pop(exp.Right)

        WRAP_DERIVED_VALUES = False
        CREATE_FUNCTION_RETURN_AS = False

        def cast_sql(self, expression: exp.Cast) -> str:
-           if isinstance(expression.this, exp.Cast) and expression.this.is_type(
-               exp.DataType.Type.JSON
-           ):
+           if isinstance(expression.this, exp.Cast) and expression.this.is_type("json"):
                schema = f"'{self.sql(expression, 'to')}'"
                return self.func("FROM_JSON", expression.this.this, schema)
-           if expression.to.is_type(exp.DataType.Type.JSON):
+           if expression.is_type("json"):
                return self.func("TO_JSON", expression.this)

            return super(Hive.Generator, self).cast_sql(expression)

@@ -260,7 +247,7 @@ class Spark2(Hive):
                expression,
                sep=": "
                if isinstance(expression.parent, exp.DataType)
-               and expression.parent.is_type(exp.DataType.Type.STRUCT)
+               and expression.parent.is_type("struct")
                else sep,
            )

@@ -132,7 +132,7 @@ class SQLite(Dialect):
        LIMIT_FETCH = "LIMIT"

        def cast_sql(self, expression: exp.Cast) -> str:
-           if expression.to.this == exp.DataType.Type.DATE:
+           if expression.is_type("date"):
                return self.func("DATE", expression.this)

            return super().cast_sql(expression)

@@ -183,3 +183,20 @@ class Teradata(Dialect):
            each_sql = f" EACH {each_sql}" if each_sql else ""

            return f"RANGE_N({this} BETWEEN {expressions_sql}{each_sql})"
+
+       def createable_sql(
+           self,
+           expression: exp.Create,
+           locations: dict[exp.Properties.Location, list[exp.Property]],
+       ) -> str:
+           kind = self.sql(expression, "kind").upper()
+           if kind == "TABLE" and locations.get(exp.Properties.Location.POST_NAME):
+               this_name = self.sql(expression.this, "this")
+               this_properties = self.properties(
+                   exp.Properties(expressions=locations[exp.Properties.Location.POST_NAME]),
+                   wrapped=False,
+                   prefix=",",
+               )
+               this_schema = self.schema_columns_sql(expression.this)
+               return f"{this_name}{this_properties}{self.sep()}{this_schema}"
+           return super().createable_sql(expression, locations)

@@ -1653,11 +1653,15 @@ class Join(Expression):
        "side": False,
        "kind": False,
        "using": False,
        "natural": False,
+       "method": False,
        "global": False,
        "hint": False,
    }

+   @property
+   def method(self) -> str:
+       return self.text("method").upper()
+
    @property
    def kind(self) -> str:
        return self.text("kind").upper()

@@ -1913,6 +1917,24 @@ class LanguageProperty(Property):
    arg_types = {"this": True}


+class DictProperty(Property):
+   arg_types = {"this": True, "kind": True, "settings": False}
+
+
+class DictSubProperty(Property):
+   pass
+
+
+class DictRange(Property):
+   arg_types = {"this": True, "min": True, "max": True}
+
+
+# Clickhouse CREATE ... ON CLUSTER modifier
+# https://clickhouse.com/docs/en/sql-reference/distributed-ddl
+class OnCluster(Property):
+   arg_types = {"this": True}
+
+
class LikeProperty(Property):
    arg_types = {"this": True, "expressions": False}

@@ -2797,12 +2819,12 @@ class Select(Subqueryable):
        Returns:
            Select: the modified expression.
        """
-       parse_args = {"dialect": dialect, **opts}
+       parse_args: t.Dict[str, t.Any] = {"dialect": dialect, **opts}

        try:
-           expression = maybe_parse(expression, into=Join, prefix="JOIN", **parse_args)  # type: ignore
+           expression = maybe_parse(expression, into=Join, prefix="JOIN", **parse_args)
        except ParseError:
-           expression = maybe_parse(expression, into=(Join, Expression), **parse_args)  # type: ignore
+           expression = maybe_parse(expression, into=(Join, Expression), **parse_args)

        join = expression if isinstance(expression, Join) else Join(this=expression)

@@ -2810,14 +2832,14 @@ class Select(Subqueryable):
            join.this.replace(join.this.subquery())

        if join_type:
-           natural: t.Optional[Token]
+           method: t.Optional[Token]
            side: t.Optional[Token]
            kind: t.Optional[Token]

-           natural, side, kind = maybe_parse(join_type, into="JOIN_TYPE", **parse_args)  # type: ignore
+           method, side, kind = maybe_parse(join_type, into="JOIN_TYPE", **parse_args)  # type: ignore

-           if natural:
-               join.set("natural", True)
+           if method:
+               join.set("method", method.text)
            if side:
                join.set("side", side.text)
            if kind:

@@ -3222,6 +3244,18 @@ class DataType(Expression):
        DATE = auto()
        DATETIME = auto()
        DATETIME64 = auto()
+       INT4RANGE = auto()
+       INT4MULTIRANGE = auto()
+       INT8RANGE = auto()
+       INT8MULTIRANGE = auto()
+       NUMRANGE = auto()
+       NUMMULTIRANGE = auto()
+       TSRANGE = auto()
+       TSMULTIRANGE = auto()
+       TSTZRANGE = auto()
+       TSTZMULTIRANGE = auto()
+       DATERANGE = auto()
+       DATEMULTIRANGE = auto()
        DECIMAL = auto()
        DOUBLE = auto()
        FLOAT = auto()

@@ -3331,8 +3365,8 @@ class DataType(Expression):

        return DataType(**{**data_type_exp.args, **kwargs})

-   def is_type(self, dtype: DataType.Type) -> bool:
-       return self.this == dtype
+   def is_type(self, *dtypes: str | DataType | DataType.Type) -> bool:
+       return any(self.this == DataType.build(dtype).this for dtype in dtypes)


# https://www.postgresql.org/docs/15/datatype-pseudo.html

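Note: this widened signature is what drives the many one-line edits across the dialects in this commit; type names can now be given as strings (or several at once) instead of only DataType.Type members. A small sketch of the new call style:

    from sqlglot import exp

    dt = exp.DataType.build("varchar(10)")
    print(dt.is_type("varchar"))               # True: compares on the base type
    print(dt.is_type("char", "varchar"))       # True: any of the dtypes may match
    print(dt.is_type(exp.DataType.Type.TEXT))  # False: enum members still work
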
@@ -3846,8 +3880,8 @@ class Cast(Func):
    def output_name(self) -> str:
        return self.name

-   def is_type(self, dtype: DataType.Type) -> bool:
-       return self.to.is_type(dtype)
+   def is_type(self, *dtypes: str | DataType | DataType.Type) -> bool:
+       return self.to.is_type(*dtypes)


class CastToStrType(Func):

@@ -4130,8 +4164,16 @@ class Least(Func):
    is_var_len_args = True


+class Left(Func):
+   arg_types = {"this": True, "expression": True}
+
+
+class Right(Func):
+   arg_types = {"this": True, "expression": True}
+
+
class Length(Func):
-   pass
+   _sql_names = ["LENGTH", "LEN"]


class Levenshtein(Func):

@@ -4356,6 +4398,10 @@ class NumberToStr(Func):
    arg_types = {"this": True, "format": True}


+class FromBase(Func):
+   arg_types = {"this": True, "expression": True}
+
+
class Struct(Func):
    arg_types = {"expressions": True}
    is_var_len_args = True

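Note: exp.FromBase models Redshift's STRTOL(string, base); together with the Redshift parser and generator entries above it round-trips, and it gives other dialects a typed node to translate. A hedged sketch:

    import sqlglot

    # STRTOL parses into exp.FromBase and is re-emitted via rename_func("STRTOL")
    print(sqlglot.transpile("SELECT STRTOL('0xf', 16)", read="redshift", write="redshift")[0])
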
@@ -44,6 +44,8 @@ class Generator:
            Default: "upper"
        alias_post_tablesample (bool): if the table alias comes after tablesample
            Default: False
+       identifiers_can_start_with_digit (bool): if an unquoted identifier can start with digit
+           Default: False
        unsupported_level (ErrorLevel): determines the generator's behavior when it encounters
            unsupported expressions. Default ErrorLevel.WARN.
        null_ordering (str): Indicates the default null ordering method to use if not explicitly set.

@@ -188,6 +190,8 @@ class Generator:
        exp.Cluster: exp.Properties.Location.POST_SCHEMA,
        exp.DataBlocksizeProperty: exp.Properties.Location.POST_NAME,
        exp.DefinerProperty: exp.Properties.Location.POST_CREATE,
+       exp.DictRange: exp.Properties.Location.POST_SCHEMA,
+       exp.DictProperty: exp.Properties.Location.POST_SCHEMA,
        exp.DistKeyProperty: exp.Properties.Location.POST_SCHEMA,
        exp.DistStyleProperty: exp.Properties.Location.POST_SCHEMA,
        exp.EngineProperty: exp.Properties.Location.POST_SCHEMA,

@@ -233,6 +237,7 @@ class Generator:

    JOIN_HINTS = True
    TABLE_HINTS = True
+   IS_BOOL = True

    RESERVED_KEYWORDS: t.Set[str] = set()
    WITH_SEPARATED_COMMENTS = (exp.Select, exp.From, exp.Where, exp.With)

@@ -264,6 +269,7 @@ class Generator:
        "index_offset",
        "unnest_column_only",
        "alias_post_tablesample",
+       "identifiers_can_start_with_digit",
        "normalize_functions",
        "unsupported_level",
        "unsupported_messages",

@@ -304,6 +310,7 @@ class Generator:
        index_offset=0,
        unnest_column_only=False,
        alias_post_tablesample=False,
+       identifiers_can_start_with_digit=False,
        normalize_functions="upper",
        unsupported_level=ErrorLevel.WARN,
        null_ordering=None,

@@ -337,6 +344,7 @@ class Generator:
        self.index_offset = index_offset
        self.unnest_column_only = unnest_column_only
        self.alias_post_tablesample = alias_post_tablesample
+       self.identifiers_can_start_with_digit = identifiers_can_start_with_digit
        self.normalize_functions = normalize_functions
        self.unsupported_level = unsupported_level
        self.unsupported_messages = []

@@ -634,35 +642,31 @@ class Generator:
        this = f" {this}" if this else ""
        return f"UNIQUE{this}"

+   def createable_sql(
+       self, expression: exp.Create, locations: dict[exp.Properties.Location, list[exp.Property]]
+   ) -> str:
+       return self.sql(expression, "this")
+
    def create_sql(self, expression: exp.Create) -> str:
        kind = self.sql(expression, "kind").upper()
        properties = expression.args.get("properties")
-       properties_exp = expression.copy()
        properties_locs = self.locate_properties(properties) if properties else {}

+       this = self.createable_sql(expression, properties_locs)
+
+       properties_sql = ""
        if properties_locs.get(exp.Properties.Location.POST_SCHEMA) or properties_locs.get(
            exp.Properties.Location.POST_WITH
        ):
-           properties_exp.set(
-               "properties",
+           properties_sql = self.sql(
                exp.Properties(
                    expressions=[
                        *properties_locs[exp.Properties.Location.POST_SCHEMA],
                        *properties_locs[exp.Properties.Location.POST_WITH],
                    ]
-               ),
+               )
            )
-       if kind == "TABLE" and properties_locs.get(exp.Properties.Location.POST_NAME):
-           this_name = self.sql(expression.this, "this")
-           this_properties = self.properties(
-               exp.Properties(expressions=properties_locs[exp.Properties.Location.POST_NAME]),
-               wrapped=False,
-           )
-           this_schema = f"({self.expressions(expression.this)})"
-           this = f"{this_name}, {this_properties} {this_schema}"
-           properties_sql = ""
-       else:
-           this = self.sql(expression, "this")
-           properties_sql = self.sql(properties_exp, "properties")

        begin = " BEGIN" if expression.args.get("begin") else ""
        expression_sql = self.sql(expression, "expression")
        if expression_sql:

@@ -894,6 +898,7 @@ class Generator:
            expression.quoted
            or should_identify(text, self.identify)
            or lower in self.RESERVED_KEYWORDS
+           or (not self.identifiers_can_start_with_digit and text[:1].isdigit())
        ):
            text = f"{self.identifier_start}{text}{self.identifier_end}"
        return text

@@ -1082,7 +1087,7 @@ class Generator:

    def lockingproperty_sql(self, expression: exp.LockingProperty) -> str:
        kind = expression.args.get("kind")
-       this: str = f" {this}" if expression.this else ""
+       this = f" {self.sql(expression, 'this')}" if expression.this else ""
        for_or_in = expression.args.get("for_or_in")
        lock_type = expression.args.get("lock_type")
        override = " OVERRIDE" if expression.args.get("override") else ""

@@ -1313,7 +1318,7 @@ class Generator:
        op_sql = " ".join(
            op
            for op in (
-               "NATURAL" if expression.args.get("natural") else None,
+               expression.method,
                "GLOBAL" if expression.args.get("global") else None,
                expression.side,
                expression.kind,

@@ -1573,9 +1578,12 @@ class Generator:
    def schema_sql(self, expression: exp.Schema) -> str:
        this = self.sql(expression, "this")
        this = f"{this} " if this else ""
-       sql = f"({self.sep('')}{self.expressions(expression)}{self.seg(')', sep='')}"
+       sql = self.schema_columns_sql(expression)
        return f"{this}{sql}"

+   def schema_columns_sql(self, expression: exp.Schema) -> str:
+       return f"({self.sep('')}{self.expressions(expression)}{self.seg(')', sep='')}"
+
    def star_sql(self, expression: exp.Star) -> str:
        except_ = self.expressions(expression, key="except", flat=True)
        except_ = f"{self.seg(self.STAR_MAPPING['except'])} ({except_})" if except_ else ""

@@ -1643,32 +1651,26 @@ class Generator:

    def window_sql(self, expression: exp.Window) -> str:
        this = self.sql(expression, "this")
-
        partition = self.partition_by_sql(expression)
-
        order = expression.args.get("order")
-       order_sql = self.order_sql(order, flat=True) if order else ""
-
-       partition_sql = partition + " " if partition and order else partition
-
-       spec = expression.args.get("spec")
-       spec_sql = " " + self.windowspec_sql(spec) if spec else ""
-
+       order = self.order_sql(order, flat=True) if order else ""
+       spec = self.sql(expression, "spec")
        alias = self.sql(expression, "alias")
        over = self.sql(expression, "over") or "OVER"
-
        this = f"{this} {'AS' if expression.arg_key == 'windows' else over}"

        first = expression.args.get("first")
-       if first is not None:
-           first = " FIRST " if first else " LAST "
-       first = first or ""
+       if first is None:
+           first = ""
+       else:
+           first = "FIRST" if first else "LAST"

        if not partition and not order and not spec and alias:
            return f"{this} {alias}"

-       window_args = alias + first + partition_sql + order_sql + spec_sql
-
-       return f"{this} ({window_args.strip()})"
+       args = " ".join(arg for arg in (alias, first, partition, order, spec) if arg)
+       return f"{this} ({args})"

    def partition_by_sql(self, expression: exp.Window | exp.MatchRecognize) -> str:
        partition = self.expressions(expression, key="partition_by", flat=True)

@@ -2125,6 +2127,10 @@ class Generator:
        return self.binary(expression, "ILIKE ANY")

    def is_sql(self, expression: exp.Is) -> str:
+       if not self.IS_BOOL and isinstance(expression.expression, exp.Boolean):
+           return self.sql(
+               expression.this if expression.expression.this else exp.not_(expression.this)
+           )
        return self.binary(expression, "IS")

    def like_sql(self, expression: exp.Like) -> str:

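Note: with IS_BOOL = False (set on Presto's generator above), boolean IS predicates are folded away instead of emitted verbatim. A hedged sketch; the expected output in the comment is an assumption:

    import sqlglot

    # Presto/Trino reject "x IS TRUE", so the new branch reduces it to the operand
    print(sqlglot.transpile("SELECT a IS TRUE, b IS FALSE", write="presto")[0])
    # e.g. SELECT a, NOT b
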
@@ -2322,6 +2328,25 @@ class Generator:

        return self.sql(exp.cast(expression.this, "text"))

+   def dictproperty_sql(self, expression: exp.DictProperty) -> str:
+       this = self.sql(expression, "this")
+       kind = self.sql(expression, "kind")
+       settings_sql = self.expressions(expression, key="settings", sep=" ")
+       args = f"({self.sep('')}{settings_sql}{self.seg(')', sep='')}" if settings_sql else "()"
+       return f"{this}({kind}{args})"
+
+   def dictrange_sql(self, expression: exp.DictRange) -> str:
+       this = self.sql(expression, "this")
+       max = self.sql(expression, "max")
+       min = self.sql(expression, "min")
+       return f"{this}(MIN {min} MAX {max})"
+
+   def dictsubproperty_sql(self, expression: exp.DictSubProperty) -> str:
+       return f"{self.sql(expression, 'this')} {self.sql(expression, 'value')}"
+
+   def oncluster_sql(self, expression: exp.OnCluster) -> str:
+       return ""
+

def cached_generator(
    cache: t.Optional[t.Dict[int, str]] = None

@@ -1,7 +1,7 @@
from sqlglot import exp
from sqlglot.helper import tsort

-JOIN_ATTRS = ("on", "side", "kind", "using", "natural")
+JOIN_ATTRS = ("on", "side", "kind", "using", "method")


def optimize_joins(expression):

@@ -10,10 +10,10 @@ def pushdown_predicates(expression):

    Example:
        >>> import sqlglot
-       >>> sql = "SELECT * FROM (SELECT * FROM x AS x) AS y WHERE y.a = 1"
+       >>> sql = "SELECT y.a AS a FROM (SELECT x.a AS a FROM x AS x) AS y WHERE y.a = 1"
        >>> expression = sqlglot.parse_one(sql)
        >>> pushdown_predicates(expression).sql()
-       'SELECT * FROM (SELECT * FROM x AS x WHERE y.a = 1) AS y WHERE TRUE'
+       'SELECT y.a AS a FROM (SELECT x.a AS a FROM x AS x WHERE x.a = 1) AS y WHERE TRUE'

    Args:
        expression (sqlglot.Expression): expression to optimize

@@ -155,6 +155,18 @@ class Parser(metaclass=_Parser):
        TokenType.DATETIME,
        TokenType.DATETIME64,
        TokenType.DATE,
+       TokenType.INT4RANGE,
+       TokenType.INT4MULTIRANGE,
+       TokenType.INT8RANGE,
+       TokenType.INT8MULTIRANGE,
+       TokenType.NUMRANGE,
+       TokenType.NUMMULTIRANGE,
+       TokenType.TSRANGE,
+       TokenType.TSMULTIRANGE,
+       TokenType.TSTZRANGE,
+       TokenType.TSTZMULTIRANGE,
+       TokenType.DATERANGE,
+       TokenType.DATEMULTIRANGE,
        TokenType.DECIMAL,
        TokenType.BIGDECIMAL,
        TokenType.UUID,

@@ -193,6 +205,7 @@ class Parser(metaclass=_Parser):
        TokenType.SCHEMA,
        TokenType.TABLE,
        TokenType.VIEW,
+       TokenType.DICTIONARY,
    }

    CREATABLES = {

@@ -220,6 +233,7 @@ class Parser(metaclass=_Parser):
        TokenType.DELETE,
        TokenType.DESC,
        TokenType.DESCRIBE,
+       TokenType.DICTIONARY,
        TokenType.DIV,
        TokenType.END,
        TokenType.EXECUTE,

@@ -272,6 +286,7 @@ class Parser(metaclass=_Parser):

    TABLE_ALIAS_TOKENS = ID_VAR_TOKENS - {
        TokenType.APPLY,
+       TokenType.ASOF,
        TokenType.FULL,
        TokenType.LEFT,
        TokenType.LOCK,

@@ -375,6 +390,11 @@ class Parser(metaclass=_Parser):
        TokenType.EXCEPT,
    }

+   JOIN_METHODS = {
+       TokenType.NATURAL,
+       TokenType.ASOF,
+   }
+
    JOIN_SIDES = {
        TokenType.LEFT,
        TokenType.RIGHT,

@@ -465,7 +485,7 @@ class Parser(metaclass=_Parser):
        exp.Where: lambda self: self._parse_where(),
        exp.Window: lambda self: self._parse_named_window(),
        exp.With: lambda self: self._parse_with(),
-       "JOIN_TYPE": lambda self: self._parse_join_side_and_kind(),
+       "JOIN_TYPE": lambda self: self._parse_join_parts(),
    }

    STATEMENT_PARSERS = {

@@ -580,6 +600,8 @@ class Parser(metaclass=_Parser):
        ),
        "JOURNAL": lambda self, **kwargs: self._parse_journal(**kwargs),
        "LANGUAGE": lambda self: self._parse_property_assignment(exp.LanguageProperty),
+       "LAYOUT": lambda self: self._parse_dict_property(this="LAYOUT"),
+       "LIFETIME": lambda self: self._parse_dict_range(this="LIFETIME"),
        "LIKE": lambda self: self._parse_create_like(),
        "LOCATION": lambda self: self._parse_property_assignment(exp.LocationProperty),
        "LOCK": lambda self: self._parse_locking(),

@@ -594,7 +616,8 @@ class Parser(metaclass=_Parser):
        "PARTITION BY": lambda self: self._parse_partitioned_by(),
        "PARTITIONED BY": lambda self: self._parse_partitioned_by(),
        "PARTITIONED_BY": lambda self: self._parse_partitioned_by(),
-       "PRIMARY KEY": lambda self: self._parse_primary_key(),
+       "PRIMARY KEY": lambda self: self._parse_primary_key(in_props=True),
+       "RANGE": lambda self: self._parse_dict_range(this="RANGE"),
        "RETURNS": lambda self: self._parse_returns(),
        "ROW": lambda self: self._parse_row(),
        "ROW_FORMAT": lambda self: self._parse_property_assignment(exp.RowFormatProperty),

@@ -603,6 +626,7 @@ class Parser(metaclass=_Parser):
            exp.SettingsProperty, expressions=self._parse_csv(self._parse_set_item)
        ),
        "SORTKEY": lambda self: self._parse_sortkey(),
+       "SOURCE": lambda self: self._parse_dict_property(this="SOURCE"),
        "STABLE": lambda self: self.expression(
            exp.StabilityProperty, this=exp.Literal.string("STABLE")
        ),

@@ -1133,13 +1157,16 @@ class Parser(metaclass=_Parser):
        begin = None
        clone = None

+       def extend_props(temp_props: t.Optional[exp.Expression]) -> None:
+           nonlocal properties
+           if properties and temp_props:
+               properties.expressions.extend(temp_props.expressions)
+           elif temp_props:
+               properties = temp_props
+
        if create_token.token_type in (TokenType.FUNCTION, TokenType.PROCEDURE):
            this = self._parse_user_defined_function(kind=create_token.token_type)
-           temp_properties = self._parse_properties()
-           if properties and temp_properties:
-               properties.expressions.extend(temp_properties.expressions)
-           elif temp_properties:
-               properties = temp_properties
+           extend_props(self._parse_properties())

            self._match(TokenType.ALIAS)
            begin = self._match(TokenType.BEGIN)

@@ -1154,21 +1181,13 @@ class Parser(metaclass=_Parser):
            table_parts = self._parse_table_parts(schema=True)

            # exp.Properties.Location.POST_NAME
-           if self._match(TokenType.COMMA):
-               temp_properties = self._parse_properties(before=True)
-               if properties and temp_properties:
-                   properties.expressions.extend(temp_properties.expressions)
-               elif temp_properties:
-                   properties = temp_properties
+           self._match(TokenType.COMMA)
+           extend_props(self._parse_properties(before=True))

            this = self._parse_schema(this=table_parts)

            # exp.Properties.Location.POST_SCHEMA and POST_WITH
-           temp_properties = self._parse_properties()
-           if properties and temp_properties:
-               properties.expressions.extend(temp_properties.expressions)
-           elif temp_properties:
-               properties = temp_properties
+           extend_props(self._parse_properties())

            self._match(TokenType.ALIAS)

@@ -1178,11 +1197,7 @@ class Parser(metaclass=_Parser):
                or self._match(TokenType.WITH, advance=False)
                or self._match(TokenType.L_PAREN, advance=False)
            ):
-               temp_properties = self._parse_properties()
-               if properties and temp_properties:
-                   properties.expressions.extend(temp_properties.expressions)
-               elif temp_properties:
-                   properties = temp_properties
+               extend_props(self._parse_properties())

            expression = self._parse_ddl_select()

@@ -1192,11 +1207,7 @@ class Parser(metaclass=_Parser):
                index = self._parse_index()

                # exp.Properties.Location.POST_EXPRESSION or exp.Properties.Location.POST_INDEX
-               temp_properties = self._parse_properties()
-               if properties and temp_properties:
-                   properties.expressions.extend(temp_properties.expressions)
-               elif temp_properties:
-                   properties = temp_properties
+               extend_props(self._parse_properties())

                if not index:
                    break

@@ -1888,8 +1899,16 @@ class Parser(metaclass=_Parser):

            this = self._parse_query_modifiers(this)
        elif (table or nested) and self._match(TokenType.L_PAREN):
-           this = self._parse_table() if table else self._parse_select(nested=True)
-           this = self._parse_set_operations(self._parse_query_modifiers(this))
+           if self._match(TokenType.PIVOT):
+               this = self._parse_simplified_pivot()
+           elif self._match(TokenType.FROM):
+               this = exp.select("*").from_(
+                   t.cast(exp.From, self._parse_from(skip_from_token=True))
+               )
+           else:
+               this = self._parse_table() if table else self._parse_select(nested=True)
+               this = self._parse_set_operations(self._parse_query_modifiers(this))

            self._match_r_paren()

            # early return so that subquery unions aren't parsed again

@@ -1902,10 +1921,6 @@ class Parser(metaclass=_Parser):
                expressions=self._parse_csv(self._parse_value),
                alias=self._parse_table_alias(),
            )
-       elif self._match(TokenType.PIVOT):
-           this = self._parse_simplified_pivot()
-       elif self._match(TokenType.FROM):
-           this = exp.select("*").from_(t.cast(exp.From, self._parse_from(skip_from_token=True)))
        else:
            this = None

@@ -2154,11 +2169,11 @@ class Parser(metaclass=_Parser):

        return expression

-   def _parse_join_side_and_kind(
+   def _parse_join_parts(
        self,
    ) -> t.Tuple[t.Optional[Token], t.Optional[Token], t.Optional[Token]]:
        return (
-           self._match(TokenType.NATURAL) and self._prev,
+           self._match_set(self.JOIN_METHODS) and self._prev,
            self._match_set(self.JOIN_SIDES) and self._prev,
            self._match_set(self.JOIN_KINDS) and self._prev,
        )

@@ -2168,14 +2183,14 @@ class Parser(metaclass=_Parser):
            return self.expression(exp.Join, this=self._parse_table())

        index = self._index
-       natural, side, kind = self._parse_join_side_and_kind()
+       method, side, kind = self._parse_join_parts()
        hint = self._prev.text if self._match_texts(self.JOIN_HINTS) else None
        join = self._match(TokenType.JOIN)

        if not skip_join_token and not join:
            self._retreat(index)
            kind = None
-           natural = None
+           method = None
            side = None

        outer_apply = self._match_pair(TokenType.OUTER, TokenType.APPLY, False)

@@ -2187,12 +2202,10 @@ class Parser(metaclass=_Parser):
        if outer_apply:
            side = Token(TokenType.LEFT, "LEFT")

-       kwargs: t.Dict[
-           str, t.Optional[exp.Expression] | bool | str | t.List[t.Optional[exp.Expression]]
-       ] = {"this": self._parse_table()}
+       kwargs: t.Dict[str, t.Any] = {"this": self._parse_table()}

-       if natural:
-           kwargs["natural"] = True
+       if method:
+           kwargs["method"] = method.text
        if side:
            kwargs["side"] = side.text
        if kind:

@@ -2205,7 +2218,7 @@ class Parser(metaclass=_Parser):
        elif self._match(TokenType.USING):
            kwargs["using"] = self._parse_wrapped_id_vars()

-       return self.expression(exp.Join, **kwargs)  # type: ignore
+       return self.expression(exp.Join, **kwargs)

    def _parse_index(
        self,

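Note: after this rename, NATURAL and ClickHouse's ASOF both travel in the join's "method" slot, separate from side (LEFT/RIGHT/FULL) and kind (INNER/OUTER/...). A hedged sketch of inspecting the parsed node:

    import sqlglot
    from sqlglot import exp

    # ASOF is tokenized via the new TokenType.ASOF and lands in Join.args["method"]
    join = sqlglot.parse_one(
        "SELECT * FROM t1 ASOF JOIN t2 ON t1.k = t2.k", read="clickhouse"
    ).find(exp.Join)
    print(join.method)  # 'ASOF', via the new Join.method property
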
@@ -2886,7 +2899,9 @@ class Parser(metaclass=_Parser):
            exp.DataTypeSize, this=this, expression=self._parse_var(any_token=True)
        )

-   def _parse_types(self, check_func: bool = False) -> t.Optional[exp.Expression]:
+   def _parse_types(
+       self, check_func: bool = False, schema: bool = False
+   ) -> t.Optional[exp.Expression]:
        index = self._index

        prefix = self._match_text_seq("SYSUDTLIB", ".")

@@ -2908,7 +2923,9 @@ class Parser(metaclass=_Parser):
        if is_struct:
            expressions = self._parse_csv(self._parse_struct_types)
        elif nested:
-           expressions = self._parse_csv(self._parse_types)
+           expressions = self._parse_csv(
+               lambda: self._parse_types(check_func=check_func, schema=schema)
+           )
        else:
            expressions = self._parse_csv(self._parse_type_size)

@@ -2943,7 +2960,9 @@ class Parser(metaclass=_Parser):
        if is_struct:
            expressions = self._parse_csv(self._parse_struct_types)
        else:
-           expressions = self._parse_csv(self._parse_types)
+           expressions = self._parse_csv(
+               lambda: self._parse_types(check_func=check_func, schema=schema)
+           )

        if not self._match(TokenType.GT):
            self.raise_error("Expecting >")

@@ -3038,11 +3057,7 @@ class Parser(metaclass=_Parser):
                    else exp.Literal.string(value)
                )
            else:
-               field = (
-                   self._parse_star()
-                   or self._parse_function(anonymous=True)
-                   or self._parse_id_var()
-               )
+               field = self._parse_field(anonymous_func=True)

            if isinstance(field, exp.Func):
                # bigquery allows function calls like x.y.count(...)

@@ -3113,10 +3128,11 @@ class Parser(metaclass=_Parser):
        self,
        any_token: bool = False,
        tokens: t.Optional[t.Collection[TokenType]] = None,
+       anonymous_func: bool = False,
    ) -> t.Optional[exp.Expression]:
        return (
            self._parse_primary()
-           or self._parse_function()
+           or self._parse_function(anonymous=anonymous_func)
            or self._parse_id_var(any_token=any_token, tokens=tokens)
        )

@@ -3270,7 +3286,7 @@ class Parser(metaclass=_Parser):
        # column defs are not really columns, they're identifiers
        if isinstance(this, exp.Column):
            this = this.this
-       kind = self._parse_types()
+       kind = self._parse_types(schema=True)

        if self._match_text_seq("FOR", "ORDINALITY"):
            return self.expression(exp.ColumnDef, this=this, ordinality=True)

@@ -3483,16 +3499,18 @@ class Parser(metaclass=_Parser):
            exp.ForeignKey, expressions=expressions, reference=reference, **options  # type: ignore
        )

-   def _parse_primary_key(self) -> exp.Expression:
+   def _parse_primary_key(
+       self, wrapped_optional: bool = False, in_props: bool = False
+   ) -> exp.Expression:
        desc = (
            self._match_set((TokenType.ASC, TokenType.DESC))
            and self._prev.token_type == TokenType.DESC
        )

-       if not self._match(TokenType.L_PAREN, advance=False):
+       if not in_props and not self._match(TokenType.L_PAREN, advance=False):
            return self.expression(exp.PrimaryKeyColumnConstraint, desc=desc)

-       expressions = self._parse_wrapped_csv(self._parse_field)
+       expressions = self._parse_wrapped_csv(self._parse_field, optional=wrapped_optional)
        options = self._parse_key_constraint_options()
        return self.expression(exp.PrimaryKey, expressions=expressions, options=options)

@@ -3509,10 +3527,11 @@ class Parser(metaclass=_Parser):
            return this

        bracket_kind = self._prev.token_type
-       expressions: t.List[t.Optional[exp.Expression]]

        if self._match(TokenType.COLON):
-           expressions = [self.expression(exp.Slice, expression=self._parse_conjunction())]
+           expressions: t.List[t.Optional[exp.Expression]] = [
+               self.expression(exp.Slice, expression=self._parse_conjunction())
+           ]
        else:
            expressions = self._parse_csv(lambda: self._parse_slice(self._parse_conjunction()))

@@ -4011,22 +4030,15 @@ class Parser(metaclass=_Parser):
        self,
        any_token: bool = True,
        tokens: t.Optional[t.Collection[TokenType]] = None,
-       prefix_tokens: t.Optional[t.Collection[TokenType]] = None,
    ) -> t.Optional[exp.Expression]:
        identifier = self._parse_identifier()

        if identifier:
            return identifier

-       prefix = ""
-
-       if prefix_tokens:
-           while self._match_set(prefix_tokens):
-               prefix += self._prev.text
-
        if (any_token and self._advance_any()) or self._match_set(tokens or self.ID_VAR_TOKENS):
            quoted = self._prev.token_type == TokenType.STRING
-           return exp.Identifier(this=prefix + self._prev.text, quoted=quoted)
+           return exp.Identifier(this=self._prev.text, quoted=quoted)

        return None

@@ -4472,6 +4484,44 @@ class Parser(metaclass=_Parser):
        size = len(start.text)
        return exp.Command(this=text[:size], expression=text[size:])

+   def _parse_dict_property(self, this: str) -> exp.DictProperty:
+       settings = []
+
+       self._match_l_paren()
+       kind = self._parse_id_var()
+
+       if self._match(TokenType.L_PAREN):
+           while True:
+               key = self._parse_id_var()
+               value = self._parse_primary()
+
+               if not key and value is None:
+                   break
+               settings.append(self.expression(exp.DictSubProperty, this=key, value=value))
+           self._match(TokenType.R_PAREN)
+
+       self._match_r_paren()
+
+       return self.expression(
+           exp.DictProperty,
+           this=this,
+           kind=kind.this if kind else None,
+           settings=settings,
+       )
+
+   def _parse_dict_range(self, this: str) -> exp.DictRange:
+       self._match_l_paren()
+       has_min = self._match_text_seq("MIN")
+       if has_min:
+           min = self._parse_var() or self._parse_primary()
+           self._match_text_seq("MAX")
+           max = self._parse_var() or self._parse_primary()
+       else:
+           max = self._parse_var() or self._parse_primary()
+           min = exp.Literal.number(0)
+       self._match_r_paren()
+       return self.expression(exp.DictRange, this=this, min=min, max=max)
+
    def _find_parser(
        self, parsers: t.Dict[str, t.Callable], trie: t.Dict
    ) -> t.Optional[t.Callable]:

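Note: these two parsers back the LAYOUT, LIFETIME, SOURCE and RANGE property hooks registered earlier; the syntax they target is ClickHouse's CREATE DICTIONARY. A hedged sketch; the DDL follows the ClickHouse docs and clean round-tripping is an assumption:

    import sqlglot

    ddl = """
    CREATE DICTIONARY d (id String, v String)
    PRIMARY KEY id
    SOURCE(CLICKHOUSE(TABLE 'src'))
    LAYOUT(FLAT())
    LIFETIME(MIN 0 MAX 300)
    """
    print(sqlglot.transpile(ddl, read="clickhouse", write="clickhouse")[0])
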
@@ -5,7 +5,7 @@ import typing as t
from sqlglot import expressions as exp

if t.TYPE_CHECKING:
-   JSON = t.Union[dict, list, str, float, int, bool]
+   JSON = t.Union[dict, list, str, float, int, bool, None]
    Node = t.Union[t.List["Node"], exp.DataType.Type, exp.Expression, JSON]

@@ -24,12 +24,12 @@ def dump(node: Node) -> JSON:
        klass = node.__class__.__qualname__
        if node.__class__.__module__ != exp.__name__:
            klass = f"{node.__module__}.{klass}"
-       obj = {
+       obj: t.Dict = {
            "class": klass,
            "args": {k: dump(v) for k, v in node.args.items() if v is not None and v != []},
        }
        if node.type:
-           obj["type"] = node.type.sql()
+           obj["type"] = dump(node.type)
        if node.comments:
            obj["comments"] = node.comments
        if node._meta is not None:

@@ -60,7 +60,7 @@ def load(obj: JSON) -> Node:
    klass = getattr(module, class_name)

    expression = klass(**{k: load(v) for k, v in obj["args"].items()})
-   expression.type = obj.get("type")
+   expression.type = t.cast(exp.DataType, load(obj.get("type")))
    expression.comments = obj.get("comments")
    expression._meta = obj.get("meta")

@@ -113,6 +113,18 @@ class TokenType(AutoName):
    DATETIME = auto()
    DATETIME64 = auto()
    DATE = auto()
+   INT4RANGE = auto()
+   INT4MULTIRANGE = auto()
+   INT8RANGE = auto()
+   INT8MULTIRANGE = auto()
+   NUMRANGE = auto()
+   NUMMULTIRANGE = auto()
+   TSRANGE = auto()
+   TSMULTIRANGE = auto()
+   TSTZRANGE = auto()
+   TSTZMULTIRANGE = auto()
+   DATERANGE = auto()
+   DATEMULTIRANGE = auto()
    UUID = auto()
    GEOGRAPHY = auto()
    NULLABLE = auto()

@@ -167,6 +179,7 @@ class TokenType(AutoName):
    DELETE = auto()
    DESC = auto()
    DESCRIBE = auto()
+   DICTIONARY = auto()
    DISTINCT = auto()
    DIV = auto()
    DROP = auto()

@@ -480,6 +493,7 @@ class Tokenizer(metaclass=_Tokenizer):
        "ANY": TokenType.ANY,
        "ASC": TokenType.ASC,
        "AS": TokenType.ALIAS,
+       "ASOF": TokenType.ASOF,
        "AUTOINCREMENT": TokenType.AUTO_INCREMENT,
        "AUTO_INCREMENT": TokenType.AUTO_INCREMENT,
        "BEGIN": TokenType.BEGIN,

@@ -669,6 +683,18 @@ class Tokenizer(metaclass=_Tokenizer):
        "TIMESTAMPLTZ": TokenType.TIMESTAMPLTZ,
        "DATE": TokenType.DATE,
        "DATETIME": TokenType.DATETIME,
+       "INT4RANGE": TokenType.INT4RANGE,
+       "INT4MULTIRANGE": TokenType.INT4MULTIRANGE,
+       "INT8RANGE": TokenType.INT8RANGE,
+       "INT8MULTIRANGE": TokenType.INT8MULTIRANGE,
+       "NUMRANGE": TokenType.NUMRANGE,
+       "NUMMULTIRANGE": TokenType.NUMMULTIRANGE,
+       "TSRANGE": TokenType.TSRANGE,
+       "TSMULTIRANGE": TokenType.TSMULTIRANGE,
+       "TSTZRANGE": TokenType.TSTZRANGE,
+       "TSTZMULTIRANGE": TokenType.TSTZMULTIRANGE,
+       "DATERANGE": TokenType.DATERANGE,
+       "DATEMULTIRANGE": TokenType.DATEMULTIRANGE,
        "UNIQUE": TokenType.UNIQUE,
        "STRUCT": TokenType.STRUCT,
        "VARIANT": TokenType.VARIANT,

@@ -709,8 +735,6 @@ class Tokenizer(metaclass=_Tokenizer):
    COMMENTS = ["--", ("/*", "*/"), ("{#", "#}")]
    KEYWORD_TRIE: t.Dict = {}  # autofilled

-   IDENTIFIER_CAN_START_WITH_DIGIT = False
-
    __slots__ = (
        "sql",
        "size",

@@ -724,6 +748,7 @@ class Tokenizer(metaclass=_Tokenizer):
        "_end",
        "_peek",
        "_prev_token_line",
+       "identifiers_can_start_with_digit",
    )

    def __init__(self) -> None:

@@ -826,6 +851,12 @@ class Tokenizer(metaclass=_Tokenizer):
    def _text(self) -> str:
        return self.sql[self._start : self._current]

+   def peek(self, i: int = 0) -> str:
+       i = self._current + i
+       if i < self.size:
+           return self.sql[i]
+       return ""
+
    def _add(self, token_type: TokenType, text: t.Optional[str] = None) -> None:
        self._prev_token_line = self._line
        self.tokens.append(

@@ -962,8 +993,12 @@ class Tokenizer(metaclass=_Tokenizer):
            if self._peek.isdigit():
                self._advance()
            elif self._peek == "." and not decimal:
-               decimal = True
-               self._advance()
+               after = self.peek(1)
+               if after.isdigit() or not after.strip():
+                   decimal = True
+                   self._advance()
+               else:
+                   return self._add(TokenType.VAR)
            elif self._peek in ("-", "+") and scientific == 1:
                scientific += 1
                self._advance()

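Note: the new one-character lookahead keeps a bare trailing dot inside a number but backs off to a VAR when the dot is followed by a letter, which matters for dialects whose identifiers can start with a digit (Hive, per the flag introduced in this commit). A hedged sketch; the token names printed are not asserted here:

    from sqlglot.tokens import Tokenizer

    # '1.5' stays a single NUMBER; in '1.x' the leading '1' now lexes as a VAR
    for sql in ("SELECT 1.5", "SELECT 1.x"):
        print(sql, [tok.token_type.name for tok in Tokenizer().tokenize(sql)])
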
@@ -984,7 +1019,7 @@ class Tokenizer(metaclass=_Tokenizer):
                self._add(TokenType.NUMBER, number_text)
                self._add(TokenType.DCOLON, "::")
                return self._add(token_type, literal)
-           elif self.IDENTIFIER_CAN_START_WITH_DIGIT:
+           elif self.identifiers_can_start_with_digit:  # type: ignore
                return self._add(TokenType.VAR)

            self._add(TokenType.NUMBER, number_text)

@@ -268,6 +268,17 @@ def add_recursive_cte_column_names(expression: exp.Expression) -> exp.Expression:
    return expression


+def epoch_cast_to_ts(expression: exp.Expression) -> exp.Expression:
+   if (
+       isinstance(expression, (exp.Cast, exp.TryCast))
+       and expression.name.lower() == "epoch"
+       and expression.to.this in exp.DataType.TEMPORAL_TYPES
+   ):
+       expression.this.replace(exp.Literal.string("1970-01-01 00:00:00"))
+
+   return expression
+
+
def preprocess(
    transforms: t.List[t.Callable[[exp.Expression], exp.Expression]],
) -> t.Callable[[Generator, exp.Expression], str]:

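Note: epoch_cast_to_ts is the transform wired into Presto's exp.Cast and exp.TryCast handlers above; it concretizes the DuckDB-style CAST('epoch' AS TIMESTAMP) before generation. A hedged sketch; the expected output in the comment is an assumption:

    import sqlglot

    # 'epoch' means nothing to Presto, so the literal is replaced before the cast is emitted
    print(sqlglot.transpile("SELECT CAST('epoch' AS TIMESTAMP)", read="duckdb", write="presto")[0])
    # e.g. SELECT CAST('1970-01-01 00:00:00' AS TIMESTAMP)
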