Merging upstream version 21.0.1.
Signed-off-by: Daniel Baumann <daniel@debian.org>
parent 3759c601a7
commit 96b10de29a
115 changed files with 66603 additions and 60920 deletions
@@ -1,3 +1,4 @@
+# ruff: noqa: F401
 """
 .. include:: ../README.md
@@ -87,11 +88,13 @@ def parse(


 @t.overload
-def parse_one(sql: str, *, into: t.Type[E], **opts) -> E: ...
+def parse_one(sql: str, *, into: t.Type[E], **opts) -> E:
+    ...


 @t.overload
-def parse_one(sql: str, **opts) -> Expression: ...
+def parse_one(sql: str, **opts) -> Expression:
+    ...


 def parse_one(
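For context, the overloads above describe sqlglot's public parse_one API: with into= it returns that node type, otherwise a generic Expression. A minimal usage sketch (illustrative, not part of the diff):

import sqlglot
from sqlglot import exp

# Generic overload: returns an Expression
tree = sqlglot.parse_one("SELECT 1")
# Typed overload: `into` narrows the return type for type checkers
select = sqlglot.parse_one("SELECT 1", into=exp.Select)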
@@ -13,4 +13,5 @@ if t.TYPE_CHECKING:
 A = t.TypeVar("A", bound=t.Any)
 B = t.TypeVar("B", bound="sqlglot.exp.Binary")
 E = t.TypeVar("E", bound="sqlglot.exp.Expression")
+F = t.TypeVar("F", bound="sqlglot.exp.Func")
 T = t.TypeVar("T")
@@ -140,10 +140,12 @@ class DataFrame:
         return cte, name

     @t.overload
-    def _ensure_list_of_columns(self, cols: t.Collection[ColumnOrLiteral]) -> t.List[Column]: ...
+    def _ensure_list_of_columns(self, cols: t.Collection[ColumnOrLiteral]) -> t.List[Column]:
+        ...

     @t.overload
-    def _ensure_list_of_columns(self, cols: ColumnOrLiteral) -> t.List[Column]: ...
+    def _ensure_list_of_columns(self, cols: ColumnOrLiteral) -> t.List[Column]:
+        ...

     def _ensure_list_of_columns(self, cols):
         return Column.ensure_cols(ensure_list(cols))
@@ -368,7 +368,10 @@ def covar_samp(col1: ColumnOrName, col2: ColumnOrName) -> Column:


 def first(col: ColumnOrName, ignorenulls: t.Optional[bool] = None) -> Column:
-    return Column.invoke_expression_over_column(col, expression.First, ignore_nulls=ignorenulls)
+    this = Column.invoke_expression_over_column(col, expression.First)
+    if ignorenulls:
+        return Column.invoke_expression_over_column(this, expression.IgnoreNulls)
+    return this


 def grouping_id(*cols: ColumnOrName) -> Column:
@@ -392,7 +395,10 @@ def isnull(col: ColumnOrName) -> Column:


 def last(col: ColumnOrName, ignorenulls: t.Optional[bool] = None) -> Column:
-    return Column.invoke_expression_over_column(col, expression.Last, ignore_nulls=ignorenulls)
+    this = Column.invoke_expression_over_column(col, expression.Last)
+    if ignorenulls:
+        return Column.invoke_expression_over_column(this, expression.IgnoreNulls)
+    return this


 def monotonically_increasing_id() -> Column:
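The first/last change above routes ignorenulls through an explicit IgnoreNulls node instead of a constructor argument. A minimal sketch of the resulting behavior, assuming the PySpark-style helper module shown in this diff:

from sqlglot.dataframe.sql import functions as F

# With ignorenulls, the aggregate should be wrapped in an IGNORE NULLS node
col = F.first("x", ignorenulls=True)
print(col.expression.sql())  # expected to render roughly as FIRST(x) IGNORE NULLS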
@@ -485,31 +491,28 @@ def factorial(col: ColumnOrName) -> Column:
 def lag(
     col: ColumnOrName, offset: t.Optional[int] = 1, default: t.Optional[ColumnOrLiteral] = None
 ) -> Column:
-    if default is not None:
-        return Column.invoke_anonymous_function(col, "LAG", offset, default)
-    if offset != 1:
-        return Column.invoke_anonymous_function(col, "LAG", offset)
-    return Column.invoke_anonymous_function(col, "LAG")
+    return Column.invoke_expression_over_column(
+        col, expression.Lag, offset=None if offset == 1 else offset, default=default
+    )


 def lead(
     col: ColumnOrName, offset: t.Optional[int] = 1, default: t.Optional[t.Any] = None
 ) -> Column:
-    if default is not None:
-        return Column.invoke_anonymous_function(col, "LEAD", offset, default)
-    if offset != 1:
-        return Column.invoke_anonymous_function(col, "LEAD", offset)
-    return Column.invoke_anonymous_function(col, "LEAD")
+    return Column.invoke_expression_over_column(
+        col, expression.Lead, offset=None if offset == 1 else offset, default=default
+    )


 def nth_value(
     col: ColumnOrName, offset: t.Optional[int] = 1, ignoreNulls: t.Optional[bool] = None
 ) -> Column:
-    if ignoreNulls is not None:
-        raise NotImplementedError("There is currently not support for `ignoreNulls` parameter")
-    if offset != 1:
-        return Column.invoke_anonymous_function(col, "NTH_VALUE", offset)
-    return Column.invoke_anonymous_function(col, "NTH_VALUE")
+    this = Column.invoke_expression_over_column(
+        col, expression.NthValue, offset=None if offset == 1 else offset
+    )
+    if ignoreNulls:
+        return Column.invoke_expression_over_column(this, expression.IgnoreNulls)
+    return this


 def ntile(n: int) -> Column:
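Similarly, lag/lead/nth_value now build typed Lag/Lead/NthValue expressions rather than anonymous function calls. A sketch under the same assumptions:

from sqlglot.dataframe.sql import functions as F

# Offset 1 is the default and is omitted from the generated SQL
print(F.lag("x").expression.sql())        # expected: LAG(x)
print(F.lag("x", 2, 0).expression.sql())  # expected: LAG(x, 2, 0)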
@@ -1,9 +1,10 @@
+# ruff: noqa: F401
 """
 ## Dialects

 While there is a SQL standard, most SQL engines support a variation of that standard. This makes it difficult
 to write portable SQL code. SQLGlot bridges all the different variations, called "dialects", with an extensible
 SQL transpilation framework.

 The base `sqlglot.dialects.dialect.Dialect` class implements a generic dialect that aims to be as universal as possible.
@@ -19,7 +19,6 @@ from sqlglot.dialects.dialect import (
     min_or_least,
     no_ilike_sql,
     parse_date_delta_with_interval,
-    path_to_jsonpath,
     regexp_replace_sql,
     rename_func,
     timestrtotime_sql,
@@ -458,8 +457,10 @@ class BigQuery(Dialect):

         return this

-    def _parse_table_parts(self, schema: bool = False) -> exp.Table:
-        table = super()._parse_table_parts(schema=schema)
+    def _parse_table_parts(
+        self, schema: bool = False, is_db_reference: bool = False
+    ) -> exp.Table:
+        table = super()._parse_table_parts(schema=schema, is_db_reference=is_db_reference)
         if isinstance(table.this, exp.Identifier) and "." in table.name:
             catalog, db, this, *rest = (
                 t.cast(t.Optional[exp.Expression], exp.to_identifier(x))
@@ -474,10 +475,12 @@ class BigQuery(Dialect):
         return table

     @t.overload
-    def _parse_json_object(self, agg: Lit[False]) -> exp.JSONObject: ...
+    def _parse_json_object(self, agg: Lit[False]) -> exp.JSONObject:
+        ...

     @t.overload
-    def _parse_json_object(self, agg: Lit[True]) -> exp.JSONObjectAgg: ...
+    def _parse_json_object(self, agg: Lit[True]) -> exp.JSONObjectAgg:
+        ...

     def _parse_json_object(self, agg=False):
         json_object = super()._parse_json_object()
@@ -536,6 +539,8 @@ class BigQuery(Dialect):
         UNPIVOT_ALIASES_ARE_IDENTIFIERS = False
         JSON_KEY_VALUE_PAIR_SEP = ","
         NULL_ORDERING_SUPPORTED = False
+        IGNORE_NULLS_IN_FUNC = True
+        JSON_PATH_SINGLE_QUOTE_ESCAPE = True

         TRANSFORMS = {
             **generator.Generator.TRANSFORMS,
@@ -554,7 +559,8 @@ class BigQuery(Dialect):
             exp.Create: _create_sql,
             exp.CTE: transforms.preprocess([_pushdown_cte_column_names]),
             exp.DateAdd: date_add_interval_sql("DATE", "ADD"),
-            exp.DateDiff: lambda self, e: f"DATE_DIFF({self.sql(e, 'this')}, {self.sql(e, 'expression')}, {self.sql(e.args.get('unit', 'DAY'))})",
+            exp.DateDiff: lambda self,
+            e: f"DATE_DIFF({self.sql(e, 'this')}, {self.sql(e, 'expression')}, {self.sql(e.args.get('unit', 'DAY'))})",
             exp.DateFromParts: rename_func("DATE"),
             exp.DateStrToDate: datestrtodate_sql,
             exp.DateSub: date_add_interval_sql("DATE", "SUB"),
@@ -565,7 +571,6 @@ class BigQuery(Dialect):
                 "DATETIME", self.func("TIMESTAMP", e.this, e.args.get("zone")), "'UTC'"
             ),
             exp.GenerateSeries: rename_func("GENERATE_ARRAY"),
-            exp.GetPath: path_to_jsonpath(),
             exp.GroupConcat: rename_func("STRING_AGG"),
             exp.Hex: rename_func("TO_HEX"),
             exp.If: if_sql(false_value="NULL"),
@@ -597,12 +602,13 @@ class BigQuery(Dialect):
                 ]
             ),
             exp.SHA2: lambda self, e: self.func(
-                f"SHA256" if e.text("length") == "256" else "SHA512", e.this
+                "SHA256" if e.text("length") == "256" else "SHA512", e.this
             ),
             exp.StabilityProperty: lambda self, e: (
-                f"DETERMINISTIC" if e.name == "IMMUTABLE" else "NOT DETERMINISTIC"
+                "DETERMINISTIC" if e.name == "IMMUTABLE" else "NOT DETERMINISTIC"
             ),
-            exp.StrToDate: lambda self, e: f"PARSE_DATE({self.format_time(e)}, {self.sql(e, 'this')})",
+            exp.StrToDate: lambda self,
+            e: f"PARSE_DATE({self.format_time(e)}, {self.sql(e, 'this')})",
             exp.StrToTime: lambda self, e: self.func(
                 "PARSE_TIMESTAMP", self.format_time(e), e.this, e.args.get("zone")
             ),
@@ -610,9 +616,10 @@ class BigQuery(Dialect):
             exp.TimeFromParts: rename_func("TIME"),
             exp.TimeSub: date_add_interval_sql("TIME", "SUB"),
             exp.TimestampAdd: date_add_interval_sql("TIMESTAMP", "ADD"),
+            exp.TimestampDiff: rename_func("TIMESTAMP_DIFF"),
             exp.TimestampSub: date_add_interval_sql("TIMESTAMP", "SUB"),
             exp.TimeStrToTime: timestrtotime_sql,
-            exp.Trim: lambda self, e: self.func(f"TRIM", e.this, e.expression),
+            exp.Trim: lambda self, e: self.func("TRIM", e.this, e.expression),
             exp.TsOrDsAdd: _ts_or_ds_add_sql,
             exp.TsOrDsDiff: _ts_or_ds_diff_sql,
             exp.TsOrDsToTime: rename_func("TIME"),
@@ -623,6 +630,12 @@ class BigQuery(Dialect):
             exp.VariancePop: rename_func("VAR_POP"),
         }

+        SUPPORTED_JSON_PATH_PARTS = {
+            exp.JSONPathKey,
+            exp.JSONPathRoot,
+            exp.JSONPathSubscript,
+        }
+
         TYPE_MAPPING = {
             **generator.Generator.TYPE_MAPPING,
             exp.DataType.Type.BIGDECIMAL: "BIGNUMERIC",
@@ -8,12 +8,15 @@ from sqlglot.dialects.dialect import (
     arg_max_or_min_no_count,
     date_delta_sql,
     inline_array_sql,
+    json_extract_segments,
+    json_path_key_only_name,
     no_pivot_sql,
+    parse_json_extract_path,
     rename_func,
     var_map_sql,
 )
 from sqlglot.errors import ParseError
-from sqlglot.helper import seq_get
+from sqlglot.helper import is_int, seq_get
 from sqlglot.parser import parse_var_map
 from sqlglot.tokens import Token, TokenType

@@ -120,6 +123,9 @@ class ClickHouse(Dialect):
             "DATEDIFF": lambda args: exp.DateDiff(
                 this=seq_get(args, 2), expression=seq_get(args, 1), unit=seq_get(args, 0)
             ),
+            "JSONEXTRACTSTRING": parse_json_extract_path(
+                exp.JSONExtractScalar, zero_based_indexing=False
+            ),
             "MAP": parse_var_map,
             "MATCH": exp.RegexpLike.from_arg_list,
             "RANDCANONICAL": exp.Rand.from_arg_list,
@@ -354,9 +360,14 @@ class ClickHouse(Dialect):
             joins: bool = False,
             alias_tokens: t.Optional[t.Collection[TokenType]] = None,
             parse_bracket: bool = False,
+            is_db_reference: bool = False,
         ) -> t.Optional[exp.Expression]:
             this = super()._parse_table(
-                schema=schema, joins=joins, alias_tokens=alias_tokens, parse_bracket=parse_bracket
+                schema=schema,
+                joins=joins,
+                alias_tokens=alias_tokens,
+                parse_bracket=parse_bracket,
+                is_db_reference=is_db_reference,
             )

             if self._match(TokenType.FINAL):
@@ -518,6 +529,12 @@ class ClickHouse(Dialect):
             exp.DataType.Type.VARCHAR: "String",
         }

+        SUPPORTED_JSON_PATH_PARTS = {
+            exp.JSONPathKey,
+            exp.JSONPathRoot,
+            exp.JSONPathSubscript,
+        }
+
         TYPE_MAPPING = {
             **generator.Generator.TYPE_MAPPING,
             **STRING_TYPE_MAPPING,
@@ -570,6 +587,10 @@ class ClickHouse(Dialect):
             exp.Explode: rename_func("arrayJoin"),
             exp.Final: lambda self, e: f"{self.sql(e, 'this')} FINAL",
             exp.IsNan: rename_func("isNaN"),
+            exp.JSONExtract: json_extract_segments("JSONExtractString", quoted_index=False),
+            exp.JSONExtractScalar: json_extract_segments("JSONExtractString", quoted_index=False),
+            exp.JSONPathKey: json_path_key_only_name,
+            exp.JSONPathRoot: lambda *_: "",
             exp.Map: lambda self, e: _lower_func(var_map_sql(self, e)),
             exp.Nullif: rename_func("nullIf"),
             exp.PartitionedByProperty: lambda self, e: f"PARTITION BY {self.sql(e, 'this')}",
@@ -579,7 +600,8 @@ class ClickHouse(Dialect):
             exp.Rand: rename_func("randCanonical"),
             exp.Select: transforms.preprocess([transforms.eliminate_qualify]),
             exp.StartsWith: rename_func("startsWith"),
-            exp.StrPosition: lambda self, e: f"position({self.format_args(e.this, e.args.get('substr'), e.args.get('position'))})",
+            exp.StrPosition: lambda self,
+            e: f"position({self.format_args(e.this, e.args.get('substr'), e.args.get('position'))})",
             exp.VarMap: lambda self, e: _lower_func(var_map_sql(self, e)),
             exp.Xor: lambda self, e: self.func("xor", e.this, e.expression, *e.expressions),
         }
@@ -608,6 +630,13 @@ class ClickHouse(Dialect):
             "NAMED COLLECTION",
         }

+        def _jsonpathsubscript_sql(self, expression: exp.JSONPathSubscript) -> str:
+            this = self.json_path_part(expression.this)
+            return str(int(this) + 1) if is_int(this) else this
+
+        def likeproperty_sql(self, expression: exp.LikeProperty) -> str:
+            return f"AS {self.sql(expression, 'this')}"
+
         def _any_to_has(
             self,
             expression: exp.EQ | exp.NEQ,
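The new _jsonpathsubscript_sql compensates for ClickHouse's 1-based JSON indexing: parsing uses zero_based_indexing=False and generation adds 1 back. A hedged round-trip sketch:

import sqlglot

# Indexes should survive a ClickHouse -> ClickHouse round trip unchanged
sql = "SELECT JSONExtractString(j, 'a', 1)"
print(sqlglot.transpile(sql, read="clickhouse", write="clickhouse")[0])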
@@ -22,6 +22,7 @@ class Databricks(Spark):
             "DATEADD": parse_date_delta(exp.DateAdd),
             "DATE_ADD": parse_date_delta(exp.DateAdd),
             "DATEDIFF": parse_date_delta(exp.DateDiff),
+            "TIMESTAMPDIFF": parse_date_delta(exp.TimestampDiff),
         }

         FACTOR = {
@@ -48,6 +49,9 @@ class Databricks(Spark):
             exp.DatetimeDiff: lambda self, e: self.func(
                 "TIMESTAMPDIFF", e.text("unit"), e.expression, e.this
             ),
+            exp.TimestampDiff: lambda self, e: self.func(
+                "TIMESTAMPDIFF", e.text("unit"), e.expression, e.this
+            ),
             exp.DatetimeTrunc: timestamptrunc_sql,
             exp.JSONExtract: lambda self, e: self.binary(e, ":"),
             exp.Select: transforms.preprocess(
@@ -1,5 +1,6 @@
 from __future__ import annotations

+import logging
 import typing as t
 from enum import Enum, auto
 from functools import reduce
@@ -7,7 +8,8 @@ from functools import reduce
 from sqlglot import exp
 from sqlglot.errors import ParseError
 from sqlglot.generator import Generator
-from sqlglot.helper import AutoName, flatten, seq_get
+from sqlglot.helper import AutoName, flatten, is_int, seq_get
+from sqlglot.jsonpath import parse as parse_json_path
 from sqlglot.parser import Parser
 from sqlglot.time import TIMEZONES, format_time
 from sqlglot.tokens import Token, Tokenizer, TokenType
@@ -17,7 +19,11 @@ DATE_ADD_OR_DIFF = t.Union[exp.DateAdd, exp.TsOrDsAdd, exp.DateDiff, exp.TsOrDsD
 DATE_ADD_OR_SUB = t.Union[exp.DateAdd, exp.TsOrDsAdd, exp.DateSub]

 if t.TYPE_CHECKING:
-    from sqlglot._typing import B, E
+    from sqlglot._typing import B, E, F
+
+    JSON_EXTRACT_TYPE = t.Union[exp.JSONExtract, exp.JSONExtractScalar]
+
+logger = logging.getLogger("sqlglot")


 class Dialects(str, Enum):
@@ -256,7 +262,7 @@ class Dialect(metaclass=_Dialect):

     INVERSE_ESCAPE_SEQUENCES: t.Dict[str, str] = {}

-    # Delimiters for quotes, identifiers and the corresponding escape characters
+    # Delimiters for string literals and identifiers
     QUOTE_START = "'"
     QUOTE_END = "'"
     IDENTIFIER_START = '"'
@@ -373,7 +379,7 @@ class Dialect(metaclass=_Dialect):
         """
         if (
             isinstance(expression, exp.Identifier)
-            and not self.normalization_strategy is NormalizationStrategy.CASE_SENSITIVE
+            and self.normalization_strategy is not NormalizationStrategy.CASE_SENSITIVE
             and (
                 not expression.quoted
                 or self.normalization_strategy is NormalizationStrategy.CASE_INSENSITIVE
@@ -440,6 +446,19 @@ class Dialect(metaclass=_Dialect):

         return expression

+    def to_json_path(self, path: t.Optional[exp.Expression]) -> t.Optional[exp.Expression]:
+        if isinstance(path, exp.Literal):
+            path_text = path.name
+            if path.is_number:
+                path_text = f"[{path_text}]"
+
+            try:
+                return parse_json_path(path_text)
+            except ParseError as e:
+                logger.warning(f"Invalid JSON path syntax. {str(e)}")
+
+        return path
+
     def parse(self, sql: str, **opts) -> t.List[t.Optional[exp.Expression]]:
         return self.parser(**opts).parse(self.tokenize(sql), sql)
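The new Dialect.to_json_path hook converts literal paths into structured exp.JSONPath trees, falling back (with a warning) on syntax it cannot parse. A minimal sketch:

from sqlglot import exp
from sqlglot.dialects.dialect import Dialect

path = Dialect().to_json_path(exp.Literal.string("$.a.b"))
print(type(path).__name__)  # expected: JSONPath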
@@ -500,14 +519,12 @@ def if_sql(
     return _if_sql


-def arrow_json_extract_sql(self: Generator, expression: exp.JSONExtract | exp.JSONBExtract) -> str:
-    return self.binary(expression, "->")
+def arrow_json_extract_sql(self: Generator, expression: JSON_EXTRACT_TYPE) -> str:
+    this = expression.this
+    if self.JSON_TYPE_REQUIRED_FOR_EXTRACTION and isinstance(this, exp.Literal) and this.is_string:
+        this.replace(exp.cast(this, "json"))

-
-def arrow_json_extract_scalar_sql(
-    self: Generator, expression: exp.JSONExtractScalar | exp.JSONBExtractScalar
-) -> str:
-    return self.binary(expression, "->>")
+    return self.binary(expression, "->" if isinstance(expression, exp.JSONExtract) else "->>")


 def inline_array_sql(self: Generator, expression: exp.Array) -> str:
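The merged arrow_json_extract_sql now picks -> or ->> from the expression type and, where JSON_TYPE_REQUIRED_FOR_EXTRACTION is set (e.g. MySQL per this diff), casts string literals to JSON first. An illustrative sketch:

import sqlglot

# The string literal operand should gain a CAST(... AS JSON) under MySQL
print(sqlglot.transpile("""SELECT '{"a": 1}' ->> '$.a'""", read="mysql", write="mysql")[0])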
@@ -552,11 +569,6 @@ def no_trycast_sql(self: Generator, expression: exp.TryCast) -> str:
     return self.cast_sql(expression)


-def no_properties_sql(self: Generator, expression: exp.Properties) -> str:
-    self.unsupported("Properties unsupported")
-    return ""
-
-
 def no_comment_column_constraint_sql(
     self: Generator, expression: exp.CommentColumnConstraint
 ) -> str:
@@ -965,32 +977,6 @@ def date_delta_sql(name: str, cast: bool = False) -> t.Callable[[Generator, DATE
     return _delta_sql


-def prepend_dollar_to_path(expression: exp.GetPath) -> exp.GetPath:
-    from sqlglot.optimizer.simplify import simplify
-
-    # Makes sure the path will be evaluated correctly at runtime to include the path root.
-    # For example, `[0].foo` will become `$[0].foo`, and `foo` will become `$.foo`.
-    path = expression.expression
-    path = exp.func(
-        "if",
-        exp.func("startswith", path, "'['"),
-        exp.func("concat", "'$'", path),
-        exp.func("concat", "'$.'", path),
-    )
-
-    expression.expression.replace(simplify(path))
-    return expression
-
-
-def path_to_jsonpath(
-    name: str = "JSON_EXTRACT",
-) -> t.Callable[[Generator, exp.GetPath], str]:
-    def _transform(self: Generator, expression: exp.GetPath) -> str:
-        return rename_func(name)(self, prepend_dollar_to_path(expression))
-
-    return _transform
-
-
 def no_last_day_sql(self: Generator, expression: exp.LastDay) -> str:
     trunc_curr_date = exp.func("date_trunc", "month", expression.this)
     plus_one_month = exp.func("date_add", trunc_curr_date, 1, "month")
@@ -1003,9 +989,8 @@ def merge_without_target_sql(self: Generator, expression: exp.Merge) -> str:
     """Remove table refs from columns in when statements."""
     alias = expression.this.args.get("alias")

-    normalize = lambda identifier: (
-        self.dialect.normalize_identifier(identifier).name if identifier else None
-    )
+    def normalize(identifier: t.Optional[exp.Identifier]) -> t.Optional[str]:
+        return self.dialect.normalize_identifier(identifier).name if identifier else None

     targets = {normalize(expression.this.this)}

@@ -1023,3 +1008,60 @@ def merge_without_target_sql(self: Generator, expression: exp.Merge) -> str:
     )

     return self.merge_sql(expression)
+
+
+def parse_json_extract_path(
+    expr_type: t.Type[F], zero_based_indexing: bool = True
+) -> t.Callable[[t.List], F]:
+    def _parse_json_extract_path(args: t.List) -> F:
+        segments: t.List[exp.JSONPathPart] = [exp.JSONPathRoot()]
+        for arg in args[1:]:
+            if not isinstance(arg, exp.Literal):
+                # We use the fallback parser because we can't really transpile non-literals safely
+                return expr_type.from_arg_list(args)
+
+            text = arg.name
+            if is_int(text):
+                index = int(text)
+                segments.append(
+                    exp.JSONPathSubscript(this=index if zero_based_indexing else index - 1)
+                )
+            else:
+                segments.append(exp.JSONPathKey(this=text))
+
+        # This is done to avoid failing in the expression validator due to the arg count
+        del args[2:]
+        return expr_type(this=seq_get(args, 0), expression=exp.JSONPath(expressions=segments))
+
+    return _parse_json_extract_path
+
+
+def json_extract_segments(
+    name: str, quoted_index: bool = True
+) -> t.Callable[[Generator, JSON_EXTRACT_TYPE], str]:
+    def _json_extract_segments(self: Generator, expression: JSON_EXTRACT_TYPE) -> str:
+        path = expression.expression
+        if not isinstance(path, exp.JSONPath):
+            return rename_func(name)(self, expression)
+
+        segments = []
+        for segment in path.expressions:
+            path = self.sql(segment)
+            if path:
+                if isinstance(segment, exp.JSONPathPart) and (
+                    quoted_index or not isinstance(segment, exp.JSONPathSubscript)
+                ):
+                    path = f"{self.dialect.QUOTE_START}{path}{self.dialect.QUOTE_END}"
+
+                segments.append(path)
+
+        return self.func(name, expression.this, *segments)
+
+    return _json_extract_segments
+
+
+def json_path_key_only_name(self: Generator, expression: exp.JSONPathKey) -> str:
+    if isinstance(expression.this, exp.JSONPathWildcard):
+        self.unsupported("Unsupported wildcard in JSONPathKey expression")
+
+    return expression.name
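Together, parse_json_extract_path and json_extract_segments let variadic path arguments round-trip through a structured JSONPath. A hedged example against the Postgres mappings added later in this diff:

import sqlglot

sql = "SELECT JSON_EXTRACT_PATH(j, 'a', '0')"
# The path arguments are parsed into JSONPath segments and re-rendered per dialect
print(sqlglot.transpile(sql, read="postgres", write="postgres")[0])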
@@ -55,11 +55,14 @@ class Doris(MySQL):
             exp.Map: rename_func("ARRAY_MAP"),
             exp.RegexpLike: rename_func("REGEXP"),
             exp.RegexpSplit: rename_func("SPLIT_BY_STRING"),
-            exp.StrToUnix: lambda self, e: f"UNIX_TIMESTAMP({self.sql(e, 'this')}, {self.format_time(e)})",
+            exp.StrToUnix: lambda self,
+            e: f"UNIX_TIMESTAMP({self.sql(e, 'this')}, {self.format_time(e)})",
             exp.Split: rename_func("SPLIT_BY_STRING"),
             exp.TimeStrToDate: rename_func("TO_DATE"),
-            exp.ToChar: lambda self, e: f"DATE_FORMAT({self.sql(e, 'this')}, {self.format_time(e)})",
-            exp.TsOrDsAdd: lambda self, e: f"DATE_ADD({self.sql(e, 'this')}, {self.sql(e, 'expression')})",  # Only for day level
+            exp.ToChar: lambda self,
+            e: f"DATE_FORMAT({self.sql(e, 'this')}, {self.format_time(e)})",
+            exp.TsOrDsAdd: lambda self,
+            e: f"DATE_ADD({self.sql(e, 'this')}, {self.sql(e, 'expression')})",  # Only for day level
             exp.TsOrDsToDate: lambda self, e: self.func("TO_DATE", e.this),
             exp.TimeToUnix: rename_func("UNIX_TIMESTAMP"),
             exp.TimestampTrunc: lambda self, e: self.func(
@@ -99,6 +99,7 @@ class Drill(Dialect):
         QUERY_HINTS = False
         NVL2_SUPPORTED = False
         LAST_DAY_SUPPORTS_DATE_PART = False
+        SUPPORTS_CREATE_TABLE_LIKE = False

         TYPE_MAPPING = {
             **generator.Generator.TYPE_MAPPING,
@@ -128,10 +129,14 @@ class Drill(Dialect):
             exp.DateAdd: _date_add_sql("ADD"),
             exp.DateStrToDate: datestrtodate_sql,
             exp.DateSub: _date_add_sql("SUB"),
-            exp.DateToDi: lambda self, e: f"CAST(TO_DATE({self.sql(e, 'this')}, {Drill.DATEINT_FORMAT}) AS INT)",
-            exp.DiToDate: lambda self, e: f"TO_DATE(CAST({self.sql(e, 'this')} AS VARCHAR), {Drill.DATEINT_FORMAT})",
-            exp.If: lambda self, e: f"`IF`({self.format_args(e.this, e.args.get('true'), e.args.get('false'))})",
-            exp.ILike: lambda self, e: f" {self.sql(e, 'this')} `ILIKE` {self.sql(e, 'expression')}",
+            exp.DateToDi: lambda self,
+            e: f"CAST(TO_DATE({self.sql(e, 'this')}, {Drill.DATEINT_FORMAT}) AS INT)",
+            exp.DiToDate: lambda self,
+            e: f"TO_DATE(CAST({self.sql(e, 'this')} AS VARCHAR), {Drill.DATEINT_FORMAT})",
+            exp.If: lambda self,
+            e: f"`IF`({self.format_args(e.this, e.args.get('true'), e.args.get('false'))})",
+            exp.ILike: lambda self,
+            e: f" {self.sql(e, 'this')} `ILIKE` {self.sql(e, 'expression')}",
             exp.Levenshtein: rename_func("LEVENSHTEIN_DISTANCE"),
             exp.PartitionedByProperty: lambda self, e: f"PARTITION BY {self.sql(e, 'this')}",
             exp.RegexpLike: rename_func("REGEXP_MATCHES"),
@@ -141,7 +146,8 @@ class Drill(Dialect):
             exp.Select: transforms.preprocess(
                 [transforms.eliminate_distinct_on, transforms.eliminate_semi_and_anti_joins]
             ),
-            exp.StrToTime: lambda self, e: f"TO_TIMESTAMP({self.sql(e, 'this')}, {self.format_time(e)})",
+            exp.StrToTime: lambda self,
+            e: f"TO_TIMESTAMP({self.sql(e, 'this')}, {self.format_time(e)})",
             exp.TimeStrToDate: lambda self, e: f"CAST({self.sql(e, 'this')} AS DATE)",
             exp.TimeStrToTime: timestrtotime_sql,
             exp.TimeStrToUnix: rename_func("UNIX_TIMESTAMP"),
@@ -149,8 +155,10 @@ class Drill(Dialect):
             exp.TimeToUnix: rename_func("UNIX_TIMESTAMP"),
             exp.ToChar: lambda self, e: self.function_fallback_sql(e),
             exp.TryCast: no_trycast_sql,
-            exp.TsOrDsAdd: lambda self, e: f"DATE_ADD(CAST({self.sql(e, 'this')} AS DATE), {self.sql(exp.Interval(this=e.expression, unit=exp.var('DAY')))})",
-            exp.TsOrDiToDi: lambda self, e: f"CAST(SUBSTR(REPLACE(CAST({self.sql(e, 'this')} AS VARCHAR), '-', ''), 1, 8) AS INT)",
+            exp.TsOrDsAdd: lambda self,
+            e: f"DATE_ADD(CAST({self.sql(e, 'this')} AS DATE), {self.sql(exp.Interval(this=e.expression, unit=exp.var('DAY')))})",
+            exp.TsOrDiToDi: lambda self,
+            e: f"CAST(SUBSTR(REPLACE(CAST({self.sql(e, 'this')} AS VARCHAR), '-', ''), 1, 8) AS INT)",
         }

         def normalize_func(self, name: str) -> str:
@@ -8,7 +8,6 @@ from sqlglot.dialects.dialect import (
     NormalizationStrategy,
     approx_count_distinct_sql,
     arg_max_or_min_no_count,
-    arrow_json_extract_scalar_sql,
    arrow_json_extract_sql,
     binary_from_function,
     bool_xor_sql,
@@ -18,11 +17,9 @@ from sqlglot.dialects.dialect import (
     format_time_lambda,
     inline_array_sql,
     no_comment_column_constraint_sql,
-    no_properties_sql,
     no_safe_divide_sql,
     no_timestamp_sql,
     pivot_column_names,
-    prepend_dollar_to_path,
     regexp_extract_sql,
     rename_func,
     str_position_sql,
@@ -172,6 +169,18 @@ class DuckDB(Dialect):
     # https://duckdb.org/docs/sql/introduction.html#creating-a-new-table
     NORMALIZATION_STRATEGY = NormalizationStrategy.CASE_INSENSITIVE

+    def to_json_path(self, path: t.Optional[exp.Expression]) -> t.Optional[exp.Expression]:
+        if isinstance(path, exp.Literal):
+            # DuckDB also supports the JSON pointer syntax, where every path starts with a `/`.
+            # Additionally, it allows accessing the back of lists using the `[#-i]` syntax.
+            # This check ensures we'll avoid trying to parse these as JSON paths, which can
+            # either result in a noisy warning or in an invalid representation of the path.
+            path_text = path.name
+            if path_text.startswith("/") or "[#" in path_text:
+                return path
+
+        return super().to_json_path(path)
+
     class Tokenizer(tokens.Tokenizer):
         KEYWORDS = {
             **tokens.Tokenizer.KEYWORDS,
@@ -229,6 +238,8 @@ class DuckDB(Dialect):
                 this=seq_get(args, 0), scale=exp.UnixToTime.MILLIS
             ),
             "JSON": exp.ParseJSON.from_arg_list,
+            "JSON_EXTRACT_PATH": parser.parse_extract_json_with_path(exp.JSONExtract),
+            "JSON_EXTRACT_STRING": parser.parse_extract_json_with_path(exp.JSONExtractScalar),
             "LIST_HAS": exp.ArrayContains.from_arg_list,
             "LIST_REVERSE_SORT": _sort_array_reverse,
             "LIST_SORT": exp.SortArray.from_arg_list,
@@ -319,6 +330,9 @@ class DuckDB(Dialect):
         TABLESAMPLE_SEED_KEYWORD = "REPEATABLE"
         LAST_DAY_SUPPORTS_DATE_PART = False
         JSON_KEY_VALUE_PAIR_SEP = ","
+        IGNORE_NULLS_IN_FUNC = True
+        JSON_PATH_BRACKETED_KEY_SUPPORTED = False
+        SUPPORTS_CREATE_TABLE_LIKE = False

         TRANSFORMS = {
             **generator.Generator.TRANSFORMS,
@@ -350,18 +364,18 @@ class DuckDB(Dialect):
                 "DATE_DIFF", f"'{e.args.get('unit') or 'DAY'}'", e.expression, e.this
             ),
             exp.DateStrToDate: datestrtodate_sql,
-            exp.DateToDi: lambda self, e: f"CAST(STRFTIME({self.sql(e, 'this')}, {DuckDB.DATEINT_FORMAT}) AS INT)",
+            exp.DateToDi: lambda self,
+            e: f"CAST(STRFTIME({self.sql(e, 'this')}, {DuckDB.DATEINT_FORMAT}) AS INT)",
             exp.Decode: lambda self, e: encode_decode_sql(self, e, "DECODE", replace=False),
-            exp.DiToDate: lambda self, e: f"CAST(STRPTIME(CAST({self.sql(e, 'this')} AS TEXT), {DuckDB.DATEINT_FORMAT}) AS DATE)",
+            exp.DiToDate: lambda self,
+            e: f"CAST(STRPTIME(CAST({self.sql(e, 'this')} AS TEXT), {DuckDB.DATEINT_FORMAT}) AS DATE)",
             exp.Encode: lambda self, e: encode_decode_sql(self, e, "ENCODE", replace=False),
             exp.Explode: rename_func("UNNEST"),
             exp.IntDiv: lambda self, e: self.binary(e, "//"),
             exp.IsInf: rename_func("ISINF"),
             exp.IsNan: rename_func("ISNAN"),
-            exp.JSONBExtract: arrow_json_extract_sql,
-            exp.JSONBExtractScalar: arrow_json_extract_scalar_sql,
             exp.JSONExtract: arrow_json_extract_sql,
-            exp.JSONExtractScalar: arrow_json_extract_scalar_sql,
+            exp.JSONExtractScalar: arrow_json_extract_sql,
             exp.JSONFormat: _json_format_sql,
             exp.LogicalOr: rename_func("BOOL_OR"),
             exp.LogicalAnd: rename_func("BOOL_AND"),
@@ -377,7 +391,6 @@ class DuckDB(Dialect):
             # DuckDB doesn't allow qualified columns inside of PIVOT expressions.
             # See: https://github.com/duckdb/duckdb/blob/671faf92411182f81dce42ac43de8bfb05d9909e/src/planner/binder/tableref/bind_pivot.cpp#L61-L62
             exp.Pivot: transforms.preprocess([transforms.unqualify_columns]),
-            exp.Properties: no_properties_sql,
             exp.RegexpExtract: regexp_extract_sql,
             exp.RegexpReplace: lambda self, e: self.func(
                 "REGEXP_REPLACE",
@@ -395,7 +408,8 @@ class DuckDB(Dialect):
             exp.StrPosition: str_position_sql,
             exp.StrToDate: lambda self, e: f"CAST({str_to_time_sql(self, e)} AS DATE)",
             exp.StrToTime: str_to_time_sql,
-            exp.StrToUnix: lambda self, e: f"EPOCH(STRPTIME({self.sql(e, 'this')}, {self.format_time(e)}))",
+            exp.StrToUnix: lambda self,
+            e: f"EPOCH(STRPTIME({self.sql(e, 'this')}, {self.format_time(e)}))",
             exp.Struct: _struct_sql,
             exp.Timestamp: no_timestamp_sql,
             exp.TimestampDiff: lambda self, e: self.func(
@@ -405,9 +419,11 @@ class DuckDB(Dialect):
             exp.TimeStrToDate: lambda self, e: f"CAST({self.sql(e, 'this')} AS DATE)",
             exp.TimeStrToTime: timestrtotime_sql,
             exp.TimeStrToUnix: lambda self, e: f"EPOCH(CAST({self.sql(e, 'this')} AS TIMESTAMP))",
-            exp.TimeToStr: lambda self, e: f"STRFTIME({self.sql(e, 'this')}, {self.format_time(e)})",
+            exp.TimeToStr: lambda self,
+            e: f"STRFTIME({self.sql(e, 'this')}, {self.format_time(e)})",
             exp.TimeToUnix: rename_func("EPOCH"),
-            exp.TsOrDiToDi: lambda self, e: f"CAST(SUBSTR(REPLACE(CAST({self.sql(e, 'this')} AS TEXT), '-', ''), 1, 8) AS INT)",
+            exp.TsOrDiToDi: lambda self,
+            e: f"CAST(SUBSTR(REPLACE(CAST({self.sql(e, 'this')} AS TEXT), '-', ''), 1, 8) AS INT)",
             exp.TsOrDsAdd: _ts_or_ds_add_sql,
             exp.TsOrDsDiff: lambda self, e: self.func(
                 "DATE_DIFF",
@@ -415,7 +431,8 @@ class DuckDB(Dialect):
                 exp.cast(e.expression, "TIMESTAMP"),
                 exp.cast(e.this, "TIMESTAMP"),
             ),
-            exp.UnixToStr: lambda self, e: f"STRFTIME(TO_TIMESTAMP({self.sql(e, 'this')}), {self.format_time(e)})",
+            exp.UnixToStr: lambda self,
+            e: f"STRFTIME(TO_TIMESTAMP({self.sql(e, 'this')}), {self.format_time(e)})",
             exp.UnixToTime: _unix_to_time_sql,
             exp.UnixToTimeStr: lambda self, e: f"CAST(TO_TIMESTAMP({self.sql(e, 'this')}) AS TEXT)",
             exp.VariancePop: rename_func("VAR_POP"),
@@ -423,6 +440,13 @@ class DuckDB(Dialect):
             exp.Xor: bool_xor_sql,
         }

+        SUPPORTED_JSON_PATH_PARTS = {
+            exp.JSONPathKey,
+            exp.JSONPathRoot,
+            exp.JSONPathSubscript,
+            exp.JSONPathWildcard,
+        }
+
         TYPE_MAPPING = {
             **generator.Generator.TYPE_MAPPING,
             exp.DataType.Type.BINARY: "BLOB",
@@ -442,11 +466,18 @@ class DuckDB(Dialect):

         UNWRAPPED_INTERVAL_VALUES = (exp.Column, exp.Literal, exp.Paren)

+        # DuckDB doesn't generally support CREATE TABLE .. properties
+        # https://duckdb.org/docs/sql/statements/create_table.html
         PROPERTIES_LOCATION = {
-            **generator.Generator.PROPERTIES_LOCATION,
-            exp.VolatileProperty: exp.Properties.Location.UNSUPPORTED,
+            prop: exp.Properties.Location.UNSUPPORTED
+            for prop in generator.Generator.PROPERTIES_LOCATION
         }

+        # There are a few exceptions (e.g. temporary tables) which are supported or
+        # can be transpiled to DuckDB, so we explicitly override them accordingly
+        PROPERTIES_LOCATION[exp.LikeProperty] = exp.Properties.Location.POST_SCHEMA
+        PROPERTIES_LOCATION[exp.TemporaryProperty] = exp.Properties.Location.POST_CREATE
+
         def timefromparts_sql(self, expression: exp.TimeFromParts) -> str:
             nano = expression.args.get("nano")
             if nano is not None:
@@ -486,10 +517,6 @@ class DuckDB(Dialect):
                 expression, sep=sep, tablesample_keyword=tablesample_keyword
             )

-        def getpath_sql(self, expression: exp.GetPath) -> str:
-            expression = prepend_dollar_to_path(expression)
-            return f"{self.sql(expression, 'this')} -> {self.sql(expression, 'expression')}"
-
         def interval_sql(self, expression: exp.Interval) -> str:
             multiplier: t.Optional[int] = None
             unit = expression.text("unit").lower()
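DuckDB's to_json_path override above deliberately skips JSON pointer ('/a/b') and back-of-list ('[#-1]') paths. A sketch of the intended effect:

import sqlglot

# A JSON pointer literal should pass through untouched rather than being
# (mis)parsed as a JSONPath
print(sqlglot.transpile("SELECT j -> '/fld'", read="duckdb", write="duckdb")[0])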
@@ -192,6 +192,18 @@ def _to_date_sql(self: Hive.Generator, expression: exp.TsOrDsToDate) -> str:
     return f"TO_DATE({this})"


+def _parse_ignore_nulls(
+    exp_class: t.Type[exp.Expression],
+) -> t.Callable[[t.List[exp.Expression]], exp.Expression]:
+    def _parse(args: t.List[exp.Expression]) -> exp.Expression:
+        this = exp_class(this=seq_get(args, 0))
+        if seq_get(args, 1) == exp.true():
+            return exp.IgnoreNulls(this=this)
+        return this
+
+    return _parse
+
+
 class Hive(Dialect):
     ALIAS_POST_TABLESAMPLE = True
     IDENTIFIERS_CAN_START_WITH_DIGIT = True
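_parse_ignore_nulls maps Hive's two-argument FIRST/LAST (and FIRST_VALUE/LAST_VALUE) onto IgnoreNulls nodes when the second argument is TRUE. A hedged sketch:

import sqlglot

ast = sqlglot.parse_one("SELECT FIRST_VALUE(x, TRUE) OVER (ORDER BY y)", read="hive")
# The TRUE flag should surface as an IGNORE NULLS wrapper in the tree
print(ast.sql("hive"))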
@@ -298,8 +310,12 @@ class Hive(Dialect):
                 expression=exp.TsOrDsToDate(this=seq_get(args, 1)),
             ),
             "DAY": lambda args: exp.Day(this=exp.TsOrDsToDate(this=seq_get(args, 0))),
+            "FIRST": _parse_ignore_nulls(exp.First),
+            "FIRST_VALUE": _parse_ignore_nulls(exp.FirstValue),
             "FROM_UNIXTIME": format_time_lambda(exp.UnixToStr, "hive", True),
             "GET_JSON_OBJECT": exp.JSONExtractScalar.from_arg_list,
+            "LAST": _parse_ignore_nulls(exp.Last),
+            "LAST_VALUE": _parse_ignore_nulls(exp.LastValue),
             "LOCATE": locate_to_strposition,
             "MAP": parse_var_map,
             "MONTH": lambda args: exp.Month(this=exp.TsOrDsToDate.from_arg_list(args)),
@@ -429,6 +445,7 @@ class Hive(Dialect):
         EXTRACT_ALLOWS_QUOTES = False
         NVL2_SUPPORTED = False
         LAST_DAY_SUPPORTS_DATE_PART = False
+        JSON_PATH_SINGLE_QUOTE_ESCAPE = True

         EXPRESSIONS_WITHOUT_NESTED_CTES = {
             exp.Insert,
@@ -437,6 +454,13 @@ class Hive(Dialect):
             exp.Union,
         }

+        SUPPORTED_JSON_PATH_PARTS = {
+            exp.JSONPathKey,
+            exp.JSONPathRoot,
+            exp.JSONPathSubscript,
+            exp.JSONPathWildcard,
+        }
+
         TYPE_MAPPING = {
             **generator.Generator.TYPE_MAPPING,
             exp.DataType.Type.BIT: "BOOLEAN",
@@ -471,9 +495,12 @@ class Hive(Dialect):
             exp.DateDiff: _date_diff_sql,
             exp.DateStrToDate: datestrtodate_sql,
             exp.DateSub: _add_date_sql,
-            exp.DateToDi: lambda self, e: f"CAST(DATE_FORMAT({self.sql(e, 'this')}, {Hive.DATEINT_FORMAT}) AS INT)",
-            exp.DiToDate: lambda self, e: f"TO_DATE(CAST({self.sql(e, 'this')} AS STRING), {Hive.DATEINT_FORMAT})",
-            exp.FileFormatProperty: lambda self, e: f"STORED AS {self.sql(e, 'this') if isinstance(e.this, exp.InputOutputFormat) else e.name.upper()}",
+            exp.DateToDi: lambda self,
+            e: f"CAST(DATE_FORMAT({self.sql(e, 'this')}, {Hive.DATEINT_FORMAT}) AS INT)",
+            exp.DiToDate: lambda self,
+            e: f"TO_DATE(CAST({self.sql(e, 'this')} AS STRING), {Hive.DATEINT_FORMAT})",
+            exp.FileFormatProperty: lambda self,
+            e: f"STORED AS {self.sql(e, 'this') if isinstance(e.this, exp.InputOutputFormat) else e.name.upper()}",
             exp.FromBase64: rename_func("UNBASE64"),
             exp.If: if_sql(),
             exp.ILike: no_ilike_sql,
@@ -502,7 +529,8 @@ class Hive(Dialect):
             exp.SafeDivide: no_safe_divide_sql,
             exp.SchemaCommentProperty: lambda self, e: self.naked_property(e),
             exp.ArrayUniqueAgg: rename_func("COLLECT_SET"),
-            exp.Split: lambda self, e: f"SPLIT({self.sql(e, 'this')}, CONCAT('\\\\Q', {self.sql(e, 'expression')}))",
+            exp.Split: lambda self,
+            e: f"SPLIT({self.sql(e, 'this')}, CONCAT('\\\\Q', {self.sql(e, 'expression')}))",
             exp.StrPosition: strposition_to_locate_sql,
             exp.StrToDate: _str_to_date_sql,
             exp.StrToTime: _str_to_time_sql,
@@ -514,7 +542,8 @@ class Hive(Dialect):
             exp.TimeToStr: _time_to_str,
             exp.TimeToUnix: rename_func("UNIX_TIMESTAMP"),
             exp.ToBase64: rename_func("BASE64"),
-            exp.TsOrDiToDi: lambda self, e: f"CAST(SUBSTR(REPLACE(CAST({self.sql(e, 'this')} AS STRING), '-', ''), 1, 8) AS INT)",
+            exp.TsOrDiToDi: lambda self,
+            e: f"CAST(SUBSTR(REPLACE(CAST({self.sql(e, 'this')} AS STRING), '-', ''), 1, 8) AS INT)",
             exp.TsOrDsAdd: _add_date_sql,
             exp.TsOrDsDiff: _date_diff_sql,
             exp.TsOrDsToDate: _to_date_sql,
@@ -528,8 +557,10 @@ class Hive(Dialect):
             exp.SerdeProperties: lambda self, e: self.properties(e, prefix="WITH SERDEPROPERTIES"),
             exp.NumberToStr: rename_func("FORMAT_NUMBER"),
             exp.National: lambda self, e: self.national_sql(e, prefix=""),
-            exp.ClusteredColumnConstraint: lambda self, e: f"({self.expressions(e, 'this', indent=False)})",
-            exp.NonClusteredColumnConstraint: lambda self, e: f"({self.expressions(e, 'this', indent=False)})",
+            exp.ClusteredColumnConstraint: lambda self,
+            e: f"({self.expressions(e, 'this', indent=False)})",
+            exp.NonClusteredColumnConstraint: lambda self,
+            e: f"({self.expressions(e, 'this', indent=False)})",
             exp.NotForReplicationColumnConstraint: lambda self, e: "",
             exp.OnProperty: lambda self, e: "",
             exp.PrimaryKeyColumnConstraint: lambda self, e: "PRIMARY KEY",
|
@ -543,6 +574,13 @@ class Hive(Dialect):
|
|||
exp.WithDataProperty: exp.Properties.Location.UNSUPPORTED,
|
||||
}
|
||||
|
||||
def _jsonpathkey_sql(self, expression: exp.JSONPathKey) -> str:
|
||||
if isinstance(expression.this, exp.JSONPathWildcard):
|
||||
self.unsupported("Unsupported wildcard in JSONPathKey expression")
|
||||
return ""
|
||||
|
||||
return super()._jsonpathkey_sql(expression)
|
||||
|
||||
def temporary_storage_provider(self, expression: exp.Create) -> exp.Create:
|
||||
# Hive has no temporary storage provider (there are hive settings though)
|
||||
return expression
|
||||
|
|
|
@@ -6,7 +6,7 @@ from sqlglot import exp, generator, parser, tokens, transforms
 from sqlglot.dialects.dialect import (
     Dialect,
     NormalizationStrategy,
-    arrow_json_extract_scalar_sql,
+    arrow_json_extract_sql,
     date_add_interval_sql,
     datestrtodate_sql,
     format_time_lambda,
@@ -19,8 +19,8 @@ from sqlglot.dialects.dialect import (
     no_pivot_sql,
     no_tablesample_sql,
     no_trycast_sql,
+    parse_date_delta,
     parse_date_delta_with_interval,
-    path_to_jsonpath,
     rename_func,
     strposition_to_locate_sql,
 )
@@ -306,6 +306,7 @@ class MySQL(Dialect):
                 format=exp.Literal.string("%B"),
             ),
             "STR_TO_DATE": _str_to_date,
+            "TIMESTAMPDIFF": parse_date_delta(exp.TimestampDiff),
             "TO_DAYS": lambda args: exp.paren(
                 exp.DateDiff(
                     this=exp.TsOrDsToDate(this=seq_get(args, 0)),
@@ -357,6 +358,7 @@ class MySQL(Dialect):
             "CREATE TRIGGER": _show_parser("CREATE TRIGGER", target=True),
             "CREATE VIEW": _show_parser("CREATE VIEW", target=True),
             "DATABASES": _show_parser("DATABASES"),
+            "SCHEMAS": _show_parser("DATABASES"),
             "ENGINE": _show_parser("ENGINE", target=True),
             "STORAGE ENGINES": _show_parser("ENGINES"),
             "ENGINES": _show_parser("ENGINES"),
@@ -630,6 +632,8 @@ class MySQL(Dialect):
         VALUES_AS_TABLE = False
         NVL2_SUPPORTED = False
         LAST_DAY_SUPPORTS_DATE_PART = False
+        JSON_TYPE_REQUIRED_FOR_EXTRACTION = True
+        JSON_PATH_BRACKETED_KEY_SUPPORTED = False
         JSON_KEY_VALUE_PAIR_SEP = ","

         TRANSFORMS = {
@@ -646,10 +650,10 @@ class MySQL(Dialect):
             exp.DayOfMonth: _remove_ts_or_ds_to_date(rename_func("DAYOFMONTH")),
             exp.DayOfWeek: _remove_ts_or_ds_to_date(rename_func("DAYOFWEEK")),
             exp.DayOfYear: _remove_ts_or_ds_to_date(rename_func("DAYOFYEAR")),
-            exp.GetPath: path_to_jsonpath(),
-            exp.GroupConcat: lambda self, e: f"""GROUP_CONCAT({self.sql(e, "this")} SEPARATOR {self.sql(e, "separator") or "','"})""",
+            exp.GroupConcat: lambda self,
+            e: f"""GROUP_CONCAT({self.sql(e, "this")} SEPARATOR {self.sql(e, "separator") or "','"})""",
             exp.ILike: no_ilike_sql,
-            exp.JSONExtractScalar: arrow_json_extract_scalar_sql,
+            exp.JSONExtractScalar: arrow_json_extract_sql,
             exp.Max: max_or_greatest,
             exp.Min: min_or_least,
             exp.Month: _remove_ts_or_ds_to_date(),
@@ -672,6 +676,9 @@ class MySQL(Dialect):
             exp.TableSample: no_tablesample_sql,
             exp.TimeFromParts: rename_func("MAKETIME"),
             exp.TimestampAdd: date_add_interval_sql("DATE", "ADD"),
+            exp.TimestampDiff: lambda self, e: self.func(
+                "TIMESTAMPDIFF", e.text("unit"), e.expression, e.this
+            ),
             exp.TimestampSub: date_add_interval_sql("DATE", "SUB"),
             exp.TimeStrToUnix: rename_func("UNIX_TIMESTAMP"),
             exp.TimeStrToTime: lambda self, e: self.sql(exp.cast(e.this, "datetime", copy=True)),
@@ -199,7 +199,8 @@ class Oracle(Dialect):
                     transforms.eliminate_qualify,
                 ]
             ),
-            exp.StrToTime: lambda self, e: f"TO_TIMESTAMP({self.sql(e, 'this')}, {self.format_time(e)})",
+            exp.StrToTime: lambda self,
+            e: f"TO_TIMESTAMP({self.sql(e, 'this')}, {self.format_time(e)})",
             exp.StrToDate: lambda self, e: f"TO_DATE({self.sql(e, 'this')}, {self.format_time(e)})",
             exp.Subquery: lambda self, e: self.subquery_sql(e, sep=" "),
             exp.Substring: rename_func("SUBSTR"),
@@ -208,7 +209,8 @@ class Oracle(Dialect):
             exp.TimeToStr: lambda self, e: f"TO_CHAR({self.sql(e, 'this')}, {self.format_time(e)})",
             exp.ToChar: lambda self, e: self.function_fallback_sql(e),
             exp.Trim: trim_sql,
-            exp.UnixToTime: lambda self, e: f"TO_DATE('1970-01-01','YYYY-MM-DD') + ({self.sql(e, 'this')} / 86400)",
+            exp.UnixToTime: lambda self,
+            e: f"TO_DATE('1970-01-01','YYYY-MM-DD') + ({self.sql(e, 'this')} / 86400)",
         }

         PROPERTIES_LOCATION = {
@@ -7,11 +7,11 @@ from sqlglot.dialects.dialect import (
     DATE_ADD_OR_SUB,
     Dialect,
     any_value_to_max_sql,
-    arrow_json_extract_scalar_sql,
-    arrow_json_extract_sql,
     bool_xor_sql,
     datestrtodate_sql,
     format_time_lambda,
+    json_extract_segments,
+    json_path_key_only_name,
     max_or_greatest,
     merge_without_target_sql,
     min_or_least,
@@ -20,6 +20,7 @@ from sqlglot.dialects.dialect import (
     no_paren_current_date_sql,
     no_pivot_sql,
     no_trycast_sql,
+    parse_json_extract_path,
     parse_timestamp_trunc,
     rename_func,
     str_position_sql,
@@ -292,6 +293,8 @@ class Postgres(Dialect):
             **parser.Parser.FUNCTIONS,
             "DATE_TRUNC": parse_timestamp_trunc,
             "GENERATE_SERIES": _generate_series,
+            "JSON_EXTRACT_PATH": parse_json_extract_path(exp.JSONExtract),
+            "JSON_EXTRACT_PATH_TEXT": parse_json_extract_path(exp.JSONExtractScalar),
             "MAKE_TIME": exp.TimeFromParts.from_arg_list,
             "MAKE_TIMESTAMP": exp.TimestampFromParts.from_arg_list,
             "NOW": exp.CurrentTimestamp.from_arg_list,
@@ -375,8 +378,15 @@ class Postgres(Dialect):
         TABLESAMPLE_SIZE_IS_ROWS = False
         TABLESAMPLE_SEED_KEYWORD = "REPEATABLE"
         SUPPORTS_SELECT_INTO = True
         # https://www.postgresql.org/docs/current/sql-createtable.html
+        JSON_TYPE_REQUIRED_FOR_EXTRACTION = True
         SUPPORTS_UNLOGGED_TABLES = True
         LIKE_PROPERTY_INSIDE_SCHEMA = True

+        SUPPORTED_JSON_PATH_PARTS = {
+            exp.JSONPathKey,
+            exp.JSONPathRoot,
+            exp.JSONPathSubscript,
+        }
+
         TYPE_MAPPING = {
             **generator.Generator.TYPE_MAPPING,
@@ -412,11 +422,14 @@ class Postgres(Dialect):
             exp.DateSub: _date_add_sql("-"),
             exp.Explode: rename_func("UNNEST"),
             exp.GroupConcat: _string_agg_sql,
-            exp.JSONExtract: arrow_json_extract_sql,
-            exp.JSONExtractScalar: arrow_json_extract_scalar_sql,
+            exp.JSONExtract: json_extract_segments("JSON_EXTRACT_PATH"),
+            exp.JSONExtractScalar: json_extract_segments("JSON_EXTRACT_PATH_TEXT"),
             exp.JSONBExtract: lambda self, e: self.binary(e, "#>"),
             exp.JSONBExtractScalar: lambda self, e: self.binary(e, "#>>"),
             exp.JSONBContains: lambda self, e: self.binary(e, "?"),
+            exp.JSONPathKey: json_path_key_only_name,
+            exp.JSONPathRoot: lambda *_: "",
+            exp.JSONPathSubscript: lambda self, e: self.json_path_part(e.this),
             exp.LastDay: no_last_day_sql,
             exp.LogicalOr: rename_func("BOOL_OR"),
             exp.LogicalAnd: rename_func("BOOL_AND"),
@@ -443,7 +456,8 @@ class Postgres(Dialect):
                 ]
             ),
             exp.StrPosition: str_position_sql,
-            exp.StrToTime: lambda self, e: f"TO_TIMESTAMP({self.sql(e, 'this')}, {self.format_time(e)})",
+            exp.StrToTime: lambda self,
+            e: f"TO_TIMESTAMP({self.sql(e, 'this')}, {self.format_time(e)})",
             exp.StructExtract: struct_extract_sql,
             exp.Substring: _substring_sql,
             exp.TimeFromParts: rename_func("MAKE_TIME"),
@@ -18,7 +18,6 @@ from sqlglot.dialects.dialect import (
     no_pivot_sql,
     no_safe_divide_sql,
     no_timestamp_sql,
-    path_to_jsonpath,
     regexp_extract_sql,
     rename_func,
     right_to_substring_sql,
@@ -150,7 +149,7 @@ def _unnest_sequence(expression: exp.Expression) -> exp.Expression:
     return expression


-def _first_last_sql(self: Presto.Generator, expression: exp.First | exp.Last) -> str:
+def _first_last_sql(self: Presto.Generator, expression: exp.Func) -> str:
     """
     Trino doesn't support FIRST / LAST as functions, but they're valid in the context
     of MATCH_RECOGNIZE, so we need to preserve them in that case. In all other cases
@@ -292,6 +291,7 @@ class Presto(Dialect):
         STRUCT_DELIMITER = ("(", ")")
         LIMIT_ONLY_LITERALS = True
         SUPPORTS_SINGLE_ARG_CONCAT = False
+        LIKE_PROPERTY_INSIDE_SCHEMA = True

         PROPERTIES_LOCATION = {
             **generator.Generator.PROPERTIES_LOCATION,
@@ -324,12 +324,18 @@ class Presto(Dialect):
             exp.ArrayContains: rename_func("CONTAINS"),
             exp.ArraySize: rename_func("CARDINALITY"),
             exp.ArrayUniqueAgg: rename_func("SET_AGG"),
-            exp.BitwiseAnd: lambda self, e: f"BITWISE_AND({self.sql(e, 'this')}, {self.sql(e, 'expression')})",
-            exp.BitwiseLeftShift: lambda self, e: f"BITWISE_ARITHMETIC_SHIFT_LEFT({self.sql(e, 'this')}, {self.sql(e, 'expression')})",
+            exp.AtTimeZone: rename_func("AT_TIMEZONE"),
+            exp.BitwiseAnd: lambda self,
+            e: f"BITWISE_AND({self.sql(e, 'this')}, {self.sql(e, 'expression')})",
+            exp.BitwiseLeftShift: lambda self,
+            e: f"BITWISE_ARITHMETIC_SHIFT_LEFT({self.sql(e, 'this')}, {self.sql(e, 'expression')})",
             exp.BitwiseNot: lambda self, e: f"BITWISE_NOT({self.sql(e, 'this')})",
-            exp.BitwiseOr: lambda self, e: f"BITWISE_OR({self.sql(e, 'this')}, {self.sql(e, 'expression')})",
-            exp.BitwiseRightShift: lambda self, e: f"BITWISE_ARITHMETIC_SHIFT_RIGHT({self.sql(e, 'this')}, {self.sql(e, 'expression')})",
-            exp.BitwiseXor: lambda self, e: f"BITWISE_XOR({self.sql(e, 'this')}, {self.sql(e, 'expression')})",
+            exp.BitwiseOr: lambda self,
+            e: f"BITWISE_OR({self.sql(e, 'this')}, {self.sql(e, 'expression')})",
+            exp.BitwiseRightShift: lambda self,
+            e: f"BITWISE_ARITHMETIC_SHIFT_RIGHT({self.sql(e, 'this')}, {self.sql(e, 'expression')})",
+            exp.BitwiseXor: lambda self,
+            e: f"BITWISE_XOR({self.sql(e, 'this')}, {self.sql(e, 'expression')})",
             exp.Cast: transforms.preprocess([transforms.epoch_cast_to_ts]),
             exp.CurrentTimestamp: lambda *_: "CURRENT_TIMESTAMP",
             exp.DateAdd: lambda self, e: self.func(
@@ -344,7 +350,8 @@ class Presto(Dialect):
                 "DATE_DIFF", exp.Literal.string(e.text("unit") or "DAY"), e.expression, e.this
             ),
             exp.DateStrToDate: datestrtodate_sql,
-            exp.DateToDi: lambda self, e: f"CAST(DATE_FORMAT({self.sql(e, 'this')}, {Presto.DATEINT_FORMAT}) AS INT)",
+            exp.DateToDi: lambda self,
+            e: f"CAST(DATE_FORMAT({self.sql(e, 'this')}, {Presto.DATEINT_FORMAT}) AS INT)",
             exp.DateSub: lambda self, e: self.func(
                 "DATE_ADD",
                 exp.Literal.string(e.text("unit") or "DAY"),
@@ -352,12 +359,14 @@ class Presto(Dialect):
                 e.this,
             ),
             exp.Decode: lambda self, e: encode_decode_sql(self, e, "FROM_UTF8"),
-            exp.DiToDate: lambda self, e: f"CAST(DATE_PARSE(CAST({self.sql(e, 'this')} AS VARCHAR), {Presto.DATEINT_FORMAT}) AS DATE)",
+            exp.DiToDate: lambda self,
+            e: f"CAST(DATE_PARSE(CAST({self.sql(e, 'this')} AS VARCHAR), {Presto.DATEINT_FORMAT}) AS DATE)",
             exp.Encode: lambda self, e: encode_decode_sql(self, e, "TO_UTF8"),
             exp.FileFormatProperty: lambda self, e: f"FORMAT='{e.name.upper()}'",
             exp.First: _first_last_sql,
-            exp.FromTimeZone: lambda self, e: f"WITH_TIMEZONE({self.sql(e, 'this')}, {self.sql(e, 'zone')}) AT TIME ZONE 'UTC'",
-            exp.GetPath: path_to_jsonpath(),
+            exp.FirstValue: _first_last_sql,
+            exp.FromTimeZone: lambda self,
+            e: f"WITH_TIMEZONE({self.sql(e, 'this')}, {self.sql(e, 'zone')}) AT TIME ZONE 'UTC'",
             exp.Group: transforms.preprocess([transforms.unalias_group]),
             exp.GroupConcat: lambda self, e: self.func(
                 "ARRAY_JOIN", self.func("ARRAY_AGG", e.this), e.args.get("separator")
@@ -368,6 +377,7 @@ class Presto(Dialect):
             exp.Initcap: _initcap_sql,
             exp.ParseJSON: rename_func("JSON_PARSE"),
             exp.Last: _first_last_sql,
+            exp.LastValue: _first_last_sql,
             exp.LastDay: lambda self, e: self.func("LAST_DAY_OF_MONTH", e.this),
             exp.Lateral: _explode_to_unnest_sql,
             exp.Left: left_to_substring_sql,
@@ -394,26 +404,33 @@ class Presto(Dialect):
             exp.StrToDate: lambda self, e: f"CAST({_str_to_time_sql(self, e)} AS DATE)",
             exp.StrToMap: rename_func("SPLIT_TO_MAP"),
             exp.StrToTime: _str_to_time_sql,
-            exp.StrToUnix: lambda self, e: f"TO_UNIXTIME(DATE_PARSE({self.sql(e, 'this')}, {self.format_time(e)}))",
+            exp.StrToUnix: lambda self,
+            e: f"TO_UNIXTIME(DATE_PARSE({self.sql(e, 'this')}, {self.format_time(e)}))",
             exp.StructExtract: struct_extract_sql,
             exp.Table: transforms.preprocess([_unnest_sequence]),
             exp.Timestamp: no_timestamp_sql,
             exp.TimestampTrunc: timestamptrunc_sql,
             exp.TimeStrToDate: timestrtotime_sql,
             exp.TimeStrToTime: timestrtotime_sql,
-            exp.TimeStrToUnix: lambda self, e: f"TO_UNIXTIME(DATE_PARSE({self.sql(e, 'this')}, {Presto.TIME_FORMAT}))",
-            exp.TimeToStr: lambda self, e: f"DATE_FORMAT({self.sql(e, 'this')}, {self.format_time(e)})",
+            exp.TimeStrToUnix: lambda self,
+            e: f"TO_UNIXTIME(DATE_PARSE({self.sql(e, 'this')}, {Presto.TIME_FORMAT}))",
+            exp.TimeToStr: lambda self,
+            e: f"DATE_FORMAT({self.sql(e, 'this')}, {self.format_time(e)})",
             exp.TimeToUnix: rename_func("TO_UNIXTIME"),
-            exp.ToChar: lambda self, e: f"DATE_FORMAT({self.sql(e, 'this')}, {self.format_time(e)})",
+            exp.ToChar: lambda self,
+            e: f"DATE_FORMAT({self.sql(e, 'this')}, {self.format_time(e)})",
             exp.TryCast: transforms.preprocess([transforms.epoch_cast_to_ts]),
-            exp.TsOrDiToDi: lambda self, e: f"CAST(SUBSTR(REPLACE(CAST({self.sql(e, 'this')} AS VARCHAR), '-', ''), 1, 8) AS INT)",
+            exp.TsOrDiToDi: lambda self,
+            e: f"CAST(SUBSTR(REPLACE(CAST({self.sql(e, 'this')} AS VARCHAR), '-', ''), 1, 8) AS INT)",
             exp.TsOrDsAdd: _ts_or_ds_add_sql,
             exp.TsOrDsDiff: _ts_or_ds_diff_sql,
             exp.TsOrDsToDate: _ts_or_ds_to_date_sql,
             exp.Unhex: rename_func("FROM_HEX"),
-            exp.UnixToStr: lambda self, e: f"DATE_FORMAT(FROM_UNIXTIME({self.sql(e, 'this')}), {self.format_time(e)})",
+            exp.UnixToStr: lambda self,
+            e: f"DATE_FORMAT(FROM_UNIXTIME({self.sql(e, 'this')}), {self.format_time(e)})",
             exp.UnixToTime: _unix_to_time_sql,
-            exp.UnixToTimeStr: lambda self, e: f"CAST(FROM_UNIXTIME({self.sql(e, 'this')}) AS VARCHAR)",
+            exp.UnixToTimeStr: lambda self,
+            e: f"CAST(FROM_UNIXTIME({self.sql(e, 'this')}) AS VARCHAR)",
             exp.VariancePop: rename_func("VAR_POP"),
             exp.With: transforms.preprocess([transforms.add_recursive_cte_column_names]),
             exp.WithinGroup: transforms.preprocess(
@@ -9,6 +9,7 @@ from sqlglot.dialects.dialect import (
     concat_ws_to_dpipe_sql,
     date_delta_sql,
     generatedasidentitycolumnconstraint_sql,
+    json_extract_segments,
     no_tablesample_sql,
     rename_func,
 )
@@ -20,10 +21,6 @@ if t.TYPE_CHECKING:
     from sqlglot._typing import E


-def _json_sql(self: Redshift.Generator, expression: exp.JSONExtract | exp.JSONExtractScalar) -> str:
-    return f'{self.sql(expression, "this")}."{expression.expression.name}"'
-
-
 def _parse_date_delta(expr_type: t.Type[E]) -> t.Callable[[t.List], E]:
     def _parse_delta(args: t.List) -> E:
         expr = expr_type(this=seq_get(args, 2), expression=seq_get(args, 1), unit=seq_get(args, 0))
@@ -62,6 +59,7 @@ class Redshift(Postgres):
             "DATE_ADD": _parse_date_delta(exp.TsOrDsAdd),
             "DATEDIFF": _parse_date_delta(exp.TsOrDsDiff),
             "DATE_DIFF": _parse_date_delta(exp.TsOrDsDiff),
+            "GETDATE": exp.CurrentTimestamp.from_arg_list,
             "LISTAGG": exp.GroupConcat.from_arg_list,
             "STRTOL": exp.FromBase.from_arg_list,
         }
@@ -69,6 +67,7 @@ class Redshift(Postgres):
         NO_PAREN_FUNCTION_PARSERS = {
             **Postgres.Parser.NO_PAREN_FUNCTION_PARSERS,
             "APPROXIMATE": lambda self: self._parse_approximate_count(),
+            "SYSDATE": lambda self: self.expression(exp.CurrentTimestamp, transaction=True),
         }

         def _parse_table(
@@ -77,6 +76,7 @@ class Redshift(Postgres):
             joins: bool = False,
             alias_tokens: t.Optional[t.Collection[TokenType]] = None,
             parse_bracket: bool = False,
+            is_db_reference: bool = False,
         ) -> t.Optional[exp.Expression]:
             # Redshift supports UNPIVOTing SUPER objects, e.g. `UNPIVOT foo.obj[0] AS val AT attr`
             unpivot = self._match(TokenType.UNPIVOT)
@ -85,6 +85,7 @@ class Redshift(Postgres):
|
|||
joins=joins,
|
||||
alias_tokens=alias_tokens,
|
||||
parse_bracket=parse_bracket,
|
||||
is_db_reference=is_db_reference,
|
||||
)
|
||||
|
||||
return self.expression(exp.Pivot, this=table, unpivot=True) if unpivot else table
|
||||
|
@ -153,7 +154,6 @@ class Redshift(Postgres):
|
|||
**Postgres.Tokenizer.KEYWORDS,
|
||||
"HLLSKETCH": TokenType.HLLSKETCH,
|
||||
"SUPER": TokenType.SUPER,
|
||||
"SYSDATE": TokenType.CURRENT_TIMESTAMP,
|
||||
"TOP": TokenType.TOP,
|
||||
"UNLOAD": TokenType.COMMAND,
|
||||
"VARBYTE": TokenType.VARBINARY,
|
||||
|
@ -180,31 +180,29 @@ class Redshift(Postgres):
|
|||
exp.DataType.Type.VARBINARY: "VARBYTE",
|
||||
}
|
||||
|
||||
PROPERTIES_LOCATION = {
|
||||
**Postgres.Generator.PROPERTIES_LOCATION,
|
||||
exp.LikeProperty: exp.Properties.Location.POST_WITH,
|
||||
}
|
||||
|
||||
TRANSFORMS = {
|
||||
**Postgres.Generator.TRANSFORMS,
|
||||
exp.Concat: concat_to_dpipe_sql,
|
||||
exp.ConcatWs: concat_ws_to_dpipe_sql,
|
||||
exp.ApproxDistinct: lambda self, e: f"APPROXIMATE COUNT(DISTINCT {self.sql(e, 'this')})",
|
||||
exp.CurrentTimestamp: lambda self, e: "SYSDATE",
|
||||
exp.ApproxDistinct: lambda self,
|
||||
e: f"APPROXIMATE COUNT(DISTINCT {self.sql(e, 'this')})",
|
||||
exp.CurrentTimestamp: lambda self, e: (
|
||||
"SYSDATE" if e.args.get("transaction") else "GETDATE()"
|
||||
),
|
||||
exp.DateAdd: date_delta_sql("DATEADD"),
|
||||
exp.DateDiff: date_delta_sql("DATEDIFF"),
|
||||
exp.DistKeyProperty: lambda self, e: f"DISTKEY({e.name})",
|
||||
exp.DistStyleProperty: lambda self, e: self.naked_property(e),
|
||||
exp.FromBase: rename_func("STRTOL"),
|
||||
exp.GeneratedAsIdentityColumnConstraint: generatedasidentitycolumnconstraint_sql,
|
||||
exp.JSONExtract: _json_sql,
|
||||
exp.JSONExtractScalar: _json_sql,
|
||||
exp.JSONExtract: json_extract_segments("JSON_EXTRACT_PATH_TEXT"),
|
||||
exp.GroupConcat: rename_func("LISTAGG"),
|
||||
exp.ParseJSON: rename_func("JSON_PARSE"),
|
||||
exp.Select: transforms.preprocess(
|
||||
[transforms.eliminate_distinct_on, transforms.eliminate_semi_and_anti_joins]
|
||||
),
|
||||
exp.SortKeyProperty: lambda self, e: f"{'COMPOUND ' if e.args['compound'] else ''}SORTKEY({self.format_args(*e.this)})",
|
||||
exp.SortKeyProperty: lambda self,
|
||||
e: f"{'COMPOUND ' if e.args['compound'] else ''}SORTKEY({self.format_args(*e.this)})",
|
||||
exp.TableSample: no_tablesample_sql,
|
||||
exp.TsOrDsAdd: date_delta_sql("DATEADD"),
|
||||
exp.TsOrDsDiff: date_delta_sql("DATEDIFF"),
|
||||
|
@ -228,6 +226,13 @@ class Redshift(Postgres):
|
|||
"""Redshift doesn't have `WITH` as part of their with_properties so we remove it"""
|
||||
return self.properties(properties, prefix=" ", suffix="")
|
||||
|
||||
def cast_sql(self, expression: exp.Cast, safe_prefix: t.Optional[str] = None) -> str:
|
||||
if expression.is_type(exp.DataType.Type.JSON):
|
||||
# Redshift doesn't support a JSON type, so casting to it is treated as a noop
|
||||
return self.sql(expression, "this")
|
||||
|
||||
return super().cast_sql(expression, safe_prefix=safe_prefix)
|
||||
|
||||
def datatype_sql(self, expression: exp.DataType) -> str:
|
||||
"""
|
||||
Redshift converts the `TEXT` data type to `VARCHAR(255)` by default when people more generally mean
|
||||
|
|
|
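# --- Illustrative aside (not part of the diff) ---
# Per cast_sql above, Redshift treats a cast to the unsupported JSON type as
# a noop, so the cast should simply disappear. The expected output is an
# assumption inferred from the hunk, not a verified run.
import sqlglot

print(sqlglot.transpile("SELECT CAST(x AS JSON)", write="redshift")[0])
# expected: SELECT x
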
@@ -21,19 +21,13 @@ from sqlglot.dialects.dialect import (
var_map_sql,
)
from sqlglot.expressions import Literal
from sqlglot.helper import seq_get
from sqlglot.helper import is_int, seq_get
from sqlglot.tokens import TokenType

if t.TYPE_CHECKING:
from sqlglot._typing import E


def _check_int(s: str) -> bool:
if s[0] in ("-", "+"):
return s[1:].isdigit()
return s.isdigit()


# from https://docs.snowflake.com/en/sql-reference/functions/to_timestamp.html
def _parse_to_timestamp(args: t.List) -> t.Union[exp.StrToTime, exp.UnixToTime, exp.TimeStrToTime]:
if len(args) == 2:

@@ -53,7 +47,7 @@ def _parse_to_timestamp(args: t.List) -> t.Union[exp.StrToTime, exp.UnixToTime,
return exp.TimeStrToTime.from_arg_list(args)

if first_arg.is_string:
if _check_int(first_arg.this):
if is_int(first_arg.this):
# case: <integer>
return exp.UnixToTime.from_arg_list(args)

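# --- Illustrative aside (not part of the diff) ---
# A sketch of the TO_TIMESTAMP branches above: an integer-like string is
# treated as a Unix epoch, while a second argument acts as a format string.
# The assertions encode assumptions inferred from the hunk.
import sqlglot
from sqlglot import exp

epoch = sqlglot.parse_one("SELECT TO_TIMESTAMP('1659981729')", read="snowflake")
assert epoch.find(exp.UnixToTime) is not None

formatted = sqlglot.parse_one("SELECT TO_TIMESTAMP(x, 'yyyy-mm-dd')", read="snowflake")
assert formatted.find(exp.StrToTime) is not None
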
@@ -241,7 +235,6 @@ DATE_PART_MAPPING = {
"NSECOND": "NANOSECOND",
"NSECONDS": "NANOSECOND",
"NANOSECS": "NANOSECOND",
"NSECONDS": "NANOSECOND",
"EPOCH": "EPOCH_SECOND",
"EPOCH_SECONDS": "EPOCH_SECOND",
"EPOCH_MILLISECONDS": "EPOCH_MILLISECOND",

@@ -291,7 +284,9 @@ def _parse_colon_get_path(
path = exp.Literal.string(path.sql(dialect="snowflake"))

# The extraction operator : is left-associative
this = self.expression(exp.GetPath, this=this, expression=path)
this = self.expression(
exp.JSONExtract, this=this, expression=self.dialect.to_json_path(path)
)

if target_type:
this = exp.cast(this, target_type)

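# --- Illustrative aside (not part of the diff) ---
# Because ':' is parsed left-associatively (see the hunk above), a:b:c should
# nest as JSONExtract(JSONExtract(a, b), c); the exact tree shape is an
# assumption based on this diff.
import sqlglot
from sqlglot import exp

tree = sqlglot.parse_one("SELECT a:b:c FROM t", read="snowflake")
outer = tree.find(exp.JSONExtract)
assert outer is not None and isinstance(outer.this, exp.JSONExtract)
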
@@ -411,6 +406,9 @@ class Snowflake(Dialect):
"DATEDIFF": _parse_datediff,
"DIV0": _div0_to_if,
"FLATTEN": exp.Explode.from_arg_list,
"GET_PATH": lambda args, dialect: exp.JSONExtract(
this=seq_get(args, 0), expression=dialect.to_json_path(seq_get(args, 1))
),
"IFF": exp.If.from_arg_list,
"LAST_DAY": lambda args: exp.LastDay(
this=seq_get(args, 0), unit=_map_date_part(seq_get(args, 1))

@@ -474,6 +472,8 @@ class Snowflake(Dialect):
"TERSE SCHEMAS": _show_parser("SCHEMAS"),
"OBJECTS": _show_parser("OBJECTS"),
"TERSE OBJECTS": _show_parser("OBJECTS"),
"TABLES": _show_parser("TABLES"),
"TERSE TABLES": _show_parser("TABLES"),
"PRIMARY KEYS": _show_parser("PRIMARY KEYS"),
"TERSE PRIMARY KEYS": _show_parser("PRIMARY KEYS"),
"COLUMNS": _show_parser("COLUMNS"),

@@ -534,7 +534,9 @@ class Snowflake(Dialect):

return table

def _parse_table_parts(self, schema: bool = False) -> exp.Table:
def _parse_table_parts(
self, schema: bool = False, is_db_reference: bool = False
) -> exp.Table:
# https://docs.snowflake.com/en/user-guide/querying-stage
if self._match(TokenType.STRING, advance=False):
table = self._parse_string()

@@ -550,7 +552,9 @@ class Snowflake(Dialect):
self._match(TokenType.L_PAREN)
while self._curr and not self._match(TokenType.R_PAREN):
if self._match_text_seq("FILE_FORMAT", "=>"):
file_format = self._parse_string() or super()._parse_table_parts()
file_format = self._parse_string() or super()._parse_table_parts(
is_db_reference=is_db_reference
)
elif self._match_text_seq("PATTERN", "=>"):
pattern = self._parse_string()
else:

@@ -560,7 +564,7 @@ class Snowflake(Dialect):

table = self.expression(exp.Table, this=table, format=file_format, pattern=pattern)
else:
table = super()._parse_table_parts(schema=schema)
table = super()._parse_table_parts(schema=schema, is_db_reference=is_db_reference)

return self._parse_at_before(table)

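# --- Illustrative aside (not part of the diff) ---
# The staged-file branch above lets a stage reference stand in for a table,
# with FILE_FORMAT (and PATTERN) passed as named arguments. The stage and
# format names below are invented for the example.
import sqlglot

q = sqlglot.parse_one(
    "SELECT * FROM @my_stage (FILE_FORMAT => 'my_csv_format')", read="snowflake"
)
print(q.sql(dialect="snowflake"))
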
@@ -587,6 +591,8 @@ class Snowflake(Dialect):
# which is syntactically valid but has no effect on the output
terse = self._tokens[self._index - 2].text.upper() == "TERSE"

history = self._match_text_seq("HISTORY")

like = self._parse_string() if self._match(TokenType.LIKE) else None

if self._match(TokenType.IN):

@@ -597,7 +603,7 @@ class Snowflake(Dialect):
if self._curr:
scope = self._parse_table_parts()
elif self._curr:
scope_kind = "SCHEMA" if this == "OBJECTS" else "TABLE"
scope_kind = "SCHEMA" if this in ("OBJECTS", "TABLES") else "TABLE"
scope = self._parse_table_parts()

return self.expression(

@@ -605,6 +611,7 @@ class Snowflake(Dialect):
**{
"terse": terse,
"this": this,
"history": history,
"like": like,
"scope": scope,
"scope_kind": scope_kind,

@@ -715,8 +722,10 @@ class Snowflake(Dialect):
),
exp.GroupConcat: rename_func("LISTAGG"),
exp.If: if_sql(name="IFF", false_value="NULL"),
exp.JSONExtract: lambda self, e: f"{self.sql(e, 'this')}[{self.sql(e, 'expression')}]",
exp.JSONExtract: rename_func("GET_PATH"),
exp.JSONExtractScalar: rename_func("JSON_EXTRACT_PATH_TEXT"),
exp.JSONObject: lambda self, e: self.func("OBJECT_CONSTRUCT_KEEP_NULL", *e.expressions),
exp.JSONPathRoot: lambda *_: "",
exp.LogicalAnd: rename_func("BOOLAND_AGG"),
exp.LogicalOr: rename_func("BOOLOR_AGG"),
exp.Map: lambda self, e: var_map_sql(self, e, "OBJECT_CONSTRUCT"),

@@ -745,7 +754,8 @@ class Snowflake(Dialect):
exp.StrPosition: lambda self, e: self.func(
"POSITION", e.args.get("substr"), e.this, e.args.get("position")
),
exp.StrToTime: lambda self, e: f"TO_TIMESTAMP({self.sql(e, 'this')}, {self.format_time(e)})",
exp.StrToTime: lambda self,
e: f"TO_TIMESTAMP({self.sql(e, 'this')}, {self.format_time(e)})",
exp.Struct: lambda self, e: self.func(
"OBJECT_CONSTRUCT",
*(arg for expression in e.expressions for arg in expression.flatten()),

@@ -771,6 +781,12 @@ class Snowflake(Dialect):
exp.Xor: rename_func("BOOLXOR"),
}

SUPPORTED_JSON_PATH_PARTS = {
exp.JSONPathKey,
exp.JSONPathRoot,
exp.JSONPathSubscript,
}

TYPE_MAPPING = {
**generator.Generator.TYPE_MAPPING,
exp.DataType.Type.TIMESTAMP: "TIMESTAMPNTZ",

@@ -841,6 +857,7 @@ class Snowflake(Dialect):

def show_sql(self, expression: exp.Show) -> str:
terse = "TERSE " if expression.args.get("terse") else ""
history = " HISTORY" if expression.args.get("history") else ""
like = self.sql(expression, "like")
like = f" LIKE {like}" if like else ""

@@ -861,9 +878,7 @@ class Snowflake(Dialect):
if from_:
from_ = f" FROM {from_}"

return (
f"SHOW {terse}{expression.name}{like}{scope_kind}{scope}{starts_with}{limit}{from_}"
)
return f"SHOW {terse}{expression.name}{history}{like}{scope_kind}{scope}{starts_with}{limit}{from_}"

def regexpextract_sql(self, expression: exp.RegexpExtract) -> str:
# Other dialects don't support all of the following parameters, so we need to

@@ -4,6 +4,7 @@ import typing as t

from sqlglot import exp
from sqlglot.dialects.dialect import rename_func
from sqlglot.dialects.hive import _parse_ignore_nulls
from sqlglot.dialects.spark2 import Spark2
from sqlglot.helper import seq_get

@@ -45,9 +46,7 @@ class Spark(Spark2):
class Parser(Spark2.Parser):
FUNCTIONS = {
**Spark2.Parser.FUNCTIONS,
"ANY_VALUE": lambda args: exp.AnyValue(
this=seq_get(args, 0), ignore_nulls=seq_get(args, 1)
),
"ANY_VALUE": _parse_ignore_nulls(exp.AnyValue),
"DATEDIFF": _parse_datediff,
}

@@ -187,8 +187,10 @@ class Spark2(Hive):
TRANSFORMS = {
**Hive.Generator.TRANSFORMS,
exp.ApproxDistinct: rename_func("APPROX_COUNT_DISTINCT"),
exp.ArraySum: lambda self, e: f"AGGREGATE({self.sql(e, 'this')}, 0, (acc, x) -> acc + x, acc -> acc)",
exp.AtTimeZone: lambda self, e: f"FROM_UTC_TIMESTAMP({self.sql(e, 'this')}, {self.sql(e, 'zone')})",
exp.ArraySum: lambda self,
e: f"AGGREGATE({self.sql(e, 'this')}, 0, (acc, x) -> acc + x, acc -> acc)",
exp.AtTimeZone: lambda self,
e: f"FROM_UTC_TIMESTAMP({self.sql(e, 'this')}, {self.sql(e, 'zone')})",
exp.BitwiseLeftShift: rename_func("SHIFTLEFT"),
exp.BitwiseRightShift: rename_func("SHIFTRIGHT"),
exp.DateFromParts: rename_func("MAKE_DATE"),

@@ -198,7 +200,8 @@ class Spark2(Hive):
exp.DayOfYear: rename_func("DAYOFYEAR"),
exp.FileFormatProperty: lambda self, e: f"USING {e.name.upper()}",
exp.From: transforms.preprocess([_unalias_pivot]),
exp.FromTimeZone: lambda self, e: f"TO_UTC_TIMESTAMP({self.sql(e, 'this')}, {self.sql(e, 'zone')})",
exp.FromTimeZone: lambda self,
e: f"TO_UTC_TIMESTAMP({self.sql(e, 'this')}, {self.sql(e, 'zone')})",
exp.LogicalAnd: rename_func("BOOL_AND"),
exp.LogicalOr: rename_func("BOOL_OR"),
exp.Map: _map_sql,

@@ -212,7 +215,8 @@ class Spark2(Hive):
e.args.get("position"),
),
exp.StrToDate: _str_to_date,
exp.StrToTime: lambda self, e: f"TO_TIMESTAMP({self.sql(e, 'this')}, {self.format_time(e)})",
exp.StrToTime: lambda self,
e: f"TO_TIMESTAMP({self.sql(e, 'this')}, {self.format_time(e)})",
exp.TimestampTrunc: lambda self, e: self.func(
"DATE_TRUNC", exp.Literal.string(e.text("unit")), e.this
),

@@ -7,7 +7,6 @@ from sqlglot.dialects.dialect import (
Dialect,
NormalizationStrategy,
any_value_to_max_sql,
arrow_json_extract_scalar_sql,
arrow_json_extract_sql,
concat_to_dpipe_sql,
count_if_to_sum,

@@ -28,6 +27,12 @@ def _date_add_sql(self: SQLite.Generator, expression: exp.DateAdd) -> str:
return self.func("DATE", expression.this, modifier)


def _json_extract_sql(self: SQLite.Generator, expression: exp.JSONExtract) -> str:
if expression.expressions:
return self.function_fallback_sql(expression)
return arrow_json_extract_sql(self, expression)


def _transform_create(expression: exp.Expression) -> exp.Expression:
"""Move primary key to a column and enforce auto_increment on primary keys."""
schema = expression.this

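# --- Illustrative aside (not part of the diff) ---
# Per _json_extract_sql above, a single-path JSON_EXTRACT should render with
# SQLite's arrow operator, while a variadic call falls back to the function
# form. Both outputs are assumptions inferred from the hunk.
import sqlglot

print(sqlglot.transpile("SELECT JSON_EXTRACT(x, '$.a')", write="sqlite")[0])
# e.g. SELECT x -> '$.a'
print(sqlglot.transpile("SELECT JSON_EXTRACT(x, '$.a', '$.b')", write="sqlite")[0])
# e.g. SELECT JSON_EXTRACT(x, '$.a', '$.b')
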
@@ -85,6 +90,14 @@ class SQLite(Dialect):
TABLE_HINTS = False
QUERY_HINTS = False
NVL2_SUPPORTED = False
JSON_PATH_BRACKETED_KEY_SUPPORTED = False
SUPPORTS_CREATE_TABLE_LIKE = False

SUPPORTED_JSON_PATH_PARTS = {
exp.JSONPathKey,
exp.JSONPathRoot,
exp.JSONPathSubscript,
}

TYPE_MAPPING = {
**generator.Generator.TYPE_MAPPING,

@@ -120,10 +133,8 @@ class SQLite(Dialect):
exp.DateAdd: _date_add_sql,
exp.DateStrToDate: lambda self, e: self.sql(e, "this"),
exp.ILike: no_ilike_sql,
exp.JSONExtract: arrow_json_extract_sql,
exp.JSONExtractScalar: arrow_json_extract_scalar_sql,
exp.JSONBExtract: arrow_json_extract_sql,
exp.JSONBExtractScalar: arrow_json_extract_scalar_sql,
exp.JSONExtract: _json_extract_sql,
exp.JSONExtractScalar: arrow_json_extract_sql,
exp.Levenshtein: rename_func("EDITDIST3"),
exp.LogicalOr: rename_func("MAX"),
exp.LogicalAnd: rename_func("MIN"),

@@ -141,11 +152,18 @@ class SQLite(Dialect):
exp.TryCast: no_trycast_sql,
}

# SQLite doesn't generally support CREATE TABLE .. properties
# https://www.sqlite.org/lang_createtable.html
PROPERTIES_LOCATION = {
k: exp.Properties.Location.UNSUPPORTED
for k, v in generator.Generator.PROPERTIES_LOCATION.items()
prop: exp.Properties.Location.UNSUPPORTED
for prop in generator.Generator.PROPERTIES_LOCATION
}

# There are a few exceptions (e.g. temporary tables) which are supported or
# can be transpiled to SQLite, so we explicitly override them accordingly
PROPERTIES_LOCATION[exp.LikeProperty] = exp.Properties.Location.POST_SCHEMA
PROPERTIES_LOCATION[exp.TemporaryProperty] = exp.Properties.Location.POST_CREATE

LIMIT_FETCH = "LIMIT"

def cast_sql(self, expression: exp.Cast, safe_prefix: t.Optional[str] = None) -> str:

@@ -44,12 +44,14 @@ class StarRocks(MySQL):
exp.JSONExtractScalar: arrow_json_extract_sql,
exp.JSONExtract: arrow_json_extract_sql,
exp.RegexpLike: rename_func("REGEXP"),
exp.StrToUnix: lambda self, e: f"UNIX_TIMESTAMP({self.sql(e, 'this')}, {self.format_time(e)})",
exp.StrToUnix: lambda self,
e: f"UNIX_TIMESTAMP({self.sql(e, 'this')}, {self.format_time(e)})",
exp.TimestampTrunc: lambda self, e: self.func(
"DATE_TRUNC", exp.Literal.string(e.text("unit")), e.this
),
exp.TimeStrToDate: rename_func("TO_DATE"),
exp.UnixToStr: lambda self, e: f"FROM_UNIXTIME({self.sql(e, 'this')}, {self.format_time(e)})",
exp.UnixToStr: lambda self,
e: f"FROM_UNIXTIME({self.sql(e, 'this')}, {self.format_time(e)})",
exp.UnixToTime: rename_func("FROM_UNIXTIME"),
}

@@ -200,7 +200,8 @@ class Teradata(Dialect):
exp.Select: transforms.preprocess(
[transforms.eliminate_distinct_on, transforms.eliminate_semi_and_anti_joins]
),
exp.StrToDate: lambda self, e: f"CAST({self.sql(e, 'this')} AS DATE FORMAT {self.format_time(e)})",
exp.StrToDate: lambda self,
e: f"CAST({self.sql(e, 'this')} AS DATE FORMAT {self.format_time(e)})",
exp.ToChar: lambda self, e: self.function_fallback_sql(e),
exp.Use: lambda self, e: f"DATABASE {self.sql(e, 'this')}",
}

@@ -11,9 +11,16 @@ class Trino(Presto):
class Generator(Presto.Generator):
TRANSFORMS = {
**Presto.Generator.TRANSFORMS,
exp.ArraySum: lambda self, e: f"REDUCE({self.sql(e, 'this')}, 0, (acc, x) -> acc + x, acc -> acc)",
exp.ArraySum: lambda self,
e: f"REDUCE({self.sql(e, 'this')}, 0, (acc, x) -> acc + x, acc -> acc)",
exp.Merge: merge_without_target_sql,
}

SUPPORTED_JSON_PATH_PARTS = {
exp.JSONPathKey,
exp.JSONPathRoot,
exp.JSONPathSubscript,
}

class Tokenizer(Presto.Tokenizer):
HEX_STRINGS = [("X'", "'")]

@@ -14,7 +14,6 @@ from sqlglot.dialects.dialect import (
max_or_greatest,
min_or_least,
parse_date_delta,
path_to_jsonpath,
rename_func,
timestrtotime_sql,
trim_sql,

@@ -266,13 +265,32 @@ def _parse_timefromparts(args: t.List) -> exp.TimeFromParts:
)


def _parse_len(args: t.List) -> exp.Length:
this = seq_get(args, 0)
def _parse_as_text(
klass: t.Type[exp.Expression],
) -> t.Callable[[t.List[exp.Expression]], exp.Expression]:
def _parse(args: t.List[exp.Expression]) -> exp.Expression:
this = seq_get(args, 0)

if this and not this.is_string:
this = exp.cast(this, exp.DataType.Type.TEXT)
if this and not this.is_string:
this = exp.cast(this, exp.DataType.Type.TEXT)

return exp.Length(this=this)
expression = seq_get(args, 1)
kwargs = {"this": this}

if expression:
kwargs["expression"] = expression

return klass(**kwargs)

return _parse


def _json_extract_sql(
self: TSQL.Generator, expression: exp.JSONExtract | exp.JSONExtractScalar
) -> str:
json_query = rename_func("JSON_QUERY")(self, expression)
json_value = rename_func("JSON_VALUE")(self, expression)
return self.func("ISNULL", json_query, json_value)


class TSQL(Dialect):

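# --- Illustrative aside (not part of the diff) ---
# Per _json_extract_sql above, T-SQL wraps extraction in ISNULL over both
# JSON_QUERY (objects/arrays) and JSON_VALUE (scalars). The rendered string
# is an assumption inferred from the hunk.
import sqlglot

print(sqlglot.transpile("SELECT JSON_EXTRACT(x, '$.a')", write="tsql")[0])
# e.g. SELECT ISNULL(JSON_QUERY(x, '$.a'), JSON_VALUE(x, '$.a'))
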
@@ -441,8 +459,11 @@ class TSQL(Dialect):
"HASHBYTES": _parse_hashbytes,
"IIF": exp.If.from_arg_list,
"ISNULL": exp.Coalesce.from_arg_list,
"JSON_VALUE": exp.JSONExtractScalar.from_arg_list,
"LEN": _parse_len,
"JSON_QUERY": parser.parse_extract_json_with_path(exp.JSONExtract),
"JSON_VALUE": parser.parse_extract_json_with_path(exp.JSONExtractScalar),
"LEN": _parse_as_text(exp.Length),
"LEFT": _parse_as_text(exp.Left),
"RIGHT": _parse_as_text(exp.Right),
"REPLICATE": exp.Repeat.from_arg_list,
"SQUARE": lambda args: exp.Pow(this=seq_get(args, 0), expression=exp.Literal.number(2)),
"SYSDATETIME": exp.CurrentTimestamp.from_arg_list,

@@ -677,6 +698,7 @@ class TSQL(Dialect):
SUPPORTS_SINGLE_ARG_CONCAT = False
TABLESAMPLE_SEED_KEYWORD = "REPEATABLE"
SUPPORTS_SELECT_INTO = True
JSON_PATH_BRACKETED_KEY_SUPPORTED = False

EXPRESSIONS_WITHOUT_NESTED_CTES = {
exp.Delete,

@@ -688,6 +710,12 @@ class TSQL(Dialect):
exp.Update,
}

SUPPORTED_JSON_PATH_PARTS = {
exp.JSONPathKey,
exp.JSONPathRoot,
exp.JSONPathSubscript,
}

TYPE_MAPPING = {
**generator.Generator.TYPE_MAPPING,
exp.DataType.Type.BOOLEAN: "BIT",

@@ -712,9 +740,10 @@ class TSQL(Dialect):
exp.CurrentTimestamp: rename_func("GETDATE"),
exp.Extract: rename_func("DATEPART"),
exp.GeneratedAsIdentityColumnConstraint: generatedasidentitycolumnconstraint_sql,
exp.GetPath: path_to_jsonpath("JSON_VALUE"),
exp.GroupConcat: _string_agg_sql,
exp.If: rename_func("IIF"),
exp.JSONExtract: _json_extract_sql,
exp.JSONExtractScalar: _json_extract_sql,
exp.LastDay: lambda self, e: self.func("EOMONTH", e.this),
exp.Max: max_or_greatest,
exp.MD5: lambda self, e: self.func("HASHBYTES", exp.Literal.string("MD5"), e.this),

@@ -831,15 +860,21 @@ class TSQL(Dialect):
exists = expression.args.pop("exists", None)
sql = super().create_sql(expression)

like_property = expression.find(exp.LikeProperty)
if like_property:
ctas_expression = like_property.this
else:
ctas_expression = expression.expression

table = expression.find(exp.Table)

# Convert CTAS statement to SELECT .. INTO ..
if kind == "TABLE" and expression.expression:
ctas_with = expression.expression.args.get("with")
if kind == "TABLE" and ctas_expression:
ctas_with = ctas_expression.args.get("with")
if ctas_with:
ctas_with = ctas_with.pop()

subquery = expression.expression
subquery = ctas_expression
if isinstance(subquery, exp.Subqueryable):
subquery = subquery.subquery()

@@ -847,6 +882,9 @@ class TSQL(Dialect):
select_into.set("into", exp.Into(this=table))
select_into.set("with", ctas_with)

if like_property:
select_into.limit(0, copy=False)

sql = self.sql(select_into)

if exists:

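# --- Illustrative aside (not part of the diff) ---
# The branch above rewrites CTAS, and now CREATE TABLE ... LIKE via a
# zero-row select, into T-SQL's SELECT ... INTO form. Both outputs are
# assumptions inferred from the hunks, not verified runs.
import sqlglot

print(sqlglot.transpile("CREATE TABLE t AS SELECT a FROM b", write="tsql")[0])
# e.g. SELECT a INTO t FROM b
print(sqlglot.transpile("CREATE TABLE t LIKE s", write="tsql")[0])
# e.g. SELECT TOP 0 * INTO t FROM s
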
@@ -937,9 +975,19 @@ class TSQL(Dialect):
return f"CONSTRAINT {this} {expressions}"

def length_sql(self, expression: exp.Length) -> str:
return self._uncast_text(expression, "LEN")

def right_sql(self, expression: exp.Right) -> str:
return self._uncast_text(expression, "RIGHT")

def left_sql(self, expression: exp.Left) -> str:
return self._uncast_text(expression, "LEFT")

def _uncast_text(self, expression: exp.Expression, name: str) -> str:
this = expression.this
if isinstance(this, exp.Cast) and this.is_type(exp.DataType.Type.TEXT):
this_sql = self.sql(this, "this")
else:
this_sql = self.sql(this)
return self.func("LEN", this_sql)
expression_sql = self.sql(expression, "expression")
return self.func(name, this_sql, expression_sql if expression_sql else None)

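# --- Illustrative aside (not part of the diff) ---
# _uncast_text above strips the TEXT cast that _parse_as_text inserted, so
# LEN/LEFT/RIGHT round-trip cleanly within T-SQL while other dialects keep
# the explicit cast. Outputs are assumptions inferred from the hunks.
import sqlglot

print(sqlglot.transpile("SELECT LEN(x)", read="tsql", write="tsql")[0])
# e.g. SELECT LEN(x)
print(sqlglot.transpile("SELECT LEN(x)", read="tsql")[0])
# e.g. SELECT LENGTH(CAST(x AS TEXT))
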
@@ -10,7 +10,6 @@ import logging
import time
import typing as t

from sqlglot import maybe_parse
from sqlglot.errors import ExecuteError
from sqlglot.executor.python import PythonExecutor
from sqlglot.executor.table import Table, ensure_tables

@@ -23,7 +22,6 @@ logger = logging.getLogger("sqlglot")

if t.TYPE_CHECKING:
from sqlglot.dialects.dialect import DialectType
from sqlglot.executor.table import Tables
from sqlglot.expressions import Expression
from sqlglot.schema import Schema

@@ -44,9 +44,9 @@ class Context:

for other in self.tables.values():
if self._table.columns != other.columns:
raise Exception(f"Columns are different.")
raise Exception("Columns are different.")
if len(self._table.rows) != len(other.rows):
raise Exception(f"Rows are different.")
raise Exception("Rows are different.")

return self._table

@@ -6,7 +6,7 @@ from functools import wraps

from sqlglot import exp
from sqlglot.generator import Generator
from sqlglot.helper import PYTHON_VERSION
from sqlglot.helper import PYTHON_VERSION, is_int, seq_get


class reverse_key:

@@ -143,6 +143,22 @@ def arrayjoin(this, expression, null=None):
return expression.join(x for x in (x if x is not None else null for x in this) if x is not None)


@null_if_any("this", "expression")
def jsonextract(this, expression):
for path_segment in expression:
if isinstance(this, dict):
this = this.get(path_segment)
elif isinstance(this, list) and is_int(path_segment):
this = seq_get(this, int(path_segment))
else:
raise NotImplementedError(f"Unable to extract value for {this} at {path_segment}.")

if this is None:
break

return this


ENV = {
"exp": exp,
# aggs

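# --- Illustrative aside (not part of the diff) ---
# A standalone sketch of the traversal jsonextract performs above: each path
# segment indexes into a dict by key or into a list by integer position.
doc = {"a": [{"b": 1}]}
for segment in ("a", "0", "b"):
    doc = doc.get(segment) if isinstance(doc, dict) else doc[int(segment)]
assert doc == 1
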
@@ -175,12 +191,12 @@ ENV = {
"DOT": null_if_any(lambda e, this: e[this]),
"EQ": null_if_any(lambda this, e: this == e),
"EXTRACT": null_if_any(lambda this, e: getattr(e, this)),
"GETPATH": null_if_any(lambda this, e: this.get(e)),
"GT": null_if_any(lambda this, e: this > e),
"GTE": null_if_any(lambda this, e: this >= e),
"IF": lambda predicate, true, false: true if predicate else false,
"INTDIV": null_if_any(lambda e, this: e // this),
"INTERVAL": interval,
"JSONEXTRACT": jsonextract,
"LEFT": null_if_any(lambda this, e: this[:e]),
"LIKE": null_if_any(
lambda this, e: bool(re.match(e.replace("_", ".").replace("%", ".*"), this))

@@ -9,7 +9,7 @@ from sqlglot.errors import ExecuteError
from sqlglot.executor.context import Context
from sqlglot.executor.env import ENV
from sqlglot.executor.table import RowReader, Table
from sqlglot.helper import csv_reader, subclasses
from sqlglot.helper import csv_reader, ensure_list, subclasses


class PythonExecutor:

@@ -368,7 +368,7 @@ def _rename(self, e):

if isinstance(e, exp.Func) and e.is_var_len_args:
*head, tail = values
return self.func(e.key, *head, *tail)
return self.func(e.key, *head, *ensure_list(tail))

return self.func(e.key, *values)
except Exception as ex:

@@ -429,18 +429,24 @@ class Python(Dialect):
exp.Between: _rename,
exp.Boolean: lambda self, e: "True" if e.this else "False",
exp.Cast: lambda self, e: f"CAST({self.sql(e.this)}, exp.DataType.Type.{e.args['to']})",
exp.Column: lambda self, e: f"scope[{self.sql(e, 'table') or None}][{self.sql(e.this)}]",
exp.Column: lambda self,
e: f"scope[{self.sql(e, 'table') or None}][{self.sql(e.this)}]",
exp.Concat: lambda self, e: self.func(
"SAFECONCAT" if e.args.get("safe") else "CONCAT", *e.expressions
),
exp.Distinct: lambda self, e: f"set({self.sql(e, 'this')})",
exp.Div: _div_sql,
exp.Extract: lambda self, e: f"EXTRACT('{e.name.lower()}', {self.sql(e, 'expression')})",
exp.In: lambda self, e: f"{self.sql(e, 'this')} in {{{self.expressions(e, flat=True)}}}",
exp.Extract: lambda self,
e: f"EXTRACT('{e.name.lower()}', {self.sql(e, 'expression')})",
exp.In: lambda self,
e: f"{self.sql(e, 'this')} in {{{self.expressions(e, flat=True)}}}",
exp.Interval: lambda self, e: f"INTERVAL({self.sql(e.this)}, '{self.sql(e.unit)}')",
exp.Is: lambda self, e: (
self.binary(e, "==") if isinstance(e.this, exp.Literal) else self.binary(e, "is")
),
exp.JSONPath: lambda self, e: f"[{','.join(self.sql(p) for p in e.expressions[1:])}]",
exp.JSONPathKey: lambda self, e: f"'{self.sql(e.this)}'",
exp.JSONPathSubscript: lambda self, e: f"'{e.this}'",
exp.Lambda: _lambda_sql,
exp.Not: lambda self, e: f"not {self.sql(e.this)}",
exp.Null: lambda *_: "None",

@@ -29,6 +29,7 @@ from sqlglot.helper import (
camel_to_snake_case,
ensure_collection,
ensure_list,
is_int,
seq_get,
subclasses,
)

@@ -175,13 +176,7 @@ class Expression(metaclass=_Expression):
"""
Checks whether a Literal expression is an integer.
"""
if self.is_number:
try:
int(self.name)
return True
except ValueError:
pass
return False
return self.is_number and is_int(self.name)

@property
def is_star(self) -> bool:

@@ -493,8 +488,8 @@ class Expression(metaclass=_Expression):

A AND B AND C -> [A, B, C]
"""
for node, _, _ in self.dfs(prune=lambda n, p, *_: p and not type(n) is self.__class__):
if not type(node) is self.__class__:
for node, _, _ in self.dfs(prune=lambda n, p, *_: p and type(n) is not self.__class__):
if type(node) is not self.__class__:
yield node.unnest() if unnest and not isinstance(node, Subquery) else node

def __str__(self) -> str:

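# --- Illustrative aside (not part of the diff) ---
# flatten() above yields the leaves of a same-class chain, e.g. a conjunction
# of three conditions flattens to its three operands.
import sqlglot

node = sqlglot.parse_one("a AND b AND c")
assert [n.sql() for n in node.flatten()] == ["a", "b", "c"]
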
@@ -553,10 +548,12 @@ class Expression(metaclass=_Expression):
return new_node

@t.overload
def replace(self, expression: E) -> E: ...
def replace(self, expression: E) -> E:
...

@t.overload
def replace(self, expression: None) -> None: ...
def replace(self, expression: None) -> None:
...

def replace(self, expression):
"""

@@ -610,7 +607,8 @@ class Expression(metaclass=_Expression):
>>> sqlglot.parse_one("SELECT x from y").assert_is(Select).select("z").sql()
'SELECT x, z FROM y'
"""
assert isinstance(self, type_)
if not isinstance(self, type_):
raise AssertionError(f"{self} is not {type_}.")
return self

def error_messages(self, args: t.Optional[t.Sequence] = None) -> t.List[str]:

@@ -1133,6 +1131,7 @@ class SetItem(Expression):
class Show(Expression):
arg_types = {
"this": True,
"history": False,
"terse": False,
"target": False,
"offset": False,

@@ -1676,7 +1675,6 @@ class Index(Expression):
"amp": False, # teradata
"include": False,
"partition_by": False, # teradata
"where": False, # postgres partial indexes
}


@@ -2573,7 +2571,7 @@ class HistoricalData(Expression):

class Table(Expression):
arg_types = {
"this": True,
"this": False,
"alias": False,
"db": False,
"catalog": False,

@@ -3664,6 +3662,7 @@ class DataType(Expression):
BINARY = auto()
BIT = auto()
BOOLEAN = auto()
BPCHAR = auto()
CHAR = auto()
DATE = auto()
DATE32 = auto()

@@ -3805,6 +3804,7 @@ class DataType(Expression):
dtype: DATA_TYPE,
dialect: DialectType = None,
udt: bool = False,
copy: bool = True,
**kwargs,
) -> DataType:
"""

@@ -3815,7 +3815,8 @@ class DataType(Expression):
dialect: the dialect to use for parsing `dtype`, in case it's a string.
udt: when set to True, `dtype` will be used as-is if it can't be parsed into a
DataType, thus creating a user-defined type.
kawrgs: additional arguments to pass in the constructor of DataType.
copy: whether or not to copy the data type.
kwargs: additional arguments to pass in the constructor of DataType.

Returns:
The constructed DataType object.

@@ -3837,7 +3838,7 @@ class DataType(Expression):
elif isinstance(dtype, DataType.Type):
data_type_exp = DataType(this=dtype)
elif isinstance(dtype, DataType):
return dtype
return maybe_copy(dtype, copy)
else:
raise ValueError(f"Invalid data type: {type(dtype)}. Expected str or DataType.Type")

@@ -3855,7 +3856,7 @@ class DataType(Expression):
True, if and only if there is a type in `dtypes` which is equal to this DataType.
"""
for dtype in dtypes:
other = DataType.build(dtype, udt=True)
other = DataType.build(dtype, copy=False, udt=True)

if (
other.expressions

@@ -4001,7 +4002,7 @@ class Dot(Binary):
def build(self, expressions: t.Sequence[Expression]) -> Dot:
"""Build a Dot object with a sequence of expressions."""
if len(expressions) < 2:
raise ValueError(f"Dot requires >= 2 expressions.")
raise ValueError("Dot requires >= 2 expressions.")

return t.cast(Dot, reduce(lambda x, y: Dot(this=x, expression=y), expressions))

@@ -4128,10 +4129,6 @@ class Sub(Binary):
pass


class ArrayOverlaps(Binary):
pass


# Unary Expressions
# (NOT a)
class Unary(Condition):

@@ -4469,6 +4466,10 @@ class ArrayJoin(Func):
arg_types = {"this": True, "expression": True, "null": False}


class ArrayOverlaps(Binary, Func):
pass


class ArraySize(Func):
arg_types = {"this": True, "expression": False}

@@ -4490,15 +4491,37 @@ class Avg(AggFunc):


class AnyValue(AggFunc):
arg_types = {"this": True, "having": False, "max": False, "ignore_nulls": False}
arg_types = {"this": True, "having": False, "max": False}


class First(Func):
arg_types = {"this": True, "ignore_nulls": False}
class Lag(AggFunc):
arg_types = {"this": True, "offset": False, "default": False}


class Last(Func):
arg_types = {"this": True, "ignore_nulls": False}
class Lead(AggFunc):
arg_types = {"this": True, "offset": False, "default": False}


# some dialects have a distinction between first and first_value, usually first is an aggregate func
# and first_value is a window func
class First(AggFunc):
pass


class Last(AggFunc):
pass


class FirstValue(AggFunc):
pass


class LastValue(AggFunc):
pass


class NthValue(AggFunc):
arg_types = {"this": True, "offset": True}


class Case(Func):

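# --- Illustrative aside (not part of the diff) ---
# Per the comment above, FIRST is modeled as a plain aggregate while
# FIRST_VALUE stays the window-oriented function.
import sqlglot
from sqlglot import exp

assert isinstance(sqlglot.parse_one("SELECT FIRST(x) FROM t").selects[0], exp.First)
win = sqlglot.parse_one("SELECT FIRST_VALUE(x) OVER (ORDER BY y) FROM t").selects[0]
assert isinstance(win.this, exp.FirstValue)
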
@@ -4611,7 +4634,7 @@ class CurrentTime(Func):


class CurrentTimestamp(Func):
arg_types = {"this": False}
arg_types = {"this": False, "transaction": False}


class CurrentUser(Func):

@@ -4712,6 +4735,7 @@ class TimestampSub(Func, TimeUnit):


class TimestampDiff(Func, TimeUnit):
_sql_names = ["TIMESTAMPDIFF", "TIMESTAMP_DIFF"]
arg_types = {"this": True, "expression": True, "unit": False}


@@ -4857,6 +4881,59 @@ class IsInf(Func):
_sql_names = ["IS_INF", "ISINF"]


class JSONPath(Expression):
arg_types = {"expressions": True}

@property
def output_name(self) -> str:
last_segment = self.expressions[-1].this
return last_segment if isinstance(last_segment, str) else ""


class JSONPathPart(Expression):
arg_types = {}


class JSONPathFilter(JSONPathPart):
arg_types = {"this": True}


class JSONPathKey(JSONPathPart):
arg_types = {"this": True}


class JSONPathRecursive(JSONPathPart):
arg_types = {"this": False}


class JSONPathRoot(JSONPathPart):
pass


class JSONPathScript(JSONPathPart):
arg_types = {"this": True}


class JSONPathSlice(JSONPathPart):
arg_types = {"start": False, "end": False, "step": False}


class JSONPathSelector(JSONPathPart):
arg_types = {"this": True}


class JSONPathSubscript(JSONPathPart):
arg_types = {"this": True}


class JSONPathUnion(JSONPathPart):
arg_types = {"expressions": True}


class JSONPathWildcard(JSONPathPart):
pass


class FormatJson(Expression):
pass

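# --- Illustrative aside (not part of the diff) ---
# JSONPath.output_name above surfaces the last string key of a parsed path,
# which downstream code can use for naming; a small sketch:
from sqlglot import jsonpath

path = jsonpath.parse("$.a.b")
assert path.output_name == "b"
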
@@ -4940,18 +5017,30 @@ class JSONBContains(Binary):


class JSONExtract(Binary, Func):
arg_types = {"this": True, "expression": True, "expressions": False}
_sql_names = ["JSON_EXTRACT"]
is_var_len_args = True

@property
def output_name(self) -> str:
return self.expression.output_name if not self.expressions else ""


class JSONExtractScalar(JSONExtract):
class JSONExtractScalar(Binary, Func):
arg_types = {"this": True, "expression": True, "expressions": False}
_sql_names = ["JSON_EXTRACT_SCALAR"]
is_var_len_args = True

@property
def output_name(self) -> str:
return self.expression.output_name


class JSONBExtract(JSONExtract):
class JSONBExtract(Binary, Func):
_sql_names = ["JSONB_EXTRACT"]


class JSONBExtractScalar(JSONExtract):
class JSONBExtractScalar(Binary, Func):
_sql_names = ["JSONB_EXTRACT_SCALAR"]


@@ -4972,15 +5061,6 @@ class ParseJSON(Func):
is_var_len_args = True


# https://docs.snowflake.com/en/sql-reference/functions/get_path
class GetPath(Func):
arg_types = {"this": True, "expression": True}

@property
def output_name(self) -> str:
return self.expression.output_name


class Least(Func):
arg_types = {"this": True, "expressions": False}
is_var_len_args = True

@@ -5476,6 +5556,8 @@ def _norm_arg(arg):
ALL_FUNCTIONS = subclasses(__name__, Func, (AggFunc, Anonymous, Func))
FUNCTION_BY_NAME = {name: func for func in ALL_FUNCTIONS for name in func.sql_names()}

JSON_PATH_PARTS = subclasses(__name__, JSONPathPart, (JSONPathPart,))


# Helpers
@t.overload

@@ -5487,7 +5569,8 @@ def maybe_parse(
prefix: t.Optional[str] = None,
copy: bool = False,
**opts,
) -> E: ...
) -> E:
...


@t.overload

@@ -5499,7 +5582,8 @@ def maybe_parse(
prefix: t.Optional[str] = None,
copy: bool = False,
**opts,
) -> E: ...
) -> E:
...


def maybe_parse(

@@ -5539,7 +5623,7 @@ def maybe_parse(
return sql_or_expression

if sql_or_expression is None:
raise ParseError(f"SQL cannot be None")
raise ParseError("SQL cannot be None")

import sqlglot

@@ -5551,11 +5635,13 @@ def maybe_parse(


@t.overload
def maybe_copy(instance: None, copy: bool = True) -> None: ...
def maybe_copy(instance: None, copy: bool = True) -> None:
...


@t.overload
def maybe_copy(instance: E, copy: bool = True) -> E: ...
def maybe_copy(instance: E, copy: bool = True) -> E:
...


def maybe_copy(instance, copy=True):

@@ -6174,17 +6260,19 @@ def paren(expression: ExpOrStr, copy: bool = True) -> Paren:
return Paren(this=maybe_parse(expression, copy=copy))


SAFE_IDENTIFIER_RE = re.compile(r"^[_a-zA-Z][\w]*$")
SAFE_IDENTIFIER_RE: t.Pattern[str] = re.compile(r"^[_a-zA-Z][\w]*$")


@t.overload
def to_identifier(name: None, quoted: t.Optional[bool] = None, copy: bool = True) -> None: ...
def to_identifier(name: None, quoted: t.Optional[bool] = None, copy: bool = True) -> None:
...


@t.overload
def to_identifier(
name: str | Identifier, quoted: t.Optional[bool] = None, copy: bool = True
) -> Identifier: ...
) -> Identifier:
...


def to_identifier(name, quoted=None, copy=True):

@@ -6256,11 +6344,13 @@ def to_interval(interval: str | Literal) -> Interval:


@t.overload
def to_table(sql_path: str | Table, **kwargs) -> Table: ...
def to_table(sql_path: str | Table, **kwargs) -> Table:
...


@t.overload
def to_table(sql_path: None, **kwargs) -> None: ...
def to_table(sql_path: None, **kwargs) -> None:
...


def to_table(

@@ -6460,7 +6550,7 @@ def column(
return this


def cast(expression: ExpOrStr, to: DATA_TYPE, **opts) -> Cast:
def cast(expression: ExpOrStr, to: DATA_TYPE, copy: bool = True, **opts) -> Cast:
"""Cast an expression to a data type.

Example:

@@ -6470,12 +6560,13 @@ def cast(expression: ExpOrStr, to: DATA_TYPE, **opts) -> Cast:
Args:
expression: The expression to cast.
to: The datatype to cast to.
copy: Whether or not to copy the supplied expressions.

Returns:
The new Cast instance.
"""
expression = maybe_parse(expression, **opts)
data_type = DataType.build(to, **opts)
expression = maybe_parse(expression, copy=copy, **opts)
data_type = DataType.build(to, copy=copy, **opts)
expression = Cast(this=expression, to=data_type)
expression.type = data_type
return expression

@@ -9,6 +9,7 @@ from functools import reduce
from sqlglot import exp
from sqlglot.errors import ErrorLevel, UnsupportedError, concat_messages
from sqlglot.helper import apply_index_offset, csv, seq_get
from sqlglot.jsonpath import ALL_JSON_PATH_PARTS, JSON_PATH_PART_TRANSFORMS
from sqlglot.time import format_time
from sqlglot.tokens import TokenType

@@ -21,7 +22,18 @@ logger = logging.getLogger("sqlglot")
ESCAPED_UNICODE_RE = re.compile(r"\\(\d+)")


class Generator:
class _Generator(type):
def __new__(cls, clsname, bases, attrs):
klass = super().__new__(cls, clsname, bases, attrs)

# Remove transforms that correspond to unsupported JSONPathPart expressions
for part in ALL_JSON_PATH_PARTS - klass.SUPPORTED_JSON_PATH_PARTS:
klass.TRANSFORMS.pop(part, None)

return klass


class Generator(metaclass=_Generator):
"""
Generator converts a given syntax tree to the corresponding SQL string.

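# --- Illustrative aside (not part of the diff) ---
# The metaclass above prunes TRANSFORMS entries for JSONPathPart types that a
# dialect omits from SUPPORTED_JSON_PATH_PARTS, so unsupported path syntax is
# reported instead of silently rendered. E.g. SQLite (see its hunk above)
# keeps only key/root/subscript parts:
from sqlglot import exp
from sqlglot.dialects.sqlite import SQLite

assert exp.JSONPathSlice not in SQLite.Generator.TRANSFORMS
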
@@ -58,19 +70,23 @@ class Generator:
Default: True
"""

TRANSFORMS = {
TRANSFORMS: t.Dict[t.Type[exp.Expression], t.Callable[..., str]] = {
**JSON_PATH_PART_TRANSFORMS,
exp.AutoRefreshProperty: lambda self, e: f"AUTO REFRESH {self.sql(e, 'this')}",
exp.CaseSpecificColumnConstraint: lambda self,
e: f"{'NOT ' if e.args.get('not_') else ''}CASESPECIFIC",
exp.CharacterSetColumnConstraint: lambda self, e: f"CHARACTER SET {self.sql(e, 'this')}",
exp.CharacterSetProperty: lambda self,
e: f"{'DEFAULT ' if e.args.get('default') else ''}CHARACTER SET={self.sql(e, 'this')}",
exp.CheckColumnConstraint: lambda self, e: f"CHECK ({self.sql(e, 'this')})",
exp.ClusteredColumnConstraint: lambda self,
e: f"CLUSTERED ({self.expressions(e, 'this', indent=False)})",
exp.CollateColumnConstraint: lambda self, e: f"COLLATE {self.sql(e, 'this')}",
exp.CommentColumnConstraint: lambda self, e: f"COMMENT {self.sql(e, 'this')}",
exp.CopyGrantsProperty: lambda self, e: "COPY GRANTS",
exp.DateAdd: lambda self, e: self.func(
"DATE_ADD", e.this, e.expression, exp.Literal.string(e.text("unit"))
),
exp.CaseSpecificColumnConstraint: lambda self, e: f"{'NOT ' if e.args.get('not_') else ''}CASESPECIFIC",
exp.CharacterSetColumnConstraint: lambda self, e: f"CHARACTER SET {self.sql(e, 'this')}",
exp.CharacterSetProperty: lambda self, e: f"{'DEFAULT ' if e.args.get('default') else ''}CHARACTER SET={self.sql(e, 'this')}",
exp.CheckColumnConstraint: lambda self, e: f"CHECK ({self.sql(e, 'this')})",
exp.ClusteredColumnConstraint: lambda self, e: f"CLUSTERED ({self.expressions(e, 'this', indent=False)})",
exp.CollateColumnConstraint: lambda self, e: f"COLLATE {self.sql(e, 'this')}",
exp.AutoRefreshProperty: lambda self, e: f"AUTO REFRESH {self.sql(e, 'this')}",
exp.CopyGrantsProperty: lambda self, e: "COPY GRANTS",
exp.CommentColumnConstraint: lambda self, e: f"COMMENT {self.sql(e, 'this')}",
exp.DateFormatColumnConstraint: lambda self, e: f"FORMAT {self.sql(e, 'this')}",
exp.DefaultColumnConstraint: lambda self, e: f"DEFAULT {self.sql(e, 'this')}",
exp.EncodeColumnConstraint: lambda self, e: f"ENCODE {self.sql(e, 'this')}",

@@ -85,29 +101,33 @@ class Generator:
exp.LocationProperty: lambda self, e: self.naked_property(e),
exp.LogProperty: lambda self, e: f"{'NO ' if e.args.get('no') else ''}LOG",
exp.MaterializedProperty: lambda self, e: "MATERIALIZED",
exp.NonClusteredColumnConstraint: lambda self,
e: f"NONCLUSTERED ({self.expressions(e, 'this', indent=False)})",
exp.NoPrimaryIndexProperty: lambda self, e: "NO PRIMARY INDEX",
exp.NonClusteredColumnConstraint: lambda self, e: f"NONCLUSTERED ({self.expressions(e, 'this', indent=False)})",
exp.NotForReplicationColumnConstraint: lambda self, e: "NOT FOR REPLICATION",
exp.OnCommitProperty: lambda self, e: f"ON COMMIT {'DELETE' if e.args.get('delete') else 'PRESERVE'} ROWS",
exp.OnCommitProperty: lambda self,
e: f"ON COMMIT {'DELETE' if e.args.get('delete') else 'PRESERVE'} ROWS",
exp.OnProperty: lambda self, e: f"ON {self.sql(e, 'this')}",
exp.OnUpdateColumnConstraint: lambda self, e: f"ON UPDATE {self.sql(e, 'this')}",
exp.OutputModelProperty: lambda self, e: f"OUTPUT{self.sql(e, 'this')}",
exp.PathColumnConstraint: lambda self, e: f"PATH {self.sql(e, 'this')}",
exp.RemoteWithConnectionModelProperty: lambda self, e: f"REMOTE WITH CONNECTION {self.sql(e, 'this')}",
exp.RemoteWithConnectionModelProperty: lambda self,
e: f"REMOTE WITH CONNECTION {self.sql(e, 'this')}",
exp.ReturnsProperty: lambda self, e: self.naked_property(e),
exp.SampleProperty: lambda self, e: f"SAMPLE BY {self.sql(e, 'this')}",
exp.SetProperty: lambda self, e: f"{'MULTI' if e.args.get('multi') else ''}SET",
exp.SetConfigProperty: lambda self, e: self.sql(e, "this"),
exp.SetProperty: lambda self, e: f"{'MULTI' if e.args.get('multi') else ''}SET",
exp.SettingsProperty: lambda self, e: f"SETTINGS{self.seg('')}{(self.expressions(e))}",
exp.SqlReadWriteProperty: lambda self, e: e.name,
exp.SqlSecurityProperty: lambda self, e: f"SQL SECURITY {'DEFINER' if e.args.get('definer') else 'INVOKER'}",
exp.SqlSecurityProperty: lambda self,
e: f"SQL SECURITY {'DEFINER' if e.args.get('definer') else 'INVOKER'}",
exp.StabilityProperty: lambda self, e: e.name,
exp.TemporaryProperty: lambda self, e: f"TEMPORARY",
exp.ToTableProperty: lambda self, e: f"TO {self.sql(e.this)}",
exp.TransientProperty: lambda self, e: "TRANSIENT",
exp.TransformModelProperty: lambda self, e: self.func("TRANSFORM", *e.expressions),
exp.TemporaryProperty: lambda self, e: "TEMPORARY",
exp.TitleColumnConstraint: lambda self, e: f"TITLE {self.sql(e, 'this')}",
exp.UppercaseColumnConstraint: lambda self, e: f"UPPERCASE",
exp.ToTableProperty: lambda self, e: f"TO {self.sql(e.this)}",
exp.TransformModelProperty: lambda self, e: self.func("TRANSFORM", *e.expressions),
exp.TransientProperty: lambda self, e: "TRANSIENT",
exp.UppercaseColumnConstraint: lambda self, e: "UPPERCASE",
exp.VarMap: lambda self, e: self.func("MAP", e.args["keys"], e.args["values"]),
exp.VolatileProperty: lambda self, e: "VOLATILE",
exp.WithJournalTableProperty: lambda self, e: f"WITH JOURNAL TABLE={self.sql(e, 'this')}",

@@ -117,6 +137,10 @@ class Generator:
# True: Full Support, None: No support, False: No support in window specifications
NULL_ORDERING_SUPPORTED: t.Optional[bool] = True

# Whether or not ignore nulls is inside the agg or outside.
# FIRST(x IGNORE NULLS) OVER vs FIRST (x) IGNORE NULLS OVER
IGNORE_NULLS_IN_FUNC = False

# Whether or not locking reads (i.e. SELECT ... FOR UPDATE/SHARE) are supported
LOCKING_READS_SUPPORTED = False

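# --- Illustrative aside (not part of the diff) ---
# _embed_ignore_nulls (later in this file) keys off IGNORE_NULLS_IN_FUNC to
# place IGNORE NULLS inside or outside the call parentheses. The dialect
# choice and the output below are assumptions inferred from the comment.
import sqlglot

sql = "SELECT FIRST_VALUE(x IGNORE NULLS) OVER (ORDER BY y) FROM t"
print(sqlglot.transpile(sql, read="bigquery", write="snowflake")[0])
# e.g. ... FIRST_VALUE(x) IGNORE NULLS OVER (ORDER BY y) ...
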
@@ -266,6 +290,24 @@ class Generator:
# Whether or not UNLOGGED tables can be created
SUPPORTS_UNLOGGED_TABLES = False

# Whether or not the CREATE TABLE LIKE statement is supported
SUPPORTS_CREATE_TABLE_LIKE = True

# Whether or not the LikeProperty needs to be specified inside of the schema clause
LIKE_PROPERTY_INSIDE_SCHEMA = False

# Whether or not the JSON extraction operators expect a value of type JSON
JSON_TYPE_REQUIRED_FOR_EXTRACTION = False

# Whether or not bracketed keys like ["foo"] are supported in JSON paths
JSON_PATH_BRACKETED_KEY_SUPPORTED = True

# Whether or not to escape keys using single quotes in JSON paths
JSON_PATH_SINGLE_QUOTE_ESCAPE = False

# The JSONPathPart expressions supported by this dialect
SUPPORTED_JSON_PATH_PARTS = ALL_JSON_PATH_PARTS.copy()

TYPE_MAPPING = {
exp.DataType.Type.NCHAR: "CHAR",
exp.DataType.Type.NVARCHAR: "VARCHAR",

@@ -641,8 +683,6 @@ class Generator:

if callable(transform):
sql = transform(self, expression)
elif transform:
sql = transform
elif isinstance(expression, exp.Expression):
exp_handler_name = f"{expression.key}_sql"

@@ -802,7 +842,7 @@ class Generator:
desc = expression.args.get("desc")
if desc is not None:
return f"PRIMARY KEY{' DESC' if desc else ' ASC'}"
return f"PRIMARY KEY"
return "PRIMARY KEY"

def uniquecolumnconstraint_sql(self, expression: exp.UniqueColumnConstraint) -> str:
this = self.sql(expression, "this")

@@ -1218,9 +1258,21 @@ class Generator:
return f"{property_name}={self.sql(expression, 'this')}"

def likeproperty_sql(self, expression: exp.LikeProperty) -> str:
options = " ".join(f"{e.name} {self.sql(e, 'value')}" for e in expression.expressions)
options = f" {options}" if options else ""
return f"LIKE {self.sql(expression, 'this')}{options}"
if self.SUPPORTS_CREATE_TABLE_LIKE:
options = " ".join(f"{e.name} {self.sql(e, 'value')}" for e in expression.expressions)
options = f" {options}" if options else ""

like = f"LIKE {self.sql(expression, 'this')}{options}"
if self.LIKE_PROPERTY_INSIDE_SCHEMA and not isinstance(expression.parent, exp.Schema):
like = f"({like})"

return like

if expression.expressions:
self.unsupported("Transpilation of LIKE property options is unsupported")

select = exp.select("*").from_(expression.this).limit(0)
return f"AS {self.sql(select)}"

def fallbackproperty_sql(self, expression: exp.FallbackProperty) -> str:
no = "NO " if expression.args.get("no") else ""

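# --- Illustrative aside (not part of the diff) ---
# When SUPPORTS_CREATE_TABLE_LIKE is False (e.g. SQLite, per its hunk above),
# likeproperty_sql falls back to a zero-row CTAS. The output is an assumption
# inferred from the hunk.
import sqlglot

print(sqlglot.transpile("CREATE TABLE t LIKE s", write="sqlite")[0])
# e.g. CREATE TABLE t AS SELECT * FROM s LIMIT 0
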
@@ -2367,6 +2419,31 @@ class Generator:
def jsonkeyvalue_sql(self, expression: exp.JSONKeyValue) -> str:
return f"{self.sql(expression, 'this')}{self.JSON_KEY_VALUE_PAIR_SEP} {self.sql(expression, 'expression')}"

def jsonpath_sql(self, expression: exp.JSONPath) -> str:
path = self.expressions(expression, sep="", flat=True).lstrip(".")
return f"{self.dialect.QUOTE_START}{path}{self.dialect.QUOTE_END}"

def json_path_part(self, expression: int | str | exp.JSONPathPart) -> str:
if isinstance(expression, exp.JSONPathPart):
transform = self.TRANSFORMS.get(expression.__class__)
if not callable(transform):
self.unsupported(f"Unsupported JSONPathPart type {expression.__class__.__name__}")
return ""

return transform(self, expression)

if isinstance(expression, int):
return str(expression)

if self.JSON_PATH_SINGLE_QUOTE_ESCAPE:
escaped = expression.replace("'", "\\'")
escaped = f"\\'{escaped}\\'"
else:
escaped = expression.replace('"', '\\"')
escaped = f'"{escaped}"'

return escaped

def formatjson_sql(self, expression: exp.FormatJson) -> str:
return f"{self.sql(expression, 'this')} FORMAT JSON"

@@ -2620,6 +2697,9 @@ class Generator:
zone = self.sql(expression, "this")
return f"CURRENT_DATE({zone})" if zone else "CURRENT_DATE"

def currenttimestamp_sql(self, expression: exp.CurrentTimestamp) -> str:
return self.func("CURRENT_TIMESTAMP", expression.this)

def collate_sql(self, expression: exp.Collate) -> str:
if self.COLLATE_IS_FUNC:
return self.function_fallback_sql(expression)

@@ -2761,10 +2841,20 @@ class Generator:
return f"DISTINCT{this}{on}"

def ignorenulls_sql(self, expression: exp.IgnoreNulls) -> str:
return f"{self.sql(expression, 'this')} IGNORE NULLS"
return self._embed_ignore_nulls(expression, "IGNORE NULLS")

def respectnulls_sql(self, expression: exp.RespectNulls) -> str:
return f"{self.sql(expression, 'this')} RESPECT NULLS"
return self._embed_ignore_nulls(expression, "RESPECT NULLS")

def _embed_ignore_nulls(self, expression: exp.IgnoreNulls | exp.RespectNulls, text: str) -> str:
if self.IGNORE_NULLS_IN_FUNC:
this = expression.find(exp.AggFunc)
if this:
sql = self.sql(this)
sql = sql[:-1] + f" {text})"
return sql

return f"{self.sql(expression, 'this')} {text}"

def intdiv_sql(self, expression: exp.IntDiv) -> str:
return self.sql(

@@ -2935,7 +3025,7 @@ class Generator:
def format_args(self, *args: t.Optional[str | exp.Expression]) -> str:
arg_sqls = tuple(self.sql(arg) for arg in args if arg is not None)
if self.pretty and self.text_width(arg_sqls) > self.max_text_width:
return self.indent("\n" + f",\n".join(arg_sqls) + "\n", skip_first=True, skip_last=True)
return self.indent("\n" + ",\n".join(arg_sqls) + "\n", skip_first=True, skip_last=True)
return ", ".join(arg_sqls)

def text_width(self, args: t.Iterable) -> int:

@@ -3279,6 +3369,22 @@ class Generator:

return self.func("LAST_DAY", expression.this)

def _jsonpathkey_sql(self, expression: exp.JSONPathKey) -> str:
this = expression.this
if isinstance(this, exp.JSONPathWildcard):
this = self.json_path_part(this)
return f".{this}" if this else ""

if exp.SAFE_IDENTIFIER_RE.match(this):
return f".{this}"

this = self.json_path_part(this)
return f"[{this}]" if self.JSON_PATH_BRACKETED_KEY_SUPPORTED else f".{this}"

def _jsonpathsubscript_sql(self, expression: exp.JSONPathSubscript) -> str:
this = self.json_path_part(expression.this)
return f"[{this}]" if this else ""

def _simplify_unless_literal(self, expression: E) -> E:
if not isinstance(expression, exp.Literal):
from sqlglot.optimizer.simplify import simplify


@@ -53,11 +53,13 @@ def seq_get(seq: t.Sequence[T], index: int) -> t.Optional[T]:

@t.overload
-def ensure_list(value: t.Collection[T]) -> t.List[T]: ...
+def ensure_list(value: t.Collection[T]) -> t.List[T]:
+    ...


@t.overload
-def ensure_list(value: T) -> t.List[T]: ...
+def ensure_list(value: T) -> t.List[T]:
+    ...


def ensure_list(value):

@@ -79,11 +81,13 @@ def ensure_list(value):

@t.overload
-def ensure_collection(value: t.Collection[T]) -> t.Collection[T]: ...
+def ensure_collection(value: t.Collection[T]) -> t.Collection[T]:
+    ...


@t.overload
-def ensure_collection(value: T) -> t.Collection[T]: ...
+def ensure_collection(value: T) -> t.Collection[T]:
+    ...


def ensure_collection(value):

@@ -232,7 +236,7 @@ def tsort(dag: t.Dict[T, t.Set[T]]) -> t.List[T]:

    for node, deps in tuple(dag.items()):
        for dep in deps:
-            if not dep in dag:
+            if dep not in dag:
                dag[dep] = set()

    while dag:

@@ -316,6 +320,14 @@ def find_new_name(taken: t.Collection[str], base: str) -> str:
    return new


+def is_int(text: str) -> bool:
+    try:
+        int(text)
+        return True
+    except ValueError:
+        return False
+
+
def name_sequence(prefix: str) -> t.Callable[[], str]:
    """Returns a name generator given a prefix (e.g. a0, a1, a2, ... if the prefix is "a")."""
    sequence = count()
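
Editor's note: the new is_int helper simply mirrors int()'s parsing rules; name_sequence (unchanged, shown for context) is a frequent companion when fresh aliases are generated. A quick check:

from sqlglot.helper import is_int, name_sequence

assert is_int("-12") and not is_int("1.5")

next_name = name_sequence("a")
next_name(), next_name()  # -> ('a0', 'a1')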

@@ -2,8 +2,8 @@ from __future__ import annotations

import typing as t

+import sqlglot.expressions as exp
from sqlglot.errors import ParseError
-from sqlglot.expressions import SAFE_IDENTIFIER_RE
from sqlglot.tokens import Token, Tokenizer, TokenType

if t.TYPE_CHECKING:

@@ -36,20 +36,8 @@ class JSONPathTokenizer(Tokenizer):
    STRING_ESCAPES = ["\\"]


-JSONPathNode = t.Dict[str, t.Any]
-
-
-def _node(kind: str, value: t.Any = None, **kwargs: t.Any) -> JSONPathNode:
-    node = {"kind": kind, **kwargs}
-
-    if value is not None:
-        node["value"] = value
-
-    return node
-
-
-def parse(path: str) -> t.List[JSONPathNode]:
-    """Takes in a JSONPath string and converts into a list of nodes."""
+def parse(path: str) -> exp.JSONPath:
+    """Takes in a JSON path string and parses it into a JSONPath expression."""
    tokens = JSONPathTokenizer().tokenize(path)
    size = len(tokens)

@@ -89,7 +77,7 @@ def parse(path: str) -> t.List[JSONPathNode]:
        if token:
            return token.text
        if _match(TokenType.STAR):
-            return _node("wildcard")
+            return exp.JSONPathWildcard()
        if _match(TokenType.PLACEHOLDER) or _match(TokenType.L_PAREN):
            script = _prev().text == "("
            start = i

@@ -100,9 +88,9 @@ def parse(path: str) -> t.List[JSONPathNode]:
                if _curr() in (TokenType.R_BRACKET, None):
                    break
                _advance()
-            return _node(
-                "script" if script else "filter", path[tokens[start].start : tokens[i].end]
-            )
+
+            expr_type = exp.JSONPathScript if script else exp.JSONPathFilter
+            return expr_type(this=path[tokens[start].start : tokens[i].end])

        number = "-" if _match(TokenType.DASH) else ""

@@ -112,6 +100,7 @@ def parse(path: str) -> t.List[JSONPathNode]:

        if number:
            return int(number)

        return False

    def _parse_slice() -> t.Any:

@@ -121,9 +110,10 @@ def parse(path: str) -> t.List[JSONPathNode]:

        if end is None and step is None:
            return start
-        return _node("slice", start=start, end=end, step=step)
+
+        return exp.JSONPathSlice(start=start, end=end, step=step)

-    def _parse_bracket() -> JSONPathNode:
+    def _parse_bracket() -> exp.JSONPathPart:
        literal = _parse_slice()

        if isinstance(literal, str) or literal is not False:

@@ -136,13 +126,15 @@ def parse(path: str) -> t.List[JSONPathNode]:

            if len(indexes) == 1:
                if isinstance(literal, str):
-                    node = _node("key", indexes[0])
-                elif isinstance(literal, dict) and literal["kind"] in ("script", "filter"):
-                    node = _node("selector", indexes[0])
+                    node: exp.JSONPathPart = exp.JSONPathKey(this=indexes[0])
+                elif isinstance(literal, exp.JSONPathPart) and isinstance(
+                    literal, (exp.JSONPathScript, exp.JSONPathFilter)
+                ):
+                    node = exp.JSONPathSelector(this=indexes[0])
                else:
-                    node = _node("subscript", indexes[0])
+                    node = exp.JSONPathSubscript(this=indexes[0])
            else:
-                node = _node("union", indexes)
+                node = exp.JSONPathUnion(expressions=indexes)
        else:
            raise ParseError(_error("Cannot have empty segment"))

@@ -150,66 +142,56 @@ def parse(path: str) -> t.List[JSONPathNode]:

        return node

-    nodes = []
+    # We canonicalize the JSON path AST so that it always starts with a
+    # "root" element, so paths like "field" will be generated as "$.field"
+    _match(TokenType.DOLLAR)
+    expressions: t.List[exp.JSONPathPart] = [exp.JSONPathRoot()]

    while _curr():
-        if _match(TokenType.DOLLAR):
-            nodes.append(_node("root"))
-        elif _match(TokenType.DOT):
+        if _match(TokenType.DOT) or _match(TokenType.COLON):
            recursive = _prev().text == ".."
-            value = _match(TokenType.VAR) or _match(TokenType.STAR)
-            nodes.append(
-                _node("recursive" if recursive else "child", value=value.text if value else None)
-            )
+
+            if _match(TokenType.VAR) or _match(TokenType.IDENTIFIER):
+                value: t.Optional[str | exp.JSONPathWildcard] = _prev().text
+            elif _match(TokenType.STAR):
+                value = exp.JSONPathWildcard()
+            else:
+                value = None
+
+            if recursive:
+                expressions.append(exp.JSONPathRecursive(this=value))
+            elif value:
+                expressions.append(exp.JSONPathKey(this=value))
+            else:
+                raise ParseError(_error("Expected key name or * after DOT"))
        elif _match(TokenType.L_BRACKET):
-            nodes.append(_parse_bracket())
-        elif _match(TokenType.VAR):
-            nodes.append(_node("key", _prev().text))
+            expressions.append(_parse_bracket())
+        elif _match(TokenType.VAR) or _match(TokenType.IDENTIFIER):
+            expressions.append(exp.JSONPathKey(this=_prev().text))
        elif _match(TokenType.STAR):
-            nodes.append(_node("wildcard"))
-        elif _match(TokenType.PARAMETER):
-            nodes.append(_node("current"))
+            expressions.append(exp.JSONPathWildcard())
        else:
            raise ParseError(_error(f"Unexpected {tokens[i].token_type}"))

-    return nodes
+    return exp.JSONPath(expressions=expressions)


-MAPPING = {
-    "child": lambda n: f".{n['value']}" if n.get("value") is not None else "",
-    "filter": lambda n: f"?{n['value']}",
-    "key": lambda n: (
-        f".{n['value']}" if SAFE_IDENTIFIER_RE.match(n["value"]) else f'[{generate([n["value"]])}]'
-    ),
-    "recursive": lambda n: f"..{n['value']}" if n.get("value") is not None else "..",
-    "root": lambda _: "$",
-    "script": lambda n: f"({n['value']}",
-    "slice": lambda n: ":".join(
-        "" if p is False else generate([p])
-        for p in [n["start"], n["end"], n["step"]]
-        if p is not None
-    ),
-    "selector": lambda n: f"[{generate([n['value']])}]",
-    "subscript": lambda n: f"[{generate([n['value']])}]",
-    "union": lambda n: f"[{','.join(generate([p]) for p in n['value'])}]",
-    "wildcard": lambda _: "*",
+JSON_PATH_PART_TRANSFORMS: t.Dict[t.Type[exp.Expression], t.Callable[..., str]] = {
+    exp.JSONPathFilter: lambda _, e: f"?{e.this}",
+    exp.JSONPathKey: lambda self, e: self._jsonpathkey_sql(e),
+    exp.JSONPathRecursive: lambda _, e: f"..{e.this or ''}",
+    exp.JSONPathRoot: lambda *_: "$",
+    exp.JSONPathScript: lambda _, e: f"({e.this}",
+    exp.JSONPathSelector: lambda self, e: f"[{self.json_path_part(e.this)}]",
+    exp.JSONPathSlice: lambda self, e: ":".join(
+        "" if p is False else self.json_path_part(p)
+        for p in [e.args.get("start"), e.args.get("end"), e.args.get("step")]
+        if p is not None
+    ),
+    exp.JSONPathSubscript: lambda self, e: self._jsonpathsubscript_sql(e),
+    exp.JSONPathUnion: lambda self,
+    e: f"[{','.join(self.json_path_part(p) for p in e.expressions)}]",
+    exp.JSONPathWildcard: lambda *_: "*",
}


-def generate(
-    nodes: t.List[JSONPathNode],
-    mapping: t.Optional[t.Dict[str, t.Callable[[JSONPathNode], str]]] = None,
-) -> str:
-    mapping = MAPPING if mapping is None else mapping
-    path = []
-
-    for node in nodes:
-        if isinstance(node, dict):
-            path.append(mapping[node["kind"]](node))
-        elif isinstance(node, str):
-            escaped = node.replace('"', '\\"')
-            path.append(f'"{escaped}"')
-        else:
-            path.append(str(node))
-
-    return "".join(path)
+ALL_JSON_PATH_PARTS = set(JSON_PATH_PART_TRANSFORMS)
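
Editor's note: with the dict-based nodes gone, parsing returns a real expression tree and generation goes through the generator's json_path_part machinery. A small sanity check against the new API (the exact types shown are illustrative):

from sqlglot import jsonpath

path = jsonpath.parse("name[0]")

# Paths are canonicalized to start at the root, so the tree begins with
# JSONPathRoot even though the input omitted the leading "$".
print([type(part).__name__ for part in path.expressions])
# e.g. ['JSONPathRoot', 'JSONPathKey', 'JSONPathSubscript']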

@@ -1,3 +1,5 @@
+# ruff: noqa: F401
+
from sqlglot.optimizer.optimizer import RULES, optimize
from sqlglot.optimizer.scope import (
    Scope,

@@ -10,11 +10,13 @@ if t.TYPE_CHECKING:

@t.overload
-def normalize_identifiers(expression: E, dialect: DialectType = None) -> E: ...
+def normalize_identifiers(expression: E, dialect: DialectType = None) -> E:
+    ...


@t.overload
-def normalize_identifiers(expression: str, dialect: DialectType = None) -> exp.Identifier: ...
+def normalize_identifiers(expression: str, dialect: DialectType = None) -> exp.Identifier:
+    ...


def normalize_identifiers(expression, dialect=None):

@@ -8,10 +8,10 @@ from sqlglot.schema import ensure_schema
# Sentinel value that means an outer query selecting ALL columns
SELECT_ALL = object()


# Selection to use if selection list is empty
-DEFAULT_SELECTION = lambda is_agg: alias(
-    exp.Max(this=exp.Literal.number(1)) if is_agg else "1", "_"
-)
+def default_selection(is_agg: bool) -> exp.Alias:
+    return alias(exp.Max(this=exp.Literal.number(1)) if is_agg else "1", "_")


def pushdown_projections(expression, schema=None, remove_unused_selections=True):

@@ -129,7 +129,7 @@ def _remove_unused_selections(scope, parent_selections, schema, alias_count):

    # If there are no remaining selections, just select a single constant
    if not new_selections:
-        new_selections.append(DEFAULT_SELECTION(is_agg))
+        new_selections.append(default_selection(is_agg))

    scope.expression.select(*new_selections, append=False, copy=False)
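
Editor's note: default_selection supplies the placeholder projection whenever every selection in a subquery is pruned. A sketch of the effect (output shown as expected, not verified):

from sqlglot import parse_one
from sqlglot.optimizer.pushdown_projections import pushdown_projections

sql = "SELECT 1 FROM (SELECT a, b FROM t) AS x"
print(pushdown_projections(parse_one(sql)).sql())
# e.g. SELECT 1 FROM (SELECT 1 AS _ FROM t) AS x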

@@ -104,7 +104,6 @@ def simplify(
        if root:
            expression.replace(node)

        return node

    expression = while_changing(expression, _simplify)

@@ -174,16 +173,20 @@ def simplify_not(expression):
    if isinstance(this, exp.Paren):
        condition = this.unnest()
        if isinstance(condition, exp.And):
-            return exp.or_(
-                exp.not_(condition.left, copy=False),
-                exp.not_(condition.right, copy=False),
-                copy=False,
+            return exp.paren(
+                exp.or_(
+                    exp.not_(condition.left, copy=False),
+                    exp.not_(condition.right, copy=False),
+                    copy=False,
+                )
            )
        if isinstance(condition, exp.Or):
-            return exp.and_(
-                exp.not_(condition.left, copy=False),
-                exp.not_(condition.right, copy=False),
-                copy=False,
+            return exp.paren(
+                exp.and_(
+                    exp.not_(condition.left, copy=False),
+                    exp.not_(condition.right, copy=False),
+                    copy=False,
+                )
            )
        if is_null(condition):
            return exp.null()
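
Editor's note: wrapping the De Morgan rewrites in exp.paren preserves the parentheses the input had, so operator precedence stays explicit in the generated SQL:

from sqlglot import parse_one
from sqlglot.optimizer.simplify import simplify

# e.g. NOT (a AND b) -> (NOT a OR NOT b), parentheses retained (illustrative).
print(simplify(parse_one("NOT (a AND b)")).sql())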

@@ -490,7 +493,7 @@ def simplify_equality(expression: exp.Expression) -> exp.Expression:
    if isinstance(expression, COMPARISONS):
        l, r = expression.left, expression.right

-        if not l.__class__ in INVERSE_OPS:
+        if l.__class__ not in INVERSE_OPS:
            return expression

        if r.is_number:

@@ -714,8 +717,7 @@ def simplify_concat(expression):
    """Reduces all groups that contain string literals by concatenating them."""
    if not isinstance(expression, CONCATS) or (
        # We can't reduce a CONCAT_WS call if we don't statically know the separator
-        isinstance(expression, exp.ConcatWs)
-        and not expression.expressions[0].is_string
+        isinstance(expression, exp.ConcatWs) and not expression.expressions[0].is_string
    ):
        return expression

@@ -60,6 +60,19 @@ def parse_logarithm(args: t.List, dialect: Dialect) -> exp.Func:
    return (exp.Ln if dialect.parser_class.LOG_DEFAULTS_TO_LN else exp.Log)(this=this)


+def parse_extract_json_with_path(expr_type: t.Type[E]) -> t.Callable[[t.List, Dialect], E]:
+    def _parser(args: t.List, dialect: Dialect) -> E:
+        expression = expr_type(
+            this=seq_get(args, 0), expression=dialect.to_json_path(seq_get(args, 1))
+        )
+        if len(args) > 2 and expr_type is exp.JSONExtract:
+            expression.set("expressions", args[2:])
+
+        return expression
+
+    return _parser
+
+
class _Parser(type):
    def __new__(cls, clsname, bases, attrs):
        klass = super().__new__(cls, clsname, bases, attrs)
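
Editor's note: parse_extract_json_with_path funnels JSON_EXTRACT-style functions through dialect.to_json_path, so the path argument arrives as a parsed exp.JSONPath rather than a raw string literal. For example (illustrative):

from sqlglot import exp, parse_one

node = parse_one("JSON_EXTRACT(col, '$.a.b')")
assert isinstance(node, exp.JSONExtract)
print(type(node.expression).__name__)  # e.g. JSONPath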

@@ -102,6 +115,9 @@ class Parser(metaclass=_Parser):
            to=exp.DataType(this=exp.DataType.Type.TEXT),
        ),
        "GLOB": lambda args: exp.Glob(this=seq_get(args, 1), expression=seq_get(args, 0)),
+        "JSON_EXTRACT": parse_extract_json_with_path(exp.JSONExtract),
+        "JSON_EXTRACT_SCALAR": parse_extract_json_with_path(exp.JSONExtractScalar),
+        "JSON_EXTRACT_PATH_TEXT": parse_extract_json_with_path(exp.JSONExtractScalar),
        "LIKE": parse_like,
        "LOG": parse_logarithm,
        "TIME_TO_TIME_STR": lambda args: exp.Cast(

@@ -175,6 +191,7 @@ class Parser(metaclass=_Parser):
        TokenType.NCHAR,
        TokenType.VARCHAR,
        TokenType.NVARCHAR,
+        TokenType.BPCHAR,
        TokenType.TEXT,
        TokenType.MEDIUMTEXT,
        TokenType.LONGTEXT,

@@ -295,6 +312,7 @@ class Parser(metaclass=_Parser):
        TokenType.ASC,
        TokenType.AUTO_INCREMENT,
        TokenType.BEGIN,
+        TokenType.BPCHAR,
        TokenType.CACHE,
        TokenType.CASE,
        TokenType.COLLATE,

@@ -531,12 +549,12 @@ class Parser(metaclass=_Parser):
        TokenType.ARROW: lambda self, this, path: self.expression(
            exp.JSONExtract,
            this=this,
-            expression=path,
+            expression=self.dialect.to_json_path(path),
        ),
        TokenType.DARROW: lambda self, this, path: self.expression(
            exp.JSONExtractScalar,
            this=this,
-            expression=path,
+            expression=self.dialect.to_json_path(path),
        ),
        TokenType.HASH_ARROW: lambda self, this, path: self.expression(
            exp.JSONBExtract,
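
Editor's note: the arrow operators now run their right-hand side through dialect.to_json_path as well, which is what enables JSON path transpilation between dialects. A hedged example (output not asserted here, as it depends on each dialect's generator):

import sqlglot

# DuckDB's arrow syntax read in, another dialect written out; the path is
# restructured rather than copied verbatim.
print(sqlglot.transpile("SELECT col -> '$.a' FROM t", read="duckdb", write="snowflake")[0])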

@@ -1334,7 +1352,9 @@ class Parser(metaclass=_Parser):
            exp.Drop,
            comments=start.comments,
            exists=exists or self._parse_exists(),
-            this=self._parse_table(schema=True),
+            this=self._parse_table(
+                schema=True, is_db_reference=self._prev.token_type == TokenType.SCHEMA
+            ),
            kind=kind,
            temporary=temporary,
            materialized=materialized,

@@ -1422,7 +1442,9 @@ class Parser(metaclass=_Parser):
        elif create_token.token_type == TokenType.INDEX:
            this = self._parse_index(index=self._parse_id_var())
        elif create_token.token_type in self.DB_CREATABLES:
-            table_parts = self._parse_table_parts(schema=True)
+            table_parts = self._parse_table_parts(
+                schema=True, is_db_reference=create_token.token_type == TokenType.SCHEMA
+            )

            # exp.Properties.Location.POST_NAME
            self._match(TokenType.COMMA)

@@ -2499,11 +2521,11 @@ class Parser(metaclass=_Parser):
        elif self._match_text_seq("ALL", "ROWS", "PER", "MATCH"):
            text = "ALL ROWS PER MATCH"
            if self._match_text_seq("SHOW", "EMPTY", "MATCHES"):
-                text += f" SHOW EMPTY MATCHES"
+                text += " SHOW EMPTY MATCHES"
            elif self._match_text_seq("OMIT", "EMPTY", "MATCHES"):
-                text += f" OMIT EMPTY MATCHES"
+                text += " OMIT EMPTY MATCHES"
            elif self._match_text_seq("WITH", "UNMATCHED", "ROWS"):
-                text += f" WITH UNMATCHED ROWS"
+                text += " WITH UNMATCHED ROWS"
            rows = exp.var(text)
        else:
            rows = None

@@ -2511,9 +2533,9 @@ class Parser(metaclass=_Parser):
        if self._match_text_seq("AFTER", "MATCH", "SKIP"):
            text = "AFTER MATCH SKIP"
            if self._match_text_seq("PAST", "LAST", "ROW"):
-                text += f" PAST LAST ROW"
+                text += " PAST LAST ROW"
            elif self._match_text_seq("TO", "NEXT", "ROW"):
-                text += f" TO NEXT ROW"
+                text += " TO NEXT ROW"
            elif self._match_text_seq("TO", "FIRST"):
                text += f" TO FIRST {self._advance_any().text}"  # type: ignore
            elif self._match_text_seq("TO", "LAST"):

@@ -2772,7 +2794,7 @@ class Parser(metaclass=_Parser):
            or self._parse_placeholder()
        )

-    def _parse_table_parts(self, schema: bool = False) -> exp.Table:
+    def _parse_table_parts(self, schema: bool = False, is_db_reference: bool = False) -> exp.Table:
        catalog = None
        db = None
        table: t.Optional[exp.Expression | str] = self._parse_table_part(schema=schema)

@@ -2788,8 +2810,15 @@ class Parser(metaclass=_Parser):
                db = table
                table = self._parse_table_part(schema=schema) or ""

-        if not table:
+        if is_db_reference:
+            catalog = db
+            db = table
+            table = None
+
+        if not table and not is_db_reference:
            self.raise_error(f"Expected table name but got {self._curr}")
+        if not db and is_db_reference:
+            self.raise_error(f"Expected database name but got {self._curr}")

        return self.expression(
            exp.Table, this=table, db=db, catalog=catalog, pivots=self._parse_pivots()
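
Editor's note: is_db_reference shifts the parsed parts one slot left for statements that name a schema or database rather than a table, so DROP SCHEMA / CREATE SCHEMA resolve their identifier into the Table's db slot. Illustrative check:

from sqlglot import parse_one

drop = parse_one("DROP SCHEMA s")
print(drop.this.db)    # 's'
print(drop.this.name)  # ''  -- no table part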

@@ -2801,6 +2830,7 @@ class Parser(metaclass=_Parser):
        joins: bool = False,
        alias_tokens: t.Optional[t.Collection[TokenType]] = None,
        parse_bracket: bool = False,
+        is_db_reference: bool = False,
    ) -> t.Optional[exp.Expression]:
        lateral = self._parse_lateral()
        if lateral:

@@ -2823,7 +2853,11 @@ class Parser(metaclass=_Parser):
        bracket = parse_bracket and self._parse_bracket(None)
        bracket = self.expression(exp.Table, this=bracket) if bracket else None
        this = t.cast(
-            exp.Expression, bracket or self._parse_bracket(self._parse_table_parts(schema=schema))
+            exp.Expression,
+            bracket
+            or self._parse_bracket(
+                self._parse_table_parts(schema=schema, is_db_reference=is_db_reference)
+            ),
        )

        if schema:

@@ -3650,7 +3684,6 @@ class Parser(metaclass=_Parser):
        identifier = allow_identifiers and self._parse_id_var(
            any_token=False, tokens=(TokenType.VAR,)
        )
        if identifier:
            tokens = self.dialect.tokenize(identifier.name)

@@ -3818,12 +3851,14 @@ class Parser(metaclass=_Parser):
        return self.expression(exp.AtTimeZone, this=this, zone=self._parse_unary())

    def _parse_column(self) -> t.Optional[exp.Expression]:
+        this = self._parse_column_reference()
+        return self._parse_column_ops(this) if this else self._parse_bracket(this)
+
+    def _parse_column_reference(self) -> t.Optional[exp.Expression]:
        this = self._parse_field()
        if isinstance(this, exp.Identifier):
            this = self.expression(exp.Column, this=this)
-        elif not this:
-            return self._parse_bracket(this)
-        return self._parse_column_ops(this)
+        return this

    def _parse_column_ops(self, this: t.Optional[exp.Expression]) -> t.Optional[exp.Expression]:
        this = self._parse_bracket(this)

@@ -3837,13 +3872,7 @@ class Parser(metaclass=_Parser):
                if not field:
                    self.raise_error("Expected type")
            elif op and self._curr:
-                self._advance()
-                value = self._prev.text
-                field = (
-                    exp.Literal.number(value)
-                    if self._prev.token_type == TokenType.NUMBER
-                    else exp.Literal.string(value)
-                )
+                field = self._parse_column_reference()
            else:
                field = self._parse_field(anonymous_func=True, any_token=True)

@@ -4375,7 +4404,10 @@ class Parser(metaclass=_Parser):
            options[kind] = action

        return self.expression(
-            exp.ForeignKey, expressions=expressions, reference=reference, **options  # type: ignore
+            exp.ForeignKey,
+            expressions=expressions,
+            reference=reference,
+            **options,  # type: ignore
        )

    def _parse_primary_key_part(self) -> t.Optional[exp.Expression]:

@@ -4692,10 +4724,12 @@ class Parser(metaclass=_Parser):
            return None

    @t.overload
-    def _parse_json_object(self, agg: Lit[False]) -> exp.JSONObject: ...
+    def _parse_json_object(self, agg: Lit[False]) -> exp.JSONObject:
+        ...

    @t.overload
-    def _parse_json_object(self, agg: Lit[True]) -> exp.JSONObjectAgg: ...
+    def _parse_json_object(self, agg: Lit[True]) -> exp.JSONObjectAgg:
+        ...

    def _parse_json_object(self, agg=False):
        star = self._parse_star()

@@ -4937,6 +4971,13 @@ class Parser(metaclass=_Parser):
        # (https://docs.oracle.com/en/database/oracle/oracle-database/19/sqlrf/img_text/first_value.html)
        # and Snowflake chose to do the same for familiarity
        # https://docs.snowflake.com/en/sql-reference/functions/first_value.html#usage-notes
+        if isinstance(this, exp.AggFunc):
+            ignore_respect = this.find(exp.IgnoreNulls, exp.RespectNulls)
+
+            if ignore_respect and ignore_respect is not this:
+                ignore_respect.replace(ignore_respect.this)
+                this = self.expression(ignore_respect.__class__, this=this)
+
        this = self._parse_respect_or_ignore_nulls(this)

        # bigquery select from window x AS (partition by ...)
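
Editor's note: this block hoists a nested IGNORE NULLS / RESPECT NULLS out of an aggregate's argument so the modifier wraps the whole function, matching the Oracle/Snowflake reading cited above. Illustrative:

from sqlglot import exp, parse_one

node = parse_one("SELECT ARRAY_AGG(x IGNORE NULLS) FROM t", read="bigquery")
ignore = node.find(exp.IgnoreNulls)
print(type(ignore.this).__name__)  # e.g. ArrayAgg -- the modifier wraps the aggregate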

@@ -5732,12 +5773,14 @@ class Parser(metaclass=_Parser):
        return True

    @t.overload
-    def _replace_columns_with_dots(self, this: exp.Expression) -> exp.Expression: ...
+    def _replace_columns_with_dots(self, this: exp.Expression) -> exp.Expression:
+        ...

    @t.overload
    def _replace_columns_with_dots(
        self, this: t.Optional[exp.Expression]
-    ) -> t.Optional[exp.Expression]: ...
+    ) -> t.Optional[exp.Expression]:
+        ...

    def _replace_columns_with_dots(self, this):
        if isinstance(this, exp.Dot):

@@ -125,6 +125,7 @@ class TokenType(AutoName):
    NCHAR = auto()
    VARCHAR = auto()
    NVARCHAR = auto()
+    BPCHAR = auto()
    TEXT = auto()
    MEDIUMTEXT = auto()
    LONGTEXT = auto()

@@ -801,6 +802,7 @@ class Tokenizer(metaclass=_Tokenizer):
        "VARCHAR2": TokenType.VARCHAR,
        "NVARCHAR": TokenType.NVARCHAR,
        "NVARCHAR2": TokenType.NVARCHAR,
+        "BPCHAR": TokenType.BPCHAR,
        "STR": TokenType.TEXT,
        "STRING": TokenType.TEXT,
        "TEXT": TokenType.TEXT,
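
Editor's note: BPCHAR (Postgres's internal "blank-padded char" name) is now tokenized as a first-class type rather than falling through as an identifier. Illustrative:

from sqlglot import parse_one

print(parse_one("CAST(x AS BPCHAR)").sql())  # e.g. CAST(x AS BPCHAR)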

@@ -1141,7 +1143,7 @@ class Tokenizer(metaclass=_Tokenizer):
                self._comments.append(self._text[comment_start_size : -comment_end_size + 1])
                self._advance(comment_end_size - 1)
        else:
-            while not self._end and not self.WHITE_SPACE.get(self._peek) is TokenType.BREAK:
+            while not self._end and self.WHITE_SPACE.get(self._peek) is not TokenType.BREAK:
                self._advance(alnum=True)
            self._comments.append(self._text[comment_start_size:])

@@ -1259,7 +1261,7 @@ class Tokenizer(metaclass=_Tokenizer):
        if base:
            try:
                int(text, base)
-            except:
+            except Exception:
                raise TokenError(
                    f"Numeric string contains invalid characters from {self._line}:{self._start}"
                )

@@ -485,8 +485,8 @@ def preprocess(
        expression_type = type(expression)

        expression = transforms[0](expression)
-        for t in transforms[1:]:
-            expression = t(expression)
+        for transform in transforms[1:]:
+            expression = transform(expression)

        _sql_handler = getattr(self, expression.key + "_sql", None)
        if _sql_handler: