
Merging upstream version 15.2.0.

Signed-off-by: Daniel Baumann <daniel@debian.org>
Daniel Baumann 2025-02-13 15:58:40 +01:00
parent 2e6df1bcfa
commit 3d4adf9c16
Signed by: daniel
GPG key ID: FBB4F0E80A80222F
81 changed files with 40321 additions and 37940 deletions

View file

@ -327,6 +327,8 @@ class BigQuery(Dialect):
exp.VolatileProperty: exp.Properties.Location.UNSUPPORTED,
}
RESERVED_KEYWORDS = {*generator.Generator.RESERVED_KEYWORDS, "hash"}
def array_sql(self, expression: exp.Array) -> str:
first_arg = seq_get(expression.expressions, 0)
if isinstance(first_arg, exp.Subqueryable):
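Note: adding "hash" to RESERVED_KEYWORDS makes the generator quote it whenever it appears as an identifier (see the RESERVED_KEYWORDS check in the generator changes further down). A minimal sketch of the expected behavior; the output shown is an assumption, not taken from the commit:

import sqlglot

# "hash" is now reserved in BigQuery, so an unquoted column named hash
# should come back quoted with backticks:
print(sqlglot.transpile("SELECT hash FROM t", write="bigquery")[0])
# expected: SELECT `hash` FROM t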

View file

@ -27,14 +27,15 @@ class ClickHouse(Dialect):
class Tokenizer(tokens.Tokenizer):
COMMENTS = ["--", "#", "#!", ("/*", "*/")]
IDENTIFIERS = ['"', "`"]
STRING_ESCAPES = ["'", "\\"]
BIT_STRINGS = [("0b", "")]
HEX_STRINGS = [("0x", ""), ("0X", "")]
KEYWORDS = {
**tokens.Tokenizer.KEYWORDS,
"ASOF": TokenType.ASOF,
"ATTACH": TokenType.COMMAND,
"DATETIME64": TokenType.DATETIME64,
"DICTIONARY": TokenType.DICTIONARY,
"FINAL": TokenType.FINAL,
"FLOAT32": TokenType.FLOAT,
"FLOAT64": TokenType.DOUBLE,
@ -97,7 +98,6 @@ class ClickHouse(Dialect):
TABLE_ALIAS_TOKENS = {*parser.Parser.TABLE_ALIAS_TOKENS} - {
TokenType.ANY,
TokenType.ASOF,
TokenType.SEMI,
TokenType.ANTI,
TokenType.SETTINGS,
@ -182,7 +182,7 @@ class ClickHouse(Dialect):
return self.expression(exp.CTE, this=statement, alias=statement and statement.this)
def _parse_join_side_and_kind(
def _parse_join_parts(
self,
) -> t.Tuple[t.Optional[Token], t.Optional[Token], t.Optional[Token]]:
is_global = self._match(TokenType.GLOBAL) and self._prev
@ -201,7 +201,7 @@ class ClickHouse(Dialect):
join = super()._parse_join(skip_join_token)
if join:
join.set("global", join.args.pop("natural", None))
join.set("global", join.args.pop("method", None))
return join
def _parse_function(
@ -245,6 +245,23 @@ class ClickHouse(Dialect):
) -> t.List[t.Optional[exp.Expression]]:
return super()._parse_wrapped_id_vars(optional=True)
def _parse_primary_key(
self, wrapped_optional: bool = False, in_props: bool = False
) -> exp.Expression:
return super()._parse_primary_key(
wrapped_optional=wrapped_optional or in_props, in_props=in_props
)
def _parse_on_property(self) -> t.Optional[exp.Property]:
index = self._index
if self._match_text_seq("CLUSTER"):
this = self._parse_id_var()
if this:
return self.expression(exp.OnCluster, this=this)
else:
self._retreat(index)
return None
class Generator(generator.Generator):
STRUCT_DELIMITER = ("(", ")")
@ -292,6 +309,7 @@ class ClickHouse(Dialect):
**generator.Generator.PROPERTIES_LOCATION,
exp.VolatileProperty: exp.Properties.Location.UNSUPPORTED,
exp.PartitionedByProperty: exp.Properties.Location.POST_SCHEMA,
exp.OnCluster: exp.Properties.Location.POST_NAME,
}
JOIN_HINTS = False
@ -299,6 +317,18 @@ class ClickHouse(Dialect):
EXPLICIT_UNION = True
GROUPINGS_SEP = ""
# there's no list in the docs, but it can be found in the ClickHouse code
# see `ClickHouse/src/Parsers/ParserCreate*.cpp`
ON_CLUSTER_TARGETS = {
"DATABASE",
"TABLE",
"VIEW",
"DICTIONARY",
"INDEX",
"FUNCTION",
"NAMED COLLECTION",
}
def cte_sql(self, expression: exp.CTE) -> str:
if isinstance(expression.this, exp.Alias):
return self.sql(expression, "this")
@ -321,3 +351,21 @@ class ClickHouse(Dialect):
def placeholder_sql(self, expression: exp.Placeholder) -> str:
return f"{{{expression.name}: {self.sql(expression, 'kind')}}}"
def oncluster_sql(self, expression: exp.OnCluster) -> str:
return f"ON CLUSTER {self.sql(expression, 'this')}"
def createable_sql(
self,
expression: exp.Create,
locations: dict[exp.Properties.Location, list[exp.Property]],
) -> str:
kind = self.sql(expression, "kind").upper()
if kind in self.ON_CLUSTER_TARGETS and locations.get(exp.Properties.Location.POST_NAME):
this_name = self.sql(expression.this, "this")
this_properties = " ".join(
[self.sql(prop) for prop in locations[exp.Properties.Location.POST_NAME]]
)
this_schema = self.schema_columns_sql(expression.this)
return f"{this_name}{self.sep()}{this_properties}{self.sep()}{this_schema}"
return super().createable_sql(expression, locations)
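With exp.OnCluster registered as a POST_NAME property and createable_sql emitting it between the table name and the column schema, ON CLUSTER DDL should round-trip. A sketch, assuming this input and output shape:

import sqlglot

ddl = "CREATE TABLE foo ON CLUSTER c (id UInt8) ENGINE=MergeTree()"
# the OnCluster property is parsed by _parse_on_property above and
# re-emitted after the table name by createable_sql:
print(sqlglot.transpile(ddl, read="clickhouse", write="clickhouse")[0])
# expected (modulo whitespace): CREATE TABLE foo ON CLUSTER c (id UInt8) ENGINE=MergeTree()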

View file

@ -104,6 +104,10 @@ class _Dialect(type):
klass.byte_start, klass.byte_end = get_start_end(TokenType.BYTE_STRING)
klass.raw_start, klass.raw_end = get_start_end(TokenType.RAW_STRING)
klass.tokenizer_class.identifiers_can_start_with_digit = (
klass.identifiers_can_start_with_digit
)
return klass
@ -111,6 +115,7 @@ class Dialect(metaclass=_Dialect):
index_offset = 0
unnest_column_only = False
alias_post_tablesample = False
identifiers_can_start_with_digit = False
normalize_functions: t.Optional[str] = "upper"
null_ordering = "nulls_are_small"
@ -231,6 +236,7 @@ class Dialect(metaclass=_Dialect):
"time_trie": self.inverse_time_trie,
"unnest_column_only": self.unnest_column_only,
"alias_post_tablesample": self.alias_post_tablesample,
"identifiers_can_start_with_digit": self.identifiers_can_start_with_digit,
"normalize_functions": self.normalize_functions,
"null_ordering": self.null_ordering,
**opts,
@ -443,7 +449,7 @@ def date_trunc_to_time(args: t.List) -> exp.DateTrunc | exp.TimestampTrunc:
unit = seq_get(args, 0)
this = seq_get(args, 1)
if isinstance(this, exp.Cast) and this.is_type(exp.DataType.Type.DATE):
if isinstance(this, exp.Cast) and this.is_type("date"):
return exp.DateTrunc(unit=unit, this=this)
return exp.TimestampTrunc(this=this, unit=unit)
@ -468,6 +474,25 @@ def strposition_to_locate_sql(self: Generator, expression: exp.StrPosition) -> s
)
def left_to_substring_sql(self: Generator, expression: exp.Left) -> str:
expression = expression.copy()
return self.sql(
exp.Substring(
this=expression.this, start=exp.Literal.number(1), length=expression.expression
)
)
def right_to_substring_sql(self: Generator, expression: exp.Right) -> str:
expression = expression.copy()
return self.sql(
exp.Substring(
this=expression.this,
start=exp.Length(this=expression.this) - exp.paren(expression.expression - 1),
)
)
def timestrtotime_sql(self: Generator, expression: exp.TimeStrToTime) -> str:
return f"CAST({self.sql(expression, 'this')} AS TIMESTAMP)"

View file

@ -71,7 +71,7 @@ def _struct_sql(self: generator.Generator, expression: exp.Struct) -> str:
def _datatype_sql(self: generator.Generator, expression: exp.DataType) -> str:
if expression.this == exp.DataType.Type.ARRAY:
if expression.is_type("array"):
return f"{self.expressions(expression, flat=True)}[]"
return self.datatype_sql(expression)

View file

@ -9,6 +9,7 @@ from sqlglot.dialects.dialect import (
create_with_partitions_sql,
format_time_lambda,
if_sql,
left_to_substring_sql,
locate_to_strposition,
max_or_greatest,
min_or_least,
@ -17,6 +18,7 @@ from sqlglot.dialects.dialect import (
no_safe_divide_sql,
no_trycast_sql,
rename_func,
right_to_substring_sql,
strposition_to_locate_sql,
struct_extract_sql,
timestrtotime_sql,
@ -89,7 +91,7 @@ def _json_format_sql(self: generator.Generator, expression: exp.JSONFormat) -> s
annotate_types(this)
if this.type.is_type(exp.DataType.Type.JSON):
if this.type.is_type("json"):
return self.sql(this)
return self.func("TO_JSON", this, expression.args.get("options"))
@ -149,6 +151,7 @@ def _to_date_sql(self: generator.Generator, expression: exp.TsOrDsToDate) -> str
class Hive(Dialect):
alias_post_tablesample = True
identifiers_can_start_with_digit = True
time_mapping = {
"y": "%Y",
@ -190,7 +193,6 @@ class Hive(Dialect):
IDENTIFIERS = ["`"]
STRING_ESCAPES = ["\\"]
ENCODE = "utf-8"
IDENTIFIER_CAN_START_WITH_DIGIT = True
KEYWORDS = {
**tokens.Tokenizer.KEYWORDS,
@ -276,6 +278,39 @@ class Hive(Dialect):
"cluster": lambda self: self._parse_sort(exp.Cluster, "CLUSTER", "BY"),
}
def _parse_types(
self, check_func: bool = False, schema: bool = False
) -> t.Optional[exp.Expression]:
"""
Spark (and most likely Hive) treats casts to CHAR(length) and VARCHAR(length) as casts to
STRING in all contexts except for schema definitions. For example, this is in Spark v3.4.0:
spark-sql (default)> select cast(1234 as varchar(2));
23/06/06 15:51:18 WARN CharVarcharUtils: The Spark cast operator does not support
char/varchar type and simply treats them as string type. Please use string type
directly to avoid confusion. Otherwise, you can set spark.sql.legacy.charVarcharAsString
to true, so that Spark treat them as string type as same as Spark 3.0 and earlier
1234
Time taken: 4.265 seconds, Fetched 1 row(s)
This shows that Spark doesn't truncate the value into '12', which is inconsistent with
what other dialects (e.g. postgres) do, so we need to drop the length to transpile correctly.
Reference: https://spark.apache.org/docs/latest/sql-ref-datatypes.html
"""
this = super()._parse_types(check_func=check_func, schema=schema)
if this and not schema:
return this.transform(
lambda node: node.replace(exp.DataType.build("text"))
if isinstance(node, exp.DataType) and node.is_type("char", "varchar")
else node,
copy=False,
)
return this
class Generator(generator.Generator):
LIMIT_FETCH = "LIMIT"
TABLESAMPLE_WITH_METHOD = False
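The _parse_types override above means length-qualified character casts lose their length everywhere except schema definitions. A sketch of both sides, with assumed outputs:

import sqlglot

# cast context: VARCHAR(2) is treated as plain STRING
print(sqlglot.transpile("SELECT CAST(1234 AS VARCHAR(2))", read="spark", write="spark")[0])
# expected: SELECT CAST(1234 AS STRING)

# schema context: the length survives
print(sqlglot.transpile("CREATE TABLE t (c VARCHAR(2))", read="spark", write="spark")[0])
# expected: CREATE TABLE t (c VARCHAR(2))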
@ -323,6 +358,7 @@ class Hive(Dialect):
exp.JSONExtract: rename_func("GET_JSON_OBJECT"),
exp.JSONExtractScalar: rename_func("GET_JSON_OBJECT"),
exp.JSONFormat: _json_format_sql,
exp.Left: left_to_substring_sql,
exp.Map: var_map_sql,
exp.Max: max_or_greatest,
exp.Min: min_or_least,
@ -332,6 +368,7 @@ class Hive(Dialect):
exp.ApproxQuantile: rename_func("PERCENTILE_APPROX"),
exp.RegexpLike: lambda self, e: self.binary(e, "RLIKE"),
exp.RegexpSplit: rename_func("SPLIT"),
exp.Right: right_to_substring_sql,
exp.SafeDivide: no_safe_divide_sql,
exp.SchemaCommentProperty: lambda self, e: self.naked_property(e),
exp.SetAgg: rename_func("COLLECT_SET"),

View file

@ -186,9 +186,6 @@ class MySQL(Dialect):
"DATE_FORMAT": format_time_lambda(exp.TimeToStr, "mysql"),
"DATE_SUB": parse_date_delta_with_interval(exp.DateSub),
"INSTR": lambda args: exp.StrPosition(substr=seq_get(args, 1), this=seq_get(args, 0)),
"LEFT": lambda args: exp.Substring(
this=seq_get(args, 0), start=exp.Literal.number(1), length=seq_get(args, 1)
),
"LOCATE": locate_to_strposition,
"STR_TO_DATE": _str_to_date,
}

View file

@ -18,7 +18,9 @@ from sqlglot.dialects.dialect import (
rename_func,
str_position_sql,
timestamptrunc_sql,
timestrtotime_sql,
trim_sql,
ts_or_ds_to_date_sql,
)
from sqlglot.helper import seq_get
from sqlglot.parser import binary_range_parser
@ -104,7 +106,7 @@ def _string_agg_sql(self: generator.Generator, expression: exp.GroupConcat) -> s
def _datatype_sql(self: generator.Generator, expression: exp.DataType) -> str:
if expression.this == exp.DataType.Type.ARRAY:
if expression.is_type("array"):
return f"{self.expressions(expression, flat=True)}[]"
return self.datatype_sql(expression)
@ -353,12 +355,13 @@ class Postgres(Dialect):
exp.StrToTime: lambda self, e: f"TO_TIMESTAMP({self.sql(e, 'this')}, {self.format_time(e)})",
exp.Substring: _substring_sql,
exp.TimestampTrunc: timestamptrunc_sql,
exp.TimeStrToTime: lambda self, e: f"CAST({self.sql(e, 'this')} AS TIMESTAMP)",
exp.TimeStrToTime: timestrtotime_sql,
exp.TimeToStr: lambda self, e: f"TO_CHAR({self.sql(e, 'this')}, {self.format_time(e)})",
exp.TableSample: no_tablesample_sql,
exp.ToChar: lambda self, e: self.function_fallback_sql(e),
exp.Trim: trim_sql,
exp.TryCast: no_trycast_sql,
exp.TsOrDsToDate: ts_or_ds_to_date_sql("postgres"),
exp.UnixToTime: lambda self, e: f"TO_TIMESTAMP({self.sql(e, 'this')})",
exp.DataType: _datatype_sql,
exp.GroupConcat: _string_agg_sql,

View file

@ -8,10 +8,12 @@ from sqlglot.dialects.dialect import (
date_trunc_to_time,
format_time_lambda,
if_sql,
left_to_substring_sql,
no_ilike_sql,
no_pivot_sql,
no_safe_divide_sql,
rename_func,
right_to_substring_sql,
struct_extract_sql,
timestamptrunc_sql,
timestrtotime_sql,
@ -30,7 +32,7 @@ def _approx_distinct_sql(self: generator.Generator, expression: exp.ApproxDistin
def _datatype_sql(self: generator.Generator, expression: exp.DataType) -> str:
sql = self.datatype_sql(expression)
if expression.this == exp.DataType.Type.TIMESTAMPTZ:
if expression.is_type("timestamptz"):
sql = f"{sql} WITH TIME ZONE"
return sql
@ -240,6 +242,7 @@ class Presto(Dialect):
INTERVAL_ALLOWS_PLURAL_FORM = False
JOIN_HINTS = False
TABLE_HINTS = False
IS_BOOL = False
STRUCT_DELIMITER = ("(", ")")
PROPERTIES_LOCATION = {
@ -272,6 +275,7 @@ class Presto(Dialect):
exp.BitwiseOr: lambda self, e: f"BITWISE_OR({self.sql(e, 'this')}, {self.sql(e, 'expression')})",
exp.BitwiseRightShift: lambda self, e: f"BITWISE_ARITHMETIC_SHIFT_RIGHT({self.sql(e, 'this')}, {self.sql(e, 'expression')})",
exp.BitwiseXor: lambda self, e: f"BITWISE_XOR({self.sql(e, 'this')}, {self.sql(e, 'expression')})",
exp.Cast: transforms.preprocess([transforms.epoch_cast_to_ts]),
exp.CurrentTimestamp: lambda *_: "CURRENT_TIMESTAMP",
exp.DataType: _datatype_sql,
exp.DateAdd: lambda self, e: self.func(
@ -292,11 +296,13 @@ class Presto(Dialect):
exp.ILike: no_ilike_sql,
exp.Initcap: _initcap_sql,
exp.Lateral: _explode_to_unnest_sql,
exp.Left: left_to_substring_sql,
exp.Levenshtein: rename_func("LEVENSHTEIN_DISTANCE"),
exp.LogicalAnd: rename_func("BOOL_AND"),
exp.LogicalOr: rename_func("BOOL_OR"),
exp.Pivot: no_pivot_sql,
exp.Quantile: _quantile_sql,
exp.Right: right_to_substring_sql,
exp.SafeDivide: no_safe_divide_sql,
exp.Schema: _schema_sql,
exp.Select: transforms.preprocess(
@ -319,6 +325,7 @@ class Presto(Dialect):
exp.TimeStrToUnix: lambda self, e: f"TO_UNIXTIME(DATE_PARSE({self.sql(e, 'this')}, {Presto.time_format}))",
exp.TimeToStr: lambda self, e: f"DATE_FORMAT({self.sql(e, 'this')}, {self.format_time(e)})",
exp.TimeToUnix: rename_func("TO_UNIXTIME"),
exp.TryCast: transforms.preprocess([transforms.epoch_cast_to_ts]),
exp.TsOrDiToDi: lambda self, e: f"CAST(SUBSTR(REPLACE(CAST({self.sql(e, 'this')} AS VARCHAR), '-', ''), 1, 8) AS INT)",
exp.TsOrDsAdd: _ts_or_ds_add_sql,
exp.TsOrDsToDate: _ts_or_ds_to_date_sql,
@ -356,7 +363,7 @@ class Presto(Dialect):
else:
target_type = None
if target_type and target_type.is_type(exp.DataType.Type.TIMESTAMP):
if target_type and target_type.is_type("timestamp"):
to = target_type.copy()
if target_type is start.to:

View file

@ -3,6 +3,7 @@ from __future__ import annotations
import typing as t
from sqlglot import exp, transforms
from sqlglot.dialects.dialect import rename_func
from sqlglot.dialects.postgres import Postgres
from sqlglot.helper import seq_get
from sqlglot.tokens import TokenType
@ -24,26 +25,29 @@ class Redshift(Postgres):
FUNCTIONS = {
**Postgres.Parser.FUNCTIONS,
"DATEADD": lambda args: exp.DateAdd(
this=seq_get(args, 2),
this=exp.TsOrDsToDate(this=seq_get(args, 2)),
expression=seq_get(args, 1),
unit=seq_get(args, 0),
),
"DATEDIFF": lambda args: exp.DateDiff(
this=seq_get(args, 2),
expression=seq_get(args, 1),
this=exp.TsOrDsToDate(this=seq_get(args, 2)),
expression=exp.TsOrDsToDate(this=seq_get(args, 1)),
unit=seq_get(args, 0),
),
"NVL": exp.Coalesce.from_arg_list,
"STRTOL": exp.FromBase.from_arg_list,
}
CONVERT_TYPE_FIRST = True
def _parse_types(self, check_func: bool = False) -> t.Optional[exp.Expression]:
this = super()._parse_types(check_func=check_func)
def _parse_types(
self, check_func: bool = False, schema: bool = False
) -> t.Optional[exp.Expression]:
this = super()._parse_types(check_func=check_func, schema=schema)
if (
isinstance(this, exp.DataType)
and this.this == exp.DataType.Type.VARCHAR
and this.is_type("varchar")
and this.expressions
and this.expressions[0].this == exp.column("MAX")
):
@ -99,10 +103,12 @@ class Redshift(Postgres):
),
exp.DistKeyProperty: lambda self, e: f"DISTKEY({e.name})",
exp.DistStyleProperty: lambda self, e: self.naked_property(e),
exp.FromBase: rename_func("STRTOL"),
exp.JSONExtract: _json_sql,
exp.JSONExtractScalar: _json_sql,
exp.Select: transforms.preprocess([transforms.eliminate_distinct_on]),
exp.SortKeyProperty: lambda self, e: f"{'COMPOUND ' if e.args['compound'] else ''}SORTKEY({self.format_args(*e.this)})",
exp.TsOrDsToDate: lambda self, e: self.sql(e.this),
}
# Postgres maps exp.Pivot to no_pivot_sql, but Redshift supports pivots
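Wrapping the DATEADD/DATEDIFF arguments in exp.TsOrDsToDate lets other dialects coerce Redshift's string-or-timestamp arguments, while Redshift's own generator unwraps the node again (exp.TsOrDsToDate above renders as just its inner expression). A sketch; the round-trip outputs are assumptions:

from sqlglot import exp, parse_one

expr = parse_one("SELECT DATEADD(day, 1, '2023-01-01')", read="redshift")
# the date argument is now wrapped for transpilation purposes:
print(expr.find(exp.DateAdd).this)   # a TsOrDsToDate node around '2023-01-01'
# ...but the wrapper disappears on the way back out:
print(expr.sql("redshift"))          # expected: SELECT DATEADD(day, 1, '2023-01-01')
# STRTOL likewise round-trips through the new exp.FromBase:
print(parse_one("SELECT STRTOL('f', 16)", read="redshift").sql("redshift"))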
@ -158,7 +164,7 @@ class Redshift(Postgres):
without precision we convert it to `VARCHAR(max)` and if it does have precision then we just convert
`TEXT` to `VARCHAR`.
"""
if expression.this == exp.DataType.Type.TEXT:
if expression.is_type("text"):
expression = expression.copy()
expression.set("this", exp.DataType.Type.VARCHAR)
precision = expression.args.get("expressions")

View file

@ -153,9 +153,9 @@ def _nullifzero_to_if(args: t.List) -> exp.Expression:
def _datatype_sql(self: generator.Generator, expression: exp.DataType) -> str:
if expression.this == exp.DataType.Type.ARRAY:
if expression.is_type("array"):
return "ARRAY"
elif expression.this == exp.DataType.Type.MAP:
elif expression.is_type("map"):
return "OBJECT"
return self.datatype_sql(expression)

View file

@ -110,11 +110,6 @@ class Spark2(Hive):
**Hive.Parser.FUNCTIONS,
"MAP_FROM_ARRAYS": exp.Map.from_arg_list,
"TO_UNIX_TIMESTAMP": exp.StrToUnix.from_arg_list,
"LEFT": lambda args: exp.Substring(
this=seq_get(args, 0),
start=exp.Literal.number(1),
length=seq_get(args, 1),
),
"SHIFTLEFT": lambda args: exp.BitwiseLeftShift(
this=seq_get(args, 0),
expression=seq_get(args, 1),
@ -123,14 +118,6 @@ class Spark2(Hive):
this=seq_get(args, 0),
expression=seq_get(args, 1),
),
"RIGHT": lambda args: exp.Substring(
this=seq_get(args, 0),
start=exp.Sub(
this=exp.Length(this=seq_get(args, 0)),
expression=exp.Add(this=seq_get(args, 1), expression=exp.Literal.number(1)),
),
length=seq_get(args, 1),
),
"APPROX_PERCENTILE": exp.ApproxQuantile.from_arg_list,
"IIF": exp.If.from_arg_list,
"AGGREGATE": exp.Reduce.from_arg_list,
@ -240,17 +227,17 @@ class Spark2(Hive):
TRANSFORMS.pop(exp.ArrayJoin)
TRANSFORMS.pop(exp.ArraySort)
TRANSFORMS.pop(exp.ILike)
TRANSFORMS.pop(exp.Left)
TRANSFORMS.pop(exp.Right)
WRAP_DERIVED_VALUES = False
CREATE_FUNCTION_RETURN_AS = False
def cast_sql(self, expression: exp.Cast) -> str:
if isinstance(expression.this, exp.Cast) and expression.this.is_type(
exp.DataType.Type.JSON
):
if isinstance(expression.this, exp.Cast) and expression.this.is_type("json"):
schema = f"'{self.sql(expression, 'to')}'"
return self.func("FROM_JSON", expression.this.this, schema)
if expression.to.is_type(exp.DataType.Type.JSON):
if expression.is_type("json"):
return self.func("TO_JSON", expression.this)
return super(Hive.Generator, self).cast_sql(expression)
@ -260,7 +247,7 @@ class Spark2(Hive):
expression,
sep=": "
if isinstance(expression.parent, exp.DataType)
and expression.parent.is_type(exp.DataType.Type.STRUCT)
and expression.parent.is_type("struct")
else sep,
)

View file

@ -132,7 +132,7 @@ class SQLite(Dialect):
LIMIT_FETCH = "LIMIT"
def cast_sql(self, expression: exp.Cast) -> str:
if expression.to.this == exp.DataType.Type.DATE:
if expression.is_type("date"):
return self.func("DATE", expression.this)
return super().cast_sql(expression)

View file

@ -183,3 +183,20 @@ class Teradata(Dialect):
each_sql = f" EACH {each_sql}" if each_sql else ""
return f"RANGE_N({this} BETWEEN {expressions_sql}{each_sql})"
def createable_sql(
self,
expression: exp.Create,
locations: dict[exp.Properties.Location, list[exp.Property]],
) -> str:
kind = self.sql(expression, "kind").upper()
if kind == "TABLE" and locations.get(exp.Properties.Location.POST_NAME):
this_name = self.sql(expression.this, "this")
this_properties = self.properties(
exp.Properties(expressions=locations[exp.Properties.Location.POST_NAME]),
wrapped=False,
prefix=",",
)
this_schema = self.schema_columns_sql(expression.this)
return f"{this_name}{this_properties}{self.sep()}{this_schema}"
return super().createable_sql(expression, locations)

View file

@ -1653,11 +1653,15 @@ class Join(Expression):
"side": False,
"kind": False,
"using": False,
"natural": False,
"method": False,
"global": False,
"hint": False,
}
@property
def method(self) -> str:
return self.text("method").upper()
@property
def kind(self) -> str:
return self.text("kind").upper()
@ -1913,6 +1917,24 @@ class LanguageProperty(Property):
arg_types = {"this": True}
class DictProperty(Property):
arg_types = {"this": True, "kind": True, "settings": False}
class DictSubProperty(Property):
pass
class DictRange(Property):
arg_types = {"this": True, "min": True, "max": True}
# Clickhouse CREATE ... ON CLUSTER modifier
# https://clickhouse.com/docs/en/sql-reference/distributed-ddl
class OnCluster(Property):
arg_types = {"this": True}
class LikeProperty(Property):
arg_types = {"this": True, "expressions": False}
@ -2797,12 +2819,12 @@ class Select(Subqueryable):
Returns:
Select: the modified expression.
"""
parse_args = {"dialect": dialect, **opts}
parse_args: t.Dict[str, t.Any] = {"dialect": dialect, **opts}
try:
expression = maybe_parse(expression, into=Join, prefix="JOIN", **parse_args) # type: ignore
expression = maybe_parse(expression, into=Join, prefix="JOIN", **parse_args)
except ParseError:
expression = maybe_parse(expression, into=(Join, Expression), **parse_args) # type: ignore
expression = maybe_parse(expression, into=(Join, Expression), **parse_args)
join = expression if isinstance(expression, Join) else Join(this=expression)
@ -2810,14 +2832,14 @@ class Select(Subqueryable):
join.this.replace(join.this.subquery())
if join_type:
natural: t.Optional[Token]
method: t.Optional[Token]
side: t.Optional[Token]
kind: t.Optional[Token]
natural, side, kind = maybe_parse(join_type, into="JOIN_TYPE", **parse_args) # type: ignore
method, side, kind = maybe_parse(join_type, into="JOIN_TYPE", **parse_args) # type: ignore
if natural:
join.set("natural", True)
if method:
join.set("method", method.text)
if side:
join.set("side", side.text)
if kind:
@ -3222,6 +3244,18 @@ class DataType(Expression):
DATE = auto()
DATETIME = auto()
DATETIME64 = auto()
INT4RANGE = auto()
INT4MULTIRANGE = auto()
INT8RANGE = auto()
INT8MULTIRANGE = auto()
NUMRANGE = auto()
NUMMULTIRANGE = auto()
TSRANGE = auto()
TSMULTIRANGE = auto()
TSTZRANGE = auto()
TSTZMULTIRANGE = auto()
DATERANGE = auto()
DATEMULTIRANGE = auto()
DECIMAL = auto()
DOUBLE = auto()
FLOAT = auto()
@ -3331,8 +3365,8 @@ class DataType(Expression):
return DataType(**{**data_type_exp.args, **kwargs})
def is_type(self, dtype: DataType.Type) -> bool:
return self.this == dtype
def is_type(self, *dtypes: str | DataType | DataType.Type) -> bool:
return any(self.this == DataType.build(dtype).this for dtype in dtypes)
# https://www.postgresql.org/docs/15/datatype-pseudo.html
@ -3846,8 +3880,8 @@ class Cast(Func):
def output_name(self) -> str:
return self.name
def is_type(self, dtype: DataType.Type) -> bool:
return self.to.is_type(dtype)
def is_type(self, *dtypes: str | DataType | DataType.Type) -> bool:
return self.to.is_type(*dtypes)
class CastToStrType(Func):
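is_type on both DataType and Cast now takes any number of dtypes and coerces strings through DataType.build, which is what enables the is_type("date"), is_type("array"), etc. call sites throughout this commit. A sketch:

from sqlglot import exp

dt = exp.DataType.build("varchar(10)")
print(dt.is_type("varchar"))       # True: the string is built into a DataType first
print(dt.is_type("char", "text")) # False: none of the candidates match
# Cast.is_type delegates to the target type:
print(exp.cast(exp.column("x"), "timestamp").is_type("timestamp"))  # True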
@ -4130,8 +4164,16 @@ class Least(Func):
is_var_len_args = True
class Left(Func):
arg_types = {"this": True, "expression": True}
class Right(Func):
arg_types = {"this": True, "expression": True}
class Length(Func):
pass
_sql_names = ["LENGTH", "LEN"]
class Levenshtein(Func):
@ -4356,6 +4398,10 @@ class NumberToStr(Func):
arg_types = {"this": True, "format": True}
class FromBase(Func):
arg_types = {"this": True, "expression": True}
class Struct(Func):
arg_types = {"expressions": True}
is_var_len_args = True

View file

@ -44,6 +44,8 @@ class Generator:
Default: "upper"
alias_post_tablesample (bool): if the table alias comes after tablesample
Default: False
identifiers_can_start_with_digit (bool): if an unquoted identifier can start with a digit
Default: False
unsupported_level (ErrorLevel): determines the generator's behavior when it encounters
unsupported expressions. Default ErrorLevel.WARN.
null_ordering (str): Indicates the default null ordering method to use if not explicitly set.
@ -188,6 +190,8 @@ class Generator:
exp.Cluster: exp.Properties.Location.POST_SCHEMA,
exp.DataBlocksizeProperty: exp.Properties.Location.POST_NAME,
exp.DefinerProperty: exp.Properties.Location.POST_CREATE,
exp.DictRange: exp.Properties.Location.POST_SCHEMA,
exp.DictProperty: exp.Properties.Location.POST_SCHEMA,
exp.DistKeyProperty: exp.Properties.Location.POST_SCHEMA,
exp.DistStyleProperty: exp.Properties.Location.POST_SCHEMA,
exp.EngineProperty: exp.Properties.Location.POST_SCHEMA,
@ -233,6 +237,7 @@ class Generator:
JOIN_HINTS = True
TABLE_HINTS = True
IS_BOOL = True
RESERVED_KEYWORDS: t.Set[str] = set()
WITH_SEPARATED_COMMENTS = (exp.Select, exp.From, exp.Where, exp.With)
@ -264,6 +269,7 @@ class Generator:
"index_offset",
"unnest_column_only",
"alias_post_tablesample",
"identifiers_can_start_with_digit",
"normalize_functions",
"unsupported_level",
"unsupported_messages",
@ -304,6 +310,7 @@ class Generator:
index_offset=0,
unnest_column_only=False,
alias_post_tablesample=False,
identifiers_can_start_with_digit=False,
normalize_functions="upper",
unsupported_level=ErrorLevel.WARN,
null_ordering=None,
@ -337,6 +344,7 @@ class Generator:
self.index_offset = index_offset
self.unnest_column_only = unnest_column_only
self.alias_post_tablesample = alias_post_tablesample
self.identifiers_can_start_with_digit = identifiers_can_start_with_digit
self.normalize_functions = normalize_functions
self.unsupported_level = unsupported_level
self.unsupported_messages = []
@ -634,35 +642,31 @@ class Generator:
this = f" {this}" if this else ""
return f"UNIQUE{this}"
def createable_sql(
self, expression: exp.Create, locations: dict[exp.Properties.Location, list[exp.Property]]
) -> str:
return self.sql(expression, "this")
def create_sql(self, expression: exp.Create) -> str:
kind = self.sql(expression, "kind").upper()
properties = expression.args.get("properties")
properties_exp = expression.copy()
properties_locs = self.locate_properties(properties) if properties else {}
this = self.createable_sql(expression, properties_locs)
properties_sql = ""
if properties_locs.get(exp.Properties.Location.POST_SCHEMA) or properties_locs.get(
exp.Properties.Location.POST_WITH
):
properties_exp.set(
"properties",
properties_sql = self.sql(
exp.Properties(
expressions=[
*properties_locs[exp.Properties.Location.POST_SCHEMA],
*properties_locs[exp.Properties.Location.POST_WITH],
]
),
)
)
if kind == "TABLE" and properties_locs.get(exp.Properties.Location.POST_NAME):
this_name = self.sql(expression.this, "this")
this_properties = self.properties(
exp.Properties(expressions=properties_locs[exp.Properties.Location.POST_NAME]),
wrapped=False,
)
this_schema = f"({self.expressions(expression.this)})"
this = f"{this_name}, {this_properties} {this_schema}"
properties_sql = ""
else:
this = self.sql(expression, "this")
properties_sql = self.sql(properties_exp, "properties")
begin = " BEGIN" if expression.args.get("begin") else ""
expression_sql = self.sql(expression, "expression")
if expression_sql:
@ -894,6 +898,7 @@ class Generator:
expression.quoted
or should_identify(text, self.identify)
or lower in self.RESERVED_KEYWORDS
or (not self.identifiers_can_start_with_digit and text[:1].isdigit())
):
text = f"{self.identifier_start}{text}{self.identifier_end}"
return text
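identifiers_can_start_with_digit is now a dialect-level flag threaded through tokenizer and generator; when the output dialect lacks it, the new check above quotes leading-digit identifiers. A sketch with assumed outputs:

import sqlglot

# Hive allows unquoted identifiers starting with a digit, so identity holds:
print(sqlglot.transpile("SELECT * FROM 1a", read="hive", write="hive")[0])
# expected: SELECT * FROM 1a
# most dialects don't, so the generator quotes on the way out:
print(sqlglot.transpile("SELECT * FROM 1a", read="hive", write="duckdb")[0])
# expected: SELECT * FROM "1a"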
@ -1082,7 +1087,7 @@ class Generator:
def lockingproperty_sql(self, expression: exp.LockingProperty) -> str:
kind = expression.args.get("kind")
this: str = f" {this}" if expression.this else ""
this = f" {self.sql(expression, 'this')}" if expression.this else ""
for_or_in = expression.args.get("for_or_in")
lock_type = expression.args.get("lock_type")
override = " OVERRIDE" if expression.args.get("override") else ""
@ -1313,7 +1318,7 @@ class Generator:
op_sql = " ".join(
op
for op in (
"NATURAL" if expression.args.get("natural") else None,
expression.method,
"GLOBAL" if expression.args.get("global") else None,
expression.side,
expression.kind,
@ -1573,9 +1578,12 @@ class Generator:
def schema_sql(self, expression: exp.Schema) -> str:
this = self.sql(expression, "this")
this = f"{this} " if this else ""
sql = f"({self.sep('')}{self.expressions(expression)}{self.seg(')', sep='')}"
sql = self.schema_columns_sql(expression)
return f"{this}{sql}"
def schema_columns_sql(self, expression: exp.Schema) -> str:
return f"({self.sep('')}{self.expressions(expression)}{self.seg(')', sep='')}"
def star_sql(self, expression: exp.Star) -> str:
except_ = self.expressions(expression, key="except", flat=True)
except_ = f"{self.seg(self.STAR_MAPPING['except'])} ({except_})" if except_ else ""
@ -1643,32 +1651,26 @@ class Generator:
def window_sql(self, expression: exp.Window) -> str:
this = self.sql(expression, "this")
partition = self.partition_by_sql(expression)
order = expression.args.get("order")
order_sql = self.order_sql(order, flat=True) if order else ""
partition_sql = partition + " " if partition and order else partition
spec = expression.args.get("spec")
spec_sql = " " + self.windowspec_sql(spec) if spec else ""
order = self.order_sql(order, flat=True) if order else ""
spec = self.sql(expression, "spec")
alias = self.sql(expression, "alias")
over = self.sql(expression, "over") or "OVER"
this = f"{this} {'AS' if expression.arg_key == 'windows' else over}"
first = expression.args.get("first")
if first is not None:
first = " FIRST " if first else " LAST "
first = first or ""
if first is None:
first = ""
else:
first = "FIRST" if first else "LAST"
if not partition and not order and not spec and alias:
return f"{this} {alias}"
window_args = alias + first + partition_sql + order_sql + spec_sql
return f"{this} ({window_args.strip()})"
args = " ".join(arg for arg in (alias, first, partition, order, spec) if arg)
return f"{this} ({args})"
def partition_by_sql(self, expression: exp.Window | exp.MatchRecognize) -> str:
partition = self.expressions(expression, key="partition_by", flat=True)
@ -2125,6 +2127,10 @@ class Generator:
return self.binary(expression, "ILIKE ANY")
def is_sql(self, expression: exp.Is) -> str:
if not self.IS_BOOL and isinstance(expression.expression, exp.Boolean):
return self.sql(
expression.this if expression.expression.this else exp.not_(expression.this)
)
return self.binary(expression, "IS")
def like_sql(self, expression: exp.Like) -> str:
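With IS_BOOL = False (Presto sets it in this commit), x IS TRUE collapses to x and x IS FALSE to NOT x. A sketch, outputs assumed:

import sqlglot

print(sqlglot.transpile("SELECT a IS TRUE", write="presto")[0])   # expected: SELECT a
print(sqlglot.transpile("SELECT a IS FALSE", write="presto")[0])  # expected: SELECT NOT a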
@ -2322,6 +2328,25 @@ class Generator:
return self.sql(exp.cast(expression.this, "text"))
def dictproperty_sql(self, expression: exp.DictProperty) -> str:
this = self.sql(expression, "this")
kind = self.sql(expression, "kind")
settings_sql = self.expressions(expression, key="settings", sep=" ")
args = f"({self.sep('')}{settings_sql}{self.seg(')', sep='')}" if settings_sql else "()"
return f"{this}({kind}{args})"
def dictrange_sql(self, expression: exp.DictRange) -> str:
this = self.sql(expression, "this")
max = self.sql(expression, "max")
min = self.sql(expression, "min")
return f"{this}(MIN {min} MAX {max})"
def dictsubproperty_sql(self, expression: exp.DictSubProperty) -> str:
return f"{self.sql(expression, 'this')} {self.sql(expression, 'value')}"
def oncluster_sql(self, expression: exp.OnCluster) -> str:
return ""
def cached_generator(
cache: t.Optional[t.Dict[int, str]] = None

View file

@ -1,7 +1,7 @@
from sqlglot import exp
from sqlglot.helper import tsort
JOIN_ATTRS = ("on", "side", "kind", "using", "natural")
JOIN_ATTRS = ("on", "side", "kind", "using", "method")
def optimize_joins(expression):

View file

@ -10,10 +10,10 @@ def pushdown_predicates(expression):
Example:
>>> import sqlglot
>>> sql = "SELECT * FROM (SELECT * FROM x AS x) AS y WHERE y.a = 1"
>>> sql = "SELECT y.a AS a FROM (SELECT x.a AS a FROM x AS x) AS y WHERE y.a = 1"
>>> expression = sqlglot.parse_one(sql)
>>> pushdown_predicates(expression).sql()
'SELECT * FROM (SELECT * FROM x AS x WHERE y.a = 1) AS y WHERE TRUE'
'SELECT y.a AS a FROM (SELECT x.a AS a FROM x AS x WHERE x.a = 1) AS y WHERE TRUE'
Args:
expression (sqlglot.Expression): expression to optimize

View file

@ -155,6 +155,18 @@ class Parser(metaclass=_Parser):
TokenType.DATETIME,
TokenType.DATETIME64,
TokenType.DATE,
TokenType.INT4RANGE,
TokenType.INT4MULTIRANGE,
TokenType.INT8RANGE,
TokenType.INT8MULTIRANGE,
TokenType.NUMRANGE,
TokenType.NUMMULTIRANGE,
TokenType.TSRANGE,
TokenType.TSMULTIRANGE,
TokenType.TSTZRANGE,
TokenType.TSTZMULTIRANGE,
TokenType.DATERANGE,
TokenType.DATEMULTIRANGE,
TokenType.DECIMAL,
TokenType.BIGDECIMAL,
TokenType.UUID,
@ -193,6 +205,7 @@ class Parser(metaclass=_Parser):
TokenType.SCHEMA,
TokenType.TABLE,
TokenType.VIEW,
TokenType.DICTIONARY,
}
CREATABLES = {
@ -220,6 +233,7 @@ class Parser(metaclass=_Parser):
TokenType.DELETE,
TokenType.DESC,
TokenType.DESCRIBE,
TokenType.DICTIONARY,
TokenType.DIV,
TokenType.END,
TokenType.EXECUTE,
@ -272,6 +286,7 @@ class Parser(metaclass=_Parser):
TABLE_ALIAS_TOKENS = ID_VAR_TOKENS - {
TokenType.APPLY,
TokenType.ASOF,
TokenType.FULL,
TokenType.LEFT,
TokenType.LOCK,
@ -375,6 +390,11 @@ class Parser(metaclass=_Parser):
TokenType.EXCEPT,
}
JOIN_METHODS = {
TokenType.NATURAL,
TokenType.ASOF,
}
JOIN_SIDES = {
TokenType.LEFT,
TokenType.RIGHT,
@ -465,7 +485,7 @@ class Parser(metaclass=_Parser):
exp.Where: lambda self: self._parse_where(),
exp.Window: lambda self: self._parse_named_window(),
exp.With: lambda self: self._parse_with(),
"JOIN_TYPE": lambda self: self._parse_join_side_and_kind(),
"JOIN_TYPE": lambda self: self._parse_join_parts(),
}
STATEMENT_PARSERS = {
@ -580,6 +600,8 @@ class Parser(metaclass=_Parser):
),
"JOURNAL": lambda self, **kwargs: self._parse_journal(**kwargs),
"LANGUAGE": lambda self: self._parse_property_assignment(exp.LanguageProperty),
"LAYOUT": lambda self: self._parse_dict_property(this="LAYOUT"),
"LIFETIME": lambda self: self._parse_dict_range(this="LIFETIME"),
"LIKE": lambda self: self._parse_create_like(),
"LOCATION": lambda self: self._parse_property_assignment(exp.LocationProperty),
"LOCK": lambda self: self._parse_locking(),
@ -594,7 +616,8 @@ class Parser(metaclass=_Parser):
"PARTITION BY": lambda self: self._parse_partitioned_by(),
"PARTITIONED BY": lambda self: self._parse_partitioned_by(),
"PARTITIONED_BY": lambda self: self._parse_partitioned_by(),
"PRIMARY KEY": lambda self: self._parse_primary_key(),
"PRIMARY KEY": lambda self: self._parse_primary_key(in_props=True),
"RANGE": lambda self: self._parse_dict_range(this="RANGE"),
"RETURNS": lambda self: self._parse_returns(),
"ROW": lambda self: self._parse_row(),
"ROW_FORMAT": lambda self: self._parse_property_assignment(exp.RowFormatProperty),
@ -603,6 +626,7 @@ class Parser(metaclass=_Parser):
exp.SettingsProperty, expressions=self._parse_csv(self._parse_set_item)
),
"SORTKEY": lambda self: self._parse_sortkey(),
"SOURCE": lambda self: self._parse_dict_property(this="SOURCE"),
"STABLE": lambda self: self.expression(
exp.StabilityProperty, this=exp.Literal.string("STABLE")
),
@ -1133,13 +1157,16 @@ class Parser(metaclass=_Parser):
begin = None
clone = None
def extend_props(temp_props: t.Optional[exp.Expression]) -> None:
nonlocal properties
if properties and temp_props:
properties.expressions.extend(temp_props.expressions)
elif temp_props:
properties = temp_props
if create_token.token_type in (TokenType.FUNCTION, TokenType.PROCEDURE):
this = self._parse_user_defined_function(kind=create_token.token_type)
temp_properties = self._parse_properties()
if properties and temp_properties:
properties.expressions.extend(temp_properties.expressions)
elif temp_properties:
properties = temp_properties
extend_props(self._parse_properties())
self._match(TokenType.ALIAS)
begin = self._match(TokenType.BEGIN)
@ -1154,21 +1181,13 @@ class Parser(metaclass=_Parser):
table_parts = self._parse_table_parts(schema=True)
# exp.Properties.Location.POST_NAME
if self._match(TokenType.COMMA):
temp_properties = self._parse_properties(before=True)
if properties and temp_properties:
properties.expressions.extend(temp_properties.expressions)
elif temp_properties:
properties = temp_properties
self._match(TokenType.COMMA)
extend_props(self._parse_properties(before=True))
this = self._parse_schema(this=table_parts)
# exp.Properties.Location.POST_SCHEMA and POST_WITH
temp_properties = self._parse_properties()
if properties and temp_properties:
properties.expressions.extend(temp_properties.expressions)
elif temp_properties:
properties = temp_properties
extend_props(self._parse_properties())
self._match(TokenType.ALIAS)
@ -1178,11 +1197,7 @@ class Parser(metaclass=_Parser):
or self._match(TokenType.WITH, advance=False)
or self._match(TokenType.L_PAREN, advance=False)
):
temp_properties = self._parse_properties()
if properties and temp_properties:
properties.expressions.extend(temp_properties.expressions)
elif temp_properties:
properties = temp_properties
extend_props(self._parse_properties())
expression = self._parse_ddl_select()
@ -1192,11 +1207,7 @@ class Parser(metaclass=_Parser):
index = self._parse_index()
# exp.Properties.Location.POST_EXPRESSION or exp.Properties.Location.POST_INDEX
temp_properties = self._parse_properties()
if properties and temp_properties:
properties.expressions.extend(temp_properties.expressions)
elif temp_properties:
properties = temp_properties
extend_props(self._parse_properties())
if not index:
break
@ -1888,8 +1899,16 @@ class Parser(metaclass=_Parser):
this = self._parse_query_modifiers(this)
elif (table or nested) and self._match(TokenType.L_PAREN):
this = self._parse_table() if table else self._parse_select(nested=True)
this = self._parse_set_operations(self._parse_query_modifiers(this))
if self._match(TokenType.PIVOT):
this = self._parse_simplified_pivot()
elif self._match(TokenType.FROM):
this = exp.select("*").from_(
t.cast(exp.From, self._parse_from(skip_from_token=True))
)
else:
this = self._parse_table() if table else self._parse_select(nested=True)
this = self._parse_set_operations(self._parse_query_modifiers(this))
self._match_r_paren()
# early return so that subquery unions aren't parsed again
@ -1902,10 +1921,6 @@ class Parser(metaclass=_Parser):
expressions=self._parse_csv(self._parse_value),
alias=self._parse_table_alias(),
)
elif self._match(TokenType.PIVOT):
this = self._parse_simplified_pivot()
elif self._match(TokenType.FROM):
this = exp.select("*").from_(t.cast(exp.From, self._parse_from(skip_from_token=True)))
else:
this = None
@ -2154,11 +2169,11 @@ class Parser(metaclass=_Parser):
return expression
def _parse_join_side_and_kind(
def _parse_join_parts(
self,
) -> t.Tuple[t.Optional[Token], t.Optional[Token], t.Optional[Token]]:
return (
self._match(TokenType.NATURAL) and self._prev,
self._match_set(self.JOIN_METHODS) and self._prev,
self._match_set(self.JOIN_SIDES) and self._prev,
self._match_set(self.JOIN_KINDS) and self._prev,
)
@ -2168,14 +2183,14 @@ class Parser(metaclass=_Parser):
return self.expression(exp.Join, this=self._parse_table())
index = self._index
natural, side, kind = self._parse_join_side_and_kind()
method, side, kind = self._parse_join_parts()
hint = self._prev.text if self._match_texts(self.JOIN_HINTS) else None
join = self._match(TokenType.JOIN)
if not skip_join_token and not join:
self._retreat(index)
kind = None
natural = None
method = None
side = None
outer_apply = self._match_pair(TokenType.OUTER, TokenType.APPLY, False)
@ -2187,12 +2202,10 @@ class Parser(metaclass=_Parser):
if outer_apply:
side = Token(TokenType.LEFT, "LEFT")
kwargs: t.Dict[
str, t.Optional[exp.Expression] | bool | str | t.List[t.Optional[exp.Expression]]
] = {"this": self._parse_table()}
kwargs: t.Dict[str, t.Any] = {"this": self._parse_table()}
if natural:
kwargs["natural"] = True
if method:
kwargs["method"] = method.text
if side:
kwargs["side"] = side.text
if kind:
@ -2205,7 +2218,7 @@ class Parser(metaclass=_Parser):
elif self._match(TokenType.USING):
kwargs["using"] = self._parse_wrapped_id_vars()
return self.expression(exp.Join, **kwargs) # type: ignore
return self.expression(exp.Join, **kwargs)
def _parse_index(
self,
@ -2886,7 +2899,9 @@ class Parser(metaclass=_Parser):
exp.DataTypeSize, this=this, expression=self._parse_var(any_token=True)
)
def _parse_types(self, check_func: bool = False) -> t.Optional[exp.Expression]:
def _parse_types(
self, check_func: bool = False, schema: bool = False
) -> t.Optional[exp.Expression]:
index = self._index
prefix = self._match_text_seq("SYSUDTLIB", ".")
@ -2908,7 +2923,9 @@ class Parser(metaclass=_Parser):
if is_struct:
expressions = self._parse_csv(self._parse_struct_types)
elif nested:
expressions = self._parse_csv(self._parse_types)
expressions = self._parse_csv(
lambda: self._parse_types(check_func=check_func, schema=schema)
)
else:
expressions = self._parse_csv(self._parse_type_size)
@ -2943,7 +2960,9 @@ class Parser(metaclass=_Parser):
if is_struct:
expressions = self._parse_csv(self._parse_struct_types)
else:
expressions = self._parse_csv(self._parse_types)
expressions = self._parse_csv(
lambda: self._parse_types(check_func=check_func, schema=schema)
)
if not self._match(TokenType.GT):
self.raise_error("Expecting >")
@ -3038,11 +3057,7 @@ class Parser(metaclass=_Parser):
else exp.Literal.string(value)
)
else:
field = (
self._parse_star()
or self._parse_function(anonymous=True)
or self._parse_id_var()
)
field = self._parse_field(anonymous_func=True)
if isinstance(field, exp.Func):
# bigquery allows function calls like x.y.count(...)
@ -3113,10 +3128,11 @@ class Parser(metaclass=_Parser):
self,
any_token: bool = False,
tokens: t.Optional[t.Collection[TokenType]] = None,
anonymous_func: bool = False,
) -> t.Optional[exp.Expression]:
return (
self._parse_primary()
or self._parse_function()
or self._parse_function(anonymous=anonymous_func)
or self._parse_id_var(any_token=any_token, tokens=tokens)
)
@ -3270,7 +3286,7 @@ class Parser(metaclass=_Parser):
# column defs are not really columns, they're identifiers
if isinstance(this, exp.Column):
this = this.this
kind = self._parse_types()
kind = self._parse_types(schema=True)
if self._match_text_seq("FOR", "ORDINALITY"):
return self.expression(exp.ColumnDef, this=this, ordinality=True)
@ -3483,16 +3499,18 @@ class Parser(metaclass=_Parser):
exp.ForeignKey, expressions=expressions, reference=reference, **options # type: ignore
)
def _parse_primary_key(self) -> exp.Expression:
def _parse_primary_key(
self, wrapped_optional: bool = False, in_props: bool = False
) -> exp.Expression:
desc = (
self._match_set((TokenType.ASC, TokenType.DESC))
and self._prev.token_type == TokenType.DESC
)
if not self._match(TokenType.L_PAREN, advance=False):
if not in_props and not self._match(TokenType.L_PAREN, advance=False):
return self.expression(exp.PrimaryKeyColumnConstraint, desc=desc)
expressions = self._parse_wrapped_csv(self._parse_field)
expressions = self._parse_wrapped_csv(self._parse_field, optional=wrapped_optional)
options = self._parse_key_constraint_options()
return self.expression(exp.PrimaryKey, expressions=expressions, options=options)
@ -3509,10 +3527,11 @@ class Parser(metaclass=_Parser):
return this
bracket_kind = self._prev.token_type
expressions: t.List[t.Optional[exp.Expression]]
if self._match(TokenType.COLON):
expressions = [self.expression(exp.Slice, expression=self._parse_conjunction())]
expressions: t.List[t.Optional[exp.Expression]] = [
self.expression(exp.Slice, expression=self._parse_conjunction())
]
else:
expressions = self._parse_csv(lambda: self._parse_slice(self._parse_conjunction()))
@ -4011,22 +4030,15 @@ class Parser(metaclass=_Parser):
self,
any_token: bool = True,
tokens: t.Optional[t.Collection[TokenType]] = None,
prefix_tokens: t.Optional[t.Collection[TokenType]] = None,
) -> t.Optional[exp.Expression]:
identifier = self._parse_identifier()
if identifier:
return identifier
prefix = ""
if prefix_tokens:
while self._match_set(prefix_tokens):
prefix += self._prev.text
if (any_token and self._advance_any()) or self._match_set(tokens or self.ID_VAR_TOKENS):
quoted = self._prev.token_type == TokenType.STRING
return exp.Identifier(this=prefix + self._prev.text, quoted=quoted)
return exp.Identifier(this=self._prev.text, quoted=quoted)
return None
@ -4472,6 +4484,44 @@ class Parser(metaclass=_Parser):
size = len(start.text)
return exp.Command(this=text[:size], expression=text[size:])
def _parse_dict_property(self, this: str) -> exp.DictProperty:
settings = []
self._match_l_paren()
kind = self._parse_id_var()
if self._match(TokenType.L_PAREN):
while True:
key = self._parse_id_var()
value = self._parse_primary()
if not key and value is None:
break
settings.append(self.expression(exp.DictSubProperty, this=key, value=value))
self._match(TokenType.R_PAREN)
self._match_r_paren()
return self.expression(
exp.DictProperty,
this=this,
kind=kind.this if kind else None,
settings=settings,
)
def _parse_dict_range(self, this: str) -> exp.DictRange:
self._match_l_paren()
has_min = self._match_text_seq("MIN")
if has_min:
min = self._parse_var() or self._parse_primary()
self._match_text_seq("MAX")
max = self._parse_var() or self._parse_primary()
else:
max = self._parse_var() or self._parse_primary()
min = exp.Literal.number(0)
self._match_r_paren()
return self.expression(exp.DictRange, this=this, min=min, max=max)
def _find_parser(
self, parsers: t.Dict[str, t.Callable], trie: t.Dict
) -> t.Optional[t.Callable]:
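Together with the LAYOUT/LIFETIME/SOURCE/RANGE property parsers registered earlier, this is enough to parse ClickHouse CREATE DICTIONARY statements. A sketch; the DDL shape follows the ClickHouse docs, and the output formatting is an assumption:

import sqlglot

ddl = """
CREATE DICTIONARY d (id UInt64, val String)
PRIMARY KEY id
SOURCE(CLICKHOUSE(TABLE 'src'))
LAYOUT(FLAT())
LIFETIME(MIN 0 MAX 300)
"""
print(sqlglot.parse_one(ddl, read="clickhouse").sql("clickhouse"))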

View file

@ -5,7 +5,7 @@ import typing as t
from sqlglot import expressions as exp
if t.TYPE_CHECKING:
JSON = t.Union[dict, list, str, float, int, bool]
JSON = t.Union[dict, list, str, float, int, bool, None]
Node = t.Union[t.List["Node"], exp.DataType.Type, exp.Expression, JSON]
@ -24,12 +24,12 @@ def dump(node: Node) -> JSON:
klass = node.__class__.__qualname__
if node.__class__.__module__ != exp.__name__:
klass = f"{node.__module__}.{klass}"
obj = {
obj: t.Dict = {
"class": klass,
"args": {k: dump(v) for k, v in node.args.items() if v is not None and v != []},
}
if node.type:
obj["type"] = node.type.sql()
obj["type"] = dump(node.type)
if node.comments:
obj["comments"] = node.comments
if node._meta is not None:
@ -60,7 +60,7 @@ def load(obj: JSON) -> Node:
klass = getattr(module, class_name)
expression = klass(**{k: load(v) for k, v in obj["args"].items()})
expression.type = obj.get("type")
expression.type = t.cast(exp.DataType, load(obj.get("type")))
expression.comments = obj.get("comments")
expression._meta = obj.get("meta")
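Since node.type is now dumped as a full expression node instead of a SQL string, load can rebuild the annotated DataType. A round-trip sketch:

from sqlglot import exp, parse_one
from sqlglot.optimizer.annotate_types import annotate_types
from sqlglot.serde import dump, load

expr = annotate_types(parse_one("SELECT 1 AS a"))
restored = load(dump(expr))
assert restored.sql() == expr.sql()
# the literal's annotated type survives as a real DataType node:
print(repr(restored.find(exp.Literal).type))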

View file

@ -113,6 +113,18 @@ class TokenType(AutoName):
DATETIME = auto()
DATETIME64 = auto()
DATE = auto()
INT4RANGE = auto()
INT4MULTIRANGE = auto()
INT8RANGE = auto()
INT8MULTIRANGE = auto()
NUMRANGE = auto()
NUMMULTIRANGE = auto()
TSRANGE = auto()
TSMULTIRANGE = auto()
TSTZRANGE = auto()
TSTZMULTIRANGE = auto()
DATERANGE = auto()
DATEMULTIRANGE = auto()
UUID = auto()
GEOGRAPHY = auto()
NULLABLE = auto()
@ -167,6 +179,7 @@ class TokenType(AutoName):
DELETE = auto()
DESC = auto()
DESCRIBE = auto()
DICTIONARY = auto()
DISTINCT = auto()
DIV = auto()
DROP = auto()
@ -480,6 +493,7 @@ class Tokenizer(metaclass=_Tokenizer):
"ANY": TokenType.ANY,
"ASC": TokenType.ASC,
"AS": TokenType.ALIAS,
"ASOF": TokenType.ASOF,
"AUTOINCREMENT": TokenType.AUTO_INCREMENT,
"AUTO_INCREMENT": TokenType.AUTO_INCREMENT,
"BEGIN": TokenType.BEGIN,
@ -669,6 +683,18 @@ class Tokenizer(metaclass=_Tokenizer):
"TIMESTAMPLTZ": TokenType.TIMESTAMPLTZ,
"DATE": TokenType.DATE,
"DATETIME": TokenType.DATETIME,
"INT4RANGE": TokenType.INT4RANGE,
"INT4MULTIRANGE": TokenType.INT4MULTIRANGE,
"INT8RANGE": TokenType.INT8RANGE,
"INT8MULTIRANGE": TokenType.INT8MULTIRANGE,
"NUMRANGE": TokenType.NUMRANGE,
"NUMMULTIRANGE": TokenType.NUMMULTIRANGE,
"TSRANGE": TokenType.TSRANGE,
"TSMULTIRANGE": TokenType.TSMULTIRANGE,
"TSTZRANGE": TokenType.TSTZRANGE,
"TSTZMULTIRANGE": TokenType.TSTZMULTIRANGE,
"DATERANGE": TokenType.DATERANGE,
"DATEMULTIRANGE": TokenType.DATEMULTIRANGE,
"UNIQUE": TokenType.UNIQUE,
"STRUCT": TokenType.STRUCT,
"VARIANT": TokenType.VARIANT,
@ -709,8 +735,6 @@ class Tokenizer(metaclass=_Tokenizer):
COMMENTS = ["--", ("/*", "*/"), ("{#", "#}")]
KEYWORD_TRIE: t.Dict = {} # autofilled
IDENTIFIER_CAN_START_WITH_DIGIT = False
__slots__ = (
"sql",
"size",
@ -724,6 +748,7 @@ class Tokenizer(metaclass=_Tokenizer):
"_end",
"_peek",
"_prev_token_line",
"identifiers_can_start_with_digit",
)
def __init__(self) -> None:
@ -826,6 +851,12 @@ class Tokenizer(metaclass=_Tokenizer):
def _text(self) -> str:
return self.sql[self._start : self._current]
def peek(self, i: int = 0) -> str:
i = self._current + i
if i < self.size:
return self.sql[i]
return ""
def _add(self, token_type: TokenType, text: t.Optional[str] = None) -> None:
self._prev_token_line = self._line
self.tokens.append(
@ -962,8 +993,12 @@ class Tokenizer(metaclass=_Tokenizer):
if self._peek.isdigit():
self._advance()
elif self._peek == "." and not decimal:
decimal = True
self._advance()
after = self.peek(1)
if after.isdigit() or not after.strip():
decimal = True
self._advance()
else:
return self._add(TokenType.VAR)
elif self._peek in ("-", "+") and scientific == 1:
scientific += 1
self._advance()
@ -984,7 +1019,7 @@ class Tokenizer(metaclass=_Tokenizer):
self._add(TokenType.NUMBER, number_text)
self._add(TokenType.DCOLON, "::")
return self._add(token_type, literal)
elif self.IDENTIFIER_CAN_START_WITH_DIGIT:
elif self.identifiers_can_start_with_digit: # type: ignore
return self._add(TokenType.VAR)
self._add(TokenType.NUMBER, number_text)

View file

@ -268,6 +268,17 @@ def add_recursive_cte_column_names(expression: exp.Expression) -> exp.Expression
return expression
def epoch_cast_to_ts(expression: exp.Expression) -> exp.Expression:
if (
isinstance(expression, (exp.Cast, exp.TryCast))
and expression.name.lower() == "epoch"
and expression.to.this in exp.DataType.TEMPORAL_TYPES
):
expression.this.replace(exp.Literal.string("1970-01-01 00:00:00"))
return expression
def preprocess(
transforms: t.List[t.Callable[[exp.Expression], exp.Expression]],
) -> t.Callable[[Generator, exp.Expression], str]:
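Presto registers epoch_cast_to_ts as a preprocess on both Cast and TryCast earlier in this commit, so the Postgres-style 'epoch' literal becomes an explicit timestamp. A sketch, output assumed:

import sqlglot

print(sqlglot.transpile("SELECT CAST('epoch' AS TIMESTAMP)", read="postgres", write="presto")[0])
# expected: SELECT CAST('1970-01-01 00:00:00' AS TIMESTAMP)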