
Merging upstream version 25.5.1.

Signed-off-by: Daniel Baumann <daniel@debian.org>
Daniel Baumann 2025-02-13 21:41:14 +01:00
parent 298e7a8147
commit 029b9c2c73
Signed by: daniel
GPG key ID: FBB4F0E80A80222F
136 changed files with 80990 additions and 72541 deletions

View file

@ -13,7 +13,6 @@ from sqlglot.dialects.dialect import (
date_add_interval_sql,
datestrtodate_sql,
build_formatted_time,
build_timestamp_from_parts,
filter_array_using_unnest,
if_sql,
inline_array_unless_query,
@ -202,10 +201,35 @@ def _unix_to_time_sql(self: BigQuery.Generator, expression: exp.UnixToTime) -> s
def _build_time(args: t.List) -> exp.Func:
if len(args) == 1:
return exp.TsOrDsToTime(this=args[0])
if len(args) == 3:
return exp.TimeFromParts.from_arg_list(args)
if len(args) == 2:
return exp.Time.from_arg_list(args)
return exp.TimeFromParts.from_arg_list(args)
return exp.Anonymous(this="TIME", expressions=args)
def _build_datetime(args: t.List) -> exp.Func:
if len(args) == 1:
return exp.TsOrDsToTimestamp.from_arg_list(args)
if len(args) == 2:
return exp.Datetime.from_arg_list(args)
return exp.TimestampFromParts.from_arg_list(args)
def _str_to_datetime_sql(
self: BigQuery.Generator, expression: exp.StrToDate | exp.StrToTime
) -> str:
this = self.sql(expression, "this")
dtype = "DATE" if isinstance(expression, exp.StrToDate) else "TIMESTAMP"
if expression.args.get("safe"):
fmt = self.format_time(
expression,
self.dialect.INVERSE_FORMAT_MAPPING,
self.dialect.INVERSE_FORMAT_TRIE,
)
return f"SAFE_CAST({this} AS {dtype} FORMAT {fmt})"
fmt = self.format_time(expression)
return self.func(f"PARSE_{dtype}", fmt, this, expression.args.get("zone"))
class BigQuery(Dialect):
@ -215,6 +239,8 @@ class BigQuery(Dialect):
SUPPORTS_SEMI_ANTI_JOIN = False
LOG_BASE_FIRST = False
HEX_LOWERCASE = True
FORCE_EARLY_ALIAS_REF_EXPANSION = True
EXPAND_ALIAS_REFS_EARLY_ONLY_IN_GROUP_BY = True
# https://cloud.google.com/bigquery/docs/reference/standard-sql/lexical#case_sensitivity
NORMALIZATION_STRATEGY = NormalizationStrategy.CASE_INSENSITIVE
@ -249,7 +275,10 @@ class BigQuery(Dialect):
PSEUDOCOLUMNS = {"_PARTITIONTIME", "_PARTITIONDATE"}
def normalize_identifier(self, expression: E) -> E:
if isinstance(expression, exp.Identifier):
if (
isinstance(expression, exp.Identifier)
and self.normalization_strategy is not NormalizationStrategy.CASE_SENSITIVE
):
parent = expression.parent
while isinstance(parent, exp.Dot):
parent = parent.parent
@ -308,6 +337,7 @@ class BigQuery(Dialect):
}
KEYWORDS.pop("DIV")
KEYWORDS.pop("VALUES")
KEYWORDS.pop("/*+")
class Parser(parser.Parser):
PREFIXED_PIVOT_COLUMNS = True
@ -323,7 +353,7 @@ class BigQuery(Dialect):
unit=exp.Literal.string(str(seq_get(args, 1))),
this=seq_get(args, 0),
),
"DATETIME": build_timestamp_from_parts,
"DATETIME": _build_datetime,
"DATETIME_ADD": build_date_delta_with_interval(exp.DatetimeAdd),
"DATETIME_SUB": build_date_delta_with_interval(exp.DatetimeSub),
"DIV": binary_from_function(exp.IntDiv),
@ -334,6 +364,7 @@ class BigQuery(Dialect):
"JSON_EXTRACT_SCALAR": lambda args: exp.JSONExtractScalar(
this=seq_get(args, 0), expression=seq_get(args, 1) or exp.Literal.string("$")
),
"LENGTH": lambda args: exp.Length(this=seq_get(args, 0), binary=True),
"MD5": exp.MD5Digest.from_arg_list,
"TO_HEX": _build_to_hex,
"PARSE_DATE": lambda args: build_formatted_time(exp.StrToDate, "bigquery")(
@ -552,7 +583,7 @@ class BigQuery(Dialect):
return bracket
class Generator(generator.Generator):
EXPLICIT_UNION = True
EXPLICIT_SET_OP = True
INTERVAL_ALLOWS_PLURAL_FORM = False
JOIN_HINTS = False
QUERY_HINTS = False
@ -644,10 +675,8 @@ class BigQuery(Dialect):
exp.StabilityProperty: lambda self, e: (
"DETERMINISTIC" if e.name == "IMMUTABLE" else "NOT DETERMINISTIC"
),
exp.StrToDate: lambda self, e: self.func("PARSE_DATE", self.format_time(e), e.this),
exp.StrToTime: lambda self, e: self.func(
"PARSE_TIMESTAMP", self.format_time(e), e.this, e.args.get("zone")
),
exp.StrToDate: _str_to_datetime_sql,
exp.StrToTime: _str_to_datetime_sql,
exp.TimeAdd: date_add_interval_sql("TIME", "ADD"),
exp.TimeFromParts: rename_func("TIME"),
exp.TimestampFromParts: rename_func("DATETIME"),
@ -661,6 +690,7 @@ class BigQuery(Dialect):
exp.TsOrDsAdd: _ts_or_ds_add_sql,
exp.TsOrDsDiff: _ts_or_ds_diff_sql,
exp.TsOrDsToTime: rename_func("TIME"),
exp.TsOrDsToTimestamp: rename_func("DATETIME"),
exp.Unhex: rename_func("FROM_HEX"),
exp.UnixDate: rename_func("UNIX_DATE"),
exp.UnixToTime: _unix_to_time_sql,

View file

@ -6,8 +6,8 @@ from sqlglot import exp, generator, parser, tokens, transforms
from sqlglot.dialects.dialect import (
Dialect,
arg_max_or_min_no_count,
build_date_delta,
build_formatted_time,
date_delta_sql,
inline_array_sql,
json_extract_segments,
json_path_key_only_name,
@ -17,10 +17,14 @@ from sqlglot.dialects.dialect import (
sha256_sql,
var_map_sql,
timestamptrunc_sql,
unit_to_var,
)
from sqlglot.generator import Generator
from sqlglot.helper import is_int, seq_get
from sqlglot.tokens import Token, TokenType
DATEΤΙΜΕ_DELTA = t.Union[exp.DateAdd, exp.DateDiff, exp.DateSub, exp.TimestampSub, exp.TimestampAdd]
def _build_date_format(args: t.List) -> exp.TimeToStr:
expr = build_formatted_time(exp.TimeToStr, "clickhouse")(args)
@ -77,12 +81,28 @@ def _build_count_if(args: t.List) -> exp.CountIf | exp.CombinedAggFunc:
return exp.CombinedAggFunc(this="countIf", expressions=args, parts=("count", "If"))
def _datetime_delta_sql(name: str) -> t.Callable[[Generator, DATEΤΙΜΕ_DELTA], str]:
def _delta_sql(self: Generator, expression: DATEΤΙΜΕ_DELTA) -> str:
if not expression.unit:
return rename_func(name)(self, expression)
return self.func(
name,
unit_to_var(expression),
expression.expression,
expression.this,
)
return _delta_sql
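A rough illustration of the unit-first form this helper emits (and that the DATE_*/TIMESTAMP_* builders below parse); round-trips, indicative only:

import sqlglot

print(sqlglot.transpile("SELECT DATE_ADD(DAY, 3, created_at)", read="clickhouse", write="clickhouse")[0])
# e.g. SELECT DATE_ADD(DAY, 3, created_at)
print(sqlglot.transpile("SELECT TIMESTAMP_SUB(MINUTE, 5, ts)", read="clickhouse", write="clickhouse")[0])
# e.g. SELECT TIMESTAMP_SUB(MINUTE, 5, ts)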
class ClickHouse(Dialect):
NORMALIZE_FUNCTIONS: bool | str = False
NULL_ORDERING = "nulls_are_last"
SUPPORTS_USER_DEFINED_TYPES = False
SAFE_DIVISION = True
LOG_BASE_FIRST: t.Optional[bool] = None
FORCE_EARLY_ALIAS_REF_EXPANSION = True
UNESCAPED_SEQUENCES = {
"\\0": "\0",
@ -128,6 +148,7 @@ class ClickHouse(Dialect):
"SYSTEM": TokenType.COMMAND,
"PREWHERE": TokenType.PREWHERE,
}
KEYWORDS.pop("/*+")
SINGLE_TOKENS = {
**tokens.Tokenizer.SINGLE_TOKENS,
@ -138,7 +159,7 @@ class ClickHouse(Dialect):
# Tested in ClickHouse's playground, it seems that the following two queries do the same thing
# * select x from t1 union all select x from t2 limit 1;
# * select x from t1 union all (select x from t2 limit 1);
MODIFIERS_ATTACHED_TO_UNION = False
MODIFIERS_ATTACHED_TO_SET_OP = False
INTERVAL_SPANS = False
FUNCTIONS = {
@ -146,19 +167,13 @@ class ClickHouse(Dialect):
"ANY": exp.AnyValue.from_arg_list,
"ARRAYSUM": exp.ArraySum.from_arg_list,
"COUNTIF": _build_count_if,
"DATE_ADD": lambda args: exp.DateAdd(
this=seq_get(args, 2), expression=seq_get(args, 1), unit=seq_get(args, 0)
),
"DATEADD": lambda args: exp.DateAdd(
this=seq_get(args, 2), expression=seq_get(args, 1), unit=seq_get(args, 0)
),
"DATE_DIFF": lambda args: exp.DateDiff(
this=seq_get(args, 2), expression=seq_get(args, 1), unit=seq_get(args, 0)
),
"DATEDIFF": lambda args: exp.DateDiff(
this=seq_get(args, 2), expression=seq_get(args, 1), unit=seq_get(args, 0)
),
"DATE_ADD": build_date_delta(exp.DateAdd, default_unit=None),
"DATEADD": build_date_delta(exp.DateAdd, default_unit=None),
"DATE_DIFF": build_date_delta(exp.DateDiff, default_unit=None),
"DATEDIFF": build_date_delta(exp.DateDiff, default_unit=None),
"DATE_FORMAT": _build_date_format,
"DATE_SUB": build_date_delta(exp.DateSub, default_unit=None),
"DATESUB": build_date_delta(exp.DateSub, default_unit=None),
"FORMATDATETIME": _build_date_format,
"JSONEXTRACTSTRING": build_json_extract_path(
exp.JSONExtractScalar, zero_based_indexing=False
@ -167,6 +182,10 @@ class ClickHouse(Dialect):
"MATCH": exp.RegexpLike.from_arg_list,
"RANDCANONICAL": exp.Rand.from_arg_list,
"TUPLE": exp.Struct.from_arg_list,
"TIMESTAMP_SUB": build_date_delta(exp.TimestampSub, default_unit=None),
"TIMESTAMPSUB": build_date_delta(exp.TimestampSub, default_unit=None),
"TIMESTAMP_ADD": build_date_delta(exp.TimestampAdd, default_unit=None),
"TIMESTAMPADD": build_date_delta(exp.TimestampAdd, default_unit=None),
"UNIQ": exp.ApproxDistinct.from_arg_list,
"XOR": lambda args: exp.Xor(expressions=args),
"MD5": exp.MD5Digest.from_arg_list,
@ -389,6 +408,23 @@ class ClickHouse(Dialect):
"INDEX",
}
def _parse_extract(self) -> exp.Extract | exp.Anonymous:
index = self._index
this = self._parse_bitwise()
if self._match(TokenType.FROM):
self._retreat(index)
return super()._parse_extract()
# We return Anonymous here because extract and regexpExtract have different semantics,
# so parsing extract(foo, bar) into RegexpExtract can potentially break queries. E.g.,
# `extract('foobar', 'b')` works, but CH crashes for `regexpExtract('foobar', 'b')`.
#
# TODO: can we somehow convert the former into an equivalent `regexpExtract` call?
self._match(TokenType.COMMA)
return self.expression(
exp.Anonymous, this="extract", expressions=[this, self._parse_bitwise()]
)
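A quick sketch of the two parse paths described in the comment above (node names as in sqlglot's expression module):

import sqlglot

sqlglot.parse_one("SELECT EXTRACT(YEAR FROM d)", read="clickhouse")         # standard form -> exp.Extract
sqlglot.parse_one("SELECT extract('foobar', '[a-z]+')", read="clickhouse")  # -> anonymous extract(...) call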
def _parse_assignment(self) -> t.Optional[exp.Expression]:
this = super()._parse_assignment()
@ -657,6 +693,12 @@ class ClickHouse(Dialect):
LAST_DAY_SUPPORTS_DATE_PART = False
CAN_IMPLEMENT_ARRAY_ANY = True
SUPPORTS_TO_NUMBER = False
JOIN_HINTS = False
TABLE_HINTS = False
EXPLICIT_SET_OP = True
GROUPINGS_SEP = ""
SET_OP_MODIFIERS = False
SUPPORTS_TABLE_ALIAS_COLUMNS = False
STRING_TYPE_MAPPING = {
exp.DataType.Type.CHAR: "String",
@ -730,8 +772,9 @@ class ClickHouse(Dialect):
exp.ComputedColumnConstraint: lambda self,
e: f"{'MATERIALIZED' if e.args.get('persisted') else 'ALIAS'} {self.sql(e, 'this')}",
exp.CurrentDate: lambda self, e: self.func("CURRENT_DATE"),
exp.DateAdd: date_delta_sql("DATE_ADD"),
exp.DateDiff: date_delta_sql("DATE_DIFF"),
exp.DateAdd: _datetime_delta_sql("DATE_ADD"),
exp.DateDiff: _datetime_delta_sql("DATE_DIFF"),
exp.DateSub: _datetime_delta_sql("DATE_SUB"),
exp.Explode: rename_func("arrayJoin"),
exp.Final: lambda self, e: f"{self.sql(e, 'this')} FINAL",
exp.IsNan: rename_func("isNaN"),
@ -754,6 +797,8 @@ class ClickHouse(Dialect):
exp.TimeToStr: lambda self, e: self.func(
"DATE_FORMAT", e.this, self.format_time(e), e.args.get("timezone")
),
exp.TimestampAdd: _datetime_delta_sql("TIMESTAMP_ADD"),
exp.TimestampSub: _datetime_delta_sql("TIMESTAMP_SUB"),
exp.VarMap: lambda self, e: _lower_func(var_map_sql(self, e)),
exp.Xor: lambda self, e: self.func("xor", e.this, e.expression, *e.expressions),
exp.MD5Digest: rename_func("MD5"),
@ -773,12 +818,6 @@ class ClickHouse(Dialect):
exp.OnCluster: exp.Properties.Location.POST_NAME,
}
JOIN_HINTS = False
TABLE_HINTS = False
EXPLICIT_UNION = True
GROUPINGS_SEP = ""
OUTER_UNION_MODIFIERS = False
# there's no list in docs, but it can be found in Clickhouse code
# see `ClickHouse/src/Parsers/ParserCreate*.cpp`
ON_CLUSTER_TARGETS = {

View file

@ -1,6 +1,8 @@
from __future__ import annotations
from sqlglot import exp, transforms
import typing as t
from sqlglot import exp, transforms, jsonpath
from sqlglot.dialects.dialect import (
date_delta_sql,
build_date_delta,
@ -10,27 +12,47 @@ from sqlglot.dialects.spark import Spark
from sqlglot.tokens import TokenType
def _build_json_extract(args: t.List) -> exp.JSONExtract:
# Transform GET_JSON_OBJECT(expr, '$.<path>') -> expr:<path>
this = args[0]
path = args[1].name.lstrip("$.")
return exp.JSONExtract(this=this, expression=path)
def _timestamp_diff(
self: Databricks.Generator, expression: exp.DatetimeDiff | exp.TimestampDiff
) -> str:
return self.func("TIMESTAMPDIFF", expression.unit, expression.expression, expression.this)
def _jsonextract_sql(
self: Databricks.Generator, expression: exp.JSONExtract | exp.JSONExtractScalar
) -> str:
this = self.sql(expression, "this")
expr = self.sql(expression, "expression")
return f"{this}:{expr}"
class Databricks(Spark):
SAFE_DIVISION = False
COPY_PARAMS_ARE_CSV = False
class JSONPathTokenizer(jsonpath.JSONPathTokenizer):
IDENTIFIERS = ["`", '"']
class Parser(Spark.Parser):
LOG_DEFAULTS_TO_LN = True
STRICT_CAST = True
COLON_IS_JSON_EXTRACT = True
COLON_IS_VARIANT_EXTRACT = True
FUNCTIONS = {
**Spark.Parser.FUNCTIONS,
"DATEADD": build_date_delta(exp.DateAdd),
"DATE_ADD": build_date_delta(exp.DateAdd),
"DATEDIFF": build_date_delta(exp.DateDiff),
"DATE_DIFF": build_date_delta(exp.DateDiff),
"TIMESTAMPDIFF": build_date_delta(exp.TimestampDiff),
"GET_JSON_OBJECT": _build_json_extract,
}
FACTOR = {
@ -42,6 +64,8 @@ class Databricks(Spark):
TABLESAMPLE_SEED_KEYWORD = "REPEATABLE"
COPY_PARAMS_ARE_WRAPPED = False
COPY_PARAMS_EQ_REQUIRED = True
JSON_PATH_SINGLE_QUOTE_ESCAPE = False
QUOTE_JSON_PATH = False
TRANSFORMS = {
**Spark.Generator.TRANSFORMS,
@ -65,6 +89,9 @@ class Databricks(Spark):
transforms.unnest_to_explode,
]
),
exp.JSONExtract: _jsonextract_sql,
exp.JSONExtractScalar: _jsonextract_sql,
exp.JSONPathRoot: lambda *_: "",
exp.ToChar: lambda self, e: self.function_fallback_sql(e),
}

View file

@ -9,7 +9,7 @@ from sqlglot import exp
from sqlglot.errors import ParseError
from sqlglot.generator import Generator
from sqlglot.helper import AutoName, flatten, is_int, seq_get
from sqlglot.jsonpath import parse as parse_json_path
from sqlglot.jsonpath import JSONPathTokenizer, parse as parse_json_path
from sqlglot.parser import Parser
from sqlglot.time import TIMEZONES, format_time
from sqlglot.tokens import Token, Tokenizer, TokenType
@ -122,15 +122,21 @@ class _Dialect(type):
)
klass.INVERSE_TIME_MAPPING = {v: k for k, v in klass.TIME_MAPPING.items()}
klass.INVERSE_TIME_TRIE = new_trie(klass.INVERSE_TIME_MAPPING)
klass.INVERSE_FORMAT_MAPPING = {v: k for k, v in klass.FORMAT_MAPPING.items()}
klass.INVERSE_FORMAT_TRIE = new_trie(klass.INVERSE_FORMAT_MAPPING)
base = seq_get(bases, 0)
base_tokenizer = (getattr(base, "tokenizer_class", Tokenizer),)
base_jsonpath_tokenizer = (getattr(base, "jsonpath_tokenizer_class", JSONPathTokenizer),)
base_parser = (getattr(base, "parser_class", Parser),)
base_generator = (getattr(base, "generator_class", Generator),)
klass.tokenizer_class = klass.__dict__.get(
"Tokenizer", type("Tokenizer", base_tokenizer, {})
)
klass.jsonpath_tokenizer_class = klass.__dict__.get(
"JSONPathTokenizer", type("JSONPathTokenizer", base_jsonpath_tokenizer, {})
)
klass.parser_class = klass.__dict__.get("Parser", type("Parser", base_parser, {}))
klass.generator_class = klass.__dict__.get(
"Generator", type("Generator", base_generator, {})
@ -164,6 +170,8 @@ class _Dialect(type):
klass.ESCAPED_SEQUENCES = {v: k for k, v in klass.UNESCAPED_SEQUENCES.items()}
klass.SUPPORTS_COLUMN_JOIN_MARKS = "(+)" in klass.tokenizer_class.KEYWORDS
if enum not in ("", "bigquery"):
klass.generator_class.SELECT_KINDS = ()
@ -232,6 +240,9 @@ class Dialect(metaclass=_Dialect):
SUPPORTS_COLUMN_JOIN_MARKS = False
"""Whether the old-style outer join (+) syntax is supported."""
COPY_PARAMS_ARE_CSV = True
"""Separator of COPY statement parameters."""
NORMALIZE_FUNCTIONS: bool | str = "upper"
"""
Determines how function names are going to be normalized.
@ -311,9 +322,44 @@ class Dialect(metaclass=_Dialect):
) SELECT c FROM y;
"""
COPY_PARAMS_ARE_CSV = True
"""
Whether COPY statement parameters are separated by comma or whitespace
"""
FORCE_EARLY_ALIAS_REF_EXPANSION = False
"""
Whether alias reference expansion (_expand_alias_refs()) should run before column qualification (_qualify_columns()).
For example:
WITH data AS (
SELECT
1 AS id,
2 AS my_id
)
SELECT
id AS my_id
FROM
data
WHERE
my_id = 1
GROUP BY
my_id
HAVING
my_id = 1
In most dialects "my_id" would refer to "data.my_id" (which is done in _qualify_columns()) across the query, except:
- BigQuery, which will forward the alias to the GROUP BY and HAVING clauses, i.e. it resolves to "WHERE my_id = 1 GROUP BY id HAVING id = 1"
- ClickHouse, which will forward the alias across the query, i.e. it resolves to "WHERE id = 1 GROUP BY id HAVING id = 1"
"""
EXPAND_ALIAS_REFS_EARLY_ONLY_IN_GROUP_BY = False
"""Whether alias reference expansion before qualification should only happen for the GROUP BY clause."""
# --- Autofilled ---
tokenizer_class = Tokenizer
jsonpath_tokenizer_class = JSONPathTokenizer
parser_class = Parser
generator_class = Generator
@ -323,6 +369,8 @@ class Dialect(metaclass=_Dialect):
INVERSE_TIME_MAPPING: t.Dict[str, str] = {}
INVERSE_TIME_TRIE: t.Dict = {}
INVERSE_FORMAT_MAPPING: t.Dict[str, str] = {}
INVERSE_FORMAT_TRIE: t.Dict = {}
ESCAPED_SEQUENCES: t.Dict[str, str] = {}
@ -342,8 +390,99 @@ class Dialect(metaclass=_Dialect):
UNICODE_START: t.Optional[str] = None
UNICODE_END: t.Optional[str] = None
# Separator of COPY statement parameters
COPY_PARAMS_ARE_CSV = True
DATE_PART_MAPPING = {
"Y": "YEAR",
"YY": "YEAR",
"YYY": "YEAR",
"YYYY": "YEAR",
"YR": "YEAR",
"YEARS": "YEAR",
"YRS": "YEAR",
"MM": "MONTH",
"MON": "MONTH",
"MONS": "MONTH",
"MONTHS": "MONTH",
"D": "DAY",
"DD": "DAY",
"DAYS": "DAY",
"DAYOFMONTH": "DAY",
"DAY OF WEEK": "DAYOFWEEK",
"WEEKDAY": "DAYOFWEEK",
"DOW": "DAYOFWEEK",
"DW": "DAYOFWEEK",
"WEEKDAY_ISO": "DAYOFWEEKISO",
"DOW_ISO": "DAYOFWEEKISO",
"DW_ISO": "DAYOFWEEKISO",
"DAY OF YEAR": "DAYOFYEAR",
"DOY": "DAYOFYEAR",
"DY": "DAYOFYEAR",
"W": "WEEK",
"WK": "WEEK",
"WEEKOFYEAR": "WEEK",
"WOY": "WEEK",
"WY": "WEEK",
"WEEK_ISO": "WEEKISO",
"WEEKOFYEARISO": "WEEKISO",
"WEEKOFYEAR_ISO": "WEEKISO",
"Q": "QUARTER",
"QTR": "QUARTER",
"QTRS": "QUARTER",
"QUARTERS": "QUARTER",
"H": "HOUR",
"HH": "HOUR",
"HR": "HOUR",
"HOURS": "HOUR",
"HRS": "HOUR",
"M": "MINUTE",
"MI": "MINUTE",
"MIN": "MINUTE",
"MINUTES": "MINUTE",
"MINS": "MINUTE",
"S": "SECOND",
"SEC": "SECOND",
"SECONDS": "SECOND",
"SECS": "SECOND",
"MS": "MILLISECOND",
"MSEC": "MILLISECOND",
"MSECS": "MILLISECOND",
"MSECOND": "MILLISECOND",
"MSECONDS": "MILLISECOND",
"MILLISEC": "MILLISECOND",
"MILLISECS": "MILLISECOND",
"MILLISECON": "MILLISECOND",
"MILLISECONDS": "MILLISECOND",
"US": "MICROSECOND",
"USEC": "MICROSECOND",
"USECS": "MICROSECOND",
"MICROSEC": "MICROSECOND",
"MICROSECS": "MICROSECOND",
"USECOND": "MICROSECOND",
"USECONDS": "MICROSECOND",
"MICROSECONDS": "MICROSECOND",
"NS": "NANOSECOND",
"NSEC": "NANOSECOND",
"NANOSEC": "NANOSECOND",
"NSECOND": "NANOSECOND",
"NSECONDS": "NANOSECOND",
"NANOSECS": "NANOSECOND",
"EPOCH_SECOND": "EPOCH",
"EPOCH_SECONDS": "EPOCH",
"EPOCH_MILLISECONDS": "EPOCH_MILLISECOND",
"EPOCH_MICROSECONDS": "EPOCH_MICROSECOND",
"EPOCH_NANOSECONDS": "EPOCH_NANOSECOND",
"TZH": "TIMEZONE_HOUR",
"TZM": "TIMEZONE_MINUTE",
"DEC": "DECADE",
"DECS": "DECADE",
"DECADES": "DECADE",
"MIL": "MILLENIUM",
"MILS": "MILLENIUM",
"MILLENIA": "MILLENIUM",
"C": "CENTURY",
"CENT": "CENTURY",
"CENTS": "CENTURY",
"CENTURIES": "CENTURY",
}
@classmethod
def get_or_raise(cls, dialect: DialectType) -> Dialect:
@ -371,8 +510,28 @@ class Dialect(metaclass=_Dialect):
return dialect
if isinstance(dialect, str):
try:
dialect_name, *kv_pairs = dialect.split(",")
kwargs = {k.strip(): v.strip() for k, v in (kv.split("=") for kv in kv_pairs)}
dialect_name, *kv_strings = dialect.split(",")
kv_pairs = (kv.split("=") for kv in kv_strings)
kwargs = {}
for pair in kv_pairs:
key = pair[0].strip()
value: t.Union[bool, str, None] = None
if len(pair) == 1:
# Default initialize standalone settings to True
value = True
elif len(pair) == 2:
value = pair[1].strip()
# Coerce the value to a boolean if it matches one of the truthy/falsy values below
value_lower = value.lower()
if value_lower in ("true", "1"):
value = True
elif value_lower in ("false", "0"):
value = False
kwargs[key] = value
except ValueError:
raise ValueError(
f"Invalid dialect format: '{dialect}'. "
@ -410,13 +569,15 @@ class Dialect(metaclass=_Dialect):
return expression
def __init__(self, **kwargs) -> None:
normalization_strategy = kwargs.get("normalization_strategy")
normalization_strategy = kwargs.pop("normalization_strategy", None)
if normalization_strategy is None:
self.normalization_strategy = self.NORMALIZATION_STRATEGY
else:
self.normalization_strategy = NormalizationStrategy(normalization_strategy.upper())
self.settings = kwargs
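The comma-separated settings parsed in get_or_raise above end up in self.settings; a minimal sketch ("some_flag" and "other" are hypothetical setting names used only for illustration):

from sqlglot.dialects.dialect import Dialect

snowflake = Dialect.get_or_raise("snowflake,normalization_strategy=lowercase")
presto = Dialect.get_or_raise("presto,some_flag,other=false")
# Standalone keys default to True; "true"/"1" and "false"/"0" are coerced to booleans.
print(presto.settings)  # e.g. {'some_flag': True, 'other': False}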
def __eq__(self, other: t.Any) -> bool:
# Does not currently take dialect state into account
return type(self) == other
@ -518,9 +679,8 @@ class Dialect(metaclass=_Dialect):
path_text = path.name
if path.is_number:
path_text = f"[{path_text}]"
try:
return parse_json_path(path_text)
return parse_json_path(path_text, self)
except ParseError as e:
logger.warning(f"Invalid JSON path syntax. {str(e)}")
@ -548,9 +708,11 @@ class Dialect(metaclass=_Dialect):
@property
def tokenizer(self) -> Tokenizer:
if not hasattr(self, "_tokenizer"):
self._tokenizer = self.tokenizer_class(dialect=self)
return self._tokenizer
return self.tokenizer_class(dialect=self)
@property
def jsonpath_tokenizer(self) -> JSONPathTokenizer:
return self.jsonpath_tokenizer_class(dialect=self)
def parser(self, **opts) -> Parser:
return self.parser_class(dialect=self, **opts)
@ -739,13 +901,17 @@ def time_format(
def build_date_delta(
exp_class: t.Type[E], unit_mapping: t.Optional[t.Dict[str, str]] = None
exp_class: t.Type[E],
unit_mapping: t.Optional[t.Dict[str, str]] = None,
default_unit: t.Optional[str] = "DAY",
) -> t.Callable[[t.List], E]:
def _builder(args: t.List) -> E:
unit_based = len(args) == 3
this = args[2] if unit_based else seq_get(args, 0)
unit = args[0] if unit_based else exp.Literal.string("DAY")
unit = exp.var(unit_mapping.get(unit.name.lower(), unit.name)) if unit_mapping else unit
unit = None
if unit_based or default_unit:
unit = args[0] if unit_based else exp.Literal.string(default_unit)
unit = exp.var(unit_mapping.get(unit.name.lower(), unit.name)) if unit_mapping else unit
return exp_class(this=this, expression=seq_get(args, 1), unit=unit)
return _builder
@ -803,19 +969,45 @@ def timestamptrunc_sql(zone: bool = False) -> t.Callable[[Generator, exp.Timesta
def no_timestamp_sql(self: Generator, expression: exp.Timestamp) -> str:
if not expression.expression:
zone = expression.args.get("zone")
if not zone:
from sqlglot.optimizer.annotate_types import annotate_types
target_type = annotate_types(expression).type or exp.DataType.Type.TIMESTAMP
return self.sql(exp.cast(expression.this, target_type))
if expression.text("expression").lower() in TIMEZONES:
if zone.name.lower() in TIMEZONES:
return self.sql(
exp.AtTimeZone(
this=exp.cast(expression.this, exp.DataType.Type.TIMESTAMP),
zone=expression.expression,
zone=zone,
)
)
return self.func("TIMESTAMP", expression.this, expression.expression)
return self.func("TIMESTAMP", expression.this, zone)
def no_time_sql(self: Generator, expression: exp.Time) -> str:
# Transpile BQ's TIME(timestamp, zone) to CAST(TIMESTAMPTZ <timestamp> AT TIME ZONE <zone> AS TIME)
this = exp.cast(expression.this, exp.DataType.Type.TIMESTAMPTZ)
expr = exp.cast(
exp.AtTimeZone(this=this, zone=expression.args.get("zone")), exp.DataType.Type.TIME
)
return self.sql(expr)
def no_datetime_sql(self: Generator, expression: exp.Datetime) -> str:
this = expression.this
expr = expression.expression
if expr.name.lower() in TIMEZONES:
# Transpile BQ's DATETIME(timestamp, zone) to CAST(TIMESTAMPTZ <timestamp> AT TIME ZONE <zone> AS TIMESTAMP)
this = exp.cast(this, exp.DataType.Type.TIMESTAMPTZ)
this = exp.cast(exp.AtTimeZone(this=this, zone=expr), exp.DataType.Type.TIMESTAMP)
return self.sql(this)
this = exp.cast(this, exp.DataType.Type.DATE)
expr = exp.cast(expr, exp.DataType.Type.TIME)
return self.sql(exp.cast(exp.Add(this=this, expression=expr), exp.DataType.Type.TIMESTAMP))
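Dialects without a DATETIME constructor (DuckDB registers this helper further down in this commit) are expected to produce something along these lines, give or take the exact casts:

import sqlglot

print(sqlglot.transpile("SELECT DATETIME('2024-01-01', '12:00:00')", read="bigquery", write="duckdb")[0])
# roughly: SELECT CAST(CAST('2024-01-01' AS DATE) + CAST('12:00:00' AS TIME) AS TIMESTAMP)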
def locate_to_strposition(args: t.List) -> exp.Expression:
@ -1058,6 +1250,25 @@ def unit_to_var(expression: exp.Expression, default: str = "DAY") -> t.Optional[
return exp.Var(this=default) if default else None
@t.overload
def map_date_part(part: exp.Expression, dialect: DialectType = Dialect) -> exp.Var:
pass
@t.overload
def map_date_part(
part: t.Optional[exp.Expression], dialect: DialectType = Dialect
) -> t.Optional[exp.Expression]:
pass
def map_date_part(part, dialect: DialectType = Dialect):
mapped = (
Dialect.get_or_raise(dialect).DATE_PART_MAPPING.get(part.name.upper()) if part else None
)
return exp.var(mapped) if mapped else part
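A minimal sketch of the now-shared date part normalization (values per the DATE_PART_MAPPING above):

from sqlglot import exp
from sqlglot.dialects.dialect import map_date_part

map_date_part(exp.var("MONS"))     # -> Var(this='MONTH')
map_date_part(exp.var("WEEKDAY"))  # -> Var(this='DAYOFWEEK')
map_date_part(None)                # -> None (passes through unchanged)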
def no_last_day_sql(self: Generator, expression: exp.LastDay) -> str:
trunc_curr_date = exp.func("date_trunc", "month", expression.this)
plus_one_month = exp.func("date_add", trunc_curr_date, 1, "month")

View file

@ -11,6 +11,15 @@ from sqlglot.dialects.dialect import (
from sqlglot.dialects.mysql import MySQL
def _lag_lead_sql(self, expression: exp.Lag | exp.Lead) -> str:
return self.func(
"LAG" if isinstance(expression, exp.Lag) else "LEAD",
expression.this,
expression.args.get("offset") or exp.Literal.number(1),
expression.args.get("default") or exp.null(),
)
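Indicatively, the helper above pads Doris' LAG/LEAD with the explicit offset and default arguments it requires:

import sqlglot

print(sqlglot.transpile("SELECT LAG(x) OVER (ORDER BY y) FROM t", read="mysql", write="doris")[0])
# e.g. SELECT LAG(x, 1, NULL) OVER (ORDER BY y) FROM t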
class Doris(MySQL):
DATE_FORMAT = "'yyyy-MM-dd'"
DATEINT_FORMAT = "'yyyyMMdd'"
@ -56,6 +65,8 @@ class Doris(MySQL):
"GROUP_CONCAT", e.this, e.args.get("separator") or exp.Literal.string(",")
),
exp.JSONExtractScalar: lambda self, e: self.func("JSON_EXTRACT", e.this, e.expression),
exp.Lag: _lag_lead_sql,
exp.Lead: _lag_lead_sql,
exp.Map: rename_func("ARRAY_MAP"),
exp.RegexpLike: rename_func("REGEXP"),
exp.RegexpSplit: rename_func("SPLIT_BY_STRING"),

View file

@ -70,6 +70,9 @@ class Drill(Dialect):
IDENTIFIERS = ["`"]
STRING_ESCAPES = ["\\"]
KEYWORDS = tokens.Tokenizer.KEYWORDS.copy()
KEYWORDS.pop("/*+")
class Parser(parser.Parser):
STRICT_CAST = False

View file

@ -15,11 +15,13 @@ from sqlglot.dialects.dialect import (
build_default_decimal_type,
date_trunc_to_time,
datestrtodate_sql,
no_datetime_sql,
encode_decode_sql,
build_formatted_time,
inline_array_unless_query,
no_comment_column_constraint_sql,
no_safe_divide_sql,
no_time_sql,
no_timestamp_sql,
pivot_column_names,
regexp_extract_sql,
@ -218,6 +220,7 @@ class DuckDB(Dialect):
"TIMESTAMP_US": TokenType.TIMESTAMP,
"VARCHAR": TokenType.TEXT,
}
KEYWORDS.pop("/*+")
SINGLE_TOKENS = {
**tokens.Tokenizer.SINGLE_TOKENS,
@ -407,6 +410,7 @@ class DuckDB(Dialect):
"DATE_DIFF", f"'{e.args.get('unit') or 'DAY'}'", e.expression, e.this
),
exp.DateStrToDate: datestrtodate_sql,
exp.Datetime: no_datetime_sql,
exp.DateToDi: lambda self,
e: f"CAST(STRFTIME({self.sql(e, 'this')}, {DuckDB.DATEINT_FORMAT}) AS INT)",
exp.Decode: lambda self, e: encode_decode_sql(self, e, "DECODE", replace=False),
@ -429,7 +433,6 @@ class DuckDB(Dialect):
exp.cast(e.expression, exp.DataType.Type.TIMESTAMP, copy=True),
exp.cast(e.this, exp.DataType.Type.TIMESTAMP, copy=True),
),
exp.ParseJSON: rename_func("JSON"),
exp.PercentileCont: rename_func("QUANTILE_CONT"),
exp.PercentileDisc: rename_func("QUANTILE_DISC"),
# DuckDB doesn't allow qualified columns inside of PIVOT expressions.
@ -450,13 +453,12 @@ class DuckDB(Dialect):
exp.Split: rename_func("STR_SPLIT"),
exp.SortArray: _sort_array_sql,
exp.StrPosition: str_position_sql,
exp.StrToDate: lambda self, e: f"CAST({str_to_time_sql(self, e)} AS DATE)",
exp.StrToTime: str_to_time_sql,
exp.StrToUnix: lambda self, e: self.func(
"EPOCH", self.func("STRPTIME", e.this, self.format_time(e))
),
exp.Struct: _struct_sql,
exp.TimeAdd: _date_delta_sql,
exp.Time: no_time_sql,
exp.Timestamp: no_timestamp_sql,
exp.TimestampDiff: lambda self, e: self.func(
"DATE_DIFF", exp.Literal.string(e.unit), e.expression, e.this
@ -608,6 +610,24 @@ class DuckDB(Dialect):
PROPERTIES_LOCATION[exp.LikeProperty] = exp.Properties.Location.POST_SCHEMA
PROPERTIES_LOCATION[exp.TemporaryProperty] = exp.Properties.Location.POST_CREATE
def strtotime_sql(self, expression: exp.StrToTime) -> str:
if expression.args.get("safe"):
formatted_time = self.format_time(expression)
return f"CAST({self.func('TRY_STRPTIME', expression.this, formatted_time)} AS TIMESTAMP)"
return str_to_time_sql(self, expression)
def strtodate_sql(self, expression: exp.StrToDate) -> str:
if expression.args.get("safe"):
formatted_time = self.format_time(expression)
return f"CAST({self.func('TRY_STRPTIME', expression.this, formatted_time)} AS DATE)"
return f"CAST({str_to_time_sql(self, expression)} AS DATE)"
def parsejson_sql(self, expression: exp.ParseJSON) -> str:
arg = expression.this
if expression.args.get("safe"):
return self.sql(exp.case().when(exp.func("json_valid", arg), arg).else_(exp.null()))
return self.func("JSON", arg)
def timefromparts_sql(self, expression: exp.TimeFromParts) -> str:
nano = expression.args.get("nano")
if nano is not None:
@ -728,3 +748,33 @@ class DuckDB(Dialect):
this = self.sql(expression, "this").rstrip(")")
return f"{this}{expression_sql})"
def length_sql(self, expression: exp.Length) -> str:
arg = expression.this
# Dialects like BQ and Snowflake also accept binary values as args, so
# DDB will attempt to infer the type or resort to case/when resolution
if not expression.args.get("binary") or arg.is_string:
return self.func("LENGTH", arg)
if not arg.type:
from sqlglot.optimizer.annotate_types import annotate_types
arg = annotate_types(arg)
if arg.is_type(*exp.DataType.TEXT_TYPES):
return self.func("LENGTH", arg)
# We need these casts to make duckdb's static type checker happy
blob = exp.cast(arg, exp.DataType.Type.VARBINARY)
varchar = exp.cast(arg, exp.DataType.Type.VARCHAR)
case = (
exp.case(self.func("TYPEOF", arg))
.when(
"'VARCHAR'", exp.Anonymous(this="LENGTH", expressions=[varchar])
) # anonymous to break length_sql recursion
.when("'BLOB'", self.func("OCTET_LENGTH", blob))
)
return self.sql(case)
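Rough illustration of the binary-aware LENGTH handling above; when the argument type cannot be inferred, a TYPEOF-based CASE is emitted (exact casts depend on DuckDB's type mapping):

import sqlglot

print(sqlglot.transpile("SELECT LEN(col)", read="snowflake", write="duckdb")[0])
# roughly: a CASE over TYPEOF(col) that uses LENGTH(...) for text and OCTET_LENGTH(...) for blobs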

View file

@ -71,7 +71,7 @@ def _add_date_sql(self: Hive.Generator, expression: DATE_ADD_OR_SUB) -> str:
multiplier *= -1
if expression.expression.is_number:
modified_increment = exp.Literal.number(int(expression.text("expression")) * multiplier)
modified_increment = exp.Literal.number(expression.expression.to_py() * multiplier)
else:
modified_increment = expression.expression
if multiplier != 1:
@ -446,12 +446,13 @@ class Hive(Dialect):
JSON_PATH_SINGLE_QUOTE_ESCAPE = True
SUPPORTS_TO_NUMBER = False
WITH_PROPERTIES_PREFIX = "TBLPROPERTIES"
PARSE_JSON_NAME = None
EXPRESSIONS_WITHOUT_NESTED_CTES = {
exp.Insert,
exp.Select,
exp.Subquery,
exp.Union,
exp.SetOperation,
}
SUPPORTED_JSON_PATH_PARTS = {
@ -575,7 +576,6 @@ class Hive(Dialect):
exp.NotForReplicationColumnConstraint: lambda *_: "",
exp.OnProperty: lambda *_: "",
exp.PrimaryKeyColumnConstraint: lambda *_: "PRIMARY KEY",
exp.ParseJSON: lambda self, e: self.sql(e.this),
exp.WeekOfYear: rename_func("WEEKOFYEAR"),
exp.DayOfMonth: rename_func("DAYOFMONTH"),
exp.DayOfWeek: rename_func("DAYOFWEEK"),

View file

@ -689,6 +689,7 @@ class MySQL(Dialect):
JSON_PATH_BRACKETED_KEY_SUPPORTED = False
JSON_KEY_VALUE_PAIR_SEP = ","
SUPPORTS_TO_NUMBER = False
PARSE_JSON_NAME = None
TRANSFORMS = {
**generator.Generator.TRANSFORMS,
@ -714,7 +715,6 @@ class MySQL(Dialect):
exp.Month: _remove_ts_or_ds_to_date(),
exp.NullSafeEQ: lambda self, e: self.binary(e, "<=>"),
exp.NullSafeNEQ: lambda self, e: f"NOT {self.binary(e, '<=>')}",
exp.ParseJSON: lambda self, e: self.sql(e, "this"),
exp.Pivot: no_pivot_sql,
exp.Select: transforms.preprocess(
[
@ -1093,29 +1093,6 @@ class MySQL(Dialect):
"xor",
"year_month",
"zerofill",
"cume_dist",
"dense_rank",
"empty",
"except",
"first_value",
"grouping",
"groups",
"intersect",
"json_table",
"lag",
"last_value",
"lateral",
"lead",
"nth_value",
"ntile",
"of",
"over",
"percent_rank",
"rank",
"recursive",
"row_number",
"system",
"window",
}
def array_sql(self, expression: exp.Array) -> str:

View file

@ -33,171 +33,11 @@ def _build_timetostr_or_tochar(args: t.List) -> exp.TimeToStr | exp.ToChar:
return exp.ToChar.from_arg_list(args)
def eliminate_join_marks(ast: exp.Expression) -> exp.Expression:
from sqlglot.optimizer.scope import traverse_scope
"""Remove join marks from an expression
SELECT * FROM a, b WHERE a.id = b.id(+)
becomes:
SELECT * FROM a LEFT JOIN b ON a.id = b.id
- for each scope
- for each column with a join mark
- find the predicate it belongs to
- remove the predicate from the where clause
- convert the predicate to a join with the (+) side as the left join table
- replace the existing join with the new join
Args:
ast: The AST to remove join marks from
Returns:
The AST with join marks removed"""
for scope in traverse_scope(ast):
_eliminate_join_marks_from_scope(scope)
return ast
def _update_from(
select: exp.Select,
new_join_dict: t.Dict[str, exp.Join],
old_join_dict: t.Dict[str, exp.Join],
) -> None:
"""If the from clause needs to become a new join, find an appropriate table to use as the new from.
updates select in place
Args:
select: The select statement to update
new_join_dict: The dictionary of new joins
old_join_dict: The dictionary of old joins
"""
old_from = select.args["from"]
if old_from.alias_or_name not in new_join_dict:
return
in_old_not_new = old_join_dict.keys() - new_join_dict.keys()
if len(in_old_not_new) >= 1:
new_from_name = list(old_join_dict.keys() - new_join_dict.keys())[0]
new_from_this = old_join_dict[new_from_name].this
new_from = exp.From(this=new_from_this)
del old_join_dict[new_from_name]
select.set("from", new_from)
else:
raise ValueError("Cannot determine which table to use as the new from")
def _has_join_mark(col: exp.Expression) -> bool:
"""Check if the column has a join mark
Args:
The column to check
"""
return col.args.get("join_mark", False)
def _predicate_to_join(
eq: exp.Binary, old_joins: t.Dict[str, exp.Join], old_from: exp.From
) -> t.Optional[exp.Join]:
"""Convert an equality predicate to a join if it contains a join mark
Args:
eq: The equality expression to convert to a join
Returns:
The join expression if the equality contains a join mark (otherwise None)
"""
# if not (isinstance(eq.left, exp.Column) or isinstance(eq.right, exp.Column)):
# return None
left_columns = [col for col in eq.left.find_all(exp.Column) if _has_join_mark(col)]
right_columns = [col for col in eq.right.find_all(exp.Column) if _has_join_mark(col)]
left_has_join_mark = len(left_columns) > 0
right_has_join_mark = len(right_columns) > 0
if left_has_join_mark:
for col in left_columns:
col.set("join_mark", False)
join_on = col.table
elif right_has_join_mark:
for col in right_columns:
col.set("join_mark", False)
join_on = col.table
else:
return None
join_this = old_joins.get(join_on, old_from).this
return exp.Join(this=join_this, on=eq, kind="LEFT")
if t.TYPE_CHECKING:
from sqlglot.optimizer.scope import Scope
def _eliminate_join_marks_from_scope(scope: Scope) -> None:
"""Remove join marks columns in scope's where clause.
Converts them to left joins and replaces any existing joins.
Updates scope in place.
Args:
scope: The scope to remove join marks from
"""
select_scope = scope.expression
where = select_scope.args.get("where")
joins = select_scope.args.get("joins")
if not where:
return
if not joins:
return
# dictionaries used to keep track of joins to be replaced
old_joins = {join.alias_or_name: join for join in list(joins)}
new_joins: t.Dict[str, exp.Join] = {}
for node in scope.find_all(exp.Column):
if _has_join_mark(node):
predicate = node.find_ancestor(exp.Predicate)
if not isinstance(predicate, exp.Binary):
continue
predicate_parent = predicate.parent
join_on = predicate.pop()
new_join = _predicate_to_join(
join_on, old_joins=old_joins, old_from=select_scope.args["from"]
)
# upsert new_join into new_joins dictionary
if new_join:
if new_join.alias_or_name in new_joins:
new_joins[new_join.alias_or_name].set(
"on",
exp.and_(
new_joins[new_join.alias_or_name].args["on"],
new_join.args["on"],
),
)
else:
new_joins[new_join.alias_or_name] = new_join
# If the parent is a binary node with only one child, promote the child to the parent
if predicate_parent:
if isinstance(predicate_parent, exp.Binary):
if predicate_parent.left is None:
predicate_parent.replace(predicate_parent.right)
elif predicate_parent.right is None:
predicate_parent.replace(predicate_parent.left)
_update_from(select_scope, new_joins, old_joins)
replacement_joins = [new_joins.get(join.alias_or_name, join) for join in old_joins.values()]
select_scope.set("joins", replacement_joins)
if not where.this:
where.pop()
class Oracle(Dialect):
ALIAS_POST_TABLESAMPLE = True
LOCKING_READS_SUPPORTED = True
TABLESAMPLE_SIZE_IS_PERCENT = True
SUPPORTS_COLUMN_JOIN_MARKS = True
NULL_ORDERING = "nulls_are_large"
# See section 8: https://docs.oracle.com/cd/A97630_01/server.920/a96540/sql_elements9a.htm
NORMALIZATION_STRATEGY = NormalizationStrategy.UPPERCASE
@ -267,6 +107,7 @@ class Oracle(Dialect):
"TO_TIMESTAMP": build_formatted_time(exp.StrToTime, "oracle"),
"TO_DATE": build_formatted_time(exp.StrToDate, "oracle"),
}
FUNCTIONS.pop("NVL")
FUNCTION_PARSERS: t.Dict[str, t.Callable] = {
**parser.Parser.FUNCTION_PARSERS,
@ -282,13 +123,6 @@ class Oracle(Dialect):
"XMLTABLE": lambda self: self._parse_xml_table(),
}
NO_PAREN_FUNCTION_PARSERS = {
**parser.Parser.NO_PAREN_FUNCTION_PARSERS,
"CONNECT_BY_ROOT": lambda self: self.expression(
exp.ConnectByRoot, this=self._parse_column()
),
}
PROPERTY_PARSERS = {
**parser.Parser.PROPERTY_PARSERS,
"GLOBAL": lambda self: self._match_text_seq("TEMPORARY")
@ -408,7 +242,6 @@ class Oracle(Dialect):
TRANSFORMS = {
**generator.Generator.TRANSFORMS,
exp.ConnectByRoot: lambda self, e: f"CONNECT_BY_ROOT {self.sql(e, 'this')}",
exp.DateStrToDate: lambda self, e: self.func(
"TO_DATE", e.this, exp.Literal.string("YYYY-MM-DD")
),

View file

@ -56,12 +56,15 @@ def _date_add_sql(kind: str) -> t.Callable[[Postgres.Generator, DATE_ADD_OR_SUB]
this = self.sql(expression, "this")
unit = expression.args.get("unit")
expression = self._simplify_unless_literal(expression.expression)
if not isinstance(expression, exp.Literal):
e = self._simplify_unless_literal(expression.expression)
if isinstance(e, exp.Literal):
e.args["is_string"] = True
elif e.is_number:
e = exp.Literal.string(e.to_py())
else:
self.unsupported("Cannot add non literal")
expression.args["is_string"] = True
return f"{this} {kind} {self.sql(exp.Interval(this=expression, unit=unit))}"
return f"{this} {kind} {self.sql(exp.Interval(this=e, unit=unit))}"
return func
@ -331,6 +334,7 @@ class Postgres(Dialect):
"REGTYPE": TokenType.OBJECT_IDENTIFIER,
"FLOAT": TokenType.DOUBLE,
}
KEYWORDS.pop("/*+")
KEYWORDS.pop("DIV")
SINGLE_TOKENS = {

View file

@ -173,6 +173,35 @@ def _unix_to_time_sql(self: Presto.Generator, expression: exp.UnixToTime) -> str
return f"FROM_UNIXTIME(CAST({timestamp} AS DOUBLE) / POW(10, {scale}))"
def _jsonextract_sql(self: Presto.Generator, expression: exp.JSONExtract) -> str:
is_json_extract = self.dialect.settings.get("variant_extract_is_json_extract", True)
# Generate JSON_EXTRACT unless the user has configured that a Snowflake / Databricks
# VARIANT extract (e.g. col:x.y) should map to dot notation (i.e. ROW access) in Presto/Trino
if not expression.args.get("variant_extract") or is_json_extract:
return self.func(
"JSON_EXTRACT", expression.this, expression.expression, *expression.expressions
)
this = self.sql(expression, "this")
# Convert the JSONPath extraction `JSON_EXTRACT(col, '$.x.y')` to a ROW access col.x.y
segments = []
for path_key in expression.expression.expressions[1:]:
if not isinstance(path_key, exp.JSONPathKey):
# Cannot transpile subscripts, wildcards etc to dot notation
self.unsupported(f"Cannot transpile JSONPath segment '{path_key}' to ROW access")
continue
key = path_key.this
if not exp.SAFE_IDENTIFIER_RE.match(key):
key = f'"{key}"'
segments.append(f".{key}")
expr = "".join(segments)
return f"{this}{expr}"
def _to_int(expression: exp.Expression) -> exp.Expression:
if not expression.type:
from sqlglot.optimizer.annotate_types import annotate_types
@ -227,7 +256,7 @@ class Presto(Dialect):
"TDIGEST": TokenType.TDIGEST,
"HYPERLOGLOG": TokenType.HLLSKETCH,
}
KEYWORDS.pop("/*+")
KEYWORDS.pop("QUALIFY")
class Parser(parser.Parser):
@ -305,6 +334,7 @@ class Presto(Dialect):
MULTI_ARG_DISTINCT = False
SUPPORTS_TO_NUMBER = False
HEX_FUNC = "TO_HEX"
PARSE_JSON_NAME = "JSON_PARSE"
PROPERTIES_LOCATION = {
**generator.Generator.PROPERTIES_LOCATION,
@ -389,7 +419,7 @@ class Presto(Dialect):
exp.If: if_sql(),
exp.ILike: no_ilike_sql,
exp.Initcap: _initcap_sql,
exp.ParseJSON: rename_func("JSON_PARSE"),
exp.JSONExtract: _jsonextract_sql,
exp.Last: _first_last_sql,
exp.LastValue: _first_last_sql,
exp.LastDay: lambda self, e: self.func("LAST_DAY_OF_MONTH", e.this),
@ -448,9 +478,6 @@ class Presto(Dialect):
[transforms.remove_within_group_for_percentiles]
),
exp.Xor: bool_xor_sql,
exp.MD5: lambda self, e: self.func(
"LOWER", self.func("TO_HEX", self.func("MD5", self.sql(e, "this")))
),
exp.MD5Digest: rename_func("MD5"),
exp.SHA: rename_func("SHA1"),
exp.SHA2: sha256_sql,
@ -517,6 +544,19 @@ class Presto(Dialect):
"with",
}
def md5_sql(self, expression: exp.MD5) -> str:
this = expression.this
if not this.type:
from sqlglot.optimizer.annotate_types import annotate_types
this = annotate_types(this)
if this.is_type(*exp.DataType.TEXT_TYPES):
this = exp.Encode(this=this, charset=exp.Literal.string("utf-8"))
return self.func("LOWER", self.func("TO_HEX", self.func("MD5", self.sql(this))))
def strtounix_sql(self, expression: exp.StrToUnix) -> str:
# Since `TO_UNIXTIME` requires a `TIMESTAMP`, we need to parse the argument into one.
# To do this, we first try to `DATE_PARSE` it, but since this can fail when there's a

View file

@ -12,6 +12,7 @@ from sqlglot.dialects.dialect import (
json_extract_segments,
no_tablesample_sql,
rename_func,
map_date_part,
)
from sqlglot.dialects.postgres import Postgres
from sqlglot.helper import seq_get
@ -23,7 +24,11 @@ if t.TYPE_CHECKING:
def _build_date_delta(expr_type: t.Type[E]) -> t.Callable[[t.List], E]:
def _builder(args: t.List) -> E:
expr = expr_type(this=seq_get(args, 2), expression=seq_get(args, 1), unit=seq_get(args, 0))
expr = expr_type(
this=seq_get(args, 2),
expression=seq_get(args, 1),
unit=map_date_part(seq_get(args, 0)),
)
if expr_type is exp.TsOrDsAdd:
expr.set("return_type", exp.DataType.build("TIMESTAMP"))
@ -40,7 +45,6 @@ class Redshift(Postgres):
INDEX_OFFSET = 0
COPY_PARAMS_ARE_CSV = False
HEX_LOWERCASE = True
SUPPORTS_COLUMN_JOIN_MARKS = True
TIME_FORMAT = "'YYYY-MM-DD HH:MI:SS'"
TIME_MAPPING = {
@ -148,6 +152,8 @@ class Redshift(Postgres):
MULTI_ARG_DISTINCT = True
COPY_PARAMS_ARE_WRAPPED = False
HEX_FUNC = "TO_HEX"
PARSE_JSON_NAME = "JSON_PARSE"
# Redshift doesn't have `WITH` as part of their with_properties so we remove it
WITH_PROPERTIES_PREFIX = " "
@ -180,7 +186,6 @@ class Redshift(Postgres):
exp.JSONExtractScalar: json_extract_segments("JSON_EXTRACT_PATH_TEXT"),
exp.GroupConcat: rename_func("LISTAGG"),
exp.Hex: lambda self, e: self.func("UPPER", self.func("TO_HEX", self.sql(e, "this"))),
exp.ParseJSON: rename_func("JSON_PARSE"),
exp.Select: transforms.preprocess(
[
transforms.eliminate_distinct_on,
@ -203,13 +208,14 @@ class Redshift(Postgres):
# Postgres maps exp.Pivot to no_pivot_sql, but Redshift support pivots
TRANSFORMS.pop(exp.Pivot)
# Postgres doesn't support JSON_PARSE, but Redshift does
TRANSFORMS.pop(exp.ParseJSON)
# Redshift uses the POW | POWER (expr1, expr2) syntax instead of expr1 ^ expr2 (postgres)
TRANSFORMS.pop(exp.Pow)
# Redshift supports ANY_VALUE(..)
# Redshift supports these functions
TRANSFORMS.pop(exp.AnyValue)
# Redshift supports LAST_DAY(..)
TRANSFORMS.pop(exp.LastDay)
TRANSFORMS.pop(exp.SHA2)

View file

@ -21,6 +21,7 @@ from sqlglot.dialects.dialect import (
timestamptrunc_sql,
timestrtotime_sql,
var_map_sql,
map_date_part,
)
from sqlglot.helper import flatten, is_float, is_int, seq_get
from sqlglot.tokens import TokenType
@ -75,7 +76,7 @@ def _build_object_construct(args: t.List) -> t.Union[exp.StarMap, exp.Struct]:
def _build_datediff(args: t.List) -> exp.DateDiff:
return exp.DateDiff(
this=seq_get(args, 2), expression=seq_get(args, 1), unit=_map_date_part(seq_get(args, 0))
this=seq_get(args, 2), expression=seq_get(args, 1), unit=map_date_part(seq_get(args, 0))
)
@ -84,7 +85,7 @@ def _build_date_time_add(expr_type: t.Type[E]) -> t.Callable[[t.List], E]:
return expr_type(
this=seq_get(args, 2),
expression=seq_get(args, 1),
unit=_map_date_part(seq_get(args, 0)),
unit=map_date_part(seq_get(args, 0)),
)
return _builder
@ -143,97 +144,9 @@ def _show_parser(*args: t.Any, **kwargs: t.Any) -> t.Callable[[Snowflake.Parser]
return _parse
DATE_PART_MAPPING = {
"Y": "YEAR",
"YY": "YEAR",
"YYY": "YEAR",
"YYYY": "YEAR",
"YR": "YEAR",
"YEARS": "YEAR",
"YRS": "YEAR",
"MM": "MONTH",
"MON": "MONTH",
"MONS": "MONTH",
"MONTHS": "MONTH",
"D": "DAY",
"DD": "DAY",
"DAYS": "DAY",
"DAYOFMONTH": "DAY",
"WEEKDAY": "DAYOFWEEK",
"DOW": "DAYOFWEEK",
"DW": "DAYOFWEEK",
"WEEKDAY_ISO": "DAYOFWEEKISO",
"DOW_ISO": "DAYOFWEEKISO",
"DW_ISO": "DAYOFWEEKISO",
"YEARDAY": "DAYOFYEAR",
"DOY": "DAYOFYEAR",
"DY": "DAYOFYEAR",
"W": "WEEK",
"WK": "WEEK",
"WEEKOFYEAR": "WEEK",
"WOY": "WEEK",
"WY": "WEEK",
"WEEK_ISO": "WEEKISO",
"WEEKOFYEARISO": "WEEKISO",
"WEEKOFYEAR_ISO": "WEEKISO",
"Q": "QUARTER",
"QTR": "QUARTER",
"QTRS": "QUARTER",
"QUARTERS": "QUARTER",
"H": "HOUR",
"HH": "HOUR",
"HR": "HOUR",
"HOURS": "HOUR",
"HRS": "HOUR",
"M": "MINUTE",
"MI": "MINUTE",
"MIN": "MINUTE",
"MINUTES": "MINUTE",
"MINS": "MINUTE",
"S": "SECOND",
"SEC": "SECOND",
"SECONDS": "SECOND",
"SECS": "SECOND",
"MS": "MILLISECOND",
"MSEC": "MILLISECOND",
"MILLISECONDS": "MILLISECOND",
"US": "MICROSECOND",
"USEC": "MICROSECOND",
"MICROSECONDS": "MICROSECOND",
"NS": "NANOSECOND",
"NSEC": "NANOSECOND",
"NANOSEC": "NANOSECOND",
"NSECOND": "NANOSECOND",
"NSECONDS": "NANOSECOND",
"NANOSECS": "NANOSECOND",
"EPOCH": "EPOCH_SECOND",
"EPOCH_SECONDS": "EPOCH_SECOND",
"EPOCH_MILLISECONDS": "EPOCH_MILLISECOND",
"EPOCH_MICROSECONDS": "EPOCH_MICROSECOND",
"EPOCH_NANOSECONDS": "EPOCH_NANOSECOND",
"TZH": "TIMEZONE_HOUR",
"TZM": "TIMEZONE_MINUTE",
}
@t.overload
def _map_date_part(part: exp.Expression) -> exp.Var:
pass
@t.overload
def _map_date_part(part: t.Optional[exp.Expression]) -> t.Optional[exp.Expression]:
pass
def _map_date_part(part):
mapped = DATE_PART_MAPPING.get(part.name.upper()) if part else None
return exp.var(mapped) if mapped else part
def _date_trunc_to_time(args: t.List) -> exp.DateTrunc | exp.TimestampTrunc:
trunc = date_trunc_to_time(args)
trunc.set("unit", _map_date_part(trunc.args["unit"]))
trunc.set("unit", map_date_part(trunc.args["unit"]))
return trunc
@ -328,7 +241,7 @@ class Snowflake(Dialect):
class Parser(parser.Parser):
IDENTIFY_PIVOT_STRINGS = True
DEFAULT_SAMPLING_METHOD = "BERNOULLI"
COLON_IS_JSON_EXTRACT = True
COLON_IS_VARIANT_EXTRACT = True
ID_VAR_TOKENS = {
*parser.Parser.ID_VAR_TOKENS,
@ -367,8 +280,10 @@ class Snowflake(Dialect):
),
"IFF": exp.If.from_arg_list,
"LAST_DAY": lambda args: exp.LastDay(
this=seq_get(args, 0), unit=_map_date_part(seq_get(args, 1))
this=seq_get(args, 0), unit=map_date_part(seq_get(args, 1))
),
"LEN": lambda args: exp.Length(this=seq_get(args, 0), binary=True),
"LENGTH": lambda args: exp.Length(this=seq_get(args, 0), binary=True),
"LISTAGG": exp.GroupConcat.from_arg_list,
"MEDIAN": lambda args: exp.PercentileCont(
this=seq_get(args, 0), expression=exp.Literal.number(0.5)
@ -385,6 +300,7 @@ class Snowflake(Dialect):
"TIMESTAMPDIFF": _build_datediff,
"TIMESTAMPFROMPARTS": build_timestamp_from_parts,
"TIMESTAMP_FROM_PARTS": build_timestamp_from_parts,
"TRY_PARSE_JSON": lambda args: exp.ParseJSON(this=seq_get(args, 0), safe=True),
"TRY_TO_DATE": _build_datetime("TRY_TO_DATE", exp.DataType.Type.DATE, safe=True),
"TO_DATE": _build_datetime("TO_DATE", exp.DataType.Type.DATE),
"TO_NUMBER": lambda args: exp.ToNumber(
@ -541,7 +457,7 @@ class Snowflake(Dialect):
self._match(TokenType.COMMA)
expression = self._parse_bitwise()
this = _map_date_part(this)
this = map_date_part(this)
name = this.name.upper()
if name.startswith("EPOCH"):
@ -588,10 +504,11 @@ class Snowflake(Dialect):
return lateral
def _parse_at_before(self, table: exp.Table) -> exp.Table:
def _parse_historical_data(self) -> t.Optional[exp.HistoricalData]:
# https://docs.snowflake.com/en/sql-reference/constructs/at-before
index = self._index
if self._match_texts(("AT", "BEFORE")):
historical_data = None
if self._match_texts(self.HISTORICAL_DATA_PREFIX):
this = self._prev.text.upper()
kind = (
self._match(TokenType.L_PAREN)
@ -602,14 +519,27 @@ class Snowflake(Dialect):
if expression:
self._match_r_paren()
when = self.expression(
historical_data = self.expression(
exp.HistoricalData, this=this, kind=kind, expression=expression
)
table.set("when", when)
else:
self._retreat(index)
return table
return historical_data
def _parse_changes(self) -> t.Optional[exp.Changes]:
if not self._match_text_seq("CHANGES", "(", "INFORMATION", "=>"):
return None
information = self._parse_var(any_token=True)
self._match_r_paren()
return self.expression(
exp.Changes,
information=information,
at_before=self._parse_historical_data(),
end=self._parse_historical_data(),
)
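An indicative parse of the Snowflake CHANGES clause handled above (syntax per Snowflake's docs):

import sqlglot

ast = sqlglot.parse_one(
    "SELECT * FROM t CHANGES (INFORMATION => DEFAULT) AT (TIMESTAMP => '2024-01-01'::TIMESTAMP)",
    read="snowflake",
)
# The table node is expected to carry a Changes argument with an embedded HistoricalData (the AT clause).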
def _parse_table_parts(
self, schema: bool = False, is_db_reference: bool = False, wildcard: bool = False
@ -643,7 +573,15 @@ class Snowflake(Dialect):
else:
table = super()._parse_table_parts(schema=schema, is_db_reference=is_db_reference)
return self._parse_at_before(table)
changes = self._parse_changes()
if changes:
table.set("changes", changes)
at_before = self._parse_historical_data()
if at_before:
table.set("when", at_before)
return table
def _parse_id_var(
self,
@ -771,6 +709,7 @@ class Snowflake(Dialect):
"WAREHOUSE": TokenType.WAREHOUSE,
"STREAMLIT": TokenType.STREAMLIT,
}
KEYWORDS.pop("/*+")
SINGLE_TOKENS = {
**tokens.Tokenizer.SINGLE_TOKENS,
@ -839,6 +778,9 @@ class Snowflake(Dialect):
exp.Map: lambda self, e: var_map_sql(self, e, "OBJECT_CONSTRUCT"),
exp.Max: max_or_greatest,
exp.Min: min_or_least,
exp.ParseJSON: lambda self, e: self.func(
"TRY_PARSE_JSON" if e.args.get("safe") else "PARSE_JSON", e.this
),
exp.PartitionedByProperty: lambda self, e: f"PARTITION BY {self.sql(e, 'this')}",
exp.PercentileCont: transforms.preprocess(
[transforms.add_within_group_for_percentiles]

View file

@ -33,7 +33,7 @@ def _build_datediff(args: t.List) -> exp.Expression:
expression = seq_get(args, 1)
if len(args) == 3:
unit = this
unit = exp.var(t.cast(exp.Expression, this).name)
this = args[2]
return exp.DateDiff(
@ -91,6 +91,8 @@ def _dateadd_sql(self: Spark.Generator, expression: exp.TsOrDsAdd | exp.Timestam
class Spark(Spark2):
class Tokenizer(Spark2.Tokenizer):
STRING_ESCAPES_ALLOWED_IN_RAW_STRINGS = False
RAW_STRINGS = [
(prefix + q, q)
for q in t.cast(t.List[str], Spark2.Tokenizer.QUOTES)
@ -105,6 +107,7 @@ class Spark(Spark2):
"DATEADD": _build_dateadd,
"TIMESTAMPADD": _build_dateadd,
"DATEDIFF": _build_datediff,
"DATE_DIFF": _build_datediff,
"TIMESTAMP_LTZ": _build_as_cast("TIMESTAMP_LTZ"),
"TIMESTAMP_NTZ": _build_as_cast("TIMESTAMP_NTZ"),
"TRY_ELEMENT_AT": lambda args: exp.Bracket(

View file

@ -106,11 +106,16 @@ class SQLite(Dialect):
IDENTIFIERS = ['"', ("[", "]"), "`"]
HEX_STRINGS = [("x'", "'"), ("X'", "'"), ("0x", ""), ("0X", "")]
KEYWORDS = tokens.Tokenizer.KEYWORDS.copy()
KEYWORDS.pop("/*+")
class Parser(parser.Parser):
FUNCTIONS = {
**parser.Parser.FUNCTIONS,
"EDITDIST3": exp.Levenshtein.from_arg_list,
"STRFTIME": _build_strftime,
"DATETIME": lambda args: exp.Anonymous(this="DATETIME", expressions=args),
"TIME": lambda args: exp.Anonymous(this="TIME", expressions=args),
}
STRING_ALIASES = True

View file

@ -10,6 +10,7 @@ from sqlglot.dialects.dialect import (
rename_func,
to_number_with_nls_param,
)
from sqlglot.helper import seq_get
from sqlglot.tokens import TokenType
@ -24,9 +25,9 @@ def _date_add_sql(
if not isinstance(value, exp.Literal):
self.unsupported("Cannot add non literal")
if value.is_negative:
if isinstance(value, exp.Neg):
kind_to_op = {"+": "-", "-": "+"}
value = exp.Literal.string(value.name[1:])
value = exp.Literal.string(value.this.to_py())
else:
kind_to_op = {"+": "+", "-": "-"}
value.set("is_string", True)
@ -96,6 +97,7 @@ class Teradata(Dialect):
"TOP": TokenType.TOP,
"UPD": TokenType.UPDATE,
}
KEYWORDS.pop("/*+")
# Teradata does not support % as a modulo operator
SINGLE_TOKENS = {**tokens.Tokenizer.SINGLE_TOKENS}
@ -159,6 +161,11 @@ class Teradata(Dialect):
"TRANSLATE": lambda self: self._parse_translate(self.STRICT_CAST),
}
FUNCTIONS = {
**parser.Parser.FUNCTIONS,
"RANDOM": lambda args: exp.Rand(lower=seq_get(args, 0), upper=seq_get(args, 1)),
}
EXPONENT = {
TokenType.DSTAR: exp.Pow,
}
@ -200,6 +207,14 @@ class Teradata(Dialect):
return self.expression(exp.RangeN, this=this, expressions=expressions, each=each)
def _parse_index_params(self) -> exp.IndexParameters:
this = super()._parse_index_params()
if this.args.get("on"):
this.set("on", None)
self._retreat(self._index - 2)
return this
class Generator(generator.Generator):
LIMIT_IS_TOP = True
JOIN_HINTS = False
@ -208,11 +223,13 @@ class Teradata(Dialect):
TABLESAMPLE_KEYWORDS = "SAMPLE"
LAST_DAY_SUPPORTS_DATE_PART = False
CAN_IMPLEMENT_ARRAY_ANY = True
TZ_TO_WITH_TIME_ZONE = True
TYPE_MAPPING = {
**generator.Generator.TYPE_MAPPING,
exp.DataType.Type.GEOMETRY: "ST_GEOMETRY",
exp.DataType.Type.DOUBLE: "DOUBLE PRECISION",
exp.DataType.Type.TIMESTAMPTZ: "TIMESTAMP",
}
PROPERTIES_LOCATION = {
@ -230,6 +247,7 @@ class Teradata(Dialect):
exp.Max: max_or_greatest,
exp.Min: min_or_least,
exp.Pow: lambda self, e: self.binary(e, "**"),
exp.Rand: lambda self, e: self.func("RANDOM", e.args.get("lower"), e.args.get("upper")),
exp.Select: transforms.preprocess(
[transforms.eliminate_distinct_on, transforms.eliminate_semi_and_anti_joins]
),
@ -238,12 +256,15 @@ class Teradata(Dialect):
exp.ToChar: lambda self, e: self.function_fallback_sql(e),
exp.ToNumber: to_number_with_nls_param,
exp.Use: lambda self, e: f"DATABASE {self.sql(e, 'this')}",
exp.CurrentTimestamp: lambda *_: "CURRENT_TIMESTAMP",
exp.DateAdd: _date_add_sql("+"),
exp.DateSub: _date_add_sql("-"),
exp.Quarter: lambda self, e: self.sql(exp.Extract(this="QUARTER", expression=e.this)),
}
def currenttimestamp_sql(self, expression: exp.CurrentTimestamp) -> str:
prefix, suffix = ("(", ")") if expression.this else ("", "")
return self.func("CURRENT_TIMESTAMP", expression.this, prefix=prefix, suffix=suffix)
def cast_sql(self, expression: exp.Cast, safe_prefix: t.Optional[str] = None) -> str:
if expression.to.this == exp.DataType.Type.UNKNOWN and expression.args.get("format"):
# We don't actually want to print the unknown type in CAST(<value> AS FORMAT <format>)

View file

@ -450,6 +450,7 @@ class TSQL(Dialect):
KEYWORDS = {
**tokens.Tokenizer.KEYWORDS,
"CLUSTERED INDEX": TokenType.INDEX,
"DATETIME2": TokenType.DATETIME,
"DATETIMEOFFSET": TokenType.TIMESTAMPTZ,
"DECLARE": TokenType.DECLARE,
@ -457,6 +458,7 @@ class TSQL(Dialect):
"FOR SYSTEM_TIME": TokenType.TIMESTAMP_SNAPSHOT,
"IMAGE": TokenType.IMAGE,
"MONEY": TokenType.MONEY,
"NONCLUSTERED INDEX": TokenType.INDEX,
"NTEXT": TokenType.TEXT,
"OPTION": TokenType.OPTION,
"OUTPUT": TokenType.RETURNING,
@ -475,6 +477,7 @@ class TSQL(Dialect):
"UPDATE STATISTICS": TokenType.COMMAND,
"XML": TokenType.XML,
}
KEYWORDS.pop("/*+")
COMMANDS = {*tokens.Tokenizer.COMMANDS, TokenType.END}
@ -533,6 +536,31 @@ class TSQL(Dialect):
TokenType.DECLARE: lambda self: self._parse_declare(),
}
RANGE_PARSERS = {
**parser.Parser.RANGE_PARSERS,
TokenType.DCOLON: lambda self, this: self.expression(
exp.ScopeResolution,
this=this,
expression=self._parse_function() or self._parse_var(any_token=True),
),
}
# The DCOLON (::) operator serves as a scope resolution (exp.ScopeResolution) operator in T-SQL
COLUMN_OPERATORS = {
**parser.Parser.COLUMN_OPERATORS,
TokenType.DCOLON: lambda self, this, to: self.expression(exp.Cast, this=this, to=to)
if isinstance(to, exp.DataType) and to.this != exp.DataType.Type.USERDEFINED
else self.expression(exp.ScopeResolution, this=this, expression=to),
}
def _parse_dcolon(self) -> t.Optional[exp.Expression]:
# We want to use _parse_types() if the first token after :: is a known type,
# otherwise we could parse something like x::varchar(max) into a function
if self._match_set(self.TYPE_TOKENS, advance=False):
return self._parse_types()
return self._parse_function() or self._parse_types()
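A sketch of the two '::' behaviors described above for the tsql dialect (hierarchyid::GetRoot() is ordinary T-SQL; the cast form is simply accepted by the parser):

import sqlglot

sqlglot.parse_one("SELECT x::INT", read="tsql")                  # '::' before a known type -> exp.Cast
sqlglot.parse_one("SELECT hierarchyid::GetRoot()", read="tsql")  # expected -> exp.ScopeResolution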
def _parse_options(self) -> t.Optional[t.List[exp.Expression]]:
if not self._match(TokenType.OPTION):
return None
@ -757,12 +785,15 @@ class TSQL(Dialect):
SUPPORTS_SELECT_INTO = True
JSON_PATH_BRACKETED_KEY_SUPPORTED = False
SUPPORTS_TO_NUMBER = False
OUTER_UNION_MODIFIERS = False
SET_OP_MODIFIERS = False
COPY_PARAMS_EQ_REQUIRED = True
PARSE_JSON_NAME = None
EXPRESSIONS_WITHOUT_NESTED_CTES = {
exp.Delete,
exp.Insert,
exp.Intersect,
exp.Except,
exp.Merge,
exp.Select,
exp.Subquery,
@ -816,7 +847,6 @@ class TSQL(Dialect):
exp.MD5: lambda self, e: self.func("HASHBYTES", exp.Literal.string("MD5"), e.this),
exp.Min: min_or_least,
exp.NumberToStr: _format_sql,
exp.ParseJSON: lambda self, e: self.sql(e, "this"),
exp.Repeat: rename_func("REPLICATE"),
exp.Select: transforms.preprocess(
[
@ -850,6 +880,9 @@ class TSQL(Dialect):
exp.VolatileProperty: exp.Properties.Location.UNSUPPORTED,
}
def scope_resolution(self, rhs: str, scope_name: str) -> str:
return f"{scope_name}::{rhs}"
def select_sql(self, expression: exp.Select) -> str:
if expression.args.get("offset"):
if not expression.args.get("order"):

View file

@ -447,6 +447,7 @@ class Python(Dialect):
exp.Is: lambda self, e: (
self.binary(e, "==") if isinstance(e.this, exp.Literal) else self.binary(e, "is")
),
exp.JSONExtract: lambda self, e: self.func(e.key, e.this, e.expression, *e.expressions),
exp.JSONPath: lambda self, e: f"[{','.join(self.sql(p) for p in e.expressions[1:])}]",
exp.JSONPathKey: lambda self, e: f"'{self.sql(e.this)}'",
exp.JSONPathSubscript: lambda self, e: f"'{e.this}'",

View file

@ -20,6 +20,7 @@ import textwrap
import typing as t
from collections import deque
from copy import deepcopy
from decimal import Decimal
from enum import auto
from functools import reduce
@ -29,7 +30,6 @@ from sqlglot.helper import (
camel_to_snake_case,
ensure_collection,
ensure_list,
is_int,
seq_get,
subclasses,
)
@ -40,6 +40,7 @@ if t.TYPE_CHECKING:
from sqlglot.dialects.dialect import DialectType
Q = t.TypeVar("Q", bound="Query")
S = t.TypeVar("S", bound="SetOperation")
class _Expression(type):
@ -174,23 +175,22 @@ class Expression(metaclass=_Expression):
"""
Checks whether a Literal expression is a number.
"""
return isinstance(self, Literal) and not self.args["is_string"]
return (isinstance(self, Literal) and not self.args["is_string"]) or (
isinstance(self, Neg) and self.this.is_number
)
@property
def is_negative(self) -> bool:
def to_py(self) -> t.Any:
"""
Checks whether an expression is negative.
Handles both exp.Neg and Literal numbers with "-" which come from optimizer.simplify.
Returns a Python object equivalent of the SQL node.
"""
return isinstance(self, Neg) or (self.is_number and self.this.startswith("-"))
raise ValueError(f"{self} cannot be converted to a Python object.")
@property
def is_int(self) -> bool:
"""
Checks whether a Literal expression is an integer.
Checks whether an expression is an integer.
"""
return self.is_number and is_int(self.name)
return self.is_number and isinstance(self.to_py(), int)
@property
def is_star(self) -> bool:
@ -2002,6 +2002,10 @@ class Check(Expression):
pass
class Changes(Expression):
arg_types = {"information": True, "at_before": False, "end": False}
# https://docs.snowflake.com/en/sql-reference/constructs/connect-by
class Connect(Expression):
arg_types = {"start": False, "connect": True, "nocycle": False}
@ -2127,6 +2131,7 @@ class IndexParameters(Expression):
"partition_by": False,
"tablespace": False,
"where": False,
"on": False,
}
@ -2281,6 +2286,14 @@ class Literal(Condition):
def output_name(self) -> str:
return self.name
def to_py(self) -> int | str | Decimal:
if self.is_number:
try:
return int(self.this)
except ValueError:
return Decimal(self.this)
return self.this
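# Editorial sketch (not part of the diff): the to_py() hooks added in this diff (Literal, Neg,
# Null, Boolean) lower literal-ish nodes to Python values - int or Decimal for numbers.
from decimal import Decimal
from sqlglot import exp
assert exp.Literal.number("42").to_py() == 42
assert exp.Literal.number("1.5").to_py() == Decimal("1.5")
neg = exp.Neg(this=exp.Literal.number("3"))
assert neg.is_number and neg.to_py() == -3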
class Join(Expression):
arg_types = {
@ -2639,6 +2652,10 @@ class DictRange(Property):
arg_types = {"this": True, "min": True, "max": True}
class DynamicProperty(Property):
arg_types = {}
# Clickhouse CREATE ... ON CLUSTER modifier
# https://clickhouse.com/docs/en/sql-reference/distributed-ddl
class OnCluster(Property):
@ -2805,6 +2822,10 @@ class TemporaryProperty(Property):
arg_types = {"this": False}
class SecureProperty(Property):
arg_types = {}
class TransformModelProperty(Property):
arg_types = {"expressions": True}
@ -2834,6 +2855,10 @@ class WithJournalTableProperty(Property):
arg_types = {"this": True}
class WithSchemaBindingProperty(Property):
arg_types = {"this": True}
class WithSystemVersioningProperty(Property):
arg_types = {
"on": False,
@ -3017,6 +3042,7 @@ class Table(Expression):
"when": False,
"only": False,
"partition": False,
"changes": False,
}
@property
@ -3065,7 +3091,7 @@ class Table(Expression):
return col
class Union(Query):
class SetOperation(Query):
arg_types = {
"with": False,
"this": True,
@ -3076,13 +3102,13 @@ class Union(Query):
}
def select(
self,
self: S,
*expressions: t.Optional[ExpOrStr],
append: bool = True,
dialect: DialectType = None,
copy: bool = True,
**opts,
) -> Union:
) -> S:
this = maybe_copy(self, copy)
this.this.unnest().select(*expressions, append=append, dialect=dialect, copy=False, **opts)
this.expression.unnest().select(
@ -3111,11 +3137,15 @@ class Union(Query):
return self.expression
class Except(Union):
class Union(SetOperation):
pass
class Intersect(Union):
class Except(SetOperation):
pass
class Intersect(SetOperation):
pass
@ -3727,7 +3757,7 @@ class Select(Query):
return self.expressions
UNWRAPPED_QUERIES = (Select, Union)
UNWRAPPED_QUERIES = (Select, SetOperation)
class Subquery(DerivedTable, Query):
@ -3893,9 +3923,13 @@ class Null(Condition):
def name(self) -> str:
return "NULL"
def to_py(self) -> Lit[None]:
return None
class Boolean(Condition):
pass
def to_py(self) -> bool:
return self.this
class DataTypeParam(Expression):
@ -4019,6 +4053,7 @@ class DataType(Expression):
VARBINARY = auto()
VARCHAR = auto()
VARIANT = auto()
VECTOR = auto()
XML = auto()
YEAR = auto()
TDIGEST = auto()
@ -4473,7 +4508,10 @@ class Paren(Unary):
class Neg(Unary):
pass
def to_py(self) -> int | Decimal:
if self.is_number:
return self.this.to_py() * -1
return super().to_py()
class Alias(Expression):
@ -5065,6 +5103,12 @@ class DateTrunc(Func):
return self.args["unit"]
# https://cloud.google.com/bigquery/docs/reference/standard-sql/datetime_functions#datetime
# expression can either be time_expr or time_zone
class Datetime(Func):
arg_types = {"this": True, "expression": False}
class DatetimeAdd(Func, IntervalOp):
arg_types = {"this": True, "expression": True, "unit": False}
@ -5115,7 +5159,7 @@ class Extract(Func):
class Timestamp(Func):
arg_types = {"this": False, "expression": False, "with_tz": False}
arg_types = {"this": False, "zone": False, "with_tz": False}
class TimestampAdd(Func, TimeUnit):
@ -5441,12 +5485,18 @@ class OpenJSON(Func):
arg_types = {"this": True, "path": False, "expressions": False}
class JSONBContains(Binary):
class JSONBContains(Binary, Func):
_sql_names = ["JSONB_CONTAINS"]
class JSONExtract(Binary, Func):
arg_types = {"this": True, "expression": True, "only_json_types": False, "expressions": False}
arg_types = {
"this": True,
"expression": True,
"only_json_types": False,
"expressions": False,
"variant_extract": False,
}
_sql_names = ["JSON_EXTRACT"]
is_var_len_args = True
@ -5485,9 +5535,9 @@ class JSONArrayContains(Binary, Predicate, Func):
class ParseJSON(Func):
# BigQuery, Snowflake have PARSE_JSON, Presto has JSON_PARSE
# Snowflake also has TRY_PARSE_JSON, which is represented using `safe`
_sql_names = ["PARSE_JSON", "JSON_PARSE"]
arg_types = {"this": True, "expressions": False}
is_var_len_args = True
arg_types = {"this": True, "expression": False, "safe": False}
class Least(Func):
@ -5504,6 +5554,7 @@ class Right(Func):
class Length(Func):
arg_types = {"this": True, "binary": False}
_sql_names = ["LENGTH", "LEN"]
@ -5560,6 +5611,11 @@ class MapFromEntries(Func):
pass
# https://learn.microsoft.com/en-us/sql/t-sql/language-elements/scope-resolution-operator-transact-sql?view=sql-server-ver16
class ScopeResolution(Expression):
arg_types = {"this": False, "expression": True}
class StarMap(Func):
pass
@ -5642,9 +5698,11 @@ class Quarter(Func):
pass
# https://docs.teradata.com/r/Enterprise_IntelliFlex_VMware/SQL-Functions-Expressions-and-Predicates/Arithmetic-Trigonometric-Hyperbolic-Operators/Functions/RANDOM/RANDOM-Function-Syntax
# teradata lower and upper bounds
class Rand(Func):
_sql_names = ["RAND", "RANDOM"]
arg_types = {"this": False}
arg_types = {"this": False, "lower": False, "upper": False}
class Randn(Func):
@ -5765,11 +5823,11 @@ class StrPosition(Func):
class StrToDate(Func):
arg_types = {"this": True, "format": False}
arg_types = {"this": True, "format": False, "safe": False}
class StrToTime(Func):
arg_types = {"this": True, "format": True, "zone": False}
arg_types = {"this": True, "format": True, "zone": False, "safe": False}
# Spark allows unix_timestamp()
@ -5833,6 +5891,11 @@ class StddevSamp(AggFunc):
pass
# https://cloud.google.com/bigquery/docs/reference/standard-sql/time_functions#time
class Time(Func):
arg_types = {"this": False, "zone": False}
class TimeToStr(Func):
arg_types = {"this": True, "format": True, "culture": False, "timezone": False}

View file

@ -87,9 +87,11 @@ class Generator(metaclass=_Generator):
e: f"CLUSTERED ({self.expressions(e, 'this', indent=False)})",
exp.CollateColumnConstraint: lambda self, e: f"COLLATE {self.sql(e, 'this')}",
exp.CommentColumnConstraint: lambda self, e: f"COMMENT {self.sql(e, 'this')}",
exp.ConnectByRoot: lambda self, e: f"CONNECT_BY_ROOT {self.sql(e, 'this')}",
exp.CopyGrantsProperty: lambda *_: "COPY GRANTS",
exp.DateFormatColumnConstraint: lambda self, e: f"FORMAT {self.sql(e, 'this')}",
exp.DefaultColumnConstraint: lambda self, e: f"DEFAULT {self.sql(e, 'this')}",
exp.DynamicProperty: lambda *_: "DYNAMIC",
exp.EncodeColumnConstraint: lambda self, e: f"ENCODE {self.sql(e, 'this')}",
exp.EphemeralColumnConstraint: lambda self,
e: f"EPHEMERAL{(' ' + self.sql(e, 'this')) if e.this else ''}",
@ -131,6 +133,7 @@ class Generator(metaclass=_Generator):
"RETURNS NULL ON NULL INPUT" if e.args.get("null") else self.naked_property(e)
),
exp.SampleProperty: lambda self, e: f"SAMPLE BY {self.sql(e, 'this')}",
exp.SecureProperty: lambda *_: "SECURE",
exp.SetConfigProperty: lambda self, e: self.sql(e, "this"),
exp.SetProperty: lambda _, e: f"{'MULTI' if e.args.get('multi') else ''}SET",
exp.SettingsProperty: lambda self, e: f"SETTINGS{self.seg('')}{(self.expressions(e))}",
@ -143,7 +146,7 @@ class Generator(metaclass=_Generator):
exp.TemporaryProperty: lambda *_: "TEMPORARY",
exp.TagColumnConstraint: lambda self, e: f"TAG ({self.expressions(e, flat=True)})",
exp.TitleColumnConstraint: lambda self, e: f"TITLE {self.sql(e, 'this')}",
exp.Timestamp: lambda self, e: self.func("TIMESTAMP", e.this, e.expression),
exp.Timestamp: lambda self, e: self.func("TIMESTAMP", e.this, e.args.get("zone")),
exp.ToMap: lambda self, e: f"MAP {self.sql(e, 'this')}",
exp.ToTableProperty: lambda self, e: f"TO {self.sql(e.this)}",
exp.TransformModelProperty: lambda self, e: self.func("TRANSFORM", *e.expressions),
@ -154,6 +157,7 @@ class Generator(metaclass=_Generator):
exp.ViewAttributeProperty: lambda self, e: f"WITH {self.sql(e, 'this')}",
exp.VolatileProperty: lambda *_: "VOLATILE",
exp.WithJournalTableProperty: lambda self, e: f"WITH JOURNAL TABLE={self.sql(e, 'this')}",
exp.WithSchemaBindingProperty: lambda self, e: f"WITH SCHEMA {self.sql(e, 'this')}",
exp.WithOperator: lambda self, e: f"{self.sql(e, 'this')} WITH {self.sql(e, 'op')}",
}
@ -168,8 +172,8 @@ class Generator(metaclass=_Generator):
# Whether locking reads (i.e. SELECT ... FOR UPDATE/SHARE) are supported
LOCKING_READS_SUPPORTED = False
# Always do union distinct or union all
EXPLICIT_UNION = False
# Always do <set op> distinct or <set op> all
EXPLICIT_SET_OP = False
# Wrap derived values in parens, usually standard but spark doesn't support it
WRAP_DERIVED_VALUES = True
@ -339,10 +343,10 @@ class Generator(metaclass=_Generator):
# Whether the function TO_NUMBER is supported
SUPPORTS_TO_NUMBER = True
# Whether or not union modifiers apply to the outer union or select.
# Whether or not set op modifiers apply to the outer set op or select.
# SELECT * FROM x UNION SELECT * FROM y LIMIT 1
# True means limit 1 happens after the union, False means it happens on y.
OUTER_UNION_MODIFIERS = True
# True means limit 1 happens after the set op, False means it happens on y.
SET_OP_MODIFIERS = True
# Whether parameters from COPY statement are wrapped in parentheses
COPY_PARAMS_ARE_WRAPPED = True
@ -368,6 +372,12 @@ class Generator(metaclass=_Generator):
# The keywords to use when prefixing & separating WITH based properties
WITH_PROPERTIES_PREFIX = "WITH"
# Whether to quote the generated expression of exp.JsonPath
QUOTE_JSON_PATH = True
# The name to generate for the JSONPath expression. If `None`, only `this` will be generated
PARSE_JSON_NAME: t.Optional[str] = "PARSE_JSON"
TYPE_MAPPING = {
exp.DataType.Type.NCHAR: "CHAR",
exp.DataType.Type.NVARCHAR: "VARCHAR",
@ -430,6 +440,7 @@ class Generator(metaclass=_Generator):
exp.DefinerProperty: exp.Properties.Location.POST_CREATE,
exp.DictRange: exp.Properties.Location.POST_SCHEMA,
exp.DictProperty: exp.Properties.Location.POST_SCHEMA,
exp.DynamicProperty: exp.Properties.Location.POST_CREATE,
exp.DistKeyProperty: exp.Properties.Location.POST_SCHEMA,
exp.DistStyleProperty: exp.Properties.Location.POST_SCHEMA,
exp.EngineProperty: exp.Properties.Location.POST_SCHEMA,
@ -469,6 +480,7 @@ class Generator(metaclass=_Generator):
exp.RowFormatSerdeProperty: exp.Properties.Location.POST_SCHEMA,
exp.SampleProperty: exp.Properties.Location.POST_SCHEMA,
exp.SchemaCommentProperty: exp.Properties.Location.POST_SCHEMA,
exp.SecureProperty: exp.Properties.Location.POST_CREATE,
exp.SerdeProperties: exp.Properties.Location.POST_SCHEMA,
exp.Set: exp.Properties.Location.POST_SCHEMA,
exp.SettingsProperty: exp.Properties.Location.POST_SCHEMA,
@ -491,6 +503,7 @@ class Generator(metaclass=_Generator):
exp.VolatileProperty: exp.Properties.Location.POST_CREATE,
exp.WithDataProperty: exp.Properties.Location.POST_EXPRESSION,
exp.WithJournalTableProperty: exp.Properties.Location.POST_NAME,
exp.WithSchemaBindingProperty: exp.Properties.Location.POST_SCHEMA,
exp.WithSystemVersioningProperty: exp.Properties.Location.POST_SCHEMA,
}
@ -506,7 +519,7 @@ class Generator(metaclass=_Generator):
exp.Insert,
exp.Join,
exp.Select,
exp.Union,
exp.SetOperation,
exp.Update,
exp.Where,
exp.With,
@ -515,7 +528,7 @@ class Generator(metaclass=_Generator):
# Expressions that should not have their comments generated in maybe_comment
EXCLUDE_COMMENTS: t.Tuple[t.Type[exp.Expression], ...] = (
exp.Binary,
exp.Union,
exp.SetOperation,
)
# Expressions that can remain unwrapped when appearing in the context of an INTERVAL
@ -1298,8 +1311,10 @@ class Generator(metaclass=_Generator):
with_storage = f" WITH ({with_storage})" if with_storage else ""
tablespace = self.sql(expression, "tablespace")
tablespace = f" USING INDEX TABLESPACE {tablespace}" if tablespace else ""
on = self.sql(expression, "on")
on = f" ON {on}" if on else ""
return f"{using}{columns}{include}{with_storage}{tablespace}{partition_by}{where}"
return f"{using}{columns}{include}{with_storage}{tablespace}{partition_by}{where}{on}"
def index_sql(self, expression: exp.Index) -> str:
unique = "UNIQUE " if expression.args.get("unique") else ""
@ -1736,7 +1751,10 @@ class Generator(metaclass=_Generator):
if when:
table = f"{table} {when}"
return f"{only}{table}{partition}{version}{file_format}{alias}{hints}{pivots}{joins}{laterals}{ordinality}"
changes = self.sql(expression, "changes")
changes = f" {changes}" if changes else ""
return f"{only}{table}{changes}{partition}{version}{file_format}{alias}{hints}{pivots}{joins}{laterals}{ordinality}"
def tablesample_sql(
self,
@ -2393,8 +2411,8 @@ class Generator(metaclass=_Generator):
this = self.indent(self.sql(expression, "this"))
return f"{self.seg('QUALIFY')}{self.sep()}{this}"
def set_operations(self, expression: exp.Union) -> str:
if not self.OUTER_UNION_MODIFIERS:
def set_operations(self, expression: exp.SetOperation) -> str:
if not self.SET_OP_MODIFIERS:
limit = expression.args.get("limit")
order = expression.args.get("order")
@ -2413,7 +2431,7 @@ class Generator(metaclass=_Generator):
while stack:
node = stack.pop()
if isinstance(node, exp.Union):
if isinstance(node, exp.SetOperation):
stack.append(node.expression)
stack.append(
self.maybe_comment(
@ -2433,8 +2451,8 @@ class Generator(metaclass=_Generator):
def union_sql(self, expression: exp.Union) -> str:
return self.set_operations(expression)
def union_op(self, expression: exp.Union) -> str:
kind = " DISTINCT" if self.EXPLICIT_UNION else ""
def union_op(self, expression: exp.SetOperation) -> str:
kind = " DISTINCT" if self.EXPLICIT_SET_OP else ""
kind = kind if expression.args.get("distinct") else " ALL"
by_name = " BY NAME" if expression.args.get("by_name") else ""
return f"UNION{kind}{by_name}"
@ -2653,7 +2671,10 @@ class Generator(metaclass=_Generator):
def jsonpath_sql(self, expression: exp.JSONPath) -> str:
path = self.expressions(expression, sep="", flat=True).lstrip(".")
return f"{self.dialect.QUOTE_START}{path}{self.dialect.QUOTE_END}"
if self.QUOTE_JSON_PATH:
path = f"{self.dialect.QUOTE_START}{path}{self.dialect.QUOTE_END}"
return path
def json_path_part(self, expression: int | str | exp.JSONPathPart) -> str:
if isinstance(expression, exp.JSONPathPart):
@ -3969,3 +3990,56 @@ class Generator(metaclass=_Generator):
this = self.sql(expression, "this")
this = f"TABLE {this}"
return self.func("GAP_FILL", this, *[v for k, v in expression.args.items() if k != "this"])
def scope_resolution(self, rhs: str, scope_name: str) -> str:
return self.func("SCOPE_RESOLUTION", scope_name or None, rhs)
def scoperesolution_sql(self, expression: exp.ScopeResolution) -> str:
this = self.sql(expression, "this")
expr = expression.expression
if isinstance(expr, exp.Func):
# T-SQL's CLR functions are case sensitive
expr = f"{self.sql(expr, 'this')}({self.format_args(*expr.expressions)})"
else:
expr = self.sql(expression, "expression")
return self.scope_resolution(expr, this)
def parsejson_sql(self, expression: exp.ParseJSON) -> str:
if self.PARSE_JSON_NAME is None:
return self.sql(expression.this)
return self.func(self.PARSE_JSON_NAME, expression.this, expression.expression)
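# Editorial sketch (not part of the diff): parsejson_sql emits PARSE_JSON_NAME when it is set
# and falls back to the bare argument when a dialect sets it to None (as T-SQL does above).
from sqlglot import exp, transpile
print(exp.ParseJSON(this=exp.column("payload")).sql())  # expected: PARSE_JSON(payload)
print(transpile("SELECT PARSE_JSON('{}')", read="snowflake", write="tsql")[0])  # expected: SELECT '{}'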
def length_sql(self, expression: exp.Length) -> str:
return self.func("LENGTH", expression.this)
def rand_sql(self, expression: exp.Rand) -> str:
lower = self.sql(expression, "lower")
upper = self.sql(expression, "upper")
if lower and upper:
return f"({upper} - {lower}) * {self.func('RAND', expression.this)} + {lower}"
return self.func("RAND", expression.this)
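# Editorial sketch (not part of the diff): with both bounds present, rand_sql rewrites the call
# into the (upper - lower) * RAND() + lower form used for Teradata-style RANDOM(lower, upper).
from sqlglot import exp
print(exp.Rand(lower=exp.Literal.number(1), upper=exp.Literal.number(10)).sql())  # expected: (10 - 1) * RAND() + 1
print(exp.Rand().sql())  # expected: RAND()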
def strtodate_sql(self, expression: exp.StrToDate) -> str:
return self.func("STR_TO_DATE", expression.this, expression.args.get("format"))
def strtotime_sql(self, expression: exp.StrToTime) -> str:
return self.func(
"STR_TO_TIME",
expression.this,
expression.args.get("format"),
expression.args.get("zone"),
)
def changes_sql(self, expression: exp.Changes) -> str:
information = self.sql(expression, "information")
information = f"INFORMATION => {information}"
at_before = self.sql(expression, "at_before")
at_before = f"{self.seg('')}{at_before}" if at_before else ""
end = self.sql(expression, "end")
end = f"{self.seg('')}{end}" if end else ""
return f"CHANGES ({information}){at_before}{end}"
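# Editorial sketch (not part of the diff): changes_sql renders the Snowflake-style CHANGES table
# modifier; at_before and end carry the optional AT/BEFORE and END historical-data clauses.
from sqlglot import exp
print(exp.Changes(information=exp.var("DEFAULT")).sql())  # expected: CHANGES (INFORMATION => DEFAULT)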

View file

@ -8,6 +8,7 @@ from sqlglot.tokens import Token, Tokenizer, TokenType
if t.TYPE_CHECKING:
from sqlglot._typing import Lit
from sqlglot.dialects.dialect import DialectType
class JSONPathTokenizer(Tokenizer):
@ -36,9 +37,12 @@ class JSONPathTokenizer(Tokenizer):
STRING_ESCAPES = ["\\"]
def parse(path: str) -> exp.JSONPath:
def parse(path: str, dialect: DialectType = None) -> exp.JSONPath:
"""Takes in a JSON path string and parses it into a JSONPath expression."""
tokens = JSONPathTokenizer().tokenize(path)
from sqlglot.dialects import Dialect
jsonpath_tokenizer = Dialect.get_or_raise(dialect).jsonpath_tokenizer
tokens = jsonpath_tokenizer.tokenize(path)
size = len(tokens)
i = 0

View file

@ -152,8 +152,9 @@ def to_node(
reference_node_name=reference_node_name,
trim_selects=trim_selects,
)
if isinstance(scope.expression, exp.Union):
upstream = upstream or Node(name="UNION", source=scope.expression, expression=select)
if isinstance(scope.expression, exp.SetOperation):
name = type(scope.expression).__name__.upper()
upstream = upstream or Node(name=name, source=scope.expression, expression=select)
index = (
column

View file

@ -158,6 +158,7 @@ class TypeAnnotator(metaclass=_TypeAnnotator):
},
exp.DataType.Type.DATETIME: {
exp.CurrentDatetime,
exp.Datetime,
exp.DatetimeAdd,
exp.DatetimeSub,
},
@ -196,6 +197,9 @@ class TypeAnnotator(metaclass=_TypeAnnotator):
exp.DataType.Type.JSON: {
exp.ParseJSON,
},
exp.DataType.Type.TIME: {
exp.Time,
},
exp.DataType.Type.TIMESTAMP: {
exp.CurrentTime,
exp.CurrentTimestamp,

View file

@ -42,7 +42,7 @@ def replace_date_funcs(node: exp.Expression) -> exp.Expression:
and not node.args.get("zone")
):
return exp.cast(node.this, to=exp.DataType.Type.DATE)
if isinstance(node, exp.Timestamp) and not node.expression:
if isinstance(node, exp.Timestamp) and not node.args.get("zone"):
if not node.type:
from sqlglot.optimizer.annotate_types import annotate_types

View file

@ -47,7 +47,7 @@ def pushdown_projections(expression, schema=None, remove_unused_selections=True)
if scope.expression.args.get("distinct"):
parent_selections = {SELECT_ALL}
if isinstance(scope.expression, exp.Union):
if isinstance(scope.expression, exp.SetOperation):
left, right = scope.union_scopes
referenced_columns[left] = parent_selections

View file

@ -60,8 +60,12 @@ def qualify_columns(
_pop_table_column_aliases(scope.derived_tables)
using_column_tables = _expand_using(scope, resolver)
if schema.empty and expand_alias_refs:
_expand_alias_refs(scope, resolver)
if (schema.empty or dialect.FORCE_EARLY_ALIAS_REF_EXPANSION) and expand_alias_refs:
_expand_alias_refs(
scope,
resolver,
expand_only_groupby=dialect.EXPAND_ALIAS_REFS_EARLY_ONLY_IN_GROUP_BY,
)
_convert_columns_to_dots(scope, resolver)
_qualify_columns(scope, resolver)
@ -148,7 +152,7 @@ def _expand_using(scope: Scope, resolver: Resolver) -> t.Dict[str, t.Any]:
# Mapping of automatically joined column names to an ordered set of source names (dict).
column_tables: t.Dict[str, t.Dict[str, t.Any]] = {}
for join in joins:
for i, join in enumerate(joins):
using = join.args.get("using")
if not using:
@ -168,6 +172,7 @@ def _expand_using(scope: Scope, resolver: Resolver) -> t.Dict[str, t.Any]:
ordered.append(join_table)
join_columns = resolver.get_source_columns(join_table)
conditions = []
using_identifier_count = len(using)
for identifier in using:
identifier = identifier.name
@ -178,9 +183,21 @@ def _expand_using(scope: Scope, resolver: Resolver) -> t.Dict[str, t.Any]:
raise OptimizeError(f"Cannot automatically join: {identifier}")
table = table or source_table
conditions.append(
exp.column(identifier, table=table).eq(exp.column(identifier, table=join_table))
)
if i == 0 or using_identifier_count == 1:
lhs: exp.Expression = exp.column(identifier, table=table)
else:
coalesce_columns = [
exp.column(identifier, table=t)
for t in ordered[:-1]
if identifier in resolver.get_source_columns(t)
]
if len(coalesce_columns) > 1:
lhs = exp.func("coalesce", *coalesce_columns)
else:
lhs = exp.column(identifier, table=table)
conditions.append(lhs.eq(exp.column(identifier, table=join_table)))
# Set all values in the dict to None, because we only care about the key ordering
tables = column_tables.setdefault(identifier, {})
@ -196,8 +213,8 @@ def _expand_using(scope: Scope, resolver: Resolver) -> t.Dict[str, t.Any]:
for column in scope.columns:
if not column.table and column.name in column_tables:
tables = column_tables[column.name]
coalesce = [exp.column(column.name, table=table) for table in tables]
replacement = exp.Coalesce(this=coalesce[0], expressions=coalesce[1:])
coalesce_args = [exp.column(column.name, table=table) for table in tables]
replacement = exp.func("coalesce", *coalesce_args)
# Ensure selects keep their output name
if isinstance(column.parent, exp.Select):
@ -208,7 +225,7 @@ def _expand_using(scope: Scope, resolver: Resolver) -> t.Dict[str, t.Any]:
return column_tables
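# Editorial sketch (not part of the diff): with chained USING joins over more than one shared
# column, later join conditions now COALESCE the key across the tables joined so far (the
# schema below is purely illustrative).
import sqlglot
from sqlglot.optimizer.qualify_columns import qualify_columns
from sqlglot.schema import MappingSchema
schema = MappingSchema({"a": {"id": "INT", "ts": "INT"}, "b": {"id": "INT", "ts": "INT"}, "c": {"id": "INT", "ts": "INT"}})
sql = "SELECT * FROM a JOIN b USING (id, ts) JOIN c USING (id, ts)"
print(qualify_columns(sqlglot.parse_one(sql), schema).sql())
# the second join's condition should now read ... ON COALESCE(a.id, b.id) = c.id AND COALESCE(a.ts, b.ts) = c.ts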
def _expand_alias_refs(scope: Scope, resolver: Resolver) -> None:
def _expand_alias_refs(scope: Scope, resolver: Resolver, expand_only_groupby: bool = False) -> None:
expression = scope.expression
if not isinstance(expression, exp.Select):
@ -219,7 +236,7 @@ def _expand_alias_refs(scope: Scope, resolver: Resolver) -> None:
def replace_columns(
node: t.Optional[exp.Expression], resolve_table: bool = False, literal_index: bool = False
) -> None:
if not node:
if not node or (expand_only_groupby and not isinstance(node, exp.Group)):
return
for column in walk_in_scope(node, prune=lambda node: node.is_star):
@ -583,14 +600,10 @@ def _expand_stars(
if name in using_column_tables and table in using_column_tables[name]:
coalesced_columns.add(name)
tables = using_column_tables[name]
coalesce = [exp.column(name, table=table) for table in tables]
coalesce_args = [exp.column(name, table=table) for table in tables]
new_selections.append(
alias(
exp.Coalesce(this=coalesce[0], expressions=coalesce[1:]),
alias=name,
copy=False,
)
alias(exp.func("coalesce", *coalesce_args), alias=name, copy=False)
)
else:
alias_ = replace_columns.get(table_id, {}).get(name, name)
@ -719,6 +732,7 @@ class Resolver:
self._unambiguous_columns: t.Optional[t.Mapping[str, str]] = None
self._all_columns: t.Optional[t.Set[str]] = None
self._infer_schema = infer_schema
self._get_source_columns_cache: t.Dict[t.Tuple[str, bool], t.Sequence[str]] = {}
def get_table(self, column_name: str) -> t.Optional[exp.Identifier]:
"""
@ -771,41 +785,49 @@ class Resolver:
def get_source_columns(self, name: str, only_visible: bool = False) -> t.Sequence[str]:
"""Resolve the source columns for a given source `name`."""
if name not in self.scope.sources:
raise OptimizeError(f"Unknown table: {name}")
cache_key = (name, only_visible)
if cache_key not in self._get_source_columns_cache:
if name not in self.scope.sources:
raise OptimizeError(f"Unknown table: {name}")
source = self.scope.sources[name]
source = self.scope.sources[name]
if isinstance(source, exp.Table):
columns = self.schema.column_names(source, only_visible)
elif isinstance(source, Scope) and isinstance(source.expression, (exp.Values, exp.Unnest)):
columns = source.expression.named_selects
if isinstance(source, exp.Table):
columns = self.schema.column_names(source, only_visible)
elif isinstance(source, Scope) and isinstance(
source.expression, (exp.Values, exp.Unnest)
):
columns = source.expression.named_selects
# in bigquery, unnest structs are automatically scoped as tables, so you can
# directly select a struct field in a query.
# this handles the case where the unnest is statically defined.
if self.schema.dialect == "bigquery":
if source.expression.is_type(exp.DataType.Type.STRUCT):
for k in source.expression.type.expressions: # type: ignore
columns.append(k.name)
else:
columns = source.expression.named_selects
# in bigquery, unnest structs are automatically scoped as tables, so you can
# directly select a struct field in a query.
# this handles the case where the unnest is statically defined.
if self.schema.dialect == "bigquery":
if source.expression.is_type(exp.DataType.Type.STRUCT):
for k in source.expression.type.expressions: # type: ignore
columns.append(k.name)
else:
columns = source.expression.named_selects
node, _ = self.scope.selected_sources.get(name) or (None, None)
if isinstance(node, Scope):
column_aliases = node.expression.alias_column_names
elif isinstance(node, exp.Expression):
column_aliases = node.alias_column_names
else:
column_aliases = []
node, _ = self.scope.selected_sources.get(name) or (None, None)
if isinstance(node, Scope):
column_aliases = node.expression.alias_column_names
elif isinstance(node, exp.Expression):
column_aliases = node.alias_column_names
else:
column_aliases = []
if column_aliases:
# If the source's columns are aliased, their aliases shadow the corresponding column names.
# This can be expensive if there are lots of columns, so only do this if column_aliases exist.
return [
alias or name for (name, alias) in itertools.zip_longest(columns, column_aliases)
]
return columns
if column_aliases:
# If the source's columns are aliased, their aliases shadow the corresponding column names.
# This can be expensive if there are lots of columns, so only do this if column_aliases exist.
columns = [
alias or name
for (name, alias) in itertools.zip_longest(columns, column_aliases)
]
self._get_source_columns_cache[cache_key] = columns
return self._get_source_columns_cache[cache_key]
def _get_all_source_columns(self) -> t.Dict[str, t.Sequence[str]]:
if self._source_columns is None:

View file

@ -29,7 +29,7 @@ class Scope:
Selection scope.
Attributes:
expression (exp.Select|exp.Union): Root expression of this scope
expression (exp.Select|exp.SetOperation): Root expression of this scope
sources (dict[str, exp.Table|Scope]): Mapping of source name to either
a Table expression or another Scope instance. For example:
SELECT * FROM x {"x": Table(this="x")}
@ -233,7 +233,7 @@ class Scope:
SELECT * FROM x WHERE a IN (SELECT ...) <- that's a subquery
Returns:
list[exp.Select | exp.Union]: subqueries
list[exp.Select | exp.SetOperation]: subqueries
"""
self._ensure_collected()
return self._subqueries
@ -339,7 +339,7 @@ class Scope:
sources in the current scope.
"""
if self._external_columns is None:
if isinstance(self.expression, exp.Union):
if isinstance(self.expression, exp.SetOperation):
left, right = self.union_scopes
self._external_columns = left.external_columns + right.external_columns
else:
@ -535,7 +535,7 @@ def _traverse_scope(scope):
if isinstance(expression, exp.Select):
yield from _traverse_select(scope)
elif isinstance(expression, exp.Union):
elif isinstance(expression, exp.SetOperation):
yield from _traverse_ctes(scope)
yield from _traverse_union(scope)
return
@ -588,7 +588,7 @@ def _traverse_union(scope):
scope_type=ScopeType.UNION,
)
if isinstance(expression, exp.Union):
if isinstance(expression, exp.SetOperation):
yield from _traverse_ctes(new_scope)
union_scope_stack.append(new_scope)
@ -620,7 +620,7 @@ def _traverse_ctes(scope):
if with_ and with_.recursive:
union = cte.this
if isinstance(union, exp.Union):
if isinstance(union, exp.SetOperation):
sources[cte_name] = scope.branch(union.this, scope_type=ScopeType.CTE)
child_scope = None

View file

@ -6,7 +6,6 @@ import functools
import itertools
import typing as t
from collections import deque, defaultdict
from decimal import Decimal
from functools import reduce
import sqlglot
@ -347,8 +346,8 @@ def _simplify_comparison(expression, left, right, or_=False):
return expression
if l.is_number and r.is_number:
l = float(l.name)
r = float(r.name)
l = l.to_py()
r = r.to_py()
elif l.is_string and r.is_string:
l = l.name
r = r.name
@ -626,13 +625,8 @@ def simplify_literals(expression, root=True):
if isinstance(expression, exp.Binary) and not isinstance(expression, exp.Connector):
return _flat_simplify(expression, _simplify_binary, root)
if isinstance(expression, exp.Neg):
this = expression.this
if this.is_number:
value = this.name
if value[0] == "-":
return exp.Literal.number(value[1:])
return exp.Literal.number(f"-{value}")
if isinstance(expression, exp.Neg) and isinstance(expression.this, exp.Neg):
return expression.this.this
if type(expression) in INVERSE_DATE_OPS:
return _simplify_binary(expression, expression.this, expression.interval()) or expression
@ -650,7 +644,7 @@ def _simplify_integer_cast(expr: exp.Expression) -> exp.Expression:
this = expr.this
if isinstance(expr, exp.Cast) and this.is_int:
num = int(this.name)
num = this.to_py()
# Remove the (up)cast from small (byte-sized) integers in predicates which is side-effect free. Downcasts on any
# integer type might cause overflow, thus the cast cannot be eliminated and the behavior is
@ -690,8 +684,8 @@ def _simplify_binary(expression, a, b):
return exp.null()
if a.is_number and b.is_number:
num_a = int(a.name) if a.is_int else Decimal(a.name)
num_b = int(b.name) if b.is_int else Decimal(b.name)
num_a = a.to_py()
num_b = b.to_py()
if isinstance(expression, exp.Add):
return exp.Literal.number(num_a + num_b)
@ -1206,7 +1200,7 @@ def _is_date_literal(expression: exp.Expression) -> bool:
def extract_interval(expression):
try:
n = int(expression.name)
n = int(expression.this.to_py())
unit = expression.text("unit").lower()
return interval(unit, n)
except (UnsupportedUnit, ModuleNotFoundError, ValueError):
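# Editorial sketch (not part of the diff): numeric folding now goes through to_py(), so the
# simplifier keeps int/Decimal arithmetic for literals instead of round-tripping through floats.
import sqlglot
from sqlglot.optimizer.simplify import simplify
print(simplify(sqlglot.parse_one("1 + 2.5")).sql())   # expected: 3.5
print(simplify(sqlglot.parse_one("0.1 + 0.2")).sql())  # expected: 0.3 (exact Decimal arithmetic)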

View file

@ -48,7 +48,7 @@ def unnest(select, parent_select, next_alias_name):
):
return
if isinstance(select, exp.Union):
if isinstance(select, exp.SetOperation):
select = exp.select(*select.selects).from_(select.subquery(next_alias_name()))
alias = next_alias_name()

View file

@ -150,6 +150,7 @@ class Parser(metaclass=_Parser):
to=exp.DataType(this=exp.DataType.Type.TEXT),
),
"GLOB": lambda args: exp.Glob(this=seq_get(args, 1), expression=seq_get(args, 0)),
"HEX": build_hex,
"JSON_EXTRACT": build_extract_json_with_path(exp.JSONExtract),
"JSON_EXTRACT_SCALAR": build_extract_json_with_path(exp.JSONExtractScalar),
"JSON_EXTRACT_PATH_TEXT": build_extract_json_with_path(exp.JSONExtractScalar),
@ -157,11 +158,16 @@ class Parser(metaclass=_Parser):
"LOG": build_logarithm,
"LOG2": lambda args: exp.Log(this=exp.Literal.number(2), expression=seq_get(args, 0)),
"LOG10": lambda args: exp.Log(this=exp.Literal.number(10), expression=seq_get(args, 0)),
"LOWER": build_lower,
"MOD": build_mod,
"SCOPE_RESOLUTION": lambda args: exp.ScopeResolution(expression=seq_get(args, 0))
if len(args) != 2
else exp.ScopeResolution(this=seq_get(args, 0), expression=seq_get(args, 1)),
"TIME_TO_TIME_STR": lambda args: exp.Cast(
this=seq_get(args, 0),
to=exp.DataType(this=exp.DataType.Type.TEXT),
),
"TO_HEX": build_hex,
"TS_OR_DS_TO_DATE_STR": lambda args: exp.Substring(
this=exp.Cast(
this=seq_get(args, 0),
@ -170,11 +176,9 @@ class Parser(metaclass=_Parser):
start=exp.Literal.number(1),
length=exp.Literal.number(10),
),
"VAR_MAP": build_var_map,
"LOWER": build_lower,
"UNNEST": lambda args: exp.Unnest(expressions=ensure_list(seq_get(args, 0))),
"UPPER": build_upper,
"HEX": build_hex,
"TO_HEX": build_hex,
"VAR_MAP": build_var_map,
}
NO_PAREN_FUNCTIONS = {
@ -295,6 +299,7 @@ class Parser(metaclass=_Parser):
TokenType.ROWVERSION,
TokenType.IMAGE,
TokenType.VARIANT,
TokenType.VECTOR,
TokenType.OBJECT,
TokenType.OBJECT_IDENTIFIER,
TokenType.INET,
@ -670,6 +675,7 @@ class Parser(metaclass=_Parser):
exp.Properties: lambda self: self._parse_properties(),
exp.Qualify: lambda self: self._parse_qualify(),
exp.Returning: lambda self: self._parse_returning(),
exp.Select: lambda self: self._parse_select(),
exp.Sort: lambda self: self._parse_sort(exp.Sort, TokenType.SORT_BY),
exp.Table: lambda self: self._parse_table_parts(),
exp.TableAlias: lambda self: self._parse_table_alias(),
@ -818,6 +824,7 @@ class Parser(metaclass=_Parser):
"DETERMINISTIC": lambda self: self.expression(
exp.StabilityProperty, this=exp.Literal.string("IMMUTABLE")
),
"DYNAMIC": lambda self: self.expression(exp.DynamicProperty),
"DISTKEY": lambda self: self._parse_distkey(),
"DISTSTYLE": lambda self: self._parse_property_assignment(exp.DistStyleProperty),
"ENGINE": lambda self: self._parse_property_assignment(exp.EngineProperty),
@ -868,6 +875,7 @@ class Parser(metaclass=_Parser):
"SAMPLE": lambda self: self.expression(
exp.SampleProperty, this=self._match_text_seq("BY") and self._parse_bitwise()
),
"SECURE": lambda self: self.expression(exp.SecureProperty),
"SET": lambda self: self.expression(exp.SetProperty, multi=False),
"SETTINGS": lambda self: self.expression(
exp.SettingsProperty, expressions=self._parse_csv(self._parse_set_item)
@ -990,6 +998,9 @@ class Parser(metaclass=_Parser):
NO_PAREN_FUNCTION_PARSERS = {
"ANY": lambda self: self.expression(exp.Any, this=self._parse_bitwise()),
"CASE": lambda self: self._parse_case(),
"CONNECT_BY_ROOT": lambda self: self.expression(
exp.ConnectByRoot, this=self._parse_column()
),
"IF": lambda self: self._parse_if(),
"NEXT": lambda self: self._parse_next_value_for(),
}
@ -1118,9 +1129,15 @@ class Parser(metaclass=_Parser):
CAST_ACTIONS: OPTIONS_TYPE = dict.fromkeys(("RENAME", "ADD"), ("FIELDS",))
SCHEMA_BINDING_OPTIONS: OPTIONS_TYPE = {
"TYPE": ("EVOLUTION",),
**dict.fromkeys(("BINDING", "COMPENSATION", "EVOLUTION"), tuple()),
}
INSERT_ALTERNATIVES = {"ABORT", "FAIL", "IGNORE", "REPLACE", "ROLLBACK"}
CLONE_KEYWORDS = {"CLONE", "COPY"}
HISTORICAL_DATA_PREFIX = {"AT", "BEFORE", "END"}
HISTORICAL_DATA_KIND = {"TIMESTAMP", "OFFSET", "STATEMENT", "STREAM"}
OPCLASS_FOLLOW_KEYWORDS = {"ASC", "DESC", "NULLS", "WITH"}
@ -1184,8 +1201,8 @@ class Parser(metaclass=_Parser):
STRING_ALIASES = False
# Whether query modifiers such as LIMIT are attached to the UNION node (vs its right operand)
MODIFIERS_ATTACHED_TO_UNION = True
UNION_MODIFIERS = {"order", "limit", "offset"}
MODIFIERS_ATTACHED_TO_SET_OP = True
SET_OP_MODIFIERS = {"order", "limit", "offset"}
# Whether to parse IF statements that aren't followed by a left parenthesis as commands
NO_PAREN_IF_COMMANDS = True
@ -1193,8 +1210,8 @@ class Parser(metaclass=_Parser):
# Whether the -> and ->> operators expect documents of type JSON (e.g. Postgres)
JSON_ARROWS_REQUIRE_JSON_TYPE = False
# Whether the `:` operator is used to extract a value from a JSON document
COLON_IS_JSON_EXTRACT = False
# Whether the `:` operator is used to extract a value from a VARIANT column
COLON_IS_VARIANT_EXTRACT = False
# Whether or not a VALUES keyword needs to be followed by '(' to form a VALUES clause.
# If this is True and '(' is not found, the keyword will be treated as an identifier
@ -1466,9 +1483,9 @@ class Parser(metaclass=_Parser):
def _try_parse(self, parse_method: t.Callable[[], T], retreat: bool = False) -> t.Optional[T]:
"""
Attempts to backtrack if a parse function that contains a try/catch internally raises an error. This behavior can
be different depending on the user-set ErrorLevel, so _try_parse aims to solve this by setting & resetting
the parser state accordingly
Attempts to backtrack if a parse function that contains a try/catch internally raises an error.
This behavior can be different depending on the user-set ErrorLevel, so _try_parse aims to
solve this by setting & resetting the parser state accordingly
"""
index = self._index
error_level = self.error_level
@ -2005,6 +2022,12 @@ class Parser(metaclass=_Parser):
if self._match(TokenType.SERDE_PROPERTIES, advance=False):
return self._parse_serde_properties(with_=True)
if self._match(TokenType.SCHEMA):
return self.expression(
exp.WithSchemaBindingProperty,
this=self._parse_var_from_options(self.SCHEMA_BINDING_OPTIONS),
)
if not self._next:
return None
@ -2899,7 +2922,7 @@ class Parser(metaclass=_Parser):
continue
break
if self.SUPPORTS_IMPLICIT_UNNEST and this and "from" in this.args:
if self.SUPPORTS_IMPLICIT_UNNEST and this and this.args.get("from"):
this = self._implicit_unnests_to_explicit(this)
return this
@ -3187,6 +3210,8 @@ class Parser(metaclass=_Parser):
)
where = self._parse_where()
on = self._parse_field() if self._match(TokenType.ON) else None
return self.expression(
exp.IndexParameters,
using=using,
@ -3196,6 +3221,7 @@ class Parser(metaclass=_Parser):
where=where,
with_storage=with_storage,
tablespace=tablespace,
on=on,
)
def _parse_index(
@ -3959,7 +3985,7 @@ class Parser(metaclass=_Parser):
token_type = self._prev.token_type
if token_type == TokenType.UNION:
operation = exp.Union
operation: t.Type[exp.SetOperation] = exp.Union
elif token_type == TokenType.EXCEPT:
operation = exp.Except
else:
@ -3979,11 +4005,11 @@ class Parser(metaclass=_Parser):
expression=expression,
)
if isinstance(this, exp.Union) and self.MODIFIERS_ATTACHED_TO_UNION:
if isinstance(this, exp.SetOperation) and self.MODIFIERS_ATTACHED_TO_SET_OP:
expression = this.expression
if expression:
for arg in self.UNION_MODIFIERS:
for arg in self.SET_OP_MODIFIERS:
expr = expression.args.get(arg)
if expr:
this.set(arg, expr.pop())
@ -4122,7 +4148,7 @@ class Parser(metaclass=_Parser):
# Most dialects support, e.g., the form INTERVAL '5' day, thus we try to parse
# each INTERVAL expression into this canonical form so it's easy to transpile
if this and this.is_number:
this = exp.Literal.string(this.name)
this = exp.Literal.string(this.to_py())
elif this and this.is_string:
parts = exp.INTERVAL_STRING_RE.findall(this.name)
if len(parts) == 1:
@ -4286,8 +4312,8 @@ class Parser(metaclass=_Parser):
identifier = allow_identifiers and self._parse_id_var(
any_token=False, tokens=(TokenType.VAR,)
)
if identifier:
tokens = self.dialect.tokenize(identifier.name)
if isinstance(identifier, exp.Identifier):
tokens = self.dialect.tokenize(identifier.sql(dialect=self.dialect))
if len(tokens) != 1:
self.raise_error("Unexpected identifier", self._prev)
@ -4370,6 +4396,10 @@ class Parser(metaclass=_Parser):
else:
expressions = self._parse_csv(self._parse_type_size)
# https://docs.snowflake.com/en/sql-reference/data-types-vector
if type_token == TokenType.VECTOR and len(expressions) == 2:
expressions[0] = exp.DataType.build(expressions[0].name, dialect=self.dialect)
if not expressions or not self._match(TokenType.R_PAREN):
self._retreat(index)
return None
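# Editorial sketch (not part of the diff): Snowflake's VECTOR type takes an element type and a
# dimension, so the first argument is rebuilt as a DataType rather than left as an identifier.
import sqlglot
print(sqlglot.parse_one("CREATE TABLE t (v VECTOR(INT, 16))", read="snowflake").sql("snowflake"))
# expected: CREATE TABLE t (v VECTOR(INT, 16))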
@ -4481,10 +4511,22 @@ class Parser(metaclass=_Parser):
def _parse_struct_types(self, type_required: bool = False) -> t.Optional[exp.Expression]:
index = self._index
this = (
self._parse_type(parse_interval=False, fallback_to_identifier=True)
or self._parse_id_var()
)
if (
self._curr
and self._next
and self._curr.token_type in self.TYPE_TOKENS
and self._next.token_type in self.TYPE_TOKENS
):
# Takes care of special cases like `STRUCT<list ARRAY<...>>` where the identifier is also a
# type token. Without this, the list will be parsed as a type and we'll eventually crash
this = self._parse_id_var()
else:
this = (
self._parse_type(parse_interval=False, fallback_to_identifier=True)
or self._parse_id_var()
)
self._match(TokenType.COLON)
if (
@ -4527,7 +4569,7 @@ class Parser(metaclass=_Parser):
return this
def _parse_colon_as_json_extract(
def _parse_colon_as_variant_extract(
self, this: t.Optional[exp.Expression]
) -> t.Optional[exp.Expression]:
casts = []
@ -4560,11 +4602,14 @@ class Parser(metaclass=_Parser):
if path:
json_path.append(self._find_sql(self._tokens[start_index], end_token))
# The VARIANT extract in Snowflake/Databricks is parsed as a JSONExtract; Snowflake uses the json_path in GET_PATH() while
# Databricks transforms it back to the colon/dot notation
if json_path:
this = self.expression(
exp.JSONExtract,
this=this,
expression=self.dialect.to_json_path(exp.Literal.string(".".join(json_path))),
variant_extract=True,
)
while casts:
@ -4572,6 +4617,9 @@ class Parser(metaclass=_Parser):
return this
def _parse_dcolon(self) -> t.Optional[exp.Expression]:
return self._parse_types()
def _parse_column_ops(self, this: t.Optional[exp.Expression]) -> t.Optional[exp.Expression]:
this = self._parse_bracket(this)
@ -4580,7 +4628,7 @@ class Parser(metaclass=_Parser):
op = self.COLUMN_OPERATORS.get(op_token)
if op_token == TokenType.DCOLON:
field = self._parse_types()
field = self._parse_dcolon()
if not field:
self.raise_error("Expected type")
elif op and self._curr:
@ -4618,7 +4666,7 @@ class Parser(metaclass=_Parser):
this = self._parse_bracket(this)
return self._parse_colon_as_json_extract(this) if self.COLON_IS_JSON_EXTRACT else this
return self._parse_colon_as_variant_extract(this) if self.COLON_IS_VARIANT_EXTRACT else this
def _parse_primary(self) -> t.Optional[exp.Expression]:
if self._match_set(self.PRIMARY_PARSERS):
@ -5312,8 +5360,8 @@ class Parser(metaclass=_Parser):
order=self._match(TokenType.OVER) and self._parse_wrapped(self._parse_order),
)
def _parse_extract(self) -> exp.Extract:
this = self._parse_function() or self._parse_var() or self._parse_type()
def _parse_extract(self) -> exp.Extract | exp.Anonymous:
this = self._parse_function() or self._parse_var_or_string(upper=True)
if self._match(TokenType.FROM):
return self.expression(exp.Extract, this=this, expression=self._parse_bitwise())
@ -5362,6 +5410,7 @@ class Parser(metaclass=_Parser):
self.dialect.FORMAT_TRIE or self.dialect.TIME_TRIE,
)
),
safe=safe,
)
if isinstance(fmt, exp.AtTimeZone) and isinstance(this, exp.StrToTime):
@ -5942,8 +5991,8 @@ class Parser(metaclass=_Parser):
return self._prev
return None
def _parse_var_or_string(self) -> t.Optional[exp.Expression]:
return self._parse_var() or self._parse_string()
def _parse_var_or_string(self, upper: bool = False) -> t.Optional[exp.Expression]:
return self._parse_string() or self._parse_var(any_token=True, upper=upper)
def _parse_primary_or_var(self) -> t.Optional[exp.Expression]:
return self._parse_primary() or self._parse_var(any_token=True)

View file

@ -108,7 +108,7 @@ class Step:
if isinstance(expression, exp.Select) and from_:
step = Scan.from_expression(from_.this, ctes)
elif isinstance(expression, exp.Union):
elif isinstance(expression, exp.SetOperation):
step = SetOperation.from_expression(expression, ctes)
else:
step = Scan()
@ -124,13 +124,13 @@ class Step:
projections = [] # final selects in this chain of steps representing a select
operands = {} # intermediate computations of agg funcs eg x + 1 in SUM(x + 1)
aggregations = set()
aggregations = {}
next_operand_name = name_sequence("_a_")
def extract_agg_operands(expression):
agg_funcs = tuple(expression.find_all(exp.AggFunc))
if agg_funcs:
aggregations.add(expression)
aggregations[expression] = None
for agg in agg_funcs:
for operand in agg.unnest_operands():
@ -426,7 +426,7 @@ class SetOperation(Step):
def from_expression(
cls, expression: exp.Expression, ctes: t.Optional[t.Dict[str, Step]] = None
) -> SetOperation:
assert isinstance(expression, exp.Union)
assert isinstance(expression, exp.SetOperation)
left = Step.from_expression(expression.left, ctes)
# SELECT 1 UNION SELECT 2 <-- these subqueries don't have names

View file

@ -386,6 +386,8 @@ class MappingSchema(AbstractMappingSchema, Schema):
if not isinstance(columns, dict):
raise SchemaError(error_msg.format(".".join(keys[:-1]), len(flattened_schema[0])))
if not columns:
raise SchemaError(f"Table {'.'.join(keys[:-1])} must have at least one column")
if isinstance(first(columns.values()), dict):
raise SchemaError(
error_msg.format(

View file

@ -202,6 +202,7 @@ class TokenType(AutoName):
SIMPLEAGGREGATEFUNCTION = auto()
TDIGEST = auto()
UNKNOWN = auto()
VECTOR = auto()
# keywords
ALIAS = auto()
@ -526,6 +527,7 @@ class _Tokenizer(type):
_TOKEN_TYPE_TO_INDEX[v] for v in klass.COMMAND_PREFIX_TOKENS
},
heredoc_tag_is_identifier=klass.HEREDOC_TAG_IS_IDENTIFIER,
string_escapes_allowed_in_raw_strings=klass.STRING_ESCAPES_ALLOWED_IN_RAW_STRINGS,
)
token_types = RsTokenTypeSettings(
bit_string=_TOKEN_TYPE_TO_INDEX[TokenType.BIT_STRING],
@ -602,6 +604,9 @@ class Tokenizer(metaclass=_Tokenizer):
# Token that we'll generate as a fallback if the heredoc prefix doesn't correspond to a heredoc
HEREDOC_STRING_ALTERNATIVE = TokenType.VAR
# Whether string escape characters function as such when placed within raw strings
STRING_ESCAPES_ALLOWED_IN_RAW_STRINGS = True
# Autofilled
_COMMENTS: t.Dict[str, str] = {}
_FORMAT_STRINGS: t.Dict[str, t.Tuple[str, TokenType]] = {}
@ -877,6 +882,7 @@ class Tokenizer(metaclass=_Tokenizer):
"DATERANGE": TokenType.DATERANGE,
"DATEMULTIRANGE": TokenType.DATEMULTIRANGE,
"UNIQUE": TokenType.UNIQUE,
"VECTOR": TokenType.VECTOR,
"STRUCT": TokenType.STRUCT,
"SEQUENCE": TokenType.SEQUENCE,
"VARIANT": TokenType.VARIANT,
@ -1162,10 +1168,22 @@ class Tokenizer(metaclass=_Tokenizer):
# Skip the comment's start delimiter
self._advance(comment_start_size)
comment_count = 1
comment_end_size = len(comment_end)
while not self._end and self._chars(comment_end_size) != comment_end:
while not self._end:
if self._chars(comment_end_size) == comment_end:
comment_count -= 1
if not comment_count:
break
self._advance(alnum=True)
# Nested comments are allowed by some dialects, e.g. databricks, duckdb, postgres
if not self._end and self._chars(comment_end_size) == comment_start:
self._advance(comment_start_size)
comment_count += 1
self._comments.append(self._text[comment_start_size : -comment_end_size + 1])
self._advance(comment_end_size - 1)
else:
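# Editorial sketch (not part of the diff): with comment_count tracking, an inner /* ... */ no
# longer terminates the outer block comment in dialects that allow nesting (e.g. DuckDB).
import sqlglot
sql = "SELECT 1 /* outer /* inner */ still outer */"
print(sqlglot.parse_one(sql, read="duckdb").sql(comments=False))  # expected: SELECT 1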
@ -1280,7 +1298,7 @@ class Tokenizer(metaclass=_Tokenizer):
else:
tag = self._extract_string(
end,
unescape_sequences=False,
raw_string=True,
raise_unmatched=not self.HEREDOC_TAG_IS_IDENTIFIER,
)
@ -1297,7 +1315,7 @@ class Tokenizer(metaclass=_Tokenizer):
return False
self._advance(len(start))
text = self._extract_string(end, unescape_sequences=token_type != TokenType.RAW_STRING)
text = self._extract_string(end, raw_string=token_type == TokenType.RAW_STRING)
if base:
try:
@ -1333,7 +1351,7 @@ class Tokenizer(metaclass=_Tokenizer):
self,
delimiter: str,
escapes: t.Optional[t.Set[str]] = None,
unescape_sequences: bool = True,
raw_string: bool = False,
raise_unmatched: bool = True,
) -> str:
text = ""
@ -1342,7 +1360,7 @@ class Tokenizer(metaclass=_Tokenizer):
while True:
if (
unescape_sequences
not raw_string
and self.dialect.UNESCAPED_SEQUENCES
and self._peek
and self._char in self.STRING_ESCAPES
@ -1353,7 +1371,8 @@ class Tokenizer(metaclass=_Tokenizer):
text += unescaped_sequence
continue
if (
self._char in escapes
(self.STRING_ESCAPES_ALLOWED_IN_RAW_STRINGS or not raw_string)
and self._char in escapes
and (self._peek == delimiter or self._peek in escapes)
and (self._char not in self._QUOTES or self._char == self._peek)
):

View file

@ -9,6 +9,52 @@ if t.TYPE_CHECKING:
from sqlglot.generator import Generator
def preprocess(
transforms: t.List[t.Callable[[exp.Expression], exp.Expression]],
) -> t.Callable[[Generator, exp.Expression], str]:
"""
Creates a new transform by chaining a sequence of transformations and converts the resulting
expression to SQL, using either the "_sql" method corresponding to the resulting expression,
or the appropriate `Generator.TRANSFORMS` function (when applicable -- see below).
Args:
transforms: sequence of transform functions. These will be called in order.
Returns:
Function that can be used as a generator transform.
"""
def _to_sql(self, expression: exp.Expression) -> str:
expression_type = type(expression)
expression = transforms[0](expression)
for transform in transforms[1:]:
expression = transform(expression)
_sql_handler = getattr(self, expression.key + "_sql", None)
if _sql_handler:
return _sql_handler(expression)
transforms_handler = self.TRANSFORMS.get(type(expression))
if transforms_handler:
if expression_type is type(expression):
if isinstance(expression, exp.Func):
return self.function_fallback_sql(expression)
# Ensures we don't enter an infinite loop. This can happen when the original expression
# has the same type as the final expression and there's no _sql method available for it,
# because then it'd re-enter _to_sql.
raise ValueError(
f"Expression type {expression.__class__.__name__} requires a _sql method in order to be transformed."
)
return transforms_handler(self, expression)
raise ValueError(f"Unsupported expression type {expression.__class__.__name__}.")
return _to_sql
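# Editorial sketch (not part of the diff): preprocess() is typically registered in a dialect's
# Generator.TRANSFORMS so the chained rewrites run right before SQL generation, e.g.
#     exp.Select: transforms.preprocess([transforms.eliminate_distinct_on])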
def unalias_group(expression: exp.Expression) -> exp.Expression:
"""
Replace references to select aliases in GROUP BY clauses.
@ -393,7 +439,7 @@ def add_recursive_cte_column_names(expression: exp.Expression) -> exp.Expression
for cte in expression.expressions:
if not cte.args["alias"].columns:
query = cte.this
if isinstance(query, exp.Union):
if isinstance(query, exp.SetOperation):
query = query.this
cte.args["alias"].set(
@ -623,47 +669,103 @@ def struct_kv_to_alias(expression: exp.Expression) -> exp.Expression:
return expression
def preprocess(
transforms: t.List[t.Callable[[exp.Expression], exp.Expression]],
) -> t.Callable[[Generator, exp.Expression], str]:
def eliminate_join_marks(expression: exp.Expression) -> exp.Expression:
"""
Creates a new transform by chaining a sequence of transformations and converts the resulting
expression to SQL, using either the "_sql" method corresponding to the resulting expression,
or the appropriate `Generator.TRANSFORMS` function (when applicable -- see below).
Remove join marks from an AST. This rule assumes that all marked columns are qualified.
If this does not hold for a query, consider running `sqlglot.optimizer.qualify` first.
For example,
SELECT * FROM a, b WHERE a.id = b.id(+) -- ... is converted to
SELECT * FROM a LEFT JOIN b ON a.id = b.id -- this
Args:
transforms: sequence of transform functions. These will be called in order.
expression: The AST to remove join marks from.
Returns:
Function that can be used as a generator transform.
The AST with join marks removed.
"""
from sqlglot.optimizer.scope import traverse_scope
def _to_sql(self, expression: exp.Expression) -> str:
expression_type = type(expression)
for scope in traverse_scope(expression):
query = scope.expression
expression = transforms[0](expression)
for transform in transforms[1:]:
expression = transform(expression)
where = query.args.get("where")
joins = query.args.get("joins")
_sql_handler = getattr(self, expression.key + "_sql", None)
if _sql_handler:
return _sql_handler(expression)
if not where or not joins:
continue
transforms_handler = self.TRANSFORMS.get(type(expression))
if transforms_handler:
if expression_type is type(expression):
if isinstance(expression, exp.Func):
return self.function_fallback_sql(expression)
query_from = query.args["from"]
# Ensures we don't enter an infinite loop. This can happen when the original expression
# has the same type as the final expression and there's no _sql method available for it,
# because then it'd re-enter _to_sql.
raise ValueError(
f"Expression type {expression.__class__.__name__} requires a _sql method in order to be transformed."
)
# These keep track of the joins to be replaced
new_joins: t.Dict[str, exp.Join] = {}
old_joins = {join.alias_or_name: join for join in joins}
return transforms_handler(self, expression)
for column in scope.columns:
if not column.args.get("join_mark"):
continue
raise ValueError(f"Unsupported expression type {expression.__class__.__name__}.")
predicate = column.find_ancestor(exp.Predicate, exp.Select)
assert isinstance(
predicate, exp.Binary
), "Columns can only be marked with (+) when involved in a binary operation"
return _to_sql
predicate_parent = predicate.parent
join_predicate = predicate.pop()
left_columns = [
c for c in join_predicate.left.find_all(exp.Column) if c.args.get("join_mark")
]
right_columns = [
c for c in join_predicate.right.find_all(exp.Column) if c.args.get("join_mark")
]
assert not (
left_columns and right_columns
), "The (+) marker cannot appear in both sides of a binary predicate"
marked_column_tables = set()
for col in left_columns or right_columns:
table = col.table
assert table, f"Column {col} needs to be qualified with a table"
col.set("join_mark", False)
marked_column_tables.add(table)
assert (
len(marked_column_tables) == 1
), "Columns of only a single table can be marked with (+) in a given binary predicate"
join_this = old_joins.get(col.table, query_from).this
new_join = exp.Join(this=join_this, on=join_predicate, kind="LEFT")
# Upsert new_join into new_joins dictionary
new_join_alias_or_name = new_join.alias_or_name
existing_join = new_joins.get(new_join_alias_or_name)
if existing_join:
existing_join.set("on", exp.and_(existing_join.args.get("on"), new_join.args["on"]))
else:
new_joins[new_join_alias_or_name] = new_join
# If the parent of the target predicate is a binary node, then it now has only one child
if isinstance(predicate_parent, exp.Binary):
if predicate_parent.left is None:
predicate_parent.replace(predicate_parent.right)
else:
predicate_parent.replace(predicate_parent.left)
if query_from.alias_or_name in new_joins:
only_old_joins = old_joins.keys() - new_joins.keys()
assert (
len(only_old_joins) >= 1
), "Cannot determine which table to use in the new FROM clause"
new_from_name = list(only_old_joins)[0]
query.set("from", exp.From(this=old_joins[new_from_name].this))
query.set("joins", list(new_joins.values()))
if not where.this:
where.pop()
return expression
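# Editorial sketch (not part of the diff): the docstring's own example, run end to end under the
# assumption that the marked columns are already qualified.
import sqlglot
from sqlglot.transforms import eliminate_join_marks
tree = sqlglot.parse_one("SELECT * FROM a, b WHERE a.id = b.id(+)", read="oracle")
print(eliminate_join_marks(tree).sql(dialect="oracle"))  # expected: SELECT * FROM a LEFT JOIN b ON a.id = b.id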