Merging upstream version 25.5.1.

Signed-off-by: Daniel Baumann <daniel@debian.org>

parent 298e7a8147
commit 029b9c2c73

136 changed files with 80990 additions and 72541 deletions
@@ -13,7 +13,6 @@ from sqlglot.dialects.dialect import (
    date_add_interval_sql,
    datestrtodate_sql,
    build_formatted_time,
    build_timestamp_from_parts,
    filter_array_using_unnest,
    if_sql,
    inline_array_unless_query,

@@ -202,10 +201,35 @@ def _unix_to_time_sql(self: BigQuery.Generator, expression: exp.UnixToTime) -> s
def _build_time(args: t.List) -> exp.Func:
    if len(args) == 1:
        return exp.TsOrDsToTime(this=args[0])
    if len(args) == 3:
        return exp.TimeFromParts.from_arg_list(args)
    if len(args) == 2:
        return exp.Time.from_arg_list(args)
    return exp.TimeFromParts.from_arg_list(args)

    return exp.Anonymous(this="TIME", expressions=args)

def _build_datetime(args: t.List) -> exp.Func:
    if len(args) == 1:
        return exp.TsOrDsToTimestamp.from_arg_list(args)
    if len(args) == 2:
        return exp.Datetime.from_arg_list(args)
    return exp.TimestampFromParts.from_arg_list(args)


def _str_to_datetime_sql(
    self: BigQuery.Generator, expression: exp.StrToDate | exp.StrToTime
) -> str:
    this = self.sql(expression, "this")
    dtype = "DATE" if isinstance(expression, exp.StrToDate) else "TIMESTAMP"

    if expression.args.get("safe"):
        fmt = self.format_time(
            expression,
            self.dialect.INVERSE_FORMAT_MAPPING,
            self.dialect.INVERSE_FORMAT_TRIE,
        )
        return f"SAFE_CAST({this} AS {dtype} FORMAT {fmt})"

    fmt = self.format_time(expression)
    return self.func(f"PARSE_{dtype}", fmt, this, expression.args.get("zone"))


class BigQuery(Dialect):

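The builders above dispatch on argument count. A minimal sketch of how they are expected to behave when parsing BigQuery SQL (assuming a sqlglot build that includes this change; the column names are invented):

```python
import sqlglot
from sqlglot import exp

# Single argument: DATETIME(col) should become a TsOrDsToTimestamp node.
one_arg = sqlglot.parse_one("SELECT DATETIME(ts_col)", read="bigquery")
print(one_arg.find(exp.TsOrDsToTimestamp))

# Two arguments: DATETIME(date_col, time_col) should map to exp.Datetime.
two_args = sqlglot.parse_one("SELECT DATETIME(d, t)", read="bigquery")
print(two_args.find(exp.Datetime))
```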
@@ -215,6 +239,8 @@ class BigQuery(Dialect):
    SUPPORTS_SEMI_ANTI_JOIN = False
    LOG_BASE_FIRST = False
    HEX_LOWERCASE = True
    FORCE_EARLY_ALIAS_REF_EXPANSION = True
    EXPAND_ALIAS_REFS_EARLY_ONLY_IN_GROUP_BY = True

    # https://cloud.google.com/bigquery/docs/reference/standard-sql/lexical#case_sensitivity
    NORMALIZATION_STRATEGY = NormalizationStrategy.CASE_INSENSITIVE

@@ -249,7 +275,10 @@ class BigQuery(Dialect):
    PSEUDOCOLUMNS = {"_PARTITIONTIME", "_PARTITIONDATE"}

    def normalize_identifier(self, expression: E) -> E:
        if isinstance(expression, exp.Identifier):
        if (
            isinstance(expression, exp.Identifier)
            and self.normalization_strategy is not NormalizationStrategy.CASE_SENSITIVE
        ):
            parent = expression.parent
            while isinstance(parent, exp.Dot):
                parent = parent.parent

@@ -308,6 +337,7 @@ class BigQuery(Dialect):
        }
        KEYWORDS.pop("DIV")
        KEYWORDS.pop("VALUES")
        KEYWORDS.pop("/*+")

    class Parser(parser.Parser):
        PREFIXED_PIVOT_COLUMNS = True

@@ -323,7 +353,7 @@ class BigQuery(Dialect):
                unit=exp.Literal.string(str(seq_get(args, 1))),
                this=seq_get(args, 0),
            ),
            "DATETIME": build_timestamp_from_parts,
            "DATETIME": _build_datetime,
            "DATETIME_ADD": build_date_delta_with_interval(exp.DatetimeAdd),
            "DATETIME_SUB": build_date_delta_with_interval(exp.DatetimeSub),
            "DIV": binary_from_function(exp.IntDiv),

@@ -334,6 +364,7 @@ class BigQuery(Dialect):
            "JSON_EXTRACT_SCALAR": lambda args: exp.JSONExtractScalar(
                this=seq_get(args, 0), expression=seq_get(args, 1) or exp.Literal.string("$")
            ),
            "LENGTH": lambda args: exp.Length(this=seq_get(args, 0), binary=True),
            "MD5": exp.MD5Digest.from_arg_list,
            "TO_HEX": _build_to_hex,
            "PARSE_DATE": lambda args: build_formatted_time(exp.StrToDate, "bigquery")(

@@ -552,7 +583,7 @@ class BigQuery(Dialect):
            return bracket

    class Generator(generator.Generator):
        EXPLICIT_UNION = True
        EXPLICIT_SET_OP = True
        INTERVAL_ALLOWS_PLURAL_FORM = False
        JOIN_HINTS = False
        QUERY_HINTS = False

@@ -644,10 +675,8 @@ class BigQuery(Dialect):
            exp.StabilityProperty: lambda self, e: (
                "DETERMINISTIC" if e.name == "IMMUTABLE" else "NOT DETERMINISTIC"
            ),
            exp.StrToDate: lambda self, e: self.func("PARSE_DATE", self.format_time(e), e.this),
            exp.StrToTime: lambda self, e: self.func(
                "PARSE_TIMESTAMP", self.format_time(e), e.this, e.args.get("zone")
            ),
            exp.StrToDate: _str_to_datetime_sql,
            exp.StrToTime: _str_to_datetime_sql,
            exp.TimeAdd: date_add_interval_sql("TIME", "ADD"),
            exp.TimeFromParts: rename_func("TIME"),
            exp.TimestampFromParts: rename_func("DATETIME"),

@@ -661,6 +690,7 @@ class BigQuery(Dialect):
            exp.TsOrDsAdd: _ts_or_ds_add_sql,
            exp.TsOrDsDiff: _ts_or_ds_diff_sql,
            exp.TsOrDsToTime: rename_func("TIME"),
            exp.TsOrDsToTimestamp: rename_func("DATETIME"),
            exp.Unhex: rename_func("FROM_HEX"),
            exp.UnixDate: rename_func("UNIX_DATE"),
            exp.UnixToTime: _unix_to_time_sql,

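Taken together, the generator entries above route safe string-to-date/time parsing onto BigQuery's SAFE_CAST ... FORMAT form via _str_to_datetime_sql. A hedged sketch (the Snowflake input is only an illustration):

```python
import sqlglot

# Snowflake's TRY_TO_DATE parses with safe=True, so _str_to_datetime_sql should
# emit SAFE_CAST(... AS DATE FORMAT ...) rather than PARSE_DATE when targeting BigQuery.
print(sqlglot.transpile("SELECT TRY_TO_DATE(s, 'yyyy-mm-dd')", read="snowflake", write="bigquery")[0])
```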
@@ -6,8 +6,8 @@ from sqlglot import exp, generator, parser, tokens, transforms
from sqlglot.dialects.dialect import (
    Dialect,
    arg_max_or_min_no_count,
    build_date_delta,
    build_formatted_time,
    date_delta_sql,
    inline_array_sql,
    json_extract_segments,
    json_path_key_only_name,

@@ -17,10 +17,14 @@ from sqlglot.dialects.dialect import (
    sha256_sql,
    var_map_sql,
    timestamptrunc_sql,
    unit_to_var,
)
from sqlglot.generator import Generator
from sqlglot.helper import is_int, seq_get
from sqlglot.tokens import Token, TokenType

DATETIME_DELTA = t.Union[exp.DateAdd, exp.DateDiff, exp.DateSub, exp.TimestampSub, exp.TimestampAdd]


def _build_date_format(args: t.List) -> exp.TimeToStr:
    expr = build_formatted_time(exp.TimeToStr, "clickhouse")(args)

@@ -77,12 +81,28 @@ def _build_count_if(args: t.List) -> exp.CountIf | exp.CombinedAggFunc:
    return exp.CombinedAggFunc(this="countIf", expressions=args, parts=("count", "If"))


def _datetime_delta_sql(name: str) -> t.Callable[[Generator, DATETIME_DELTA], str]:
    def _delta_sql(self: Generator, expression: DATETIME_DELTA) -> str:
        if not expression.unit:
            return rename_func(name)(self, expression)

        return self.func(
            name,
            unit_to_var(expression),
            expression.expression,
            expression.this,
        )

    return _delta_sql


class ClickHouse(Dialect):
    NORMALIZE_FUNCTIONS: bool | str = False
    NULL_ORDERING = "nulls_are_last"
    SUPPORTS_USER_DEFINED_TYPES = False
    SAFE_DIVISION = True
    LOG_BASE_FIRST: t.Optional[bool] = None
    FORCE_EARLY_ALIAS_REF_EXPANSION = True

    UNESCAPED_SEQUENCES = {
        "\\0": "\0",

@@ -128,6 +148,7 @@ class ClickHouse(Dialect):
            "SYSTEM": TokenType.COMMAND,
            "PREWHERE": TokenType.PREWHERE,
        }
        KEYWORDS.pop("/*+")

        SINGLE_TOKENS = {
            **tokens.Tokenizer.SINGLE_TOKENS,

@@ -138,7 +159,7 @@ class ClickHouse(Dialect):
        # Tested in ClickHouse's playground, it seems that the following two queries do the same thing
        # * select x from t1 union all select x from t2 limit 1;
        # * select x from t1 union all (select x from t2 limit 1);
        MODIFIERS_ATTACHED_TO_UNION = False
        MODIFIERS_ATTACHED_TO_SET_OP = False
        INTERVAL_SPANS = False

        FUNCTIONS = {

@@ -146,19 +167,13 @@ class ClickHouse(Dialect):
            "ANY": exp.AnyValue.from_arg_list,
            "ARRAYSUM": exp.ArraySum.from_arg_list,
            "COUNTIF": _build_count_if,
            "DATE_ADD": lambda args: exp.DateAdd(
                this=seq_get(args, 2), expression=seq_get(args, 1), unit=seq_get(args, 0)
            ),
            "DATEADD": lambda args: exp.DateAdd(
                this=seq_get(args, 2), expression=seq_get(args, 1), unit=seq_get(args, 0)
            ),
            "DATE_DIFF": lambda args: exp.DateDiff(
                this=seq_get(args, 2), expression=seq_get(args, 1), unit=seq_get(args, 0)
            ),
            "DATEDIFF": lambda args: exp.DateDiff(
                this=seq_get(args, 2), expression=seq_get(args, 1), unit=seq_get(args, 0)
            ),
            "DATE_ADD": build_date_delta(exp.DateAdd, default_unit=None),
            "DATEADD": build_date_delta(exp.DateAdd, default_unit=None),
            "DATE_DIFF": build_date_delta(exp.DateDiff, default_unit=None),
            "DATEDIFF": build_date_delta(exp.DateDiff, default_unit=None),
            "DATE_FORMAT": _build_date_format,
            "DATE_SUB": build_date_delta(exp.DateSub, default_unit=None),
            "DATESUB": build_date_delta(exp.DateSub, default_unit=None),
            "FORMATDATETIME": _build_date_format,
            "JSONEXTRACTSTRING": build_json_extract_path(
                exp.JSONExtractScalar, zero_based_indexing=False

@@ -167,6 +182,10 @@ class ClickHouse(Dialect):
            "MATCH": exp.RegexpLike.from_arg_list,
            "RANDCANONICAL": exp.Rand.from_arg_list,
            "TUPLE": exp.Struct.from_arg_list,
            "TIMESTAMP_SUB": build_date_delta(exp.TimestampSub, default_unit=None),
            "TIMESTAMPSUB": build_date_delta(exp.TimestampSub, default_unit=None),
            "TIMESTAMP_ADD": build_date_delta(exp.TimestampAdd, default_unit=None),
            "TIMESTAMPADD": build_date_delta(exp.TimestampAdd, default_unit=None),
            "UNIQ": exp.ApproxDistinct.from_arg_list,
            "XOR": lambda args: exp.Xor(expressions=args),
            "MD5": exp.MD5Digest.from_arg_list,

@@ -389,6 +408,23 @@ class ClickHouse(Dialect):
            "INDEX",
        }

        def _parse_extract(self) -> exp.Extract | exp.Anonymous:
            index = self._index
            this = self._parse_bitwise()
            if self._match(TokenType.FROM):
                self._retreat(index)
                return super()._parse_extract()

            # We return Anonymous here because extract and regexpExtract have different semantics,
            # so parsing extract(foo, bar) into RegexpExtract can potentially break queries. E.g.,
            # `extract('foobar', 'b')` works, but CH crashes for `regexpExtract('foobar', 'b')`.
            #
            # TODO: can we somehow convert the former into an equivalent `regexpExtract` call?
            self._match(TokenType.COMMA)
            return self.expression(
                exp.Anonymous, this="extract", expressions=[this, self._parse_bitwise()]
            )

        def _parse_assignment(self) -> t.Optional[exp.Expression]:
            this = super()._parse_assignment()

@@ -657,6 +693,12 @@ class ClickHouse(Dialect):
        LAST_DAY_SUPPORTS_DATE_PART = False
        CAN_IMPLEMENT_ARRAY_ANY = True
        SUPPORTS_TO_NUMBER = False
        JOIN_HINTS = False
        TABLE_HINTS = False
        EXPLICIT_SET_OP = True
        GROUPINGS_SEP = ""
        SET_OP_MODIFIERS = False
        SUPPORTS_TABLE_ALIAS_COLUMNS = False

        STRING_TYPE_MAPPING = {
            exp.DataType.Type.CHAR: "String",

@@ -730,8 +772,9 @@ class ClickHouse(Dialect):
            exp.ComputedColumnConstraint: lambda self,
            e: f"{'MATERIALIZED' if e.args.get('persisted') else 'ALIAS'} {self.sql(e, 'this')}",
            exp.CurrentDate: lambda self, e: self.func("CURRENT_DATE"),
            exp.DateAdd: date_delta_sql("DATE_ADD"),
            exp.DateDiff: date_delta_sql("DATE_DIFF"),
            exp.DateAdd: _datetime_delta_sql("DATE_ADD"),
            exp.DateDiff: _datetime_delta_sql("DATE_DIFF"),
            exp.DateSub: _datetime_delta_sql("DATE_SUB"),
            exp.Explode: rename_func("arrayJoin"),
            exp.Final: lambda self, e: f"{self.sql(e, 'this')} FINAL",
            exp.IsNan: rename_func("isNaN"),

@@ -754,6 +797,8 @@ class ClickHouse(Dialect):
            exp.TimeToStr: lambda self, e: self.func(
                "DATE_FORMAT", e.this, self.format_time(e), e.args.get("timezone")
            ),
            exp.TimestampAdd: _datetime_delta_sql("TIMESTAMP_ADD"),
            exp.TimestampSub: _datetime_delta_sql("TIMESTAMP_SUB"),
            exp.VarMap: lambda self, e: _lower_func(var_map_sql(self, e)),
            exp.Xor: lambda self, e: self.func("xor", e.this, e.expression, *e.expressions),
            exp.MD5Digest: rename_func("MD5"),

@@ -773,12 +818,6 @@ class ClickHouse(Dialect):
            exp.OnCluster: exp.Properties.Location.POST_NAME,
        }

        JOIN_HINTS = False
        TABLE_HINTS = False
        EXPLICIT_UNION = True
        GROUPINGS_SEP = ""
        OUTER_UNION_MODIFIERS = False

        # there's no list in docs, but it can be found in Clickhouse code
        # see `ClickHouse/src/Parsers/ParserCreate*.cpp`
        ON_CLUSTER_TARGETS = {

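For the ClickHouse changes above, switching to build_date_delta(..., default_unit=None) and _datetime_delta_sql keeps the unit-first argument order on both the parsing and generation side. A rough usage sketch (illustrative queries only):

```python
import sqlglot

# Unit-first, three-argument form: parsed by build_date_delta and re-emitted by
# _datetime_delta_sql with the unit leading.
print(sqlglot.transpile("SELECT dateAdd(DAY, 3, created_at)", read="clickhouse", write="clickhouse")[0])

# The same expression can be rewritten for another engine.
print(sqlglot.transpile("SELECT dateDiff('day', start_at, end_at)", read="clickhouse", write="duckdb")[0])
```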
@@ -1,6 +1,8 @@
from __future__ import annotations

from sqlglot import exp, transforms
import typing as t

from sqlglot import exp, transforms, jsonpath
from sqlglot.dialects.dialect import (
    date_delta_sql,
    build_date_delta,

@@ -10,27 +12,47 @@ from sqlglot.dialects.spark import Spark
from sqlglot.tokens import TokenType


def _build_json_extract(args: t.List) -> exp.JSONExtract:
    # Transform GET_JSON_OBJECT(expr, '$.<path>') -> expr:<path>
    this = args[0]
    path = args[1].name.lstrip("$.")
    return exp.JSONExtract(this=this, expression=path)


def _timestamp_diff(
    self: Databricks.Generator, expression: exp.DatetimeDiff | exp.TimestampDiff
) -> str:
    return self.func("TIMESTAMPDIFF", expression.unit, expression.expression, expression.this)


def _jsonextract_sql(
    self: Databricks.Generator, expression: exp.JSONExtract | exp.JSONExtractScalar
) -> str:
    this = self.sql(expression, "this")
    expr = self.sql(expression, "expression")
    return f"{this}:{expr}"


class Databricks(Spark):
    SAFE_DIVISION = False
    COPY_PARAMS_ARE_CSV = False

    class JSONPathTokenizer(jsonpath.JSONPathTokenizer):
        IDENTIFIERS = ["`", '"']

    class Parser(Spark.Parser):
        LOG_DEFAULTS_TO_LN = True
        STRICT_CAST = True
        COLON_IS_JSON_EXTRACT = True
        COLON_IS_VARIANT_EXTRACT = True

        FUNCTIONS = {
            **Spark.Parser.FUNCTIONS,
            "DATEADD": build_date_delta(exp.DateAdd),
            "DATE_ADD": build_date_delta(exp.DateAdd),
            "DATEDIFF": build_date_delta(exp.DateDiff),
            "DATE_DIFF": build_date_delta(exp.DateDiff),
            "TIMESTAMPDIFF": build_date_delta(exp.TimestampDiff),
            "GET_JSON_OBJECT": _build_json_extract,
        }

        FACTOR = {

@@ -42,6 +64,8 @@ class Databricks(Spark):
        TABLESAMPLE_SEED_KEYWORD = "REPEATABLE"
        COPY_PARAMS_ARE_WRAPPED = False
        COPY_PARAMS_EQ_REQUIRED = True
        JSON_PATH_SINGLE_QUOTE_ESCAPE = False
        QUOTE_JSON_PATH = False

        TRANSFORMS = {
            **Spark.Generator.TRANSFORMS,

@@ -65,6 +89,9 @@ class Databricks(Spark):
                    transforms.unnest_to_explode,
                ]
            ),
            exp.JSONExtract: _jsonextract_sql,
            exp.JSONExtractScalar: _jsonextract_sql,
            exp.JSONPathRoot: lambda *_: "",
            exp.ToChar: lambda self, e: self.function_fallback_sql(e),
        }

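A small, hedged sketch of the Databricks behaviour these flags drive (COLON_IS_VARIANT_EXTRACT plus the backtick-aware JSONPathTokenizer); the table and column names are invented:

```python
import sqlglot

# Databricks' colon-style variant extraction should survive a round trip.
print(sqlglot.transpile("SELECT payload:order.total FROM events", read="databricks", write="databricks")[0])

# The same extraction can be re-expressed for Spark.
print(sqlglot.transpile("SELECT payload:order.total FROM events", read="databricks", write="spark")[0])
```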
@@ -9,7 +9,7 @@ from sqlglot import exp
from sqlglot.errors import ParseError
from sqlglot.generator import Generator
from sqlglot.helper import AutoName, flatten, is_int, seq_get
from sqlglot.jsonpath import parse as parse_json_path
from sqlglot.jsonpath import JSONPathTokenizer, parse as parse_json_path
from sqlglot.parser import Parser
from sqlglot.time import TIMEZONES, format_time
from sqlglot.tokens import Token, Tokenizer, TokenType

@@ -122,15 +122,21 @@ class _Dialect(type):
        )
        klass.INVERSE_TIME_MAPPING = {v: k for k, v in klass.TIME_MAPPING.items()}
        klass.INVERSE_TIME_TRIE = new_trie(klass.INVERSE_TIME_MAPPING)
        klass.INVERSE_FORMAT_MAPPING = {v: k for k, v in klass.FORMAT_MAPPING.items()}
        klass.INVERSE_FORMAT_TRIE = new_trie(klass.INVERSE_FORMAT_MAPPING)

        base = seq_get(bases, 0)
        base_tokenizer = (getattr(base, "tokenizer_class", Tokenizer),)
        base_jsonpath_tokenizer = (getattr(base, "jsonpath_tokenizer_class", JSONPathTokenizer),)
        base_parser = (getattr(base, "parser_class", Parser),)
        base_generator = (getattr(base, "generator_class", Generator),)

        klass.tokenizer_class = klass.__dict__.get(
            "Tokenizer", type("Tokenizer", base_tokenizer, {})
        )
        klass.jsonpath_tokenizer_class = klass.__dict__.get(
            "JSONPathTokenizer", type("JSONPathTokenizer", base_jsonpath_tokenizer, {})
        )
        klass.parser_class = klass.__dict__.get("Parser", type("Parser", base_parser, {}))
        klass.generator_class = klass.__dict__.get(
            "Generator", type("Generator", base_generator, {})

@@ -164,6 +170,8 @@ class _Dialect(type):

        klass.ESCAPED_SEQUENCES = {v: k for k, v in klass.UNESCAPED_SEQUENCES.items()}

        klass.SUPPORTS_COLUMN_JOIN_MARKS = "(+)" in klass.tokenizer_class.KEYWORDS

        if enum not in ("", "bigquery"):
            klass.generator_class.SELECT_KINDS = ()

@@ -232,6 +240,9 @@ class Dialect(metaclass=_Dialect):
    SUPPORTS_COLUMN_JOIN_MARKS = False
    """Whether the old-style outer join (+) syntax is supported."""

    COPY_PARAMS_ARE_CSV = True
    """Separator of COPY statement parameters."""

    NORMALIZE_FUNCTIONS: bool | str = "upper"
    """
    Determines how function names are going to be normalized.

@@ -311,9 +322,44 @@ class Dialect(metaclass=_Dialect):
    ) SELECT c FROM y;
    """

    COPY_PARAMS_ARE_CSV = True
    """
    Whether COPY statement parameters are separated by comma or whitespace
    """

    FORCE_EARLY_ALIAS_REF_EXPANSION = False
    """
    Whether alias reference expansion (_expand_alias_refs()) should run before column qualification (_qualify_columns()).

    For example:
        WITH data AS (
        SELECT
            1 AS id,
            2 AS my_id
        )
        SELECT
            id AS my_id
        FROM
            data
        WHERE
            my_id = 1
        GROUP BY
            my_id,
        HAVING
            my_id = 1

    In most dialects "my_id" would refer to "data.my_id" (which is done in _qualify_columns()) across the query, except:
        - BigQuery, which will forward the alias to GROUP BY + HAVING clauses i.e it resolves to "WHERE my_id = 1 GROUP BY id HAVING id = 1"
        - Clickhouse, which will forward the alias across the query i.e it resolves to "WHERE id = 1 GROUP BY id HAVING id = 1"
    """

    EXPAND_ALIAS_REFS_EARLY_ONLY_IN_GROUP_BY = False
    """Whether alias reference expansion before qualification should only happen for the GROUP BY clause."""

    # --- Autofilled ---

    tokenizer_class = Tokenizer
    jsonpath_tokenizer_class = JSONPathTokenizer
    parser_class = Parser
    generator_class = Generator

@@ -323,6 +369,8 @@ class Dialect(metaclass=_Dialect):

    INVERSE_TIME_MAPPING: t.Dict[str, str] = {}
    INVERSE_TIME_TRIE: t.Dict = {}
    INVERSE_FORMAT_MAPPING: t.Dict[str, str] = {}
    INVERSE_FORMAT_TRIE: t.Dict = {}

    ESCAPED_SEQUENCES: t.Dict[str, str] = {}

@@ -342,8 +390,99 @@ class Dialect(metaclass=_Dialect):
    UNICODE_START: t.Optional[str] = None
    UNICODE_END: t.Optional[str] = None

    # Separator of COPY statement parameters
    COPY_PARAMS_ARE_CSV = True
    DATE_PART_MAPPING = {
        "Y": "YEAR",
        "YY": "YEAR",
        "YYY": "YEAR",
        "YYYY": "YEAR",
        "YR": "YEAR",
        "YEARS": "YEAR",
        "YRS": "YEAR",
        "MM": "MONTH",
        "MON": "MONTH",
        "MONS": "MONTH",
        "MONTHS": "MONTH",
        "D": "DAY",
        "DD": "DAY",
        "DAYS": "DAY",
        "DAYOFMONTH": "DAY",
        "DAY OF WEEK": "DAYOFWEEK",
        "WEEKDAY": "DAYOFWEEK",
        "DOW": "DAYOFWEEK",
        "DW": "DAYOFWEEK",
        "WEEKDAY_ISO": "DAYOFWEEKISO",
        "DOW_ISO": "DAYOFWEEKISO",
        "DW_ISO": "DAYOFWEEKISO",
        "DAY OF YEAR": "DAYOFYEAR",
        "DOY": "DAYOFYEAR",
        "DY": "DAYOFYEAR",
        "W": "WEEK",
        "WK": "WEEK",
        "WEEKOFYEAR": "WEEK",
        "WOY": "WEEK",
        "WY": "WEEK",
        "WEEK_ISO": "WEEKISO",
        "WEEKOFYEARISO": "WEEKISO",
        "WEEKOFYEAR_ISO": "WEEKISO",
        "Q": "QUARTER",
        "QTR": "QUARTER",
        "QTRS": "QUARTER",
        "QUARTERS": "QUARTER",
        "H": "HOUR",
        "HH": "HOUR",
        "HR": "HOUR",
        "HOURS": "HOUR",
        "HRS": "HOUR",
        "M": "MINUTE",
        "MI": "MINUTE",
        "MIN": "MINUTE",
        "MINUTES": "MINUTE",
        "MINS": "MINUTE",
        "S": "SECOND",
        "SEC": "SECOND",
        "SECONDS": "SECOND",
        "SECS": "SECOND",
        "MS": "MILLISECOND",
        "MSEC": "MILLISECOND",
        "MSECS": "MILLISECOND",
        "MSECOND": "MILLISECOND",
        "MSECONDS": "MILLISECOND",
        "MILLISEC": "MILLISECOND",
        "MILLISECS": "MILLISECOND",
        "MILLISECON": "MILLISECOND",
        "MILLISECONDS": "MILLISECOND",
        "US": "MICROSECOND",
        "USEC": "MICROSECOND",
        "USECS": "MICROSECOND",
        "MICROSEC": "MICROSECOND",
        "MICROSECS": "MICROSECOND",
        "USECOND": "MICROSECOND",
        "USECONDS": "MICROSECOND",
        "MICROSECONDS": "MICROSECOND",
        "NS": "NANOSECOND",
        "NSEC": "NANOSECOND",
        "NANOSEC": "NANOSECOND",
        "NSECOND": "NANOSECOND",
        "NSECONDS": "NANOSECOND",
        "NANOSECS": "NANOSECOND",
        "EPOCH_SECOND": "EPOCH",
        "EPOCH_SECONDS": "EPOCH",
        "EPOCH_MILLISECONDS": "EPOCH_MILLISECOND",
        "EPOCH_MICROSECONDS": "EPOCH_MICROSECOND",
        "EPOCH_NANOSECONDS": "EPOCH_NANOSECOND",
        "TZH": "TIMEZONE_HOUR",
        "TZM": "TIMEZONE_MINUTE",
        "DEC": "DECADE",
        "DECS": "DECADE",
        "DECADES": "DECADE",
        "MIL": "MILLENIUM",
        "MILS": "MILLENIUM",
        "MILLENIA": "MILLENIUM",
        "C": "CENTURY",
        "CENT": "CENTURY",
        "CENTS": "CENTURY",
        "CENTURIES": "CENTURY",
    }

    @classmethod
    def get_or_raise(cls, dialect: DialectType) -> Dialect:

@@ -371,8 +510,28 @@ class Dialect(metaclass=_Dialect):
            return dialect
        if isinstance(dialect, str):
            try:
                dialect_name, *kv_pairs = dialect.split(",")
                kwargs = {k.strip(): v.strip() for k, v in (kv.split("=") for kv in kv_pairs)}
                dialect_name, *kv_strings = dialect.split(",")
                kv_pairs = (kv.split("=") for kv in kv_strings)
                kwargs = {}
                for pair in kv_pairs:
                    key = pair[0].strip()
                    value: t.Union[bool | str | None] = None

                    if len(pair) == 1:
                        # Default initialize standalone settings to True
                        value = True
                    elif len(pair) == 2:
                        value = pair[1].strip()

                        # Coerce the value to boolean if it matches to the truthy/falsy values below
                        value_lower = value.lower()
                        if value_lower in ("true", "1"):
                            value = True
                        elif value_lower in ("false", "0"):
                            value = False

                    kwargs[key] = value

            except ValueError:
                raise ValueError(
                    f"Invalid dialect format: '{dialect}'. "

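The parsing loop above is what lets per-dialect settings be packed into the dialect string. A minimal sketch of the expected behaviour (the setting names here are arbitrary placeholders, not real sqlglot options):

```python
from sqlglot.dialects.dialect import Dialect

# Standalone keys default to True; "true"/"1" and "false"/"0" are coerced to booleans.
dialect = Dialect.get_or_raise("duckdb, my_flag, other_setting=false")
print(dialect.settings)  # expected: {'my_flag': True, 'other_setting': False}
```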
@@ -410,13 +569,15 @@ class Dialect(metaclass=_Dialect):
        return expression

    def __init__(self, **kwargs) -> None:
        normalization_strategy = kwargs.get("normalization_strategy")
        normalization_strategy = kwargs.pop("normalization_strategy", None)

        if normalization_strategy is None:
            self.normalization_strategy = self.NORMALIZATION_STRATEGY
        else:
            self.normalization_strategy = NormalizationStrategy(normalization_strategy.upper())

        self.settings = kwargs

    def __eq__(self, other: t.Any) -> bool:
        # Does not currently take dialect state into account
        return type(self) == other

@@ -518,9 +679,8 @@ class Dialect(metaclass=_Dialect):
            path_text = path.name
            if path.is_number:
                path_text = f"[{path_text}]"

            try:
                return parse_json_path(path_text)
                return parse_json_path(path_text, self)
            except ParseError as e:
                logger.warning(f"Invalid JSON path syntax. {str(e)}")

@@ -548,9 +708,11 @@ class Dialect(metaclass=_Dialect):

    @property
    def tokenizer(self) -> Tokenizer:
        if not hasattr(self, "_tokenizer"):
            self._tokenizer = self.tokenizer_class(dialect=self)
        return self._tokenizer
        return self.tokenizer_class(dialect=self)

    @property
    def jsonpath_tokenizer(self) -> JSONPathTokenizer:
        return self.jsonpath_tokenizer_class(dialect=self)

    def parser(self, **opts) -> Parser:
        return self.parser_class(dialect=self, **opts)

@@ -739,13 +901,17 @@ def time_format(


def build_date_delta(
    exp_class: t.Type[E], unit_mapping: t.Optional[t.Dict[str, str]] = None
    exp_class: t.Type[E],
    unit_mapping: t.Optional[t.Dict[str, str]] = None,
    default_unit: t.Optional[str] = "DAY",
) -> t.Callable[[t.List], E]:
    def _builder(args: t.List) -> E:
        unit_based = len(args) == 3
        this = args[2] if unit_based else seq_get(args, 0)
        unit = args[0] if unit_based else exp.Literal.string("DAY")
        unit = exp.var(unit_mapping.get(unit.name.lower(), unit.name)) if unit_mapping else unit
        unit = None
        if unit_based or default_unit:
            unit = args[0] if unit_based else exp.Literal.string(default_unit)
            unit = exp.var(unit_mapping.get(unit.name.lower(), unit.name)) if unit_mapping else unit
        return exp_class(this=this, expression=seq_get(args, 1), unit=unit)

    return _builder

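A quick sketch of what the new default_unit parameter changes for two-argument calls (built directly against the helper, purely for illustration):

```python
from sqlglot import exp
from sqlglot.dialects.dialect import build_date_delta

args = [exp.column("d"), exp.Literal.number(3)]

# With the default of "DAY", a unit literal is filled in for two-argument calls...
print(build_date_delta(exp.DateAdd)(args).args.get("unit"))

# ...while default_unit=None (as ClickHouse now passes) leaves the unit unset.
print(build_date_delta(exp.DateAdd, default_unit=None)(args).args.get("unit"))
```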
@@ -803,19 +969,45 @@ def timestamptrunc_sql(zone: bool = False) -> t.Callable[[Generator, exp.Timesta


def no_timestamp_sql(self: Generator, expression: exp.Timestamp) -> str:
    if not expression.expression:
    zone = expression.args.get("zone")
    if not zone:
        from sqlglot.optimizer.annotate_types import annotate_types

        target_type = annotate_types(expression).type or exp.DataType.Type.TIMESTAMP
        return self.sql(exp.cast(expression.this, target_type))
    if expression.text("expression").lower() in TIMEZONES:
    if zone.name.lower() in TIMEZONES:
        return self.sql(
            exp.AtTimeZone(
                this=exp.cast(expression.this, exp.DataType.Type.TIMESTAMP),
                zone=expression.expression,
                zone=zone,
            )
        )
    return self.func("TIMESTAMP", expression.this, expression.expression)
    return self.func("TIMESTAMP", expression.this, zone)


def no_time_sql(self: Generator, expression: exp.Time) -> str:
    # Transpile BQ's TIME(timestamp, zone) to CAST(TIMESTAMPTZ <timestamp> AT TIME ZONE <zone> AS TIME)
    this = exp.cast(expression.this, exp.DataType.Type.TIMESTAMPTZ)
    expr = exp.cast(
        exp.AtTimeZone(this=this, zone=expression.args.get("zone")), exp.DataType.Type.TIME
    )
    return self.sql(expr)


def no_datetime_sql(self: Generator, expression: exp.Datetime) -> str:
    this = expression.this
    expr = expression.expression

    if expr.name.lower() in TIMEZONES:
        # Transpile BQ's DATETIME(timestamp, zone) to CAST(TIMESTAMPTZ <timestamp> AT TIME ZONE <zone> AS TIMESTAMP)
        this = exp.cast(this, exp.DataType.Type.TIMESTAMPTZ)
        this = exp.cast(exp.AtTimeZone(this=this, zone=expr), exp.DataType.Type.TIMESTAMP)
        return self.sql(this)

    this = exp.cast(this, exp.DataType.Type.DATE)
    expr = exp.cast(expr, exp.DataType.Type.TIME)

    return self.sql(exp.cast(exp.Add(this=this, expression=expr), exp.DataType.Type.TIMESTAMP))


def locate_to_strposition(args: t.List) -> exp.Expression:

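The three helpers above back the exp.Time / exp.Timestamp / exp.Datetime fallbacks that dialects such as DuckDB register later in this commit. A hedged example of the kind of rewrite they are meant to produce:

```python
import sqlglot

# BigQuery's DATETIME(timestamp, zone) has no direct DuckDB equivalent, so
# no_datetime_sql should rewrite it via AT TIME ZONE plus a cast.
print(sqlglot.transpile("SELECT DATETIME(ts, 'America/New_York')", read="bigquery", write="duckdb")[0])
```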
@@ -1058,6 +1250,25 @@ def unit_to_var(expression: exp.Expression, default: str = "DAY") -> t.Optional[
    return exp.Var(this=default) if default else None


@t.overload
def map_date_part(part: exp.Expression, dialect: DialectType = Dialect) -> exp.Var:
    pass


@t.overload
def map_date_part(
    part: t.Optional[exp.Expression], dialect: DialectType = Dialect
) -> t.Optional[exp.Expression]:
    pass


def map_date_part(part, dialect: DialectType = Dialect):
    mapped = (
        Dialect.get_or_raise(dialect).DATE_PART_MAPPING.get(part.name.upper()) if part else None
    )
    return exp.var(mapped) if mapped else part


def no_last_day_sql(self: Generator, expression: exp.LastDay) -> str:
    trunc_curr_date = exp.func("date_trunc", "month", expression.this)
    plus_one_month = exp.func("date_add", trunc_curr_date, 1, "month")

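Since DATE_PART_MAPPING now lives on Dialect, map_date_part can normalize unit aliases for any dialect. A small sketch:

```python
from sqlglot import exp
from sqlglot.dialects.dialect import map_date_part

print(map_date_part(exp.var("YY")))   # expected to normalize to YEAR
print(map_date_part(exp.var("WOY")))  # expected to normalize to WEEK
print(map_date_part(None))            # None passes through unchanged
```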
@ -11,6 +11,15 @@ from sqlglot.dialects.dialect import (
|
|||
from sqlglot.dialects.mysql import MySQL
|
||||
|
||||
|
||||
def _lag_lead_sql(self, expression: exp.Lag | exp.Lead) -> str:
|
||||
return self.func(
|
||||
"LAG" if isinstance(expression, exp.Lag) else "LEAD",
|
||||
expression.this,
|
||||
expression.args.get("offset") or exp.Literal.number(1),
|
||||
expression.args.get("default") or exp.null(),
|
||||
)
|
||||
|
||||
|
||||
class Doris(MySQL):
|
||||
DATE_FORMAT = "'yyyy-MM-dd'"
|
||||
DATEINT_FORMAT = "'yyyyMMdd'"
|
||||
|
@ -56,6 +65,8 @@ class Doris(MySQL):
|
|||
"GROUP_CONCAT", e.this, e.args.get("separator") or exp.Literal.string(",")
|
||||
),
|
||||
exp.JSONExtractScalar: lambda self, e: self.func("JSON_EXTRACT", e.this, e.expression),
|
||||
exp.Lag: _lag_lead_sql,
|
||||
exp.Lead: _lag_lead_sql,
|
||||
exp.Map: rename_func("ARRAY_MAP"),
|
||||
exp.RegexpLike: rename_func("REGEXP"),
|
||||
exp.RegexpSplit: rename_func("SPLIT_BY_STRING"),
|
||||
|
|
|
@ -70,6 +70,9 @@ class Drill(Dialect):
|
|||
IDENTIFIERS = ["`"]
|
||||
STRING_ESCAPES = ["\\"]
|
||||
|
||||
KEYWORDS = tokens.Tokenizer.KEYWORDS.copy()
|
||||
KEYWORDS.pop("/*+")
|
||||
|
||||
class Parser(parser.Parser):
|
||||
STRICT_CAST = False
|
||||
|
||||
|
|
|
@ -15,11 +15,13 @@ from sqlglot.dialects.dialect import (
|
|||
build_default_decimal_type,
|
||||
date_trunc_to_time,
|
||||
datestrtodate_sql,
|
||||
no_datetime_sql,
|
||||
encode_decode_sql,
|
||||
build_formatted_time,
|
||||
inline_array_unless_query,
|
||||
no_comment_column_constraint_sql,
|
||||
no_safe_divide_sql,
|
||||
no_time_sql,
|
||||
no_timestamp_sql,
|
||||
pivot_column_names,
|
||||
regexp_extract_sql,
|
||||
|
@ -218,6 +220,7 @@ class DuckDB(Dialect):
|
|||
"TIMESTAMP_US": TokenType.TIMESTAMP,
|
||||
"VARCHAR": TokenType.TEXT,
|
||||
}
|
||||
KEYWORDS.pop("/*+")
|
||||
|
||||
SINGLE_TOKENS = {
|
||||
**tokens.Tokenizer.SINGLE_TOKENS,
|
||||
|
@ -407,6 +410,7 @@ class DuckDB(Dialect):
|
|||
"DATE_DIFF", f"'{e.args.get('unit') or 'DAY'}'", e.expression, e.this
|
||||
),
|
||||
exp.DateStrToDate: datestrtodate_sql,
|
||||
exp.Datetime: no_datetime_sql,
|
||||
exp.DateToDi: lambda self,
|
||||
e: f"CAST(STRFTIME({self.sql(e, 'this')}, {DuckDB.DATEINT_FORMAT}) AS INT)",
|
||||
exp.Decode: lambda self, e: encode_decode_sql(self, e, "DECODE", replace=False),
|
||||
|
@ -429,7 +433,6 @@ class DuckDB(Dialect):
|
|||
exp.cast(e.expression, exp.DataType.Type.TIMESTAMP, copy=True),
|
||||
exp.cast(e.this, exp.DataType.Type.TIMESTAMP, copy=True),
|
||||
),
|
||||
exp.ParseJSON: rename_func("JSON"),
|
||||
exp.PercentileCont: rename_func("QUANTILE_CONT"),
|
||||
exp.PercentileDisc: rename_func("QUANTILE_DISC"),
|
||||
# DuckDB doesn't allow qualified columns inside of PIVOT expressions.
|
||||
|
@ -450,13 +453,12 @@ class DuckDB(Dialect):
|
|||
exp.Split: rename_func("STR_SPLIT"),
|
||||
exp.SortArray: _sort_array_sql,
|
||||
exp.StrPosition: str_position_sql,
|
||||
exp.StrToDate: lambda self, e: f"CAST({str_to_time_sql(self, e)} AS DATE)",
|
||||
exp.StrToTime: str_to_time_sql,
|
||||
exp.StrToUnix: lambda self, e: self.func(
|
||||
"EPOCH", self.func("STRPTIME", e.this, self.format_time(e))
|
||||
),
|
||||
exp.Struct: _struct_sql,
|
||||
exp.TimeAdd: _date_delta_sql,
|
||||
exp.Time: no_time_sql,
|
||||
exp.Timestamp: no_timestamp_sql,
|
||||
exp.TimestampDiff: lambda self, e: self.func(
|
||||
"DATE_DIFF", exp.Literal.string(e.unit), e.expression, e.this
|
||||
|
@ -608,6 +610,24 @@ class DuckDB(Dialect):
|
|||
PROPERTIES_LOCATION[exp.LikeProperty] = exp.Properties.Location.POST_SCHEMA
|
||||
PROPERTIES_LOCATION[exp.TemporaryProperty] = exp.Properties.Location.POST_CREATE
|
||||
|
||||
def strtotime_sql(self, expression: exp.StrToTime) -> str:
|
||||
if expression.args.get("safe"):
|
||||
formatted_time = self.format_time(expression)
|
||||
return f"CAST({self.func('TRY_STRPTIME', expression.this, formatted_time)} AS TIMESTAMP)"
|
||||
return str_to_time_sql(self, expression)
|
||||
|
||||
def strtodate_sql(self, expression: exp.StrToDate) -> str:
|
||||
if expression.args.get("safe"):
|
||||
formatted_time = self.format_time(expression)
|
||||
return f"CAST({self.func('TRY_STRPTIME', expression.this, formatted_time)} AS DATE)"
|
||||
return f"CAST({str_to_time_sql(self, expression)} AS DATE)"
|
||||
|
||||
def parsejson_sql(self, expression: exp.ParseJSON) -> str:
|
||||
arg = expression.this
|
||||
if expression.args.get("safe"):
|
||||
return self.sql(exp.case().when(exp.func("json_valid", arg), arg).else_(exp.null()))
|
||||
return self.func("JSON", arg)
|
||||
|
||||
def timefromparts_sql(self, expression: exp.TimeFromParts) -> str:
|
||||
nano = expression.args.get("nano")
|
||||
if nano is not None:
|
||||
|
@ -728,3 +748,33 @@ class DuckDB(Dialect):
|
|||
this = self.sql(expression, "this").rstrip(")")
|
||||
|
||||
return f"{this}{expression_sql})"
|
||||
|
||||
def length_sql(self, expression: exp.Length) -> str:
|
||||
arg = expression.this
|
||||
|
||||
# Dialects like BQ and Snowflake also accept binary values as args, so
|
||||
# DDB will attempt to infer the type or resort to case/when resolution
|
||||
if not expression.args.get("binary") or arg.is_string:
|
||||
return self.func("LENGTH", arg)
|
||||
|
||||
if not arg.type:
|
||||
from sqlglot.optimizer.annotate_types import annotate_types
|
||||
|
||||
arg = annotate_types(arg)
|
||||
|
||||
if arg.is_type(*exp.DataType.TEXT_TYPES):
|
||||
return self.func("LENGTH", arg)
|
||||
|
||||
# We need these casts to make duckdb's static type checker happy
|
||||
blob = exp.cast(arg, exp.DataType.Type.VARBINARY)
|
||||
varchar = exp.cast(arg, exp.DataType.Type.VARCHAR)
|
||||
|
||||
case = (
|
||||
exp.case(self.func("TYPEOF", arg))
|
||||
.when(
|
||||
"'VARCHAR'", exp.Anonymous(this="LENGTH", expressions=[varchar])
|
||||
) # anonymous to break length_sql recursion
|
||||
.when("'BLOB'", self.func("OCTET_LENGTH", blob))
|
||||
)
|
||||
|
||||
return self.sql(case)
|
||||
|
|
|
@ -71,7 +71,7 @@ def _add_date_sql(self: Hive.Generator, expression: DATE_ADD_OR_SUB) -> str:
|
|||
multiplier *= -1
|
||||
|
||||
if expression.expression.is_number:
|
||||
modified_increment = exp.Literal.number(int(expression.text("expression")) * multiplier)
|
||||
modified_increment = exp.Literal.number(expression.expression.to_py() * multiplier)
|
||||
else:
|
||||
modified_increment = expression.expression
|
||||
if multiplier != 1:
|
||||
|
@ -446,12 +446,13 @@ class Hive(Dialect):
|
|||
JSON_PATH_SINGLE_QUOTE_ESCAPE = True
|
||||
SUPPORTS_TO_NUMBER = False
|
||||
WITH_PROPERTIES_PREFIX = "TBLPROPERTIES"
|
||||
PARSE_JSON_NAME = None
|
||||
|
||||
EXPRESSIONS_WITHOUT_NESTED_CTES = {
|
||||
exp.Insert,
|
||||
exp.Select,
|
||||
exp.Subquery,
|
||||
exp.Union,
|
||||
exp.SetOperation,
|
||||
}
|
||||
|
||||
SUPPORTED_JSON_PATH_PARTS = {
|
||||
|
@ -575,7 +576,6 @@ class Hive(Dialect):
|
|||
exp.NotForReplicationColumnConstraint: lambda *_: "",
|
||||
exp.OnProperty: lambda *_: "",
|
||||
exp.PrimaryKeyColumnConstraint: lambda *_: "PRIMARY KEY",
|
||||
exp.ParseJSON: lambda self, e: self.sql(e.this),
|
||||
exp.WeekOfYear: rename_func("WEEKOFYEAR"),
|
||||
exp.DayOfMonth: rename_func("DAYOFMONTH"),
|
||||
exp.DayOfWeek: rename_func("DAYOFWEEK"),
|
||||
|
|
|
@ -689,6 +689,7 @@ class MySQL(Dialect):
|
|||
JSON_PATH_BRACKETED_KEY_SUPPORTED = False
|
||||
JSON_KEY_VALUE_PAIR_SEP = ","
|
||||
SUPPORTS_TO_NUMBER = False
|
||||
PARSE_JSON_NAME = None
|
||||
|
||||
TRANSFORMS = {
|
||||
**generator.Generator.TRANSFORMS,
|
||||
|
@ -714,7 +715,6 @@ class MySQL(Dialect):
|
|||
exp.Month: _remove_ts_or_ds_to_date(),
|
||||
exp.NullSafeEQ: lambda self, e: self.binary(e, "<=>"),
|
||||
exp.NullSafeNEQ: lambda self, e: f"NOT {self.binary(e, '<=>')}",
|
||||
exp.ParseJSON: lambda self, e: self.sql(e, "this"),
|
||||
exp.Pivot: no_pivot_sql,
|
||||
exp.Select: transforms.preprocess(
|
||||
[
|
||||
|
@ -1093,29 +1093,6 @@ class MySQL(Dialect):
|
|||
"xor",
|
||||
"year_month",
|
||||
"zerofill",
|
||||
"cume_dist",
|
||||
"dense_rank",
|
||||
"empty",
|
||||
"except",
|
||||
"first_value",
|
||||
"grouping",
|
||||
"groups",
|
||||
"intersect",
|
||||
"json_table",
|
||||
"lag",
|
||||
"last_value",
|
||||
"lateral",
|
||||
"lead",
|
||||
"nth_value",
|
||||
"ntile",
|
||||
"of",
|
||||
"over",
|
||||
"percent_rank",
|
||||
"rank",
|
||||
"recursive",
|
||||
"row_number",
|
||||
"system",
|
||||
"window",
|
||||
}
|
||||
|
||||
def array_sql(self, expression: exp.Array) -> str:
|
||||
|
|
|
@ -33,171 +33,11 @@ def _build_timetostr_or_tochar(args: t.List) -> exp.TimeToStr | exp.ToChar:
|
|||
return exp.ToChar.from_arg_list(args)
|
||||
|
||||
|
||||
def eliminate_join_marks(ast: exp.Expression) -> exp.Expression:
|
||||
from sqlglot.optimizer.scope import traverse_scope
|
||||
|
||||
"""Remove join marks from an expression
|
||||
|
||||
SELECT * FROM a, b WHERE a.id = b.id(+)
|
||||
becomes:
|
||||
SELECT * FROM a LEFT JOIN b ON a.id = b.id
|
||||
|
||||
- for each scope
|
||||
- for each column with a join mark
|
||||
- find the predicate it belongs to
|
||||
- remove the predicate from the where clause
|
||||
- convert the predicate to a join with the (+) side as the left join table
|
||||
- replace the existing join with the new join
|
||||
|
||||
Args:
|
||||
ast: The AST to remove join marks from
|
||||
|
||||
Returns:
|
||||
The AST with join marks removed"""
|
||||
for scope in traverse_scope(ast):
|
||||
_eliminate_join_marks_from_scope(scope)
|
||||
return ast
|
||||
|
||||
|
||||
def _update_from(
|
||||
select: exp.Select,
|
||||
new_join_dict: t.Dict[str, exp.Join],
|
||||
old_join_dict: t.Dict[str, exp.Join],
|
||||
) -> None:
|
||||
"""If the from clause needs to become a new join, find an appropriate table to use as the new from.
|
||||
updates select in place
|
||||
|
||||
Args:
|
||||
select: The select statement to update
|
||||
new_join_dict: The dictionary of new joins
|
||||
old_join_dict: The dictionary of old joins
|
||||
"""
|
||||
old_from = select.args["from"]
|
||||
if old_from.alias_or_name not in new_join_dict:
|
||||
return
|
||||
in_old_not_new = old_join_dict.keys() - new_join_dict.keys()
|
||||
if len(in_old_not_new) >= 1:
|
||||
new_from_name = list(old_join_dict.keys() - new_join_dict.keys())[0]
|
||||
new_from_this = old_join_dict[new_from_name].this
|
||||
new_from = exp.From(this=new_from_this)
|
||||
del old_join_dict[new_from_name]
|
||||
select.set("from", new_from)
|
||||
else:
|
||||
raise ValueError("Cannot determine which table to use as the new from")
|
||||
|
||||
|
||||
def _has_join_mark(col: exp.Expression) -> bool:
|
||||
"""Check if the column has a join mark
|
||||
|
||||
Args:
|
||||
The column to check
|
||||
"""
|
||||
return col.args.get("join_mark", False)
|
||||
|
||||
|
||||
def _predicate_to_join(
|
||||
eq: exp.Binary, old_joins: t.Dict[str, exp.Join], old_from: exp.From
|
||||
) -> t.Optional[exp.Join]:
|
||||
"""Convert an equality predicate to a join if it contains a join mark
|
||||
|
||||
Args:
|
||||
eq: The equality expression to convert to a join
|
||||
|
||||
Returns:
|
||||
The join expression if the equality contains a join mark (otherwise None)
|
||||
"""
|
||||
|
||||
# if not (isinstance(eq.left, exp.Column) or isinstance(eq.right, exp.Column)):
|
||||
# return None
|
||||
|
||||
left_columns = [col for col in eq.left.find_all(exp.Column) if _has_join_mark(col)]
|
||||
right_columns = [col for col in eq.right.find_all(exp.Column) if _has_join_mark(col)]
|
||||
|
||||
left_has_join_mark = len(left_columns) > 0
|
||||
right_has_join_mark = len(right_columns) > 0
|
||||
|
||||
if left_has_join_mark:
|
||||
for col in left_columns:
|
||||
col.set("join_mark", False)
|
||||
join_on = col.table
|
||||
elif right_has_join_mark:
|
||||
for col in right_columns:
|
||||
col.set("join_mark", False)
|
||||
join_on = col.table
|
||||
else:
|
||||
return None
|
||||
|
||||
join_this = old_joins.get(join_on, old_from).this
|
||||
return exp.Join(this=join_this, on=eq, kind="LEFT")
|
||||
|
||||
|
||||
if t.TYPE_CHECKING:
|
||||
from sqlglot.optimizer.scope import Scope
|
||||
|
||||
|
||||
def _eliminate_join_marks_from_scope(scope: Scope) -> None:
|
||||
"""Remove join marks columns in scope's where clause.
|
||||
Converts them to left joins and replaces any existing joins.
|
||||
Updates scope in place.
|
||||
|
||||
Args:
|
||||
scope: The scope to remove join marks from
|
||||
"""
|
||||
select_scope = scope.expression
|
||||
where = select_scope.args.get("where")
|
||||
joins = select_scope.args.get("joins")
|
||||
if not where:
|
||||
return
|
||||
if not joins:
|
||||
return
|
||||
|
||||
# dictionaries used to keep track of joins to be replaced
|
||||
old_joins = {join.alias_or_name: join for join in list(joins)}
|
||||
new_joins: t.Dict[str, exp.Join] = {}
|
||||
|
||||
for node in scope.find_all(exp.Column):
|
||||
if _has_join_mark(node):
|
||||
predicate = node.find_ancestor(exp.Predicate)
|
||||
if not isinstance(predicate, exp.Binary):
|
||||
continue
|
||||
predicate_parent = predicate.parent
|
||||
|
||||
join_on = predicate.pop()
|
||||
new_join = _predicate_to_join(
|
||||
join_on, old_joins=old_joins, old_from=select_scope.args["from"]
|
||||
)
|
||||
# upsert new_join into new_joins dictionary
|
||||
if new_join:
|
||||
if new_join.alias_or_name in new_joins:
|
||||
new_joins[new_join.alias_or_name].set(
|
||||
"on",
|
||||
exp.and_(
|
||||
new_joins[new_join.alias_or_name].args["on"],
|
||||
new_join.args["on"],
|
||||
),
|
||||
)
|
||||
else:
|
||||
new_joins[new_join.alias_or_name] = new_join
|
||||
# If the parent is a binary node with only one child, promote the child to the parent
|
||||
if predicate_parent:
|
||||
if isinstance(predicate_parent, exp.Binary):
|
||||
if predicate_parent.left is None:
|
||||
predicate_parent.replace(predicate_parent.right)
|
||||
elif predicate_parent.right is None:
|
||||
predicate_parent.replace(predicate_parent.left)
|
||||
|
||||
_update_from(select_scope, new_joins, old_joins)
|
||||
replacement_joins = [new_joins.get(join.alias_or_name, join) for join in old_joins.values()]
|
||||
select_scope.set("joins", replacement_joins)
|
||||
if not where.this:
|
||||
where.pop()
|
||||
|
||||
|
||||
class Oracle(Dialect):
|
||||
ALIAS_POST_TABLESAMPLE = True
|
||||
LOCKING_READS_SUPPORTED = True
|
||||
TABLESAMPLE_SIZE_IS_PERCENT = True
|
||||
SUPPORTS_COLUMN_JOIN_MARKS = True
|
||||
NULL_ORDERING = "nulls_are_large"
|
||||
|
||||
# See section 8: https://docs.oracle.com/cd/A97630_01/server.920/a96540/sql_elements9a.htm
|
||||
NORMALIZATION_STRATEGY = NormalizationStrategy.UPPERCASE
|
||||
|
@ -267,6 +107,7 @@ class Oracle(Dialect):
|
|||
"TO_TIMESTAMP": build_formatted_time(exp.StrToTime, "oracle"),
|
||||
"TO_DATE": build_formatted_time(exp.StrToDate, "oracle"),
|
||||
}
|
||||
FUNCTIONS.pop("NVL")
|
||||
|
||||
FUNCTION_PARSERS: t.Dict[str, t.Callable] = {
|
||||
**parser.Parser.FUNCTION_PARSERS,
|
||||
|
@ -282,13 +123,6 @@ class Oracle(Dialect):
|
|||
"XMLTABLE": lambda self: self._parse_xml_table(),
|
||||
}
|
||||
|
||||
NO_PAREN_FUNCTION_PARSERS = {
|
||||
**parser.Parser.NO_PAREN_FUNCTION_PARSERS,
|
||||
"CONNECT_BY_ROOT": lambda self: self.expression(
|
||||
exp.ConnectByRoot, this=self._parse_column()
|
||||
),
|
||||
}
|
||||
|
||||
PROPERTY_PARSERS = {
|
||||
**parser.Parser.PROPERTY_PARSERS,
|
||||
"GLOBAL": lambda self: self._match_text_seq("TEMPORARY")
|
||||
|
@ -408,7 +242,6 @@ class Oracle(Dialect):
|
|||
|
||||
TRANSFORMS = {
|
||||
**generator.Generator.TRANSFORMS,
|
||||
exp.ConnectByRoot: lambda self, e: f"CONNECT_BY_ROOT {self.sql(e, 'this')}",
|
||||
exp.DateStrToDate: lambda self, e: self.func(
|
||||
"TO_DATE", e.this, exp.Literal.string("YYYY-MM-DD")
|
||||
),
|
||||
|
|
|
@ -56,12 +56,15 @@ def _date_add_sql(kind: str) -> t.Callable[[Postgres.Generator, DATE_ADD_OR_SUB]
|
|||
this = self.sql(expression, "this")
|
||||
unit = expression.args.get("unit")
|
||||
|
||||
expression = self._simplify_unless_literal(expression.expression)
|
||||
if not isinstance(expression, exp.Literal):
|
||||
e = self._simplify_unless_literal(expression.expression)
|
||||
if isinstance(e, exp.Literal):
|
||||
e.args["is_string"] = True
|
||||
elif e.is_number:
|
||||
e = exp.Literal.string(e.to_py())
|
||||
else:
|
||||
self.unsupported("Cannot add non literal")
|
||||
|
||||
expression.args["is_string"] = True
|
||||
return f"{this} {kind} {self.sql(exp.Interval(this=expression, unit=unit))}"
|
||||
return f"{this} {kind} {self.sql(exp.Interval(this=e, unit=unit))}"
|
||||
|
||||
return func
|
||||
|
||||
|
@ -331,6 +334,7 @@ class Postgres(Dialect):
|
|||
"REGTYPE": TokenType.OBJECT_IDENTIFIER,
|
||||
"FLOAT": TokenType.DOUBLE,
|
||||
}
|
||||
KEYWORDS.pop("/*+")
|
||||
KEYWORDS.pop("DIV")
|
||||
|
||||
SINGLE_TOKENS = {
|
||||
|
|
|
@ -173,6 +173,35 @@ def _unix_to_time_sql(self: Presto.Generator, expression: exp.UnixToTime) -> str
|
|||
return f"FROM_UNIXTIME(CAST({timestamp} AS DOUBLE) / POW(10, {scale}))"
|
||||
|
||||
|
||||
def _jsonextract_sql(self: Presto.Generator, expression: exp.JSONExtract) -> str:
|
||||
is_json_extract = self.dialect.settings.get("variant_extract_is_json_extract", True)
|
||||
|
||||
# Generate JSON_EXTRACT unless the user has configured that a Snowflake / Databricks
|
||||
# VARIANT extract (e.g. col:x.y) should map to dot notation (i.e ROW access) in Presto/Trino
|
||||
if not expression.args.get("variant_extract") or is_json_extract:
|
||||
return self.func(
|
||||
"JSON_EXTRACT", expression.this, expression.expression, *expression.expressions
|
||||
)
|
||||
|
||||
this = self.sql(expression, "this")
|
||||
|
||||
# Convert the JSONPath extraction `JSON_EXTRACT(col, '$.x.y) to a ROW access col.x.y
|
||||
segments = []
|
||||
for path_key in expression.expression.expressions[1:]:
|
||||
if not isinstance(path_key, exp.JSONPathKey):
|
||||
# Cannot transpile subscripts, wildcards etc to dot notation
|
||||
self.unsupported(f"Cannot transpile JSONPath segment '{path_key}' to ROW access")
|
||||
continue
|
||||
key = path_key.this
|
||||
if not exp.SAFE_IDENTIFIER_RE.match(key):
|
||||
key = f'"{key}"'
|
||||
segments.append(f".{key}")
|
||||
|
||||
expr = "".join(segments)
|
||||
|
||||
return f"{this}{expr}"
|
||||
|
||||
|
||||
def _to_int(expression: exp.Expression) -> exp.Expression:
|
||||
if not expression.type:
|
||||
from sqlglot.optimizer.annotate_types import annotate_types
|
||||
|
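The new variant_extract handling above pairs with the dialect-settings mechanism: whether a Snowflake/Databricks col:x.y extraction becomes JSON_EXTRACT or ROW (dot) access in Presto/Trino is controlled per call. A hedged sketch:

```python
import sqlglot

# Default: variant extractions map to JSON_EXTRACT.
print(sqlglot.transpile("SELECT col:x.y FROM t", read="snowflake", write="presto")[0])

# With the setting disabled, the same extraction should be emitted as ROW access (col.x.y).
print(sqlglot.transpile(
    "SELECT col:x.y FROM t",
    read="snowflake",
    write="presto, variant_extract_is_json_extract=false",
)[0])
```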
@ -227,7 +256,7 @@ class Presto(Dialect):
|
|||
"TDIGEST": TokenType.TDIGEST,
|
||||
"HYPERLOGLOG": TokenType.HLLSKETCH,
|
||||
}
|
||||
|
||||
KEYWORDS.pop("/*+")
|
||||
KEYWORDS.pop("QUALIFY")
|
||||
|
||||
class Parser(parser.Parser):
|
||||
|
@ -305,6 +334,7 @@ class Presto(Dialect):
|
|||
MULTI_ARG_DISTINCT = False
|
||||
SUPPORTS_TO_NUMBER = False
|
||||
HEX_FUNC = "TO_HEX"
|
||||
PARSE_JSON_NAME = "JSON_PARSE"
|
||||
|
||||
PROPERTIES_LOCATION = {
|
||||
**generator.Generator.PROPERTIES_LOCATION,
|
||||
|
@ -389,7 +419,7 @@ class Presto(Dialect):
|
|||
exp.If: if_sql(),
|
||||
exp.ILike: no_ilike_sql,
|
||||
exp.Initcap: _initcap_sql,
|
||||
exp.ParseJSON: rename_func("JSON_PARSE"),
|
||||
exp.JSONExtract: _jsonextract_sql,
|
||||
exp.Last: _first_last_sql,
|
||||
exp.LastValue: _first_last_sql,
|
||||
exp.LastDay: lambda self, e: self.func("LAST_DAY_OF_MONTH", e.this),
|
||||
|
@ -448,9 +478,6 @@ class Presto(Dialect):
|
|||
[transforms.remove_within_group_for_percentiles]
|
||||
),
|
||||
exp.Xor: bool_xor_sql,
|
||||
exp.MD5: lambda self, e: self.func(
|
||||
"LOWER", self.func("TO_HEX", self.func("MD5", self.sql(e, "this")))
|
||||
),
|
||||
exp.MD5Digest: rename_func("MD5"),
|
||||
exp.SHA: rename_func("SHA1"),
|
||||
exp.SHA2: sha256_sql,
|
||||
|
@ -517,6 +544,19 @@ class Presto(Dialect):
|
|||
"with",
|
||||
}
|
||||
|
||||
def md5_sql(self, expression: exp.MD5) -> str:
|
||||
this = expression.this
|
||||
|
||||
if not this.type:
|
||||
from sqlglot.optimizer.annotate_types import annotate_types
|
||||
|
||||
this = annotate_types(this)
|
||||
|
||||
if this.is_type(*exp.DataType.TEXT_TYPES):
|
||||
this = exp.Encode(this=this, charset=exp.Literal.string("utf-8"))
|
||||
|
||||
return self.func("LOWER", self.func("TO_HEX", self.func("MD5", self.sql(this))))
|
||||
|
||||
def strtounix_sql(self, expression: exp.StrToUnix) -> str:
|
||||
# Since `TO_UNIXTIME` requires a `TIMESTAMP`, we need to parse the argument into one.
|
||||
# To do this, we first try to `DATE_PARSE` it, but since this can fail when there's a
|
||||
|
|
|
@ -12,6 +12,7 @@ from sqlglot.dialects.dialect import (
|
|||
json_extract_segments,
|
||||
no_tablesample_sql,
|
||||
rename_func,
|
||||
map_date_part,
|
||||
)
|
||||
from sqlglot.dialects.postgres import Postgres
|
||||
from sqlglot.helper import seq_get
|
||||
|
@ -23,7 +24,11 @@ if t.TYPE_CHECKING:
|
|||
|
||||
def _build_date_delta(expr_type: t.Type[E]) -> t.Callable[[t.List], E]:
|
||||
def _builder(args: t.List) -> E:
|
||||
expr = expr_type(this=seq_get(args, 2), expression=seq_get(args, 1), unit=seq_get(args, 0))
|
||||
expr = expr_type(
|
||||
this=seq_get(args, 2),
|
||||
expression=seq_get(args, 1),
|
||||
unit=map_date_part(seq_get(args, 0)),
|
||||
)
|
||||
if expr_type is exp.TsOrDsAdd:
|
||||
expr.set("return_type", exp.DataType.build("TIMESTAMP"))
|
||||
|
||||
|
@ -40,7 +45,6 @@ class Redshift(Postgres):
|
|||
INDEX_OFFSET = 0
|
||||
COPY_PARAMS_ARE_CSV = False
|
||||
HEX_LOWERCASE = True
|
||||
SUPPORTS_COLUMN_JOIN_MARKS = True
|
||||
|
||||
TIME_FORMAT = "'YYYY-MM-DD HH:MI:SS'"
|
||||
TIME_MAPPING = {
|
||||
|
@ -148,6 +152,8 @@ class Redshift(Postgres):
|
|||
MULTI_ARG_DISTINCT = True
|
||||
COPY_PARAMS_ARE_WRAPPED = False
|
||||
HEX_FUNC = "TO_HEX"
|
||||
PARSE_JSON_NAME = "JSON_PARSE"
|
||||
|
||||
# Redshift doesn't have `WITH` as part of their with_properties so we remove it
|
||||
WITH_PROPERTIES_PREFIX = " "
|
||||
|
||||
|
@ -180,7 +186,6 @@ class Redshift(Postgres):
|
|||
exp.JSONExtractScalar: json_extract_segments("JSON_EXTRACT_PATH_TEXT"),
|
||||
exp.GroupConcat: rename_func("LISTAGG"),
|
||||
exp.Hex: lambda self, e: self.func("UPPER", self.func("TO_HEX", self.sql(e, "this"))),
|
||||
exp.ParseJSON: rename_func("JSON_PARSE"),
|
||||
exp.Select: transforms.preprocess(
|
||||
[
|
||||
transforms.eliminate_distinct_on,
|
||||
|
@ -203,13 +208,14 @@ class Redshift(Postgres):
|
|||
# Postgres maps exp.Pivot to no_pivot_sql, but Redshift support pivots
|
||||
TRANSFORMS.pop(exp.Pivot)
|
||||
|
||||
# Postgres doesn't support JSON_PARSE, but Redshift does
|
||||
TRANSFORMS.pop(exp.ParseJSON)
|
||||
|
||||
# Redshift uses the POW | POWER (expr1, expr2) syntax instead of expr1 ^ expr2 (postgres)
|
||||
TRANSFORMS.pop(exp.Pow)
|
||||
|
||||
# Redshift supports ANY_VALUE(..)
|
||||
# Redshift supports these functions
|
||||
TRANSFORMS.pop(exp.AnyValue)
|
||||
|
||||
# Redshift supports LAST_DAY(..)
|
||||
TRANSFORMS.pop(exp.LastDay)
|
||||
TRANSFORMS.pop(exp.SHA2)
|
||||
|
||||
|
|
|
@ -21,6 +21,7 @@ from sqlglot.dialects.dialect import (
|
|||
timestamptrunc_sql,
|
||||
timestrtotime_sql,
|
||||
var_map_sql,
|
||||
map_date_part,
|
||||
)
|
||||
from sqlglot.helper import flatten, is_float, is_int, seq_get
|
||||
from sqlglot.tokens import TokenType
|
||||
|
@ -75,7 +76,7 @@ def _build_object_construct(args: t.List) -> t.Union[exp.StarMap, exp.Struct]:
|
|||
|
||||
def _build_datediff(args: t.List) -> exp.DateDiff:
|
||||
return exp.DateDiff(
|
||||
this=seq_get(args, 2), expression=seq_get(args, 1), unit=_map_date_part(seq_get(args, 0))
|
||||
this=seq_get(args, 2), expression=seq_get(args, 1), unit=map_date_part(seq_get(args, 0))
|
||||
)
|
||||
|
||||
|
||||
|
@ -84,7 +85,7 @@ def _build_date_time_add(expr_type: t.Type[E]) -> t.Callable[[t.List], E]:
|
|||
return expr_type(
|
||||
this=seq_get(args, 2),
|
||||
expression=seq_get(args, 1),
|
||||
unit=_map_date_part(seq_get(args, 0)),
|
||||
unit=map_date_part(seq_get(args, 0)),
|
||||
)
|
||||
|
||||
return _builder
|
||||
|
@ -143,97 +144,9 @@ def _show_parser(*args: t.Any, **kwargs: t.Any) -> t.Callable[[Snowflake.Parser]
|
|||
return _parse
|
||||
|
||||
|
||||
DATE_PART_MAPPING = {
|
||||
"Y": "YEAR",
|
||||
"YY": "YEAR",
|
||||
"YYY": "YEAR",
|
||||
"YYYY": "YEAR",
|
||||
"YR": "YEAR",
|
||||
"YEARS": "YEAR",
|
||||
"YRS": "YEAR",
|
||||
"MM": "MONTH",
|
||||
"MON": "MONTH",
|
||||
"MONS": "MONTH",
|
||||
"MONTHS": "MONTH",
|
||||
"D": "DAY",
|
||||
"DD": "DAY",
|
||||
"DAYS": "DAY",
|
||||
"DAYOFMONTH": "DAY",
|
||||
"WEEKDAY": "DAYOFWEEK",
|
||||
"DOW": "DAYOFWEEK",
|
||||
"DW": "DAYOFWEEK",
|
||||
"WEEKDAY_ISO": "DAYOFWEEKISO",
|
||||
"DOW_ISO": "DAYOFWEEKISO",
|
||||
"DW_ISO": "DAYOFWEEKISO",
|
||||
"YEARDAY": "DAYOFYEAR",
|
||||
"DOY": "DAYOFYEAR",
|
||||
"DY": "DAYOFYEAR",
|
||||
"W": "WEEK",
|
||||
"WK": "WEEK",
|
||||
"WEEKOFYEAR": "WEEK",
|
||||
"WOY": "WEEK",
|
||||
"WY": "WEEK",
|
||||
"WEEK_ISO": "WEEKISO",
|
||||
"WEEKOFYEARISO": "WEEKISO",
|
||||
"WEEKOFYEAR_ISO": "WEEKISO",
|
||||
"Q": "QUARTER",
|
||||
"QTR": "QUARTER",
|
||||
"QTRS": "QUARTER",
|
||||
"QUARTERS": "QUARTER",
|
||||
"H": "HOUR",
|
||||
"HH": "HOUR",
|
||||
"HR": "HOUR",
|
||||
"HOURS": "HOUR",
|
||||
"HRS": "HOUR",
|
||||
"M": "MINUTE",
|
||||
"MI": "MINUTE",
|
||||
"MIN": "MINUTE",
|
||||
"MINUTES": "MINUTE",
|
||||
"MINS": "MINUTE",
|
||||
"S": "SECOND",
|
||||
"SEC": "SECOND",
|
||||
"SECONDS": "SECOND",
|
||||
"SECS": "SECOND",
|
||||
"MS": "MILLISECOND",
|
||||
"MSEC": "MILLISECOND",
|
||||
"MILLISECONDS": "MILLISECOND",
|
||||
"US": "MICROSECOND",
|
||||
"USEC": "MICROSECOND",
|
||||
"MICROSECONDS": "MICROSECOND",
|
||||
"NS": "NANOSECOND",
|
||||
"NSEC": "NANOSECOND",
|
||||
"NANOSEC": "NANOSECOND",
|
||||
"NSECOND": "NANOSECOND",
|
||||
"NSECONDS": "NANOSECOND",
|
||||
"NANOSECS": "NANOSECOND",
|
||||
"EPOCH": "EPOCH_SECOND",
|
||||
"EPOCH_SECONDS": "EPOCH_SECOND",
|
||||
"EPOCH_MILLISECONDS": "EPOCH_MILLISECOND",
|
||||
"EPOCH_MICROSECONDS": "EPOCH_MICROSECOND",
|
||||
"EPOCH_NANOSECONDS": "EPOCH_NANOSECOND",
|
||||
"TZH": "TIMEZONE_HOUR",
|
||||
"TZM": "TIMEZONE_MINUTE",
|
||||
}
|
||||
|
||||
|
||||
@t.overload
|
||||
def _map_date_part(part: exp.Expression) -> exp.Var:
|
||||
pass
|
||||
|
||||
|
||||
@t.overload
|
||||
def _map_date_part(part: t.Optional[exp.Expression]) -> t.Optional[exp.Expression]:
|
||||
pass
|
||||
|
||||
|
||||
def _map_date_part(part):
|
||||
mapped = DATE_PART_MAPPING.get(part.name.upper()) if part else None
|
||||
return exp.var(mapped) if mapped else part
|
||||
|
||||
|
||||
def _date_trunc_to_time(args: t.List) -> exp.DateTrunc | exp.TimestampTrunc:
|
||||
trunc = date_trunc_to_time(args)
|
||||
trunc.set("unit", _map_date_part(trunc.args["unit"]))
|
||||
trunc.set("unit", map_date_part(trunc.args["unit"]))
|
||||
return trunc
|
||||
|
||||
|
||||
|
@ -328,7 +241,7 @@ class Snowflake(Dialect):
|
|||
class Parser(parser.Parser):
|
||||
IDENTIFY_PIVOT_STRINGS = True
|
||||
DEFAULT_SAMPLING_METHOD = "BERNOULLI"
|
||||
COLON_IS_JSON_EXTRACT = True
|
||||
COLON_IS_VARIANT_EXTRACT = True
|
||||
|
||||
ID_VAR_TOKENS = {
|
||||
*parser.Parser.ID_VAR_TOKENS,
|
||||
|
@@ -367,8 +280,10 @@ class Snowflake(Dialect):
            ),
            "IFF": exp.If.from_arg_list,
            "LAST_DAY": lambda args: exp.LastDay(
                this=seq_get(args, 0), unit=_map_date_part(seq_get(args, 1))
                this=seq_get(args, 0), unit=map_date_part(seq_get(args, 1))
            ),
            "LEN": lambda args: exp.Length(this=seq_get(args, 0), binary=True),
            "LENGTH": lambda args: exp.Length(this=seq_get(args, 0), binary=True),
            "LISTAGG": exp.GroupConcat.from_arg_list,
            "MEDIAN": lambda args: exp.PercentileCont(
                this=seq_get(args, 0), expression=exp.Literal.number(0.5)
@@ -385,6 +300,7 @@ class Snowflake(Dialect):
            "TIMESTAMPDIFF": _build_datediff,
            "TIMESTAMPFROMPARTS": build_timestamp_from_parts,
            "TIMESTAMP_FROM_PARTS": build_timestamp_from_parts,
            "TRY_PARSE_JSON": lambda args: exp.ParseJSON(this=seq_get(args, 0), safe=True),
            "TRY_TO_DATE": _build_datetime("TRY_TO_DATE", exp.DataType.Type.DATE, safe=True),
            "TO_DATE": _build_datetime("TO_DATE", exp.DataType.Type.DATE),
            "TO_NUMBER": lambda args: exp.ToNumber(
@@ -541,7 +457,7 @@ class Snowflake(Dialect):

            self._match(TokenType.COMMA)
            expression = self._parse_bitwise()
            this = _map_date_part(this)
            this = map_date_part(this)
            name = this.name.upper()

            if name.startswith("EPOCH"):
@@ -588,10 +504,11 @@ class Snowflake(Dialect):
            return lateral

        def _parse_at_before(self, table: exp.Table) -> exp.Table:
        def _parse_historical_data(self) -> t.Optional[exp.HistoricalData]:
            # https://docs.snowflake.com/en/sql-reference/constructs/at-before
            index = self._index
            if self._match_texts(("AT", "BEFORE")):
            historical_data = None
            if self._match_texts(self.HISTORICAL_DATA_PREFIX):
                this = self._prev.text.upper()
                kind = (
                    self._match(TokenType.L_PAREN)
@@ -602,14 +519,27 @@ class Snowflake(Dialect):
                if expression:
                    self._match_r_paren()
                    when = self.expression(
                    historical_data = self.expression(
                        exp.HistoricalData, this=this, kind=kind, expression=expression
                    )
                    table.set("when", when)
            else:
                self._retreat(index)

            return table
            return historical_data

        def _parse_changes(self) -> t.Optional[exp.Changes]:
            if not self._match_text_seq("CHANGES", "(", "INFORMATION", "=>"):
                return None

            information = self._parse_var(any_token=True)
            self._match_r_paren()

            return self.expression(
                exp.Changes,
                information=information,
                at_before=self._parse_historical_data(),
                end=self._parse_historical_data(),
            )

        def _parse_table_parts(
            self, schema: bool = False, is_db_reference: bool = False, wildcard: bool = False
@@ -643,7 +573,15 @@ class Snowflake(Dialect):
            else:
                table = super()._parse_table_parts(schema=schema, is_db_reference=is_db_reference)

            return self._parse_at_before(table)
            changes = self._parse_changes()
            if changes:
                table.set("changes", changes)

            at_before = self._parse_historical_data()
            if at_before:
                table.set("when", at_before)

            return table

        def _parse_id_var(
            self,
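
_parse_changes and the reworked _parse_historical_data let a Snowflake table reference carry both a CHANGES clause and an AT/BEFORE clause. A rough usage sketch (SQL and printed shape are assumptions, not taken from the diff):

import sqlglot

sql = "SELECT * FROM t CHANGES (INFORMATION => DEFAULT) AT (TIMESTAMP => '2024-07-01'::TIMESTAMP)"
table = sqlglot.parse_one(sql, read="snowflake").find(sqlglot.exp.Table)
print(table.args.get("changes"))  # exp.Changes node with its at_before argument populated
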
@@ -771,6 +709,7 @@ class Snowflake(Dialect):
            "WAREHOUSE": TokenType.WAREHOUSE,
            "STREAMLIT": TokenType.STREAMLIT,
        }
        KEYWORDS.pop("/*+")

        SINGLE_TOKENS = {
            **tokens.Tokenizer.SINGLE_TOKENS,
@ -839,6 +778,9 @@ class Snowflake(Dialect):
|
|||
exp.Map: lambda self, e: var_map_sql(self, e, "OBJECT_CONSTRUCT"),
|
||||
exp.Max: max_or_greatest,
|
||||
exp.Min: min_or_least,
|
||||
exp.ParseJSON: lambda self, e: self.func(
|
||||
"TRY_PARSE_JSON" if e.args.get("safe") else "PARSE_JSON", e.this
|
||||
),
|
||||
exp.PartitionedByProperty: lambda self, e: f"PARTITION BY {self.sql(e, 'this')}",
|
||||
exp.PercentileCont: transforms.preprocess(
|
||||
[transforms.add_within_group_for_percentiles]
|
||||
|
|
|
@ -33,7 +33,7 @@ def _build_datediff(args: t.List) -> exp.Expression:
|
|||
expression = seq_get(args, 1)
|
||||
|
||||
if len(args) == 3:
|
||||
unit = this
|
||||
unit = exp.var(t.cast(exp.Expression, this).name)
|
||||
this = args[2]
|
||||
|
||||
return exp.DateDiff(
|
||||
|
@ -91,6 +91,8 @@ def _dateadd_sql(self: Spark.Generator, expression: exp.TsOrDsAdd | exp.Timestam
|
|||
|
||||
class Spark(Spark2):
|
||||
class Tokenizer(Spark2.Tokenizer):
|
||||
STRING_ESCAPES_ALLOWED_IN_RAW_STRINGS = False
|
||||
|
||||
RAW_STRINGS = [
|
||||
(prefix + q, q)
|
||||
for q in t.cast(t.List[str], Spark2.Tokenizer.QUOTES)
|
||||
|
@ -105,6 +107,7 @@ class Spark(Spark2):
|
|||
"DATEADD": _build_dateadd,
|
||||
"TIMESTAMPADD": _build_dateadd,
|
||||
"DATEDIFF": _build_datediff,
|
||||
"DATE_DIFF": _build_datediff,
|
||||
"TIMESTAMP_LTZ": _build_as_cast("TIMESTAMP_LTZ"),
|
||||
"TIMESTAMP_NTZ": _build_as_cast("TIMESTAMP_NTZ"),
|
||||
"TRY_ELEMENT_AT": lambda args: exp.Bracket(
|
||||
|
|
|
@ -106,11 +106,16 @@ class SQLite(Dialect):
|
|||
IDENTIFIERS = ['"', ("[", "]"), "`"]
|
||||
HEX_STRINGS = [("x'", "'"), ("X'", "'"), ("0x", ""), ("0X", "")]
|
||||
|
||||
KEYWORDS = tokens.Tokenizer.KEYWORDS.copy()
|
||||
KEYWORDS.pop("/*+")
|
||||
|
||||
class Parser(parser.Parser):
|
||||
FUNCTIONS = {
|
||||
**parser.Parser.FUNCTIONS,
|
||||
"EDITDIST3": exp.Levenshtein.from_arg_list,
|
||||
"STRFTIME": _build_strftime,
|
||||
"DATETIME": lambda args: exp.Anonymous(this="DATETIME", expressions=args),
|
||||
"TIME": lambda args: exp.Anonymous(this="TIME", expressions=args),
|
||||
}
|
||||
STRING_ALIASES = True
|
||||
|
||||
|
|
|
@ -10,6 +10,7 @@ from sqlglot.dialects.dialect import (
|
|||
rename_func,
|
||||
to_number_with_nls_param,
|
||||
)
|
||||
from sqlglot.helper import seq_get
|
||||
from sqlglot.tokens import TokenType
|
||||
|
||||
|
||||
|
@@ -24,9 +25,9 @@ def _date_add_sql(
        if not isinstance(value, exp.Literal):
            self.unsupported("Cannot add non literal")

        if value.is_negative:
        if isinstance(value, exp.Neg):
            kind_to_op = {"+": "-", "-": "+"}
            value = exp.Literal.string(value.name[1:])
            value = exp.Literal.string(value.this.to_py())
        else:
            kind_to_op = {"+": "+", "-": "-"}
            value.set("is_string", True)
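
With exp.Neg handled explicitly, a negative interval amount flips the +/- operator instead of relying on the removed is_negative property. A sketch of the intended effect, with dialects and output shape assumed rather than asserted:

import sqlglot

# DATEADD with a negative amount should render as subtraction in Teradata.
print(sqlglot.transpile("SELECT DATEADD(day, -5, d)", read="tsql", write="teradata")[0])
# Expected (assumed): SELECT d - INTERVAL '5' DAY
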
@ -96,6 +97,7 @@ class Teradata(Dialect):
|
|||
"TOP": TokenType.TOP,
|
||||
"UPD": TokenType.UPDATE,
|
||||
}
|
||||
KEYWORDS.pop("/*+")
|
||||
|
||||
# Teradata does not support % as a modulo operator
|
||||
SINGLE_TOKENS = {**tokens.Tokenizer.SINGLE_TOKENS}
|
||||
|
@ -159,6 +161,11 @@ class Teradata(Dialect):
|
|||
"TRANSLATE": lambda self: self._parse_translate(self.STRICT_CAST),
|
||||
}
|
||||
|
||||
FUNCTIONS = {
|
||||
**parser.Parser.FUNCTIONS,
|
||||
"RANDOM": lambda args: exp.Rand(lower=seq_get(args, 0), upper=seq_get(args, 1)),
|
||||
}
|
||||
|
||||
EXPONENT = {
|
||||
TokenType.DSTAR: exp.Pow,
|
||||
}
|
||||
|
@ -200,6 +207,14 @@ class Teradata(Dialect):
|
|||
|
||||
return self.expression(exp.RangeN, this=this, expressions=expressions, each=each)
|
||||
|
||||
def _parse_index_params(self) -> exp.IndexParameters:
|
||||
this = super()._parse_index_params()
|
||||
|
||||
if this.args.get("on"):
|
||||
this.set("on", None)
|
||||
self._retreat(self._index - 2)
|
||||
return this
|
||||
|
||||
class Generator(generator.Generator):
|
||||
LIMIT_IS_TOP = True
|
||||
JOIN_HINTS = False
|
||||
|
@ -208,11 +223,13 @@ class Teradata(Dialect):
|
|||
TABLESAMPLE_KEYWORDS = "SAMPLE"
|
||||
LAST_DAY_SUPPORTS_DATE_PART = False
|
||||
CAN_IMPLEMENT_ARRAY_ANY = True
|
||||
TZ_TO_WITH_TIME_ZONE = True
|
||||
|
||||
TYPE_MAPPING = {
|
||||
**generator.Generator.TYPE_MAPPING,
|
||||
exp.DataType.Type.GEOMETRY: "ST_GEOMETRY",
|
||||
exp.DataType.Type.DOUBLE: "DOUBLE PRECISION",
|
||||
exp.DataType.Type.TIMESTAMPTZ: "TIMESTAMP",
|
||||
}
|
||||
|
||||
PROPERTIES_LOCATION = {
|
||||
|
@ -230,6 +247,7 @@ class Teradata(Dialect):
|
|||
exp.Max: max_or_greatest,
|
||||
exp.Min: min_or_least,
|
||||
exp.Pow: lambda self, e: self.binary(e, "**"),
|
||||
exp.Rand: lambda self, e: self.func("RANDOM", e.args.get("lower"), e.args.get("upper")),
|
||||
exp.Select: transforms.preprocess(
|
||||
[transforms.eliminate_distinct_on, transforms.eliminate_semi_and_anti_joins]
|
||||
),
|
||||
|
@ -238,12 +256,15 @@ class Teradata(Dialect):
|
|||
exp.ToChar: lambda self, e: self.function_fallback_sql(e),
|
||||
exp.ToNumber: to_number_with_nls_param,
|
||||
exp.Use: lambda self, e: f"DATABASE {self.sql(e, 'this')}",
|
||||
exp.CurrentTimestamp: lambda *_: "CURRENT_TIMESTAMP",
|
||||
exp.DateAdd: _date_add_sql("+"),
|
||||
exp.DateSub: _date_add_sql("-"),
|
||||
exp.Quarter: lambda self, e: self.sql(exp.Extract(this="QUARTER", expression=e.this)),
|
||||
}
|
||||
|
||||
def currenttimestamp_sql(self, expression: exp.CurrentTimestamp) -> str:
|
||||
prefix, suffix = ("(", ")") if expression.this else ("", "")
|
||||
return self.func("CURRENT_TIMESTAMP", expression.this, prefix=prefix, suffix=suffix)
|
||||
|
||||
def cast_sql(self, expression: exp.Cast, safe_prefix: t.Optional[str] = None) -> str:
|
||||
if expression.to.this == exp.DataType.Type.UNKNOWN and expression.args.get("format"):
|
||||
# We don't actually want to print the unknown type in CAST(<value> AS FORMAT <format>)
|
||||
|
|
|
@ -450,6 +450,7 @@ class TSQL(Dialect):
|
|||
|
||||
KEYWORDS = {
|
||||
**tokens.Tokenizer.KEYWORDS,
|
||||
"CLUSTERED INDEX": TokenType.INDEX,
|
||||
"DATETIME2": TokenType.DATETIME,
|
||||
"DATETIMEOFFSET": TokenType.TIMESTAMPTZ,
|
||||
"DECLARE": TokenType.DECLARE,
|
||||
|
@ -457,6 +458,7 @@ class TSQL(Dialect):
|
|||
"FOR SYSTEM_TIME": TokenType.TIMESTAMP_SNAPSHOT,
|
||||
"IMAGE": TokenType.IMAGE,
|
||||
"MONEY": TokenType.MONEY,
|
||||
"NONCLUSTERED INDEX": TokenType.INDEX,
|
||||
"NTEXT": TokenType.TEXT,
|
||||
"OPTION": TokenType.OPTION,
|
||||
"OUTPUT": TokenType.RETURNING,
|
||||
|
@ -475,6 +477,7 @@ class TSQL(Dialect):
|
|||
"UPDATE STATISTICS": TokenType.COMMAND,
|
||||
"XML": TokenType.XML,
|
||||
}
|
||||
KEYWORDS.pop("/*+")
|
||||
|
||||
COMMANDS = {*tokens.Tokenizer.COMMANDS, TokenType.END}
|
||||
|
||||
|
@ -533,6 +536,31 @@ class TSQL(Dialect):
|
|||
TokenType.DECLARE: lambda self: self._parse_declare(),
|
||||
}
|
||||
|
||||
RANGE_PARSERS = {
|
||||
**parser.Parser.RANGE_PARSERS,
|
||||
TokenType.DCOLON: lambda self, this: self.expression(
|
||||
exp.ScopeResolution,
|
||||
this=this,
|
||||
expression=self._parse_function() or self._parse_var(any_token=True),
|
||||
),
|
||||
}
|
||||
|
||||
# The DCOLON (::) operator serves as a scope resolution (exp.ScopeResolution) operator in T-SQL
|
||||
COLUMN_OPERATORS = {
|
||||
**parser.Parser.COLUMN_OPERATORS,
|
||||
TokenType.DCOLON: lambda self, this, to: self.expression(exp.Cast, this=this, to=to)
|
||||
if isinstance(to, exp.DataType) and to.this != exp.DataType.Type.USERDEFINED
|
||||
else self.expression(exp.ScopeResolution, this=this, expression=to),
|
||||
}
|
||||
|
||||
def _parse_dcolon(self) -> t.Optional[exp.Expression]:
|
||||
# We want to use _parse_types() if the first token after :: is a known type,
|
||||
# otherwise we could parse something like x::varchar(max) into a function
|
||||
if self._match_set(self.TYPE_TOKENS, advance=False):
|
||||
return self._parse_types()
|
||||
|
||||
return self._parse_function() or self._parse_types()
|
||||
|
||||
def _parse_options(self) -> t.Optional[t.List[exp.Expression]]:
|
||||
if not self._match(TokenType.OPTION):
|
||||
return None
|
||||
|
@ -757,12 +785,15 @@ class TSQL(Dialect):
|
|||
SUPPORTS_SELECT_INTO = True
|
||||
JSON_PATH_BRACKETED_KEY_SUPPORTED = False
|
||||
SUPPORTS_TO_NUMBER = False
|
||||
OUTER_UNION_MODIFIERS = False
|
||||
SET_OP_MODIFIERS = False
|
||||
COPY_PARAMS_EQ_REQUIRED = True
|
||||
PARSE_JSON_NAME = None
|
||||
|
||||
EXPRESSIONS_WITHOUT_NESTED_CTES = {
|
||||
exp.Delete,
|
||||
exp.Insert,
|
||||
exp.Intersect,
|
||||
exp.Except,
|
||||
exp.Merge,
|
||||
exp.Select,
|
||||
exp.Subquery,
|
||||
|
@ -816,7 +847,6 @@ class TSQL(Dialect):
|
|||
exp.MD5: lambda self, e: self.func("HASHBYTES", exp.Literal.string("MD5"), e.this),
|
||||
exp.Min: min_or_least,
|
||||
exp.NumberToStr: _format_sql,
|
||||
exp.ParseJSON: lambda self, e: self.sql(e, "this"),
|
||||
exp.Repeat: rename_func("REPLICATE"),
|
||||
exp.Select: transforms.preprocess(
|
||||
[
|
||||
|
@ -850,6 +880,9 @@ class TSQL(Dialect):
|
|||
exp.VolatileProperty: exp.Properties.Location.UNSUPPORTED,
|
||||
}
|
||||
|
||||
def scope_resolution(self, rhs: str, scope_name: str) -> str:
|
||||
return f"{scope_name}::{rhs}"
|
||||
|
||||
def select_sql(self, expression: exp.Select) -> str:
|
||||
if expression.args.get("offset"):
|
||||
if not expression.args.get("order"):
|
||||
|
|
|
@ -447,6 +447,7 @@ class Python(Dialect):
|
|||
exp.Is: lambda self, e: (
|
||||
self.binary(e, "==") if isinstance(e.this, exp.Literal) else self.binary(e, "is")
|
||||
),
|
||||
exp.JSONExtract: lambda self, e: self.func(e.key, e.this, e.expression, *e.expressions),
|
||||
exp.JSONPath: lambda self, e: f"[{','.join(self.sql(p) for p in e.expressions[1:])}]",
|
||||
exp.JSONPathKey: lambda self, e: f"'{self.sql(e.this)}'",
|
||||
exp.JSONPathSubscript: lambda self, e: f"'{e.this}'",
|
||||
|
|
|
@ -20,6 +20,7 @@ import textwrap
|
|||
import typing as t
|
||||
from collections import deque
|
||||
from copy import deepcopy
|
||||
from decimal import Decimal
|
||||
from enum import auto
|
||||
from functools import reduce
|
||||
|
||||
|
@ -29,7 +30,6 @@ from sqlglot.helper import (
|
|||
camel_to_snake_case,
|
||||
ensure_collection,
|
||||
ensure_list,
|
||||
is_int,
|
||||
seq_get,
|
||||
subclasses,
|
||||
)
|
||||
|
@ -40,6 +40,7 @@ if t.TYPE_CHECKING:
|
|||
from sqlglot.dialects.dialect import DialectType
|
||||
|
||||
Q = t.TypeVar("Q", bound="Query")
|
||||
S = t.TypeVar("S", bound="SetOperation")
|
||||
|
||||
|
||||
class _Expression(type):
|
||||
|
@@ -174,23 +175,22 @@ class Expression(metaclass=_Expression):
        """
        Checks whether a Literal expression is a number.
        """
        return isinstance(self, Literal) and not self.args["is_string"]
        return (isinstance(self, Literal) and not self.args["is_string"]) or (
            isinstance(self, Neg) and self.this.is_number
        )

    @property
    def is_negative(self) -> bool:
    def to_py(self) -> t.Any:
        """
        Checks whether an expression is negative.

        Handles both exp.Neg and Literal numbers with "-" which come from optimizer.simplify.
        Returns a Python object equivalent of the SQL node.
        """
        return isinstance(self, Neg) or (self.is_number and self.this.startswith("-"))
        raise ValueError(f"{self} cannot be converted to a Python object.")

    @property
    def is_int(self) -> bool:
        """
        Checks whether a Literal expression is an integer.
        Checks whether an expression is an integer.
        """
        return self.is_number and is_int(self.name)
        return self.is_number and isinstance(self.to_py(), int)

    @property
    def is_star(self) -> bool:
|
@ -2002,6 +2002,10 @@ class Check(Expression):
|
|||
pass
|
||||
|
||||
|
||||
class Changes(Expression):
|
||||
arg_types = {"information": True, "at_before": False, "end": False}
|
||||
|
||||
|
||||
# https://docs.snowflake.com/en/sql-reference/constructs/connect-by
|
||||
class Connect(Expression):
|
||||
arg_types = {"start": False, "connect": True, "nocycle": False}
|
||||
|
@ -2127,6 +2131,7 @@ class IndexParameters(Expression):
|
|||
"partition_by": False,
|
||||
"tablespace": False,
|
||||
"where": False,
|
||||
"on": False,
|
||||
}
|
||||
|
||||
|
||||
|
@@ -2281,6 +2286,14 @@ class Literal(Condition):
    def output_name(self) -> str:
        return self.name

    def to_py(self) -> int | str | Decimal:
        if self.is_number:
            try:
                return int(self.this)
            except ValueError:
                return Decimal(self.this)
        return self.this


class Join(Expression):
    arg_types = {
@ -2639,6 +2652,10 @@ class DictRange(Property):
|
|||
arg_types = {"this": True, "min": True, "max": True}
|
||||
|
||||
|
||||
class DynamicProperty(Property):
|
||||
arg_types = {}
|
||||
|
||||
|
||||
# Clickhouse CREATE ... ON CLUSTER modifier
|
||||
# https://clickhouse.com/docs/en/sql-reference/distributed-ddl
|
||||
class OnCluster(Property):
|
||||
|
@ -2805,6 +2822,10 @@ class TemporaryProperty(Property):
|
|||
arg_types = {"this": False}
|
||||
|
||||
|
||||
class SecureProperty(Property):
|
||||
arg_types = {}
|
||||
|
||||
|
||||
class TransformModelProperty(Property):
|
||||
arg_types = {"expressions": True}
|
||||
|
||||
|
@ -2834,6 +2855,10 @@ class WithJournalTableProperty(Property):
|
|||
arg_types = {"this": True}
|
||||
|
||||
|
||||
class WithSchemaBindingProperty(Property):
|
||||
arg_types = {"this": True}
|
||||
|
||||
|
||||
class WithSystemVersioningProperty(Property):
|
||||
arg_types = {
|
||||
"on": False,
|
||||
|
@ -3017,6 +3042,7 @@ class Table(Expression):
|
|||
"when": False,
|
||||
"only": False,
|
||||
"partition": False,
|
||||
"changes": False,
|
||||
}
|
||||
|
||||
@property
|
||||
|
@ -3065,7 +3091,7 @@ class Table(Expression):
|
|||
return col
|
||||
|
||||
|
||||
class Union(Query):
|
||||
class SetOperation(Query):
|
||||
arg_types = {
|
||||
"with": False,
|
||||
"this": True,
|
||||
|
@ -3076,13 +3102,13 @@ class Union(Query):
|
|||
}
|
||||
|
||||
def select(
|
||||
self,
|
||||
self: S,
|
||||
*expressions: t.Optional[ExpOrStr],
|
||||
append: bool = True,
|
||||
dialect: DialectType = None,
|
||||
copy: bool = True,
|
||||
**opts,
|
||||
) -> Union:
|
||||
) -> S:
|
||||
this = maybe_copy(self, copy)
|
||||
this.this.unnest().select(*expressions, append=append, dialect=dialect, copy=False, **opts)
|
||||
this.expression.unnest().select(
|
||||
|
@@ -3111,11 +3137,15 @@ class Union(Query):
        return self.expression


class Except(Union):
class Union(SetOperation):
    pass


class Intersect(Union):
class Except(SetOperation):
    pass


class Intersect(SetOperation):
    pass

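
Union is no longer the base class of Except and Intersect; all three now derive from the new SetOperation. Code that used isinstance(node, exp.Union) as a catch-all for set operations should switch to exp.SetOperation. A short sketch of the assumed effect:

from sqlglot import exp, parse_one

node = parse_one("SELECT 1 INTERSECT SELECT 2")
print(isinstance(node, exp.SetOperation))  # True
print(isinstance(node, exp.Union))         # False after this change
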
@ -3727,7 +3757,7 @@ class Select(Query):
|
|||
return self.expressions
|
||||
|
||||
|
||||
UNWRAPPED_QUERIES = (Select, Union)
|
||||
UNWRAPPED_QUERIES = (Select, SetOperation)
|
||||
|
||||
|
||||
class Subquery(DerivedTable, Query):
|
||||
|
@@ -3893,9 +3923,13 @@ class Null(Condition):
    def name(self) -> str:
        return "NULL"

    def to_py(self) -> Lit[None]:
        return None


class Boolean(Condition):
    pass
    def to_py(self) -> bool:
        return self.this


class DataTypeParam(Expression):
@ -4019,6 +4053,7 @@ class DataType(Expression):
|
|||
VARBINARY = auto()
|
||||
VARCHAR = auto()
|
||||
VARIANT = auto()
|
||||
VECTOR = auto()
|
||||
XML = auto()
|
||||
YEAR = auto()
|
||||
TDIGEST = auto()
|
||||
|
@@ -4473,7 +4508,10 @@ class Paren(Unary):


class Neg(Unary):
    pass
    def to_py(self) -> int | Decimal:
        if self.is_number:
            return self.this.to_py() * -1
        return super().to_py()


class Alias(Expression):
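
Together with the Literal, Null and Boolean overrides elsewhere in this commit, Neg.to_py gives literal-like nodes a Python-native value. A small sketch of the assumed behavior:

from sqlglot import exp

exp.Literal.number("5").to_py()                # 5
exp.Literal.number("1.5").to_py()              # Decimal("1.5")
exp.null().to_py()                             # None
exp.Neg(this=exp.Literal.number("3")).to_py()  # -3
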
@ -5065,6 +5103,12 @@ class DateTrunc(Func):
|
|||
return self.args["unit"]
|
||||
|
||||
|
||||
# https://cloud.google.com/bigquery/docs/reference/standard-sql/datetime_functions#datetime
|
||||
# expression can either be time_expr or time_zone
|
||||
class Datetime(Func):
|
||||
arg_types = {"this": True, "expression": False}
|
||||
|
||||
|
||||
class DatetimeAdd(Func, IntervalOp):
|
||||
arg_types = {"this": True, "expression": True, "unit": False}
|
||||
|
||||
|
@ -5115,7 +5159,7 @@ class Extract(Func):
|
|||
|
||||
|
||||
class Timestamp(Func):
|
||||
arg_types = {"this": False, "expression": False, "with_tz": False}
|
||||
arg_types = {"this": False, "zone": False, "with_tz": False}
|
||||
|
||||
|
||||
class TimestampAdd(Func, TimeUnit):
|
||||
|
@ -5441,12 +5485,18 @@ class OpenJSON(Func):
|
|||
arg_types = {"this": True, "path": False, "expressions": False}
|
||||
|
||||
|
||||
class JSONBContains(Binary):
|
||||
class JSONBContains(Binary, Func):
|
||||
_sql_names = ["JSONB_CONTAINS"]
|
||||
|
||||
|
||||
class JSONExtract(Binary, Func):
|
||||
arg_types = {"this": True, "expression": True, "only_json_types": False, "expressions": False}
|
||||
arg_types = {
|
||||
"this": True,
|
||||
"expression": True,
|
||||
"only_json_types": False,
|
||||
"expressions": False,
|
||||
"variant_extract": False,
|
||||
}
|
||||
_sql_names = ["JSON_EXTRACT"]
|
||||
is_var_len_args = True
|
||||
|
||||
|
@ -5485,9 +5535,9 @@ class JSONArrayContains(Binary, Predicate, Func):
|
|||
|
||||
class ParseJSON(Func):
|
||||
# BigQuery, Snowflake have PARSE_JSON, Presto has JSON_PARSE
|
||||
# Snowflake also has TRY_PARSE_JSON, which is represented using `safe`
|
||||
_sql_names = ["PARSE_JSON", "JSON_PARSE"]
|
||||
arg_types = {"this": True, "expressions": False}
|
||||
is_var_len_args = True
|
||||
arg_types = {"this": True, "expression": False, "safe": False}
|
||||
|
||||
|
||||
class Least(Func):
|
||||
|
@ -5504,6 +5554,7 @@ class Right(Func):
|
|||
|
||||
|
||||
class Length(Func):
|
||||
arg_types = {"this": True, "binary": False}
|
||||
_sql_names = ["LENGTH", "LEN"]
|
||||
|
||||
|
||||
|
@ -5560,6 +5611,11 @@ class MapFromEntries(Func):
|
|||
pass
|
||||
|
||||
|
||||
# https://learn.microsoft.com/en-us/sql/t-sql/language-elements/scope-resolution-operator-transact-sql?view=sql-server-ver16
|
||||
class ScopeResolution(Expression):
|
||||
arg_types = {"this": False, "expression": True}
|
||||
|
||||
|
||||
class StarMap(Func):
|
||||
pass
|
||||
|
||||
|
@@ -5642,9 +5698,11 @@ class Quarter(Func):
    pass


# https://docs.teradata.com/r/Enterprise_IntelliFlex_VMware/SQL-Functions-Expressions-and-Predicates/Arithmetic-Trigonometric-Hyperbolic-Operators/Functions/RANDOM/RANDOM-Function-Syntax
# teradata lower and upper bounds
class Rand(Func):
    _sql_names = ["RAND", "RANDOM"]
    arg_types = {"this": False}
    arg_types = {"this": False, "lower": False, "upper": False}


class Randn(Func):
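
The new lower/upper arguments model Teradata's bounded RANDOM(lower, upper); the default generator (see rand_sql later in this diff) rewrites it arithmetically for dialects without a bounded variant. Sketch with an assumed output:

import sqlglot

print(sqlglot.transpile("SELECT RANDOM(1, 10)", read="teradata")[0])
# Expected (assumed): SELECT (10 - 1) * RAND() + 1
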
@@ -5765,11 +5823,11 @@ class StrPosition(Func):


class StrToDate(Func):
    arg_types = {"this": True, "format": False}
    arg_types = {"this": True, "format": False, "safe": False}


class StrToTime(Func):
    arg_types = {"this": True, "format": True, "zone": False}
    arg_types = {"this": True, "format": True, "zone": False, "safe": False}


# Spark allows unix_timestamp()
@ -5833,6 +5891,11 @@ class StddevSamp(AggFunc):
|
|||
pass
|
||||
|
||||
|
||||
# https://cloud.google.com/bigquery/docs/reference/standard-sql/time_functions#time
|
||||
class Time(Func):
|
||||
arg_types = {"this": False, "zone": False}
|
||||
|
||||
|
||||
class TimeToStr(Func):
|
||||
arg_types = {"this": True, "format": True, "culture": False, "timezone": False}
|
||||
|
||||
|
|
|
@ -87,9 +87,11 @@ class Generator(metaclass=_Generator):
|
|||
e: f"CLUSTERED ({self.expressions(e, 'this', indent=False)})",
|
||||
exp.CollateColumnConstraint: lambda self, e: f"COLLATE {self.sql(e, 'this')}",
|
||||
exp.CommentColumnConstraint: lambda self, e: f"COMMENT {self.sql(e, 'this')}",
|
||||
exp.ConnectByRoot: lambda self, e: f"CONNECT_BY_ROOT {self.sql(e, 'this')}",
|
||||
exp.CopyGrantsProperty: lambda *_: "COPY GRANTS",
|
||||
exp.DateFormatColumnConstraint: lambda self, e: f"FORMAT {self.sql(e, 'this')}",
|
||||
exp.DefaultColumnConstraint: lambda self, e: f"DEFAULT {self.sql(e, 'this')}",
|
||||
exp.DynamicProperty: lambda *_: "DYNAMIC",
|
||||
exp.EncodeColumnConstraint: lambda self, e: f"ENCODE {self.sql(e, 'this')}",
|
||||
exp.EphemeralColumnConstraint: lambda self,
|
||||
e: f"EPHEMERAL{(' ' + self.sql(e, 'this')) if e.this else ''}",
|
||||
|
@ -131,6 +133,7 @@ class Generator(metaclass=_Generator):
|
|||
"RETURNS NULL ON NULL INPUT" if e.args.get("null") else self.naked_property(e)
|
||||
),
|
||||
exp.SampleProperty: lambda self, e: f"SAMPLE BY {self.sql(e, 'this')}",
|
||||
exp.SecureProperty: lambda *_: "SECURE",
|
||||
exp.SetConfigProperty: lambda self, e: self.sql(e, "this"),
|
||||
exp.SetProperty: lambda _, e: f"{'MULTI' if e.args.get('multi') else ''}SET",
|
||||
exp.SettingsProperty: lambda self, e: f"SETTINGS{self.seg('')}{(self.expressions(e))}",
|
||||
|
@ -143,7 +146,7 @@ class Generator(metaclass=_Generator):
|
|||
exp.TemporaryProperty: lambda *_: "TEMPORARY",
|
||||
exp.TagColumnConstraint: lambda self, e: f"TAG ({self.expressions(e, flat=True)})",
|
||||
exp.TitleColumnConstraint: lambda self, e: f"TITLE {self.sql(e, 'this')}",
|
||||
exp.Timestamp: lambda self, e: self.func("TIMESTAMP", e.this, e.expression),
|
||||
exp.Timestamp: lambda self, e: self.func("TIMESTAMP", e.this, e.args.get("zone")),
|
||||
exp.ToMap: lambda self, e: f"MAP {self.sql(e, 'this')}",
|
||||
exp.ToTableProperty: lambda self, e: f"TO {self.sql(e.this)}",
|
||||
exp.TransformModelProperty: lambda self, e: self.func("TRANSFORM", *e.expressions),
|
||||
|
@ -154,6 +157,7 @@ class Generator(metaclass=_Generator):
|
|||
exp.ViewAttributeProperty: lambda self, e: f"WITH {self.sql(e, 'this')}",
|
||||
exp.VolatileProperty: lambda *_: "VOLATILE",
|
||||
exp.WithJournalTableProperty: lambda self, e: f"WITH JOURNAL TABLE={self.sql(e, 'this')}",
|
||||
exp.WithSchemaBindingProperty: lambda self, e: f"WITH SCHEMA {self.sql(e, 'this')}",
|
||||
exp.WithOperator: lambda self, e: f"{self.sql(e, 'this')} WITH {self.sql(e, 'op')}",
|
||||
}
|
||||
|
||||
|
@ -168,8 +172,8 @@ class Generator(metaclass=_Generator):
|
|||
# Whether locking reads (i.e. SELECT ... FOR UPDATE/SHARE) are supported
|
||||
LOCKING_READS_SUPPORTED = False
|
||||
|
||||
# Always do union distinct or union all
|
||||
EXPLICIT_UNION = False
|
||||
# Always do <set op> distinct or <set op> all
|
||||
EXPLICIT_SET_OP = False
|
||||
|
||||
# Wrap derived values in parens, usually standard but spark doesn't support it
|
||||
WRAP_DERIVED_VALUES = True
|
||||
|
@ -339,10 +343,10 @@ class Generator(metaclass=_Generator):
|
|||
# Whether the function TO_NUMBER is supported
|
||||
SUPPORTS_TO_NUMBER = True
|
||||
|
||||
# Whether or not union modifiers apply to the outer union or select.
|
||||
# Whether or not set op modifiers apply to the outer set op or select.
|
||||
# SELECT * FROM x UNION SELECT * FROM y LIMIT 1
|
||||
# True means limit 1 happens after the union, False means it it happens on y.
|
||||
OUTER_UNION_MODIFIERS = True
|
||||
# True means limit 1 happens after the set op, False means it it happens on y.
|
||||
SET_OP_MODIFIERS = True
|
||||
|
||||
# Whether parameters from COPY statement are wrapped in parentheses
|
||||
COPY_PARAMS_ARE_WRAPPED = True
|
||||
|
@ -368,6 +372,12 @@ class Generator(metaclass=_Generator):
|
|||
# The keywords to use when prefixing & separating WITH based properties
|
||||
WITH_PROPERTIES_PREFIX = "WITH"
|
||||
|
||||
# Whether to quote the generated expression of exp.JsonPath
|
||||
QUOTE_JSON_PATH = True
|
||||
|
||||
# The name to generate for the JSONPath expression. If `None`, only `this` will be generated
|
||||
PARSE_JSON_NAME: t.Optional[str] = "PARSE_JSON"
|
||||
|
||||
TYPE_MAPPING = {
|
||||
exp.DataType.Type.NCHAR: "CHAR",
|
||||
exp.DataType.Type.NVARCHAR: "VARCHAR",
|
||||
|
@ -430,6 +440,7 @@ class Generator(metaclass=_Generator):
|
|||
exp.DefinerProperty: exp.Properties.Location.POST_CREATE,
|
||||
exp.DictRange: exp.Properties.Location.POST_SCHEMA,
|
||||
exp.DictProperty: exp.Properties.Location.POST_SCHEMA,
|
||||
exp.DynamicProperty: exp.Properties.Location.POST_CREATE,
|
||||
exp.DistKeyProperty: exp.Properties.Location.POST_SCHEMA,
|
||||
exp.DistStyleProperty: exp.Properties.Location.POST_SCHEMA,
|
||||
exp.EngineProperty: exp.Properties.Location.POST_SCHEMA,
|
||||
|
@ -469,6 +480,7 @@ class Generator(metaclass=_Generator):
|
|||
exp.RowFormatSerdeProperty: exp.Properties.Location.POST_SCHEMA,
|
||||
exp.SampleProperty: exp.Properties.Location.POST_SCHEMA,
|
||||
exp.SchemaCommentProperty: exp.Properties.Location.POST_SCHEMA,
|
||||
exp.SecureProperty: exp.Properties.Location.POST_CREATE,
|
||||
exp.SerdeProperties: exp.Properties.Location.POST_SCHEMA,
|
||||
exp.Set: exp.Properties.Location.POST_SCHEMA,
|
||||
exp.SettingsProperty: exp.Properties.Location.POST_SCHEMA,
|
||||
|
@ -491,6 +503,7 @@ class Generator(metaclass=_Generator):
|
|||
exp.VolatileProperty: exp.Properties.Location.POST_CREATE,
|
||||
exp.WithDataProperty: exp.Properties.Location.POST_EXPRESSION,
|
||||
exp.WithJournalTableProperty: exp.Properties.Location.POST_NAME,
|
||||
exp.WithSchemaBindingProperty: exp.Properties.Location.POST_SCHEMA,
|
||||
exp.WithSystemVersioningProperty: exp.Properties.Location.POST_SCHEMA,
|
||||
}
|
||||
|
||||
|
@ -506,7 +519,7 @@ class Generator(metaclass=_Generator):
|
|||
exp.Insert,
|
||||
exp.Join,
|
||||
exp.Select,
|
||||
exp.Union,
|
||||
exp.SetOperation,
|
||||
exp.Update,
|
||||
exp.Where,
|
||||
exp.With,
|
||||
|
@ -515,7 +528,7 @@ class Generator(metaclass=_Generator):
|
|||
# Expressions that should not have their comments generated in maybe_comment
|
||||
EXCLUDE_COMMENTS: t.Tuple[t.Type[exp.Expression], ...] = (
|
||||
exp.Binary,
|
||||
exp.Union,
|
||||
exp.SetOperation,
|
||||
)
|
||||
|
||||
# Expressions that can remain unwrapped when appearing in the context of an INTERVAL
|
||||
|
@ -1298,8 +1311,10 @@ class Generator(metaclass=_Generator):
|
|||
with_storage = f" WITH ({with_storage})" if with_storage else ""
|
||||
tablespace = self.sql(expression, "tablespace")
|
||||
tablespace = f" USING INDEX TABLESPACE {tablespace}" if tablespace else ""
|
||||
on = self.sql(expression, "on")
|
||||
on = f" ON {on}" if on else ""
|
||||
|
||||
return f"{using}{columns}{include}{with_storage}{tablespace}{partition_by}{where}"
|
||||
return f"{using}{columns}{include}{with_storage}{tablespace}{partition_by}{where}{on}"
|
||||
|
||||
def index_sql(self, expression: exp.Index) -> str:
|
||||
unique = "UNIQUE " if expression.args.get("unique") else ""
|
||||
|
@ -1736,7 +1751,10 @@ class Generator(metaclass=_Generator):
|
|||
if when:
|
||||
table = f"{table} {when}"
|
||||
|
||||
return f"{only}{table}{partition}{version}{file_format}{alias}{hints}{pivots}{joins}{laterals}{ordinality}"
|
||||
changes = self.sql(expression, "changes")
|
||||
changes = f" {changes}" if changes else ""
|
||||
|
||||
return f"{only}{table}{changes}{partition}{version}{file_format}{alias}{hints}{pivots}{joins}{laterals}{ordinality}"
|
||||
|
||||
def tablesample_sql(
|
||||
self,
|
||||
|
@ -2393,8 +2411,8 @@ class Generator(metaclass=_Generator):
|
|||
this = self.indent(self.sql(expression, "this"))
|
||||
return f"{self.seg('QUALIFY')}{self.sep()}{this}"
|
||||
|
||||
def set_operations(self, expression: exp.Union) -> str:
|
||||
if not self.OUTER_UNION_MODIFIERS:
|
||||
def set_operations(self, expression: exp.SetOperation) -> str:
|
||||
if not self.SET_OP_MODIFIERS:
|
||||
limit = expression.args.get("limit")
|
||||
order = expression.args.get("order")
|
||||
|
||||
|
@ -2413,7 +2431,7 @@ class Generator(metaclass=_Generator):
|
|||
while stack:
|
||||
node = stack.pop()
|
||||
|
||||
if isinstance(node, exp.Union):
|
||||
if isinstance(node, exp.SetOperation):
|
||||
stack.append(node.expression)
|
||||
stack.append(
|
||||
self.maybe_comment(
|
||||
|
@ -2433,8 +2451,8 @@ class Generator(metaclass=_Generator):
|
|||
def union_sql(self, expression: exp.Union) -> str:
|
||||
return self.set_operations(expression)
|
||||
|
||||
def union_op(self, expression: exp.Union) -> str:
|
||||
kind = " DISTINCT" if self.EXPLICIT_UNION else ""
|
||||
def union_op(self, expression: exp.SetOperation) -> str:
|
||||
kind = " DISTINCT" if self.EXPLICIT_SET_OP else ""
|
||||
kind = kind if expression.args.get("distinct") else " ALL"
|
||||
by_name = " BY NAME" if expression.args.get("by_name") else ""
|
||||
return f"UNION{kind}{by_name}"
|
||||
|
@ -2653,7 +2671,10 @@ class Generator(metaclass=_Generator):
|
|||
|
||||
def jsonpath_sql(self, expression: exp.JSONPath) -> str:
|
||||
path = self.expressions(expression, sep="", flat=True).lstrip(".")
|
||||
return f"{self.dialect.QUOTE_START}{path}{self.dialect.QUOTE_END}"
|
||||
if self.QUOTE_JSON_PATH:
|
||||
path = f"{self.dialect.QUOTE_START}{path}{self.dialect.QUOTE_END}"
|
||||
|
||||
return path
|
||||
|
||||
def json_path_part(self, expression: int | str | exp.JSONPathPart) -> str:
|
||||
if isinstance(expression, exp.JSONPathPart):
|
||||
|
@@ -3969,3 +3990,56 @@ class Generator(metaclass=_Generator):
        this = self.sql(expression, "this")
        this = f"TABLE {this}"
        return self.func("GAP_FILL", this, *[v for k, v in expression.args.items() if k != "this"])

    def scope_resolution(self, rhs: str, scope_name: str) -> str:
        return self.func("SCOPE_RESOLUTION", scope_name or None, rhs)

    def scoperesolution_sql(self, expression: exp.ScopeResolution) -> str:
        this = self.sql(expression, "this")
        expr = expression.expression

        if isinstance(expr, exp.Func):
            # T-SQL's CLR functions are case sensitive
            expr = f"{self.sql(expr, 'this')}({self.format_args(*expr.expressions)})"
        else:
            expr = self.sql(expression, "expression")

        return self.scope_resolution(expr, this)

    def parsejson_sql(self, expression: exp.ParseJSON) -> str:
        if self.PARSE_JSON_NAME is None:
            return self.sql(expression.this)

        return self.func(self.PARSE_JSON_NAME, expression.this, expression.expression)

    def length_sql(self, expression: exp.Length) -> str:
        return self.func("LENGTH", expression.this)

    def rand_sql(self, expression: exp.Rand) -> str:
        lower = self.sql(expression, "lower")
        upper = self.sql(expression, "upper")

        if lower and upper:
            return f"({upper} - {lower}) * {self.func('RAND', expression.this)} + {lower}"
        return self.func("RAND", expression.this)

    def strtodate_sql(self, expression: exp.StrToDate) -> str:
        return self.func("STR_TO_DATE", expression.this, expression.args.get("format"))

    def strtotime_sql(self, expression: exp.StrToTime) -> str:
        return self.func(
            "STR_TO_TIME",
            expression.this,
            expression.args.get("format"),
            expression.args.get("zone"),
        )

    def changes_sql(self, expression: exp.Changes) -> str:
        information = self.sql(expression, "information")
        information = f"INFORMATION => {information}"
        at_before = self.sql(expression, "at_before")
        at_before = f"{self.seg('')}{at_before}" if at_before else ""
        end = self.sql(expression, "end")
        end = f"{self.seg('')}{end}" if end else ""

        return f"CHANGES ({information}){at_before}{end}"
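
parsejson_sql, length_sql, rand_sql, strtodate_sql, strtotime_sql and changes_sql above are generic fallbacks for the expressions introduced earlier in this commit, while scope_resolution and scoperesolution_sql back T-SQL's :: operator. A round-trip sketch (behavior assumed, not asserted by the diff):

import sqlglot

sql = "SELECT GEOGRAPHY::STGeomFromText('POINT(1 1)', 4326)"
print(sqlglot.transpile(sql, read="tsql", write="tsql")[0])  # expected to round-trip unchanged
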
@ -8,6 +8,7 @@ from sqlglot.tokens import Token, Tokenizer, TokenType
|
|||
|
||||
if t.TYPE_CHECKING:
|
||||
from sqlglot._typing import Lit
|
||||
from sqlglot.dialects.dialect import DialectType
|
||||
|
||||
|
||||
class JSONPathTokenizer(Tokenizer):
|
||||
|
@@ -36,9 +37,12 @@ class JSONPathTokenizer(Tokenizer):
    STRING_ESCAPES = ["\\"]


def parse(path: str) -> exp.JSONPath:
def parse(path: str, dialect: DialectType = None) -> exp.JSONPath:
    """Takes in a JSON path string and parses it into a JSONPath expression."""
    tokens = JSONPathTokenizer().tokenize(path)
    from sqlglot.dialects import Dialect

    jsonpath_tokenizer = Dialect.get_or_raise(dialect).jsonpath_tokenizer
    tokens = jsonpath_tokenizer.tokenize(path)
    size = len(tokens)

    i = 0
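
parse() now resolves a dialect-specific jsonpath_tokenizer instead of always using the module-level JSONPathTokenizer. Sketch of the new signature in use (dialect choice purely illustrative):

from sqlglot.jsonpath import parse

print(parse("$.a[0].b"))                       # default tokenizer, as before
print(parse("$.a[0].b", dialect="snowflake"))  # dialect-aware tokenization
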
@ -152,8 +152,9 @@ def to_node(
|
|||
reference_node_name=reference_node_name,
|
||||
trim_selects=trim_selects,
|
||||
)
|
||||
if isinstance(scope.expression, exp.Union):
|
||||
upstream = upstream or Node(name="UNION", source=scope.expression, expression=select)
|
||||
if isinstance(scope.expression, exp.SetOperation):
|
||||
name = type(scope.expression).__name__.upper()
|
||||
upstream = upstream or Node(name=name, source=scope.expression, expression=select)
|
||||
|
||||
index = (
|
||||
column
|
||||
|
|
|
@ -158,6 +158,7 @@ class TypeAnnotator(metaclass=_TypeAnnotator):
|
|||
},
|
||||
exp.DataType.Type.DATETIME: {
|
||||
exp.CurrentDatetime,
|
||||
exp.Datetime,
|
||||
exp.DatetimeAdd,
|
||||
exp.DatetimeSub,
|
||||
},
|
||||
|
@ -196,6 +197,9 @@ class TypeAnnotator(metaclass=_TypeAnnotator):
|
|||
exp.DataType.Type.JSON: {
|
||||
exp.ParseJSON,
|
||||
},
|
||||
exp.DataType.Type.TIME: {
|
||||
exp.Time,
|
||||
},
|
||||
exp.DataType.Type.TIMESTAMP: {
|
||||
exp.CurrentTime,
|
||||
exp.CurrentTimestamp,
|
||||
|
|
|
@ -42,7 +42,7 @@ def replace_date_funcs(node: exp.Expression) -> exp.Expression:
|
|||
and not node.args.get("zone")
|
||||
):
|
||||
return exp.cast(node.this, to=exp.DataType.Type.DATE)
|
||||
if isinstance(node, exp.Timestamp) and not node.expression:
|
||||
if isinstance(node, exp.Timestamp) and not node.args.get("zone"):
|
||||
if not node.type:
|
||||
from sqlglot.optimizer.annotate_types import annotate_types
|
||||
|
||||
|
|
|
@ -47,7 +47,7 @@ def pushdown_projections(expression, schema=None, remove_unused_selections=True)
|
|||
if scope.expression.args.get("distinct"):
|
||||
parent_selections = {SELECT_ALL}
|
||||
|
||||
if isinstance(scope.expression, exp.Union):
|
||||
if isinstance(scope.expression, exp.SetOperation):
|
||||
left, right = scope.union_scopes
|
||||
referenced_columns[left] = parent_selections
|
||||
|
||||
|
|
|
@@ -60,8 +60,12 @@ def qualify_columns(
        _pop_table_column_aliases(scope.derived_tables)
        using_column_tables = _expand_using(scope, resolver)

        if schema.empty and expand_alias_refs:
            _expand_alias_refs(scope, resolver)
        if (schema.empty or dialect.FORCE_EARLY_ALIAS_REF_EXPANSION) and expand_alias_refs:
            _expand_alias_refs(
                scope,
                resolver,
                expand_only_groupby=dialect.EXPAND_ALIAS_REFS_EARLY_ONLY_IN_GROUP_BY,
            )

        _convert_columns_to_dots(scope, resolver)
        _qualify_columns(scope, resolver)
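
FORCE_EARLY_ALIAS_REF_EXPANSION and EXPAND_ALIAS_REFS_EARLY_ONLY_IN_GROUP_BY are new dialect flags (BigQuery enables both earlier in this commit), so select-list aliases referenced in GROUP BY are expanded during qualification even when a schema is supplied. Rough sketch, exact output not asserted:

import sqlglot
from sqlglot.optimizer.qualify import qualify

sql = "SELECT x + 1 AS y, COUNT(*) FROM t GROUP BY y"
expr = qualify(sqlglot.parse_one(sql, read="bigquery"), schema={"t": {"x": "int"}}, dialect="bigquery")
print(expr.sql("bigquery"))  # GROUP BY should now reference x + 1 rather than the alias y
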
@ -148,7 +152,7 @@ def _expand_using(scope: Scope, resolver: Resolver) -> t.Dict[str, t.Any]:
|
|||
# Mapping of automatically joined column names to an ordered set of source names (dict).
|
||||
column_tables: t.Dict[str, t.Dict[str, t.Any]] = {}
|
||||
|
||||
for join in joins:
|
||||
for i, join in enumerate(joins):
|
||||
using = join.args.get("using")
|
||||
|
||||
if not using:
|
||||
|
@ -168,6 +172,7 @@ def _expand_using(scope: Scope, resolver: Resolver) -> t.Dict[str, t.Any]:
|
|||
ordered.append(join_table)
|
||||
join_columns = resolver.get_source_columns(join_table)
|
||||
conditions = []
|
||||
using_identifier_count = len(using)
|
||||
|
||||
for identifier in using:
|
||||
identifier = identifier.name
|
||||
|
@ -178,9 +183,21 @@ def _expand_using(scope: Scope, resolver: Resolver) -> t.Dict[str, t.Any]:
|
|||
raise OptimizeError(f"Cannot automatically join: {identifier}")
|
||||
|
||||
table = table or source_table
|
||||
conditions.append(
|
||||
exp.column(identifier, table=table).eq(exp.column(identifier, table=join_table))
|
||||
)
|
||||
|
||||
if i == 0 or using_identifier_count == 1:
|
||||
lhs: exp.Expression = exp.column(identifier, table=table)
|
||||
else:
|
||||
coalesce_columns = [
|
||||
exp.column(identifier, table=t)
|
||||
for t in ordered[:-1]
|
||||
if identifier in resolver.get_source_columns(t)
|
||||
]
|
||||
if len(coalesce_columns) > 1:
|
||||
lhs = exp.func("coalesce", *coalesce_columns)
|
||||
else:
|
||||
lhs = exp.column(identifier, table=table)
|
||||
|
||||
conditions.append(lhs.eq(exp.column(identifier, table=join_table)))
|
||||
|
||||
# Set all values in the dict to None, because we only care about the key ordering
|
||||
tables = column_tables.setdefault(identifier, {})
|
||||
|
@ -196,8 +213,8 @@ def _expand_using(scope: Scope, resolver: Resolver) -> t.Dict[str, t.Any]:
|
|||
for column in scope.columns:
|
||||
if not column.table and column.name in column_tables:
|
||||
tables = column_tables[column.name]
|
||||
coalesce = [exp.column(column.name, table=table) for table in tables]
|
||||
replacement = exp.Coalesce(this=coalesce[0], expressions=coalesce[1:])
|
||||
coalesce_args = [exp.column(column.name, table=table) for table in tables]
|
||||
replacement = exp.func("coalesce", *coalesce_args)
|
||||
|
||||
# Ensure selects keep their output name
|
||||
if isinstance(column.parent, exp.Select):
|
||||
|
@ -208,7 +225,7 @@ def _expand_using(scope: Scope, resolver: Resolver) -> t.Dict[str, t.Any]:
|
|||
return column_tables
|
||||
|
||||
|
||||
def _expand_alias_refs(scope: Scope, resolver: Resolver) -> None:
|
||||
def _expand_alias_refs(scope: Scope, resolver: Resolver, expand_only_groupby: bool = False) -> None:
|
||||
expression = scope.expression
|
||||
|
||||
if not isinstance(expression, exp.Select):
|
||||
|
@ -219,7 +236,7 @@ def _expand_alias_refs(scope: Scope, resolver: Resolver) -> None:
|
|||
def replace_columns(
|
||||
node: t.Optional[exp.Expression], resolve_table: bool = False, literal_index: bool = False
|
||||
) -> None:
|
||||
if not node:
|
||||
if not node or (expand_only_groupby and not isinstance(node, exp.Group)):
|
||||
return
|
||||
|
||||
for column in walk_in_scope(node, prune=lambda node: node.is_star):
|
||||
|
@ -583,14 +600,10 @@ def _expand_stars(
|
|||
if name in using_column_tables and table in using_column_tables[name]:
|
||||
coalesced_columns.add(name)
|
||||
tables = using_column_tables[name]
|
||||
coalesce = [exp.column(name, table=table) for table in tables]
|
||||
coalesce_args = [exp.column(name, table=table) for table in tables]
|
||||
|
||||
new_selections.append(
|
||||
alias(
|
||||
exp.Coalesce(this=coalesce[0], expressions=coalesce[1:]),
|
||||
alias=name,
|
||||
copy=False,
|
||||
)
|
||||
alias(exp.func("coalesce", *coalesce_args), alias=name, copy=False)
|
||||
)
|
||||
else:
|
||||
alias_ = replace_columns.get(table_id, {}).get(name, name)
|
||||
|
@ -719,6 +732,7 @@ class Resolver:
|
|||
self._unambiguous_columns: t.Optional[t.Mapping[str, str]] = None
|
||||
self._all_columns: t.Optional[t.Set[str]] = None
|
||||
self._infer_schema = infer_schema
|
||||
self._get_source_columns_cache: t.Dict[t.Tuple[str, bool], t.Sequence[str]] = {}
|
||||
|
||||
def get_table(self, column_name: str) -> t.Optional[exp.Identifier]:
|
||||
"""
|
||||
|
@ -771,41 +785,49 @@ class Resolver:
|
|||
|
||||
def get_source_columns(self, name: str, only_visible: bool = False) -> t.Sequence[str]:
|
||||
"""Resolve the source columns for a given source `name`."""
|
||||
if name not in self.scope.sources:
|
||||
raise OptimizeError(f"Unknown table: {name}")
|
||||
cache_key = (name, only_visible)
|
||||
if cache_key not in self._get_source_columns_cache:
|
||||
if name not in self.scope.sources:
|
||||
raise OptimizeError(f"Unknown table: {name}")
|
||||
|
||||
source = self.scope.sources[name]
|
||||
source = self.scope.sources[name]
|
||||
|
||||
if isinstance(source, exp.Table):
|
||||
columns = self.schema.column_names(source, only_visible)
|
||||
elif isinstance(source, Scope) and isinstance(source.expression, (exp.Values, exp.Unnest)):
|
||||
columns = source.expression.named_selects
|
||||
if isinstance(source, exp.Table):
|
||||
columns = self.schema.column_names(source, only_visible)
|
||||
elif isinstance(source, Scope) and isinstance(
|
||||
source.expression, (exp.Values, exp.Unnest)
|
||||
):
|
||||
columns = source.expression.named_selects
|
||||
|
||||
# in bigquery, unnest structs are automatically scoped as tables, so you can
|
||||
# directly select a struct field in a query.
|
||||
# this handles the case where the unnest is statically defined.
|
||||
if self.schema.dialect == "bigquery":
|
||||
if source.expression.is_type(exp.DataType.Type.STRUCT):
|
||||
for k in source.expression.type.expressions: # type: ignore
|
||||
columns.append(k.name)
|
||||
else:
|
||||
columns = source.expression.named_selects
|
||||
# in bigquery, unnest structs are automatically scoped as tables, so you can
|
||||
# directly select a struct field in a query.
|
||||
# this handles the case where the unnest is statically defined.
|
||||
if self.schema.dialect == "bigquery":
|
||||
if source.expression.is_type(exp.DataType.Type.STRUCT):
|
||||
for k in source.expression.type.expressions: # type: ignore
|
||||
columns.append(k.name)
|
||||
else:
|
||||
columns = source.expression.named_selects
|
||||
|
||||
node, _ = self.scope.selected_sources.get(name) or (None, None)
|
||||
if isinstance(node, Scope):
|
||||
column_aliases = node.expression.alias_column_names
|
||||
elif isinstance(node, exp.Expression):
|
||||
column_aliases = node.alias_column_names
|
||||
else:
|
||||
column_aliases = []
|
||||
node, _ = self.scope.selected_sources.get(name) or (None, None)
|
||||
if isinstance(node, Scope):
|
||||
column_aliases = node.expression.alias_column_names
|
||||
elif isinstance(node, exp.Expression):
|
||||
column_aliases = node.alias_column_names
|
||||
else:
|
||||
column_aliases = []
|
||||
|
||||
if column_aliases:
|
||||
# If the source's columns are aliased, their aliases shadow the corresponding column names.
|
||||
# This can be expensive if there are lots of columns, so only do this if column_aliases exist.
|
||||
return [
|
||||
alias or name for (name, alias) in itertools.zip_longest(columns, column_aliases)
|
||||
]
|
||||
return columns
|
||||
if column_aliases:
|
||||
# If the source's columns are aliased, their aliases shadow the corresponding column names.
|
||||
# This can be expensive if there are lots of columns, so only do this if column_aliases exist.
|
||||
columns = [
|
||||
alias or name
|
||||
for (name, alias) in itertools.zip_longest(columns, column_aliases)
|
||||
]
|
||||
|
||||
self._get_source_columns_cache[cache_key] = columns
|
||||
|
||||
return self._get_source_columns_cache[cache_key]
|
||||
|
||||
def _get_all_source_columns(self) -> t.Dict[str, t.Sequence[str]]:
|
||||
if self._source_columns is None:
|
||||
|
|
|
@ -29,7 +29,7 @@ class Scope:
|
|||
Selection scope.
|
||||
|
||||
Attributes:
|
||||
expression (exp.Select|exp.Union): Root expression of this scope
|
||||
expression (exp.Select|exp.SetOperation): Root expression of this scope
|
||||
sources (dict[str, exp.Table|Scope]): Mapping of source name to either
|
||||
a Table expression or another Scope instance. For example:
|
||||
SELECT * FROM x {"x": Table(this="x")}
|
||||
|
@ -233,7 +233,7 @@ class Scope:
|
|||
SELECT * FROM x WHERE a IN (SELECT ...) <- that's a subquery
|
||||
|
||||
Returns:
|
||||
list[exp.Select | exp.Union]: subqueries
|
||||
list[exp.Select | exp.SetOperation]: subqueries
|
||||
"""
|
||||
self._ensure_collected()
|
||||
return self._subqueries
|
||||
|
@ -339,7 +339,7 @@ class Scope:
|
|||
sources in the current scope.
|
||||
"""
|
||||
if self._external_columns is None:
|
||||
if isinstance(self.expression, exp.Union):
|
||||
if isinstance(self.expression, exp.SetOperation):
|
||||
left, right = self.union_scopes
|
||||
self._external_columns = left.external_columns + right.external_columns
|
||||
else:
|
||||
|
@ -535,7 +535,7 @@ def _traverse_scope(scope):
|
|||
|
||||
if isinstance(expression, exp.Select):
|
||||
yield from _traverse_select(scope)
|
||||
elif isinstance(expression, exp.Union):
|
||||
elif isinstance(expression, exp.SetOperation):
|
||||
yield from _traverse_ctes(scope)
|
||||
yield from _traverse_union(scope)
|
||||
return
|
||||
|
@ -588,7 +588,7 @@ def _traverse_union(scope):
|
|||
scope_type=ScopeType.UNION,
|
||||
)
|
||||
|
||||
if isinstance(expression, exp.Union):
|
||||
if isinstance(expression, exp.SetOperation):
|
||||
yield from _traverse_ctes(new_scope)
|
||||
|
||||
union_scope_stack.append(new_scope)
|
||||
|
@ -620,7 +620,7 @@ def _traverse_ctes(scope):
|
|||
if with_ and with_.recursive:
|
||||
union = cte.this
|
||||
|
||||
if isinstance(union, exp.Union):
|
||||
if isinstance(union, exp.SetOperation):
|
||||
sources[cte_name] = scope.branch(union.this, scope_type=ScopeType.CTE)
|
||||
|
||||
child_scope = None
|
||||
|
|
|
@ -6,7 +6,6 @@ import functools
|
|||
import itertools
|
||||
import typing as t
|
||||
from collections import deque, defaultdict
|
||||
from decimal import Decimal
|
||||
from functools import reduce
|
||||
|
||||
import sqlglot
|
||||
|
@ -347,8 +346,8 @@ def _simplify_comparison(expression, left, right, or_=False):
|
|||
return expression
|
||||
|
||||
if l.is_number and r.is_number:
|
||||
l = float(l.name)
|
||||
r = float(r.name)
|
||||
l = l.to_py()
|
||||
r = r.to_py()
|
||||
elif l.is_string and r.is_string:
|
||||
l = l.name
|
||||
r = r.name
|
||||
|
@@ -626,13 +625,8 @@ def simplify_literals(expression, root=True):
    if isinstance(expression, exp.Binary) and not isinstance(expression, exp.Connector):
        return _flat_simplify(expression, _simplify_binary, root)

    if isinstance(expression, exp.Neg):
        this = expression.this
        if this.is_number:
            value = this.name
            if value[0] == "-":
                return exp.Literal.number(value[1:])
            return exp.Literal.number(f"-{value}")
    if isinstance(expression, exp.Neg) and isinstance(expression.this, exp.Neg):
        return expression.this.this

    if type(expression) in INVERSE_DATE_OPS:
        return _simplify_binary(expression, expression.this, expression.interval()) or expression
@ -650,7 +644,7 @@ def _simplify_integer_cast(expr: exp.Expression) -> exp.Expression:
|
|||
this = expr.this
|
||||
|
||||
if isinstance(expr, exp.Cast) and this.is_int:
|
||||
num = int(this.name)
|
||||
num = this.to_py()
|
||||
|
||||
# Remove the (up)cast from small (byte-sized) integers in predicates which is side-effect free. Downcasts on any
|
||||
# integer type might cause overflow, thus the cast cannot be eliminated and the behavior is
|
||||
|
@@ -690,8 +684,8 @@ def _simplify_binary(expression, a, b):
            return exp.null()

    if a.is_number and b.is_number:
        num_a = int(a.name) if a.is_int else Decimal(a.name)
        num_b = int(b.name) if b.is_int else Decimal(b.name)
        num_a = a.to_py()
        num_b = b.to_py()

        if isinstance(expression, exp.Add):
            return exp.Literal.number(num_a + num_b)
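
Since _simplify_binary now relies on to_py(), integer and decimal literals fold without the earlier string round-trips. Sketch:

from sqlglot import parse_one
from sqlglot.optimizer.simplify import simplify

print(simplify(parse_one("SELECT 1 + 2 * 3")).sql())  # expected: SELECT 7
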
@ -1206,7 +1200,7 @@ def _is_date_literal(expression: exp.Expression) -> bool:
|
|||
|
||||
def extract_interval(expression):
|
||||
try:
|
||||
n = int(expression.name)
|
||||
n = int(expression.this.to_py())
|
||||
unit = expression.text("unit").lower()
|
||||
return interval(unit, n)
|
||||
except (UnsupportedUnit, ModuleNotFoundError, ValueError):
|
||||
|
|
|
@ -48,7 +48,7 @@ def unnest(select, parent_select, next_alias_name):
|
|||
):
|
||||
return
|
||||
|
||||
if isinstance(select, exp.Union):
|
||||
if isinstance(select, exp.SetOperation):
|
||||
select = exp.select(*select.selects).from_(select.subquery(next_alias_name()))
|
||||
|
||||
alias = next_alias_name()
|
||||
|
|
|
@ -150,6 +150,7 @@ class Parser(metaclass=_Parser):
|
|||
to=exp.DataType(this=exp.DataType.Type.TEXT),
|
||||
),
|
||||
"GLOB": lambda args: exp.Glob(this=seq_get(args, 1), expression=seq_get(args, 0)),
|
||||
"HEX": build_hex,
|
||||
"JSON_EXTRACT": build_extract_json_with_path(exp.JSONExtract),
|
||||
"JSON_EXTRACT_SCALAR": build_extract_json_with_path(exp.JSONExtractScalar),
|
||||
"JSON_EXTRACT_PATH_TEXT": build_extract_json_with_path(exp.JSONExtractScalar),
|
||||
|
@ -157,11 +158,16 @@ class Parser(metaclass=_Parser):
|
|||
"LOG": build_logarithm,
|
||||
"LOG2": lambda args: exp.Log(this=exp.Literal.number(2), expression=seq_get(args, 0)),
|
||||
"LOG10": lambda args: exp.Log(this=exp.Literal.number(10), expression=seq_get(args, 0)),
|
||||
"LOWER": build_lower,
|
||||
"MOD": build_mod,
|
||||
"SCOPE_RESOLUTION": lambda args: exp.ScopeResolution(expression=seq_get(args, 0))
|
||||
if len(args) != 2
|
||||
else exp.ScopeResolution(this=seq_get(args, 0), expression=seq_get(args, 1)),
|
||||
"TIME_TO_TIME_STR": lambda args: exp.Cast(
|
||||
this=seq_get(args, 0),
|
||||
to=exp.DataType(this=exp.DataType.Type.TEXT),
|
||||
),
|
||||
"TO_HEX": build_hex,
|
||||
"TS_OR_DS_TO_DATE_STR": lambda args: exp.Substring(
|
||||
this=exp.Cast(
|
||||
this=seq_get(args, 0),
|
||||
|
@ -170,11 +176,9 @@ class Parser(metaclass=_Parser):
|
|||
start=exp.Literal.number(1),
|
||||
length=exp.Literal.number(10),
|
||||
),
|
||||
"VAR_MAP": build_var_map,
|
||||
"LOWER": build_lower,
|
||||
"UNNEST": lambda args: exp.Unnest(expressions=ensure_list(seq_get(args, 0))),
|
||||
"UPPER": build_upper,
|
||||
"HEX": build_hex,
|
||||
"TO_HEX": build_hex,
|
||||
"VAR_MAP": build_var_map,
|
||||
}
|
||||
|
||||
NO_PAREN_FUNCTIONS = {
|
||||
|
@ -295,6 +299,7 @@ class Parser(metaclass=_Parser):
|
|||
TokenType.ROWVERSION,
|
||||
TokenType.IMAGE,
|
||||
TokenType.VARIANT,
|
||||
TokenType.VECTOR,
|
||||
TokenType.OBJECT,
|
||||
TokenType.OBJECT_IDENTIFIER,
|
||||
TokenType.INET,
|
||||
|
@ -670,6 +675,7 @@ class Parser(metaclass=_Parser):
|
|||
exp.Properties: lambda self: self._parse_properties(),
|
||||
exp.Qualify: lambda self: self._parse_qualify(),
|
||||
exp.Returning: lambda self: self._parse_returning(),
|
||||
exp.Select: lambda self: self._parse_select(),
|
||||
exp.Sort: lambda self: self._parse_sort(exp.Sort, TokenType.SORT_BY),
|
||||
exp.Table: lambda self: self._parse_table_parts(),
|
||||
exp.TableAlias: lambda self: self._parse_table_alias(),
|
||||
|
@ -818,6 +824,7 @@ class Parser(metaclass=_Parser):
|
|||
"DETERMINISTIC": lambda self: self.expression(
|
||||
exp.StabilityProperty, this=exp.Literal.string("IMMUTABLE")
|
||||
),
|
||||
"DYNAMIC": lambda self: self.expression(exp.DynamicProperty),
|
||||
"DISTKEY": lambda self: self._parse_distkey(),
|
||||
"DISTSTYLE": lambda self: self._parse_property_assignment(exp.DistStyleProperty),
|
||||
"ENGINE": lambda self: self._parse_property_assignment(exp.EngineProperty),
|
||||
|
@ -868,6 +875,7 @@ class Parser(metaclass=_Parser):
|
|||
"SAMPLE": lambda self: self.expression(
|
||||
exp.SampleProperty, this=self._match_text_seq("BY") and self._parse_bitwise()
|
||||
),
|
||||
"SECURE": lambda self: self.expression(exp.SecureProperty),
|
||||
"SET": lambda self: self.expression(exp.SetProperty, multi=False),
|
||||
"SETTINGS": lambda self: self.expression(
|
||||
exp.SettingsProperty, expressions=self._parse_csv(self._parse_set_item)
|
||||
|
@ -990,6 +998,9 @@ class Parser(metaclass=_Parser):
|
|||
NO_PAREN_FUNCTION_PARSERS = {
|
||||
"ANY": lambda self: self.expression(exp.Any, this=self._parse_bitwise()),
|
||||
"CASE": lambda self: self._parse_case(),
|
||||
"CONNECT_BY_ROOT": lambda self: self.expression(
|
||||
exp.ConnectByRoot, this=self._parse_column()
|
||||
),
|
||||
"IF": lambda self: self._parse_if(),
|
||||
"NEXT": lambda self: self._parse_next_value_for(),
|
||||
}
|
||||
|
@ -1118,9 +1129,15 @@ class Parser(metaclass=_Parser):
|
|||
|
||||
CAST_ACTIONS: OPTIONS_TYPE = dict.fromkeys(("RENAME", "ADD"), ("FIELDS",))
|
||||
|
||||
SCHEMA_BINDING_OPTIONS: OPTIONS_TYPE = {
|
||||
"TYPE": ("EVOLUTION",),
|
||||
**dict.fromkeys(("BINDING", "COMPENSATION", "EVOLUTION"), tuple()),
|
||||
}
|
||||
|
||||
INSERT_ALTERNATIVES = {"ABORT", "FAIL", "IGNORE", "REPLACE", "ROLLBACK"}
|
||||
|
||||
CLONE_KEYWORDS = {"CLONE", "COPY"}
|
||||
HISTORICAL_DATA_PREFIX = {"AT", "BEFORE", "END"}
|
||||
HISTORICAL_DATA_KIND = {"TIMESTAMP", "OFFSET", "STATEMENT", "STREAM"}
|
||||
|
||||
OPCLASS_FOLLOW_KEYWORDS = {"ASC", "DESC", "NULLS", "WITH"}
|
||||
|
@ -1184,8 +1201,8 @@ class Parser(metaclass=_Parser):
|
|||
STRING_ALIASES = False
|
||||
|
||||
# Whether query modifiers such as LIMIT are attached to the UNION node (vs its right operand)
|
||||
MODIFIERS_ATTACHED_TO_UNION = True
|
||||
UNION_MODIFIERS = {"order", "limit", "offset"}
|
||||
MODIFIERS_ATTACHED_TO_SET_OP = True
|
||||
SET_OP_MODIFIERS = {"order", "limit", "offset"}
|
||||
|
||||
# Whether to parse IF statements that aren't followed by a left parenthesis as commands
|
||||
NO_PAREN_IF_COMMANDS = True
@@ -1193,8 +1210,8 @@ class Parser(metaclass=_Parser):
# Whether the -> and ->> operators expect documents of type JSON (e.g. Postgres)
JSON_ARROWS_REQUIRE_JSON_TYPE = False

# Whether the `:` operator is used to extract a value from a JSON document
COLON_IS_JSON_EXTRACT = False
# Whether the `:` operator is used to extract a value from a VARIANT column
COLON_IS_VARIANT_EXTRACT = False

# Whether or not a VALUES keyword needs to be followed by '(' to form a VALUES clause.
# If this is True and '(' is not found, the keyword will be treated as an identifier
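A hedged sketch of the parse-side effect of COLON_IS_VARIANT_EXTRACT, assuming Snowflake is one of the dialects that enables it: a colon path over a VARIANT column parses into an exp.JSONExtract node flagged as a variant extract.

# Sketch only; assumes Snowflake sets COLON_IS_VARIANT_EXTRACT = True.
import sqlglot
from sqlglot import exp

ast = sqlglot.parse_one("SELECT v:a.b FROM t", read="snowflake")
extract = ast.find(exp.JSONExtract)
assert extract is not None                   # v:a.b is represented as a JSON/variant extract
assert extract.args.get("variant_extract")   # marked so generators can special-case it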
@@ -1466,9 +1483,9 @@ class Parser(metaclass=_Parser):

def _try_parse(self, parse_method: t.Callable[[], T], retreat: bool = False) -> t.Optional[T]:
"""
Attempts to backtrack if a parse function that contains a try/catch internally raises an error. This behavior can
be different depending on the user-set ErrorLevel, so _try_parse aims to solve this by setting & resetting
the parser state accordingly
Attempts to backtrack if a parse function that contains a try/catch internally raises an error.
This behavior can be different depending on the user-set ErrorLevel, so _try_parse aims to
solve this by setting & resetting the parser state accordingly
"""
index = self._index
error_level = self.error_level
@@ -2005,6 +2022,12 @@ class Parser(metaclass=_Parser):
if self._match(TokenType.SERDE_PROPERTIES, advance=False):
return self._parse_serde_properties(with_=True)

if self._match(TokenType.SCHEMA):
return self.expression(
exp.WithSchemaBindingProperty,
this=self._parse_var_from_options(self.SCHEMA_BINDING_OPTIONS),
)

if not self._next:
return None

@@ -2899,7 +2922,7 @@ class Parser(metaclass=_Parser):
continue
break

if self.SUPPORTS_IMPLICIT_UNNEST and this and "from" in this.args:
if self.SUPPORTS_IMPLICIT_UNNEST and this and this.args.get("from"):
this = self._implicit_unnests_to_explicit(this)

return this

@@ -3187,6 +3210,8 @@ class Parser(metaclass=_Parser):
)
where = self._parse_where()

on = self._parse_field() if self._match(TokenType.ON) else None

return self.expression(
exp.IndexParameters,
using=using,

@@ -3196,6 +3221,7 @@ class Parser(metaclass=_Parser):
where=where,
with_storage=with_storage,
tablespace=tablespace,
on=on,
)

def _parse_index(

@@ -3959,7 +3985,7 @@ class Parser(metaclass=_Parser):
token_type = self._prev.token_type

if token_type == TokenType.UNION:
operation = exp.Union
operation: t.Type[exp.SetOperation] = exp.Union
elif token_type == TokenType.EXCEPT:
operation = exp.Except
else:

@@ -3979,11 +4005,11 @@ class Parser(metaclass=_Parser):
expression=expression,
)

if isinstance(this, exp.Union) and self.MODIFIERS_ATTACHED_TO_UNION:
if isinstance(this, exp.SetOperation) and self.MODIFIERS_ATTACHED_TO_SET_OP:
expression = this.expression

if expression:
for arg in self.UNION_MODIFIERS:
for arg in self.SET_OP_MODIFIERS:
expr = expression.args.get(arg)
if expr:
this.set(arg, expr.pop())

@@ -4122,7 +4148,7 @@ class Parser(metaclass=_Parser):
# Most dialects support, e.g., the form INTERVAL '5' day, thus we try to parse
# each INTERVAL expression into this canonical form so it's easy to transpile
if this and this.is_number:
this = exp.Literal.string(this.name)
this = exp.Literal.string(this.to_py())
elif this and this.is_string:
parts = exp.INTERVAL_STRING_RE.findall(this.name)
if len(parts) == 1:
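A small sketch of the canonical INTERVAL shape the comment above describes, assuming the default dialect: the numeric quantity is coerced into a string literal with a separate unit, which is what makes transpilation straightforward.

# Sketch only; assumes the default dialect.
import sqlglot
from sqlglot import exp

interval = sqlglot.parse_one("SELECT INTERVAL 5 DAY").find(exp.Interval)
assert interval is not None and interval.this.is_string  # the quantity 5 became the string literal '5'
print(interval.sql())                                     # e.g. INTERVAL '5' DAY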
@@ -4286,8 +4312,8 @@ class Parser(metaclass=_Parser):
identifier = allow_identifiers and self._parse_id_var(
any_token=False, tokens=(TokenType.VAR,)
)
if identifier:
tokens = self.dialect.tokenize(identifier.name)
if isinstance(identifier, exp.Identifier):
tokens = self.dialect.tokenize(identifier.sql(dialect=self.dialect))

if len(tokens) != 1:
self.raise_error("Unexpected identifier", self._prev)

@@ -4370,6 +4396,10 @@ class Parser(metaclass=_Parser):
else:
expressions = self._parse_csv(self._parse_type_size)

# https://docs.snowflake.com/en/sql-reference/data-types-vector
if type_token == TokenType.VECTOR and len(expressions) == 2:
expressions[0] = exp.DataType.build(expressions[0].name, dialect=self.dialect)

if not expressions or not self._match(TokenType.R_PAREN):
self._retreat(index)
return None
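A hedged sketch of what the VECTOR handling above is meant to enable, under the assumption that Snowflake's VECTOR(<type>, <dimension>) column type now round-trips through sqlglot with its element type parsed as a DataType.

# Sketch only; the output shown is an assumed example, not a verified result.
import sqlglot

ddl = "CREATE TABLE t (v VECTOR(INT, 16))"
print(sqlglot.transpile(ddl, read="snowflake", write="snowflake")[0])
# e.g. CREATE TABLE t (v VECTOR(INT, 16))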
@@ -4481,10 +4511,22 @@ class Parser(metaclass=_Parser):

def _parse_struct_types(self, type_required: bool = False) -> t.Optional[exp.Expression]:
index = self._index
this = (
self._parse_type(parse_interval=False, fallback_to_identifier=True)
or self._parse_id_var()
)

if (
self._curr
and self._next
and self._curr.token_type in self.TYPE_TOKENS
and self._next.token_type in self.TYPE_TOKENS
):
# Takes care of special cases like `STRUCT<list ARRAY<...>>` where the identifier is also a
# type token. Without this, the list will be parsed as a type and we'll eventually crash
this = self._parse_id_var()
else:
this = (
self._parse_type(parse_interval=False, fallback_to_identifier=True)
or self._parse_id_var()
)

self._match(TokenType.COLON)

if (

@@ -4527,7 +4569,7 @@ class Parser(metaclass=_Parser):

return this

def _parse_colon_as_json_extract(
def _parse_colon_as_variant_extract(
self, this: t.Optional[exp.Expression]
) -> t.Optional[exp.Expression]:
casts = []

@@ -4560,11 +4602,14 @@ class Parser(metaclass=_Parser):
if path:
json_path.append(self._find_sql(self._tokens[start_index], end_token))

# The VARIANT extract in Snowflake/Databricks is parsed as a JSONExtract; Snowflake uses the json_path in GET_PATH() while
# Databricks transforms it back to the colon/dot notation
if json_path:
this = self.expression(
exp.JSONExtract,
this=this,
expression=self.dialect.to_json_path(exp.Literal.string(".".join(json_path))),
variant_extract=True,
)

while casts:
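A hedged sketch of the generation behaviour the comment above describes; the outputs are assumed examples (Snowflake re-generating the variant extract via GET_PATH, Databricks keeping colon/dot notation), not verified results.

# Sketch only; output comments are assumptions.
import sqlglot

sql = "SELECT v:a.b FROM t"
print(sqlglot.transpile(sql, read="snowflake", write="snowflake")[0])   # e.g. SELECT GET_PATH(v, 'a.b') FROM t
print(sqlglot.transpile(sql, read="snowflake", write="databricks")[0])  # e.g. SELECT v:a.b FROM t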
@@ -4572,6 +4617,9 @@ class Parser(metaclass=_Parser):

return this

def _parse_dcolon(self) -> t.Optional[exp.Expression]:
return self._parse_types()

def _parse_column_ops(self, this: t.Optional[exp.Expression]) -> t.Optional[exp.Expression]:
this = self._parse_bracket(this)

@@ -4580,7 +4628,7 @@ class Parser(metaclass=_Parser):
op = self.COLUMN_OPERATORS.get(op_token)

if op_token == TokenType.DCOLON:
field = self._parse_types()
field = self._parse_dcolon()
if not field:
self.raise_error("Expected type")
elif op and self._curr:

@@ -4618,7 +4666,7 @@ class Parser(metaclass=_Parser):

this = self._parse_bracket(this)

return self._parse_colon_as_json_extract(this) if self.COLON_IS_JSON_EXTRACT else this
return self._parse_colon_as_variant_extract(this) if self.COLON_IS_VARIANT_EXTRACT else this

def _parse_primary(self) -> t.Optional[exp.Expression]:
if self._match_set(self.PRIMARY_PARSERS):

@@ -5312,8 +5360,8 @@ class Parser(metaclass=_Parser):
order=self._match(TokenType.OVER) and self._parse_wrapped(self._parse_order),
)

def _parse_extract(self) -> exp.Extract:
this = self._parse_function() or self._parse_var() or self._parse_type()
def _parse_extract(self) -> exp.Extract | exp.Anonymous:
this = self._parse_function() or self._parse_var_or_string(upper=True)

if self._match(TokenType.FROM):
return self.expression(exp.Extract, this=this, expression=self._parse_bitwise())

@@ -5362,6 +5410,7 @@ class Parser(metaclass=_Parser):
self.dialect.FORMAT_TRIE or self.dialect.TIME_TRIE,
)
),
safe=safe,
)

if isinstance(fmt, exp.AtTimeZone) and isinstance(this, exp.StrToTime):

@@ -5942,8 +5991,8 @@ class Parser(metaclass=_Parser):
return self._prev
return None

def _parse_var_or_string(self) -> t.Optional[exp.Expression]:
return self._parse_var() or self._parse_string()
def _parse_var_or_string(self, upper: bool = False) -> t.Optional[exp.Expression]:
return self._parse_string() or self._parse_var(any_token=True, upper=upper)

def _parse_primary_or_var(self) -> t.Optional[exp.Expression]:
return self._parse_primary() or self._parse_var(any_token=True)
@@ -108,7 +108,7 @@ class Step:

if isinstance(expression, exp.Select) and from_:
step = Scan.from_expression(from_.this, ctes)
elif isinstance(expression, exp.Union):
elif isinstance(expression, exp.SetOperation):
step = SetOperation.from_expression(expression, ctes)
else:
step = Scan()

@@ -124,13 +124,13 @@ class Step:

projections = []  # final selects in this chain of steps representing a select
operands = {}  # intermediate computations of agg funcs eg x + 1 in SUM(x + 1)
aggregations = set()
aggregations = {}
next_operand_name = name_sequence("_a_")

def extract_agg_operands(expression):
agg_funcs = tuple(expression.find_all(exp.AggFunc))
if agg_funcs:
aggregations.add(expression)
aggregations[expression] = None

for agg in agg_funcs:
for operand in agg.unnest_operands():

@@ -426,7 +426,7 @@ class SetOperation(Step):
def from_expression(
cls, expression: exp.Expression, ctes: t.Optional[t.Dict[str, Step]] = None
) -> SetOperation:
assert isinstance(expression, exp.Union)
assert isinstance(expression, exp.SetOperation)

left = Step.from_expression(expression.left, ctes)
# SELECT 1 UNION SELECT 2 <-- these subqueries don't have names

@@ -386,6 +386,8 @@ class MappingSchema(AbstractMappingSchema, Schema):

if not isinstance(columns, dict):
raise SchemaError(error_msg.format(".".join(keys[:-1]), len(flattened_schema[0])))
if not columns:
raise SchemaError(f"Table {'.'.join(keys[:-1])} must have at least one column")
if isinstance(first(columns.values()), dict):
raise SchemaError(
error_msg.format(
@@ -202,6 +202,7 @@ class TokenType(AutoName):
SIMPLEAGGREGATEFUNCTION = auto()
TDIGEST = auto()
UNKNOWN = auto()
VECTOR = auto()

# keywords
ALIAS = auto()

@@ -526,6 +527,7 @@ class _Tokenizer(type):
_TOKEN_TYPE_TO_INDEX[v] for v in klass.COMMAND_PREFIX_TOKENS
},
heredoc_tag_is_identifier=klass.HEREDOC_TAG_IS_IDENTIFIER,
string_escapes_allowed_in_raw_strings=klass.STRING_ESCAPES_ALLOWED_IN_RAW_STRINGS,
)
token_types = RsTokenTypeSettings(
bit_string=_TOKEN_TYPE_TO_INDEX[TokenType.BIT_STRING],

@@ -602,6 +604,9 @@ class Tokenizer(metaclass=_Tokenizer):
# Token that we'll generate as a fallback if the heredoc prefix doesn't correspond to a heredoc
HEREDOC_STRING_ALTERNATIVE = TokenType.VAR

# Whether string escape characters function as such when placed within raw strings
STRING_ESCAPES_ALLOWED_IN_RAW_STRINGS = True

# Autofilled
_COMMENTS: t.Dict[str, str] = {}
_FORMAT_STRINGS: t.Dict[str, t.Tuple[str, TokenType]] = {}

@@ -877,6 +882,7 @@ class Tokenizer(metaclass=_Tokenizer):
"DATERANGE": TokenType.DATERANGE,
"DATEMULTIRANGE": TokenType.DATEMULTIRANGE,
"UNIQUE": TokenType.UNIQUE,
"VECTOR": TokenType.VECTOR,
"STRUCT": TokenType.STRUCT,
"SEQUENCE": TokenType.SEQUENCE,
"VARIANT": TokenType.VARIANT,

@@ -1162,10 +1168,22 @@ class Tokenizer(metaclass=_Tokenizer):
# Skip the comment's start delimiter
self._advance(comment_start_size)

comment_count = 1
comment_end_size = len(comment_end)
while not self._end and self._chars(comment_end_size) != comment_end:

while not self._end:
if self._chars(comment_end_size) == comment_end:
comment_count -= 1
if not comment_count:
break

self._advance(alnum=True)

# Nested comments are allowed by some dialects, e.g. databricks, duckdb, postgres
if not self._end and self._chars(comment_end_size) == comment_start:
self._advance(comment_start_size)
comment_count += 1

self._comments.append(self._text[comment_start_size : -comment_end_size + 1])
self._advance(comment_end_size - 1)
else:
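A hedged sketch of the nested-comment handling the loop above adds, assuming DuckDB is one of the dialects that permits nested block comments: the whole nested comment is consumed as a single comment instead of ending at the first */ and breaking the parse.

# Sketch only; assumes DuckDB allows nested /* ... */ comments.
import sqlglot

sql = "/* outer /* inner */ still nested */ SELECT 1"
parsed = sqlglot.parse_one(sql, read="duckdb")
assert parsed.sql(comments=False) == "SELECT 1"  # the nested comment was swallowed whole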
@@ -1280,7 +1298,7 @@ class Tokenizer(metaclass=_Tokenizer):
else:
tag = self._extract_string(
end,
unescape_sequences=False,
raw_string=True,
raise_unmatched=not self.HEREDOC_TAG_IS_IDENTIFIER,
)

@@ -1297,7 +1315,7 @@ class Tokenizer(metaclass=_Tokenizer):
return False

self._advance(len(start))
text = self._extract_string(end, unescape_sequences=token_type != TokenType.RAW_STRING)
text = self._extract_string(end, raw_string=token_type == TokenType.RAW_STRING)

if base:
try:

@@ -1333,7 +1351,7 @@ class Tokenizer(metaclass=_Tokenizer):
self,
delimiter: str,
escapes: t.Optional[t.Set[str]] = None,
unescape_sequences: bool = True,
raw_string: bool = False,
raise_unmatched: bool = True,
) -> str:
text = ""

@@ -1342,7 +1360,7 @@ class Tokenizer(metaclass=_Tokenizer):

while True:
if (
unescape_sequences
not raw_string
and self.dialect.UNESCAPED_SEQUENCES
and self._peek
and self._char in self.STRING_ESCAPES

@@ -1353,7 +1371,8 @@ class Tokenizer(metaclass=_Tokenizer):
text += unescaped_sequence
continue
if (
self._char in escapes
(self.STRING_ESCAPES_ALLOWED_IN_RAW_STRINGS or not raw_string)
and self._char in escapes
and (self._peek == delimiter or self._peek in escapes)
and (self._char not in self._QUOTES or self._char == self._peek)
):
@@ -9,6 +9,52 @@ if t.TYPE_CHECKING:
from sqlglot.generator import Generator


def preprocess(
transforms: t.List[t.Callable[[exp.Expression], exp.Expression]],
) -> t.Callable[[Generator, exp.Expression], str]:
"""
Creates a new transform by chaining a sequence of transformations and converts the resulting
expression to SQL, using either the "_sql" method corresponding to the resulting expression,
or the appropriate `Generator.TRANSFORMS` function (when applicable -- see below).

Args:
transforms: sequence of transform functions. These will be called in order.

Returns:
Function that can be used as a generator transform.
"""

def _to_sql(self, expression: exp.Expression) -> str:
expression_type = type(expression)

expression = transforms[0](expression)
for transform in transforms[1:]:
expression = transform(expression)

_sql_handler = getattr(self, expression.key + "_sql", None)
if _sql_handler:
return _sql_handler(expression)

transforms_handler = self.TRANSFORMS.get(type(expression))
if transforms_handler:
if expression_type is type(expression):
if isinstance(expression, exp.Func):
return self.function_fallback_sql(expression)

# Ensures we don't enter an infinite loop. This can happen when the original expression
# has the same type as the final expression and there's no _sql method available for it,
# because then it'd re-enter _to_sql.
raise ValueError(
f"Expression type {expression.__class__.__name__} requires a _sql method in order to be transformed."
)

return transforms_handler(self, expression)

raise ValueError(f"Unsupported expression type {expression.__class__.__name__}.")

return _to_sql


def unalias_group(expression: exp.Expression) -> exp.Expression:
"""
Replace references to select aliases in GROUP BY clauses.
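A hedged sketch of how preprocess is typically wired up by a dialect's generator: an expression type is mapped to a chained transform. MyDialect is hypothetical; unalias_group is the transform shown in the context above, and exp.Group is assumed to be the node it should run on.

# Sketch only; MyDialect is a hypothetical dialect used for illustration.
from sqlglot import exp, generator
from sqlglot.dialects.dialect import Dialect
from sqlglot.transforms import preprocess, unalias_group


class MyDialect(Dialect):
    class Generator(generator.Generator):
        TRANSFORMS = {
            **generator.Generator.TRANSFORMS,
            exp.Group: preprocess([unalias_group]),
        }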
@@ -393,7 +439,7 @@ def add_recursive_cte_column_names(expression: exp.Expression) -> exp.Expression
for cte in expression.expressions:
if not cte.args["alias"].columns:
query = cte.this
if isinstance(query, exp.Union):
if isinstance(query, exp.SetOperation):
query = query.this

cte.args["alias"].set(

@@ -623,47 +669,103 @@ def struct_kv_to_alias(expression: exp.Expression) -> exp.Expression:
return expression


def preprocess(
transforms: t.List[t.Callable[[exp.Expression], exp.Expression]],
) -> t.Callable[[Generator, exp.Expression], str]:
def eliminate_join_marks(expression: exp.Expression) -> exp.Expression:
"""
Creates a new transform by chaining a sequence of transformations and converts the resulting
expression to SQL, using either the "_sql" method corresponding to the resulting expression,
or the appropriate `Generator.TRANSFORMS` function (when applicable -- see below).
Remove join marks from an AST. This rule assumes that all marked columns are qualified.
If this does not hold for a query, consider running `sqlglot.optimizer.qualify` first.

For example,
SELECT * FROM a, b WHERE a.id = b.id(+) -- ... is converted to
SELECT * FROM a LEFT JOIN b ON a.id = b.id -- this

Args:
transforms: sequence of transform functions. These will be called in order.
expression: The AST to remove join marks from.

Returns:
Function that can be used as a generator transform.
The AST with join marks removed.
"""
from sqlglot.optimizer.scope import traverse_scope

def _to_sql(self, expression: exp.Expression) -> str:
expression_type = type(expression)
for scope in traverse_scope(expression):
query = scope.expression

expression = transforms[0](expression)
for transform in transforms[1:]:
expression = transform(expression)
where = query.args.get("where")
joins = query.args.get("joins")

_sql_handler = getattr(self, expression.key + "_sql", None)
if _sql_handler:
return _sql_handler(expression)
if not where or not joins:
continue

transforms_handler = self.TRANSFORMS.get(type(expression))
if transforms_handler:
if expression_type is type(expression):
if isinstance(expression, exp.Func):
return self.function_fallback_sql(expression)
query_from = query.args["from"]

# Ensures we don't enter an infinite loop. This can happen when the original expression
# has the same type as the final expression and there's no _sql method available for it,
# because then it'd re-enter _to_sql.
raise ValueError(
f"Expression type {expression.__class__.__name__} requires a _sql method in order to be transformed."
)
# These keep track of the joins to be replaced
new_joins: t.Dict[str, exp.Join] = {}
old_joins = {join.alias_or_name: join for join in joins}

return transforms_handler(self, expression)
for column in scope.columns:
if not column.args.get("join_mark"):
continue

raise ValueError(f"Unsupported expression type {expression.__class__.__name__}.")
predicate = column.find_ancestor(exp.Predicate, exp.Select)
assert isinstance(
predicate, exp.Binary
), "Columns can only be marked with (+) when involved in a binary operation"

return _to_sql
predicate_parent = predicate.parent
join_predicate = predicate.pop()

left_columns = [
c for c in join_predicate.left.find_all(exp.Column) if c.args.get("join_mark")
]
right_columns = [
c for c in join_predicate.right.find_all(exp.Column) if c.args.get("join_mark")
]

assert not (
left_columns and right_columns
), "The (+) marker cannot appear in both sides of a binary predicate"

marked_column_tables = set()
for col in left_columns or right_columns:
table = col.table
assert table, f"Column {col} needs to be qualified with a table"

col.set("join_mark", False)
marked_column_tables.add(table)

assert (
len(marked_column_tables) == 1
), "Columns of only a single table can be marked with (+) in a given binary predicate"

join_this = old_joins.get(col.table, query_from).this
new_join = exp.Join(this=join_this, on=join_predicate, kind="LEFT")

# Upsert new_join into new_joins dictionary
new_join_alias_or_name = new_join.alias_or_name
existing_join = new_joins.get(new_join_alias_or_name)
if existing_join:
existing_join.set("on", exp.and_(existing_join.args.get("on"), new_join.args["on"]))
else:
new_joins[new_join_alias_or_name] = new_join

# If the parent of the target predicate is a binary node, then it now has only one child
if isinstance(predicate_parent, exp.Binary):
if predicate_parent.left is None:
predicate_parent.replace(predicate_parent.right)
else:
predicate_parent.replace(predicate_parent.left)

if query_from.alias_or_name in new_joins:
only_old_joins = old_joins.keys() - new_joins.keys()
assert (
len(only_old_joins) >= 1
), "Cannot determine which table to use in the new FROM clause"

new_from_name = list(only_old_joins)[0]
query.set("from", exp.From(this=old_joins[new_from_name].this))

query.set("joins", list(new_joins.values()))

if not where.this:
where.pop()

return expression
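A hedged sketch of applying the new transform to an Oracle-style outer join; all marked columns are qualified, as the docstring above requires, and the output shown is an assumed example.

# Sketch only; the printed SQL is an assumption, not a verified result.
import sqlglot
from sqlglot.transforms import eliminate_join_marks

tree = sqlglot.parse_one("SELECT * FROM a, b WHERE a.id = b.id(+)", read="oracle")
print(eliminate_join_marks(tree).sql(dialect="oracle"))
# e.g. SELECT * FROM a LEFT JOIN b ON a.id = b.id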