
Merging upstream version 21.0.1.

Signed-off-by: Daniel Baumann <daniel@debian.org>
Daniel Baumann 2025-02-13 21:20:36 +01:00
parent 3759c601a7
commit 96b10de29a
Signed by: daniel
GPG key ID: FBB4F0E80A80222F
115 changed files with 66603 additions and 60920 deletions

View file

@ -1,3 +1,4 @@
# ruff: noqa: F401
"""
.. include:: ../README.md
@ -87,11 +88,13 @@ def parse(
@t.overload
def parse_one(sql: str, *, into: t.Type[E], **opts) -> E: ...
def parse_one(sql: str, *, into: t.Type[E], **opts) -> E:
...
@t.overload
def parse_one(sql: str, **opts) -> Expression: ...
def parse_one(sql: str, **opts) -> Expression:
...
def parse_one(
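Behaviorally these overloads are unchanged (only the `...` placement moved); `into` still narrows the return type. A minimal usage sketch against the public API:

    import sqlglot
    from sqlglot import exp

    # Without `into`, the root expression type depends on the SQL (exp.Select here).
    select = sqlglot.parse_one("SELECT 1")

    # With `into`, parsing targets the given expression type.
    table = sqlglot.parse_one("foo.bar", into=exp.Table)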

View file

@ -13,4 +13,5 @@ if t.TYPE_CHECKING:
A = t.TypeVar("A", bound=t.Any)
B = t.TypeVar("B", bound="sqlglot.exp.Binary")
E = t.TypeVar("E", bound="sqlglot.exp.Expression")
F = t.TypeVar("F", bound="sqlglot.exp.Func")
T = t.TypeVar("T")

View file

@ -140,10 +140,12 @@ class DataFrame:
return cte, name
@t.overload
def _ensure_list_of_columns(self, cols: t.Collection[ColumnOrLiteral]) -> t.List[Column]: ...
def _ensure_list_of_columns(self, cols: t.Collection[ColumnOrLiteral]) -> t.List[Column]:
...
@t.overload
def _ensure_list_of_columns(self, cols: ColumnOrLiteral) -> t.List[Column]: ...
def _ensure_list_of_columns(self, cols: ColumnOrLiteral) -> t.List[Column]:
...
def _ensure_list_of_columns(self, cols):
return Column.ensure_cols(ensure_list(cols))

View file

@ -368,7 +368,10 @@ def covar_samp(col1: ColumnOrName, col2: ColumnOrName) -> Column:
def first(col: ColumnOrName, ignorenulls: t.Optional[bool] = None) -> Column:
return Column.invoke_expression_over_column(col, expression.First, ignore_nulls=ignorenulls)
this = Column.invoke_expression_over_column(col, expression.First)
if ignorenulls:
return Column.invoke_expression_over_column(this, expression.IgnoreNulls)
return this
def grouping_id(*cols: ColumnOrName) -> Column:
@ -392,7 +395,10 @@ def isnull(col: ColumnOrName) -> Column:
def last(col: ColumnOrName, ignorenulls: t.Optional[bool] = None) -> Column:
return Column.invoke_expression_over_column(col, expression.Last, ignore_nulls=ignorenulls)
this = Column.invoke_expression_over_column(col, expression.Last)
if ignorenulls:
return Column.invoke_expression_over_column(this, expression.IgnoreNulls)
return this
def monotonically_increasing_id() -> Column:
@ -485,31 +491,28 @@ def factorial(col: ColumnOrName) -> Column:
def lag(
col: ColumnOrName, offset: t.Optional[int] = 1, default: t.Optional[ColumnOrLiteral] = None
) -> Column:
if default is not None:
return Column.invoke_anonymous_function(col, "LAG", offset, default)
if offset != 1:
return Column.invoke_anonymous_function(col, "LAG", offset)
return Column.invoke_anonymous_function(col, "LAG")
return Column.invoke_expression_over_column(
col, expression.Lag, offset=None if offset == 1 else offset, default=default
)
def lead(
col: ColumnOrName, offset: t.Optional[int] = 1, default: t.Optional[t.Any] = None
) -> Column:
if default is not None:
return Column.invoke_anonymous_function(col, "LEAD", offset, default)
if offset != 1:
return Column.invoke_anonymous_function(col, "LEAD", offset)
return Column.invoke_anonymous_function(col, "LEAD")
return Column.invoke_expression_over_column(
col, expression.Lead, offset=None if offset == 1 else offset, default=default
)
def nth_value(
col: ColumnOrName, offset: t.Optional[int] = 1, ignoreNulls: t.Optional[bool] = None
) -> Column:
this = Column.invoke_expression_over_column(
col, expression.NthValue, offset=None if offset == 1 else offset
)
if ignoreNulls is not None:
raise NotImplementedError("There is currently not support for `ignoreNulls` parameter")
if offset != 1:
return Column.invoke_anonymous_function(col, "NTH_VALUE", offset)
return Column.invoke_anonymous_function(col, "NTH_VALUE")
return Column.invoke_expression_over_column(this, expression.IgnoreNulls)
return this
def ntile(n: int) -> Column:
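Net effect of these hunks: first/last wrap their result in exp.IgnoreNulls rather than passing an ignore_nulls arg, and lag/lead/nth_value build real Lag/Lead/NthValue expressions instead of anonymous function calls. A hedged sketch (exact generated SQL is assumed for 21.x):

    from sqlglot.dataframe.sql import functions as F

    # IGNORE NULLS is now modeled as a wrapping expression.
    print(F.last("x", ignorenulls=True).expression.sql(dialect="spark"))
    # e.g. LAST(x) IGNORE NULLS

    # A default offset of 1 and a missing default are elided from the call.
    print(F.lag("x", offset=2, default=0).expression.sql(dialect="spark"))
    # e.g. LAG(x, 2, 0)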

View file

@ -1,9 +1,10 @@
# ruff: noqa: F401
"""
## Dialects
While there is a SQL standard, most SQL engines support a variation of that standard. This makes it difficult
to write portable SQL code. SQLGlot bridges all the different variations, called "dialects", with an extensible
SQL transpilation framework.
The base `sqlglot.dialects.dialect.Dialect` class implements a generic dialect that aims to be as universal as possible.
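The docstring's point is easiest to see with a quick transpile (example adapted from the project README; the output shown is indicative):

    import sqlglot

    # Parse with the DuckDB dialect, generate with the Hive dialect.
    print(sqlglot.transpile("SELECT EPOCH_MS(1618088028295)", read="duckdb", write="hive")[0])
    # e.g. SELECT FROM_UNIXTIME(1618088028295 / POW(10, 3))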

View file

@ -19,7 +19,6 @@ from sqlglot.dialects.dialect import (
min_or_least,
no_ilike_sql,
parse_date_delta_with_interval,
path_to_jsonpath,
regexp_replace_sql,
rename_func,
timestrtotime_sql,
@ -458,8 +457,10 @@ class BigQuery(Dialect):
return this
def _parse_table_parts(self, schema: bool = False) -> exp.Table:
table = super()._parse_table_parts(schema=schema)
def _parse_table_parts(
self, schema: bool = False, is_db_reference: bool = False
) -> exp.Table:
table = super()._parse_table_parts(schema=schema, is_db_reference=is_db_reference)
if isinstance(table.this, exp.Identifier) and "." in table.name:
catalog, db, this, *rest = (
t.cast(t.Optional[exp.Expression], exp.to_identifier(x))
@ -474,10 +475,12 @@ class BigQuery(Dialect):
return table
@t.overload
def _parse_json_object(self, agg: Lit[False]) -> exp.JSONObject: ...
def _parse_json_object(self, agg: Lit[False]) -> exp.JSONObject:
...
@t.overload
def _parse_json_object(self, agg: Lit[True]) -> exp.JSONObjectAgg: ...
def _parse_json_object(self, agg: Lit[True]) -> exp.JSONObjectAgg:
...
def _parse_json_object(self, agg=False):
json_object = super()._parse_json_object()
@ -536,6 +539,8 @@ class BigQuery(Dialect):
UNPIVOT_ALIASES_ARE_IDENTIFIERS = False
JSON_KEY_VALUE_PAIR_SEP = ","
NULL_ORDERING_SUPPORTED = False
IGNORE_NULLS_IN_FUNC = True
JSON_PATH_SINGLE_QUOTE_ESCAPE = True
TRANSFORMS = {
**generator.Generator.TRANSFORMS,
@ -554,7 +559,8 @@ class BigQuery(Dialect):
exp.Create: _create_sql,
exp.CTE: transforms.preprocess([_pushdown_cte_column_names]),
exp.DateAdd: date_add_interval_sql("DATE", "ADD"),
exp.DateDiff: lambda self, e: f"DATE_DIFF({self.sql(e, 'this')}, {self.sql(e, 'expression')}, {self.sql(e.args.get('unit', 'DAY'))})",
exp.DateDiff: lambda self,
e: f"DATE_DIFF({self.sql(e, 'this')}, {self.sql(e, 'expression')}, {self.sql(e.args.get('unit', 'DAY'))})",
exp.DateFromParts: rename_func("DATE"),
exp.DateStrToDate: datestrtodate_sql,
exp.DateSub: date_add_interval_sql("DATE", "SUB"),
@ -565,7 +571,6 @@ class BigQuery(Dialect):
"DATETIME", self.func("TIMESTAMP", e.this, e.args.get("zone")), "'UTC'"
),
exp.GenerateSeries: rename_func("GENERATE_ARRAY"),
exp.GetPath: path_to_jsonpath(),
exp.GroupConcat: rename_func("STRING_AGG"),
exp.Hex: rename_func("TO_HEX"),
exp.If: if_sql(false_value="NULL"),
@ -597,12 +602,13 @@ class BigQuery(Dialect):
]
),
exp.SHA2: lambda self, e: self.func(
f"SHA256" if e.text("length") == "256" else "SHA512", e.this
"SHA256" if e.text("length") == "256" else "SHA512", e.this
),
exp.StabilityProperty: lambda self, e: (
f"DETERMINISTIC" if e.name == "IMMUTABLE" else "NOT DETERMINISTIC"
"DETERMINISTIC" if e.name == "IMMUTABLE" else "NOT DETERMINISTIC"
),
exp.StrToDate: lambda self, e: f"PARSE_DATE({self.format_time(e)}, {self.sql(e, 'this')})",
exp.StrToDate: lambda self,
e: f"PARSE_DATE({self.format_time(e)}, {self.sql(e, 'this')})",
exp.StrToTime: lambda self, e: self.func(
"PARSE_TIMESTAMP", self.format_time(e), e.this, e.args.get("zone")
),
@ -610,9 +616,10 @@ class BigQuery(Dialect):
exp.TimeFromParts: rename_func("TIME"),
exp.TimeSub: date_add_interval_sql("TIME", "SUB"),
exp.TimestampAdd: date_add_interval_sql("TIMESTAMP", "ADD"),
exp.TimestampDiff: rename_func("TIMESTAMP_DIFF"),
exp.TimestampSub: date_add_interval_sql("TIMESTAMP", "SUB"),
exp.TimeStrToTime: timestrtotime_sql,
exp.Trim: lambda self, e: self.func(f"TRIM", e.this, e.expression),
exp.Trim: lambda self, e: self.func("TRIM", e.this, e.expression),
exp.TsOrDsAdd: _ts_or_ds_add_sql,
exp.TsOrDsDiff: _ts_or_ds_diff_sql,
exp.TsOrDsToTime: rename_func("TIME"),
@ -623,6 +630,12 @@ class BigQuery(Dialect):
exp.VariancePop: rename_func("VAR_POP"),
}
SUPPORTED_JSON_PATH_PARTS = {
exp.JSONPathKey,
exp.JSONPathRoot,
exp.JSONPathSubscript,
}
TYPE_MAPPING = {
**generator.Generator.TYPE_MAPPING,
exp.DataType.Type.BIGDECIMAL: "BIGNUMERIC",
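Among other things, this section maps exp.TimestampDiff onto BigQuery's TIMESTAMP_DIFF, whose argument order is datetime-first rather than unit-first. A hedged round-trip sketch (column names are illustrative):

    import sqlglot

    # Databricks: TIMESTAMPDIFF(unit, start, end); BigQuery: TIMESTAMP_DIFF(end, start, unit).
    print(sqlglot.transpile(
        "SELECT TIMESTAMPDIFF(MINUTE, a, b)", read="databricks", write="bigquery"
    )[0])
    # expected: SELECT TIMESTAMP_DIFF(b, a, MINUTE)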

View file

@ -8,12 +8,15 @@ from sqlglot.dialects.dialect import (
arg_max_or_min_no_count,
date_delta_sql,
inline_array_sql,
json_extract_segments,
json_path_key_only_name,
no_pivot_sql,
parse_json_extract_path,
rename_func,
var_map_sql,
)
from sqlglot.errors import ParseError
from sqlglot.helper import seq_get
from sqlglot.helper import is_int, seq_get
from sqlglot.parser import parse_var_map
from sqlglot.tokens import Token, TokenType
@ -120,6 +123,9 @@ class ClickHouse(Dialect):
"DATEDIFF": lambda args: exp.DateDiff(
this=seq_get(args, 2), expression=seq_get(args, 1), unit=seq_get(args, 0)
),
"JSONEXTRACTSTRING": parse_json_extract_path(
exp.JSONExtractScalar, zero_based_indexing=False
),
"MAP": parse_var_map,
"MATCH": exp.RegexpLike.from_arg_list,
"RANDCANONICAL": exp.Rand.from_arg_list,
@ -354,9 +360,14 @@ class ClickHouse(Dialect):
joins: bool = False,
alias_tokens: t.Optional[t.Collection[TokenType]] = None,
parse_bracket: bool = False,
is_db_reference: bool = False,
) -> t.Optional[exp.Expression]:
this = super()._parse_table(
schema=schema, joins=joins, alias_tokens=alias_tokens, parse_bracket=parse_bracket
schema=schema,
joins=joins,
alias_tokens=alias_tokens,
parse_bracket=parse_bracket,
is_db_reference=is_db_reference,
)
if self._match(TokenType.FINAL):
@ -518,6 +529,12 @@ class ClickHouse(Dialect):
exp.DataType.Type.VARCHAR: "String",
}
SUPPORTED_JSON_PATH_PARTS = {
exp.JSONPathKey,
exp.JSONPathRoot,
exp.JSONPathSubscript,
}
TYPE_MAPPING = {
**generator.Generator.TYPE_MAPPING,
**STRING_TYPE_MAPPING,
@ -570,6 +587,10 @@ class ClickHouse(Dialect):
exp.Explode: rename_func("arrayJoin"),
exp.Final: lambda self, e: f"{self.sql(e, 'this')} FINAL",
exp.IsNan: rename_func("isNaN"),
exp.JSONExtract: json_extract_segments("JSONExtractString", quoted_index=False),
exp.JSONExtractScalar: json_extract_segments("JSONExtractString", quoted_index=False),
exp.JSONPathKey: json_path_key_only_name,
exp.JSONPathRoot: lambda *_: "",
exp.Map: lambda self, e: _lower_func(var_map_sql(self, e)),
exp.Nullif: rename_func("nullIf"),
exp.PartitionedByProperty: lambda self, e: f"PARTITION BY {self.sql(e, 'this')}",
@ -579,7 +600,8 @@ class ClickHouse(Dialect):
exp.Rand: rename_func("randCanonical"),
exp.Select: transforms.preprocess([transforms.eliminate_qualify]),
exp.StartsWith: rename_func("startsWith"),
exp.StrPosition: lambda self, e: f"position({self.format_args(e.this, e.args.get('substr'), e.args.get('position'))})",
exp.StrPosition: lambda self,
e: f"position({self.format_args(e.this, e.args.get('substr'), e.args.get('position'))})",
exp.VarMap: lambda self, e: _lower_func(var_map_sql(self, e)),
exp.Xor: lambda self, e: self.func("xor", e.this, e.expression, *e.expressions),
}
@ -608,6 +630,13 @@ class ClickHouse(Dialect):
"NAMED COLLECTION",
}
def _jsonpathsubscript_sql(self, expression: exp.JSONPathSubscript) -> str:
this = self.json_path_part(expression.this)
return str(int(this) + 1) if is_int(this) else this
def likeproperty_sql(self, expression: exp.LikeProperty) -> str:
return f"AS {self.sql(expression, 'this')}"
def _any_to_has(
self,
expression: exp.EQ | exp.NEQ,
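_jsonpathsubscript_sql converts zero-based JSONPath subscripts to ClickHouse's one-based indexing, and the new JSONEXTRACTSTRING parser (zero_based_indexing=False) reverses that on the way in, so extraction calls should round-trip. A sketch under that assumption:

    import sqlglot

    # 'a' stays a quoted key; the index is emitted unquoted and one-based.
    print(sqlglot.transpile(
        "SELECT JSONExtractString(x, 'a', 1)", read="clickhouse", write="clickhouse"
    )[0])
    # expected: SELECT JSONExtractString(x, 'a', 1)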

View file

@ -22,6 +22,7 @@ class Databricks(Spark):
"DATEADD": parse_date_delta(exp.DateAdd),
"DATE_ADD": parse_date_delta(exp.DateAdd),
"DATEDIFF": parse_date_delta(exp.DateDiff),
"TIMESTAMPDIFF": parse_date_delta(exp.TimestampDiff),
}
FACTOR = {
@ -48,6 +49,9 @@ class Databricks(Spark):
exp.DatetimeDiff: lambda self, e: self.func(
"TIMESTAMPDIFF", e.text("unit"), e.expression, e.this
),
exp.TimestampDiff: lambda self, e: self.func(
"TIMESTAMPDIFF", e.text("unit"), e.expression, e.this
),
exp.DatetimeTrunc: timestamptrunc_sql,
exp.JSONExtract: lambda self, e: self.binary(e, ":"),
exp.Select: transforms.preprocess(
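With TIMESTAMPDIFF now parsed into exp.TimestampDiff and generated back unit-first by the new transform, the function should survive a Databricks round trip; a sketch:

    import sqlglot

    print(sqlglot.transpile(
        "SELECT TIMESTAMPDIFF(MINUTE, t1, t2)", read="databricks", write="databricks"
    )[0])
    # expected: SELECT TIMESTAMPDIFF(MINUTE, t1, t2)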

View file

@ -1,5 +1,6 @@
from __future__ import annotations
import logging
import typing as t
from enum import Enum, auto
from functools import reduce
@ -7,7 +8,8 @@ from functools import reduce
from sqlglot import exp
from sqlglot.errors import ParseError
from sqlglot.generator import Generator
from sqlglot.helper import AutoName, flatten, seq_get
from sqlglot.helper import AutoName, flatten, is_int, seq_get
from sqlglot.jsonpath import parse as parse_json_path
from sqlglot.parser import Parser
from sqlglot.time import TIMEZONES, format_time
from sqlglot.tokens import Token, Tokenizer, TokenType
@ -17,7 +19,11 @@ DATE_ADD_OR_DIFF = t.Union[exp.DateAdd, exp.TsOrDsAdd, exp.DateDiff, exp.TsOrDsD
DATE_ADD_OR_SUB = t.Union[exp.DateAdd, exp.TsOrDsAdd, exp.DateSub]
if t.TYPE_CHECKING:
from sqlglot._typing import B, E
from sqlglot._typing import B, E, F
JSON_EXTRACT_TYPE = t.Union[exp.JSONExtract, exp.JSONExtractScalar]
logger = logging.getLogger("sqlglot")
class Dialects(str, Enum):
@ -256,7 +262,7 @@ class Dialect(metaclass=_Dialect):
INVERSE_ESCAPE_SEQUENCES: t.Dict[str, str] = {}
# Delimiters for quotes, identifiers and the corresponding escape characters
# Delimiters for string literals and identifiers
QUOTE_START = "'"
QUOTE_END = "'"
IDENTIFIER_START = '"'
@ -373,7 +379,7 @@ class Dialect(metaclass=_Dialect):
"""
if (
isinstance(expression, exp.Identifier)
and not self.normalization_strategy is NormalizationStrategy.CASE_SENSITIVE
and self.normalization_strategy is not NormalizationStrategy.CASE_SENSITIVE
and (
not expression.quoted
or self.normalization_strategy is NormalizationStrategy.CASE_INSENSITIVE
@ -440,6 +446,19 @@ class Dialect(metaclass=_Dialect):
return expression
def to_json_path(self, path: t.Optional[exp.Expression]) -> t.Optional[exp.Expression]:
if isinstance(path, exp.Literal):
path_text = path.name
if path.is_number:
path_text = f"[{path_text}]"
try:
return parse_json_path(path_text)
except ParseError as e:
logger.warning(f"Invalid JSON path syntax. {str(e)}")
return path
def parse(self, sql: str, **opts) -> t.List[t.Optional[exp.Expression]]:
return self.parser(**opts).parse(self.tokenize(sql), sql)
@ -500,14 +519,12 @@ def if_sql(
return _if_sql
def arrow_json_extract_sql(self: Generator, expression: exp.JSONExtract | exp.JSONBExtract) -> str:
return self.binary(expression, "->")
def arrow_json_extract_sql(self: Generator, expression: JSON_EXTRACT_TYPE) -> str:
this = expression.this
if self.JSON_TYPE_REQUIRED_FOR_EXTRACTION and isinstance(this, exp.Literal) and this.is_string:
this.replace(exp.cast(this, "json"))
def arrow_json_extract_scalar_sql(
self: Generator, expression: exp.JSONExtractScalar | exp.JSONBExtractScalar
) -> str:
return self.binary(expression, "->>")
return self.binary(expression, "->" if isinstance(expression, exp.JSONExtract) else "->>")
def inline_array_sql(self: Generator, expression: exp.Array) -> str:
@ -552,11 +569,6 @@ def no_trycast_sql(self: Generator, expression: exp.TryCast) -> str:
return self.cast_sql(expression)
def no_properties_sql(self: Generator, expression: exp.Properties) -> str:
self.unsupported("Properties unsupported")
return ""
def no_comment_column_constraint_sql(
self: Generator, expression: exp.CommentColumnConstraint
) -> str:
@ -965,32 +977,6 @@ def date_delta_sql(name: str, cast: bool = False) -> t.Callable[[Generator, DATE
return _delta_sql
def prepend_dollar_to_path(expression: exp.GetPath) -> exp.GetPath:
from sqlglot.optimizer.simplify import simplify
# Makes sure the path will be evaluated correctly at runtime to include the path root.
# For example, `[0].foo` will become `$[0].foo`, and `foo` will become `$.foo`.
path = expression.expression
path = exp.func(
"if",
exp.func("startswith", path, "'['"),
exp.func("concat", "'$'", path),
exp.func("concat", "'$.'", path),
)
expression.expression.replace(simplify(path))
return expression
def path_to_jsonpath(
name: str = "JSON_EXTRACT",
) -> t.Callable[[Generator, exp.GetPath], str]:
def _transform(self: Generator, expression: exp.GetPath) -> str:
return rename_func(name)(self, prepend_dollar_to_path(expression))
return _transform
def no_last_day_sql(self: Generator, expression: exp.LastDay) -> str:
trunc_curr_date = exp.func("date_trunc", "month", expression.this)
plus_one_month = exp.func("date_add", trunc_curr_date, 1, "month")
@ -1003,9 +989,8 @@ def merge_without_target_sql(self: Generator, expression: exp.Merge) -> str:
"""Remove table refs from columns in when statements."""
alias = expression.this.args.get("alias")
normalize = lambda identifier: (
self.dialect.normalize_identifier(identifier).name if identifier else None
)
def normalize(identifier: t.Optional[exp.Identifier]) -> t.Optional[str]:
return self.dialect.normalize_identifier(identifier).name if identifier else None
targets = {normalize(expression.this.this)}
@ -1023,3 +1008,60 @@ def merge_without_target_sql(self: Generator, expression: exp.Merge) -> str:
)
return self.merge_sql(expression)
def parse_json_extract_path(
expr_type: t.Type[F], zero_based_indexing: bool = True
) -> t.Callable[[t.List], F]:
def _parse_json_extract_path(args: t.List) -> F:
segments: t.List[exp.JSONPathPart] = [exp.JSONPathRoot()]
for arg in args[1:]:
if not isinstance(arg, exp.Literal):
# We use the fallback parser because we can't really transpile non-literals safely
return expr_type.from_arg_list(args)
text = arg.name
if is_int(text):
index = int(text)
segments.append(
exp.JSONPathSubscript(this=index if zero_based_indexing else index - 1)
)
else:
segments.append(exp.JSONPathKey(this=text))
# This is done to avoid failing in the expression validator due to the arg count
del args[2:]
return expr_type(this=seq_get(args, 0), expression=exp.JSONPath(expressions=segments))
return _parse_json_extract_path
def json_extract_segments(
name: str, quoted_index: bool = True
) -> t.Callable[[Generator, JSON_EXTRACT_TYPE], str]:
def _json_extract_segments(self: Generator, expression: JSON_EXTRACT_TYPE) -> str:
path = expression.expression
if not isinstance(path, exp.JSONPath):
return rename_func(name)(self, expression)
segments = []
for segment in path.expressions:
path = self.sql(segment)
if path:
if isinstance(segment, exp.JSONPathPart) and (
quoted_index or not isinstance(segment, exp.JSONPathSubscript)
):
path = f"{self.dialect.QUOTE_START}{path}{self.dialect.QUOTE_END}"
segments.append(path)
return self.func(name, expression.this, *segments)
return _json_extract_segments
def json_path_key_only_name(self: Generator, expression: exp.JSONPathKey) -> str:
if isinstance(expression.this, exp.JSONPathWildcard):
self.unsupported("Unsupported wildcard in JSONPathKey expression")
return expression.name
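Together with Dialect.to_json_path above, these helpers turn variadic path arguments into one structured exp.JSONPath and split it back out per dialect at generation time. A cross-dialect sketch (expected output assumed):

    import sqlglot

    # Postgres path segments -> exp.JSONPath -> DuckDB arrow syntax.
    print(sqlglot.transpile(
        "SELECT JSON_EXTRACT_PATH_TEXT(x, 'a', '0')", read="postgres", write="duckdb"
    )[0])
    # expected: SELECT x ->> '$.a[0]'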

View file

@ -55,11 +55,14 @@ class Doris(MySQL):
exp.Map: rename_func("ARRAY_MAP"),
exp.RegexpLike: rename_func("REGEXP"),
exp.RegexpSplit: rename_func("SPLIT_BY_STRING"),
exp.StrToUnix: lambda self, e: f"UNIX_TIMESTAMP({self.sql(e, 'this')}, {self.format_time(e)})",
exp.StrToUnix: lambda self,
e: f"UNIX_TIMESTAMP({self.sql(e, 'this')}, {self.format_time(e)})",
exp.Split: rename_func("SPLIT_BY_STRING"),
exp.TimeStrToDate: rename_func("TO_DATE"),
exp.ToChar: lambda self, e: f"DATE_FORMAT({self.sql(e, 'this')}, {self.format_time(e)})",
exp.TsOrDsAdd: lambda self, e: f"DATE_ADD({self.sql(e, 'this')}, {self.sql(e, 'expression')})", # Only for day level
exp.ToChar: lambda self,
e: f"DATE_FORMAT({self.sql(e, 'this')}, {self.format_time(e)})",
exp.TsOrDsAdd: lambda self,
e: f"DATE_ADD({self.sql(e, 'this')}, {self.sql(e, 'expression')})", # Only for day level
exp.TsOrDsToDate: lambda self, e: self.func("TO_DATE", e.this),
exp.TimeToUnix: rename_func("UNIX_TIMESTAMP"),
exp.TimestampTrunc: lambda self, e: self.func(

View file

@ -99,6 +99,7 @@ class Drill(Dialect):
QUERY_HINTS = False
NVL2_SUPPORTED = False
LAST_DAY_SUPPORTS_DATE_PART = False
SUPPORTS_CREATE_TABLE_LIKE = False
TYPE_MAPPING = {
**generator.Generator.TYPE_MAPPING,
@ -128,10 +129,14 @@ class Drill(Dialect):
exp.DateAdd: _date_add_sql("ADD"),
exp.DateStrToDate: datestrtodate_sql,
exp.DateSub: _date_add_sql("SUB"),
exp.DateToDi: lambda self, e: f"CAST(TO_DATE({self.sql(e, 'this')}, {Drill.DATEINT_FORMAT}) AS INT)",
exp.DiToDate: lambda self, e: f"TO_DATE(CAST({self.sql(e, 'this')} AS VARCHAR), {Drill.DATEINT_FORMAT})",
exp.If: lambda self, e: f"`IF`({self.format_args(e.this, e.args.get('true'), e.args.get('false'))})",
exp.ILike: lambda self, e: f" {self.sql(e, 'this')} `ILIKE` {self.sql(e, 'expression')}",
exp.DateToDi: lambda self,
e: f"CAST(TO_DATE({self.sql(e, 'this')}, {Drill.DATEINT_FORMAT}) AS INT)",
exp.DiToDate: lambda self,
e: f"TO_DATE(CAST({self.sql(e, 'this')} AS VARCHAR), {Drill.DATEINT_FORMAT})",
exp.If: lambda self,
e: f"`IF`({self.format_args(e.this, e.args.get('true'), e.args.get('false'))})",
exp.ILike: lambda self,
e: f" {self.sql(e, 'this')} `ILIKE` {self.sql(e, 'expression')}",
exp.Levenshtein: rename_func("LEVENSHTEIN_DISTANCE"),
exp.PartitionedByProperty: lambda self, e: f"PARTITION BY {self.sql(e, 'this')}",
exp.RegexpLike: rename_func("REGEXP_MATCHES"),
@ -141,7 +146,8 @@ class Drill(Dialect):
exp.Select: transforms.preprocess(
[transforms.eliminate_distinct_on, transforms.eliminate_semi_and_anti_joins]
),
exp.StrToTime: lambda self, e: f"TO_TIMESTAMP({self.sql(e, 'this')}, {self.format_time(e)})",
exp.StrToTime: lambda self,
e: f"TO_TIMESTAMP({self.sql(e, 'this')}, {self.format_time(e)})",
exp.TimeStrToDate: lambda self, e: f"CAST({self.sql(e, 'this')} AS DATE)",
exp.TimeStrToTime: timestrtotime_sql,
exp.TimeStrToUnix: rename_func("UNIX_TIMESTAMP"),
@ -149,8 +155,10 @@ class Drill(Dialect):
exp.TimeToUnix: rename_func("UNIX_TIMESTAMP"),
exp.ToChar: lambda self, e: self.function_fallback_sql(e),
exp.TryCast: no_trycast_sql,
exp.TsOrDsAdd: lambda self, e: f"DATE_ADD(CAST({self.sql(e, 'this')} AS DATE), {self.sql(exp.Interval(this=e.expression, unit=exp.var('DAY')))})",
exp.TsOrDiToDi: lambda self, e: f"CAST(SUBSTR(REPLACE(CAST({self.sql(e, 'this')} AS VARCHAR), '-', ''), 1, 8) AS INT)",
exp.TsOrDsAdd: lambda self,
e: f"DATE_ADD(CAST({self.sql(e, 'this')} AS DATE), {self.sql(exp.Interval(this=e.expression, unit=exp.var('DAY')))})",
exp.TsOrDiToDi: lambda self,
e: f"CAST(SUBSTR(REPLACE(CAST({self.sql(e, 'this')} AS VARCHAR), '-', ''), 1, 8) AS INT)",
}
def normalize_func(self, name: str) -> str:

View file

@ -8,7 +8,6 @@ from sqlglot.dialects.dialect import (
NormalizationStrategy,
approx_count_distinct_sql,
arg_max_or_min_no_count,
arrow_json_extract_scalar_sql,
arrow_json_extract_sql,
binary_from_function,
bool_xor_sql,
@ -18,11 +17,9 @@ from sqlglot.dialects.dialect import (
format_time_lambda,
inline_array_sql,
no_comment_column_constraint_sql,
no_properties_sql,
no_safe_divide_sql,
no_timestamp_sql,
pivot_column_names,
prepend_dollar_to_path,
regexp_extract_sql,
rename_func,
str_position_sql,
@ -172,6 +169,18 @@ class DuckDB(Dialect):
# https://duckdb.org/docs/sql/introduction.html#creating-a-new-table
NORMALIZATION_STRATEGY = NormalizationStrategy.CASE_INSENSITIVE
def to_json_path(self, path: t.Optional[exp.Expression]) -> t.Optional[exp.Expression]:
if isinstance(path, exp.Literal):
# DuckDB also supports the JSON pointer syntax, where every path starts with a `/`.
# Additionally, it allows accessing the back of lists using the `[#-i]` syntax.
# This check ensures we'll avoid trying to parse these as JSON paths, which can
# either result in a noisy warning or in an invalid representation of the path.
path_text = path.name
if path_text.startswith("/") or "[#" in path_text:
return path
return super().to_json_path(path)
class Tokenizer(tokens.Tokenizer):
KEYWORDS = {
**tokens.Tokenizer.KEYWORDS,
@ -229,6 +238,8 @@ class DuckDB(Dialect):
this=seq_get(args, 0), scale=exp.UnixToTime.MILLIS
),
"JSON": exp.ParseJSON.from_arg_list,
"JSON_EXTRACT_PATH": parser.parse_extract_json_with_path(exp.JSONExtract),
"JSON_EXTRACT_STRING": parser.parse_extract_json_with_path(exp.JSONExtractScalar),
"LIST_HAS": exp.ArrayContains.from_arg_list,
"LIST_REVERSE_SORT": _sort_array_reverse,
"LIST_SORT": exp.SortArray.from_arg_list,
@ -319,6 +330,9 @@ class DuckDB(Dialect):
TABLESAMPLE_SEED_KEYWORD = "REPEATABLE"
LAST_DAY_SUPPORTS_DATE_PART = False
JSON_KEY_VALUE_PAIR_SEP = ","
IGNORE_NULLS_IN_FUNC = True
JSON_PATH_BRACKETED_KEY_SUPPORTED = False
SUPPORTS_CREATE_TABLE_LIKE = False
TRANSFORMS = {
**generator.Generator.TRANSFORMS,
@ -350,18 +364,18 @@ class DuckDB(Dialect):
"DATE_DIFF", f"'{e.args.get('unit') or 'DAY'}'", e.expression, e.this
),
exp.DateStrToDate: datestrtodate_sql,
exp.DateToDi: lambda self, e: f"CAST(STRFTIME({self.sql(e, 'this')}, {DuckDB.DATEINT_FORMAT}) AS INT)",
exp.DateToDi: lambda self,
e: f"CAST(STRFTIME({self.sql(e, 'this')}, {DuckDB.DATEINT_FORMAT}) AS INT)",
exp.Decode: lambda self, e: encode_decode_sql(self, e, "DECODE", replace=False),
exp.DiToDate: lambda self, e: f"CAST(STRPTIME(CAST({self.sql(e, 'this')} AS TEXT), {DuckDB.DATEINT_FORMAT}) AS DATE)",
exp.DiToDate: lambda self,
e: f"CAST(STRPTIME(CAST({self.sql(e, 'this')} AS TEXT), {DuckDB.DATEINT_FORMAT}) AS DATE)",
exp.Encode: lambda self, e: encode_decode_sql(self, e, "ENCODE", replace=False),
exp.Explode: rename_func("UNNEST"),
exp.IntDiv: lambda self, e: self.binary(e, "//"),
exp.IsInf: rename_func("ISINF"),
exp.IsNan: rename_func("ISNAN"),
exp.JSONBExtract: arrow_json_extract_sql,
exp.JSONBExtractScalar: arrow_json_extract_scalar_sql,
exp.JSONExtract: arrow_json_extract_sql,
exp.JSONExtractScalar: arrow_json_extract_scalar_sql,
exp.JSONExtractScalar: arrow_json_extract_sql,
exp.JSONFormat: _json_format_sql,
exp.LogicalOr: rename_func("BOOL_OR"),
exp.LogicalAnd: rename_func("BOOL_AND"),
@ -377,7 +391,6 @@ class DuckDB(Dialect):
# DuckDB doesn't allow qualified columns inside of PIVOT expressions.
# See: https://github.com/duckdb/duckdb/blob/671faf92411182f81dce42ac43de8bfb05d9909e/src/planner/binder/tableref/bind_pivot.cpp#L61-L62
exp.Pivot: transforms.preprocess([transforms.unqualify_columns]),
exp.Properties: no_properties_sql,
exp.RegexpExtract: regexp_extract_sql,
exp.RegexpReplace: lambda self, e: self.func(
"REGEXP_REPLACE",
@ -395,7 +408,8 @@ class DuckDB(Dialect):
exp.StrPosition: str_position_sql,
exp.StrToDate: lambda self, e: f"CAST({str_to_time_sql(self, e)} AS DATE)",
exp.StrToTime: str_to_time_sql,
exp.StrToUnix: lambda self, e: f"EPOCH(STRPTIME({self.sql(e, 'this')}, {self.format_time(e)}))",
exp.StrToUnix: lambda self,
e: f"EPOCH(STRPTIME({self.sql(e, 'this')}, {self.format_time(e)}))",
exp.Struct: _struct_sql,
exp.Timestamp: no_timestamp_sql,
exp.TimestampDiff: lambda self, e: self.func(
@ -405,9 +419,11 @@ class DuckDB(Dialect):
exp.TimeStrToDate: lambda self, e: f"CAST({self.sql(e, 'this')} AS DATE)",
exp.TimeStrToTime: timestrtotime_sql,
exp.TimeStrToUnix: lambda self, e: f"EPOCH(CAST({self.sql(e, 'this')} AS TIMESTAMP))",
exp.TimeToStr: lambda self, e: f"STRFTIME({self.sql(e, 'this')}, {self.format_time(e)})",
exp.TimeToStr: lambda self,
e: f"STRFTIME({self.sql(e, 'this')}, {self.format_time(e)})",
exp.TimeToUnix: rename_func("EPOCH"),
exp.TsOrDiToDi: lambda self, e: f"CAST(SUBSTR(REPLACE(CAST({self.sql(e, 'this')} AS TEXT), '-', ''), 1, 8) AS INT)",
exp.TsOrDiToDi: lambda self,
e: f"CAST(SUBSTR(REPLACE(CAST({self.sql(e, 'this')} AS TEXT), '-', ''), 1, 8) AS INT)",
exp.TsOrDsAdd: _ts_or_ds_add_sql,
exp.TsOrDsDiff: lambda self, e: self.func(
"DATE_DIFF",
@ -415,7 +431,8 @@ class DuckDB(Dialect):
exp.cast(e.expression, "TIMESTAMP"),
exp.cast(e.this, "TIMESTAMP"),
),
exp.UnixToStr: lambda self, e: f"STRFTIME(TO_TIMESTAMP({self.sql(e, 'this')}), {self.format_time(e)})",
exp.UnixToStr: lambda self,
e: f"STRFTIME(TO_TIMESTAMP({self.sql(e, 'this')}), {self.format_time(e)})",
exp.UnixToTime: _unix_to_time_sql,
exp.UnixToTimeStr: lambda self, e: f"CAST(TO_TIMESTAMP({self.sql(e, 'this')}) AS TEXT)",
exp.VariancePop: rename_func("VAR_POP"),
@ -423,6 +440,13 @@ class DuckDB(Dialect):
exp.Xor: bool_xor_sql,
}
SUPPORTED_JSON_PATH_PARTS = {
exp.JSONPathKey,
exp.JSONPathRoot,
exp.JSONPathSubscript,
exp.JSONPathWildcard,
}
TYPE_MAPPING = {
**generator.Generator.TYPE_MAPPING,
exp.DataType.Type.BINARY: "BLOB",
@ -442,11 +466,18 @@ class DuckDB(Dialect):
UNWRAPPED_INTERVAL_VALUES = (exp.Column, exp.Literal, exp.Paren)
# DuckDB doesn't generally support CREATE TABLE .. properties
# https://duckdb.org/docs/sql/statements/create_table.html
PROPERTIES_LOCATION = {
**generator.Generator.PROPERTIES_LOCATION,
exp.VolatileProperty: exp.Properties.Location.UNSUPPORTED,
prop: exp.Properties.Location.UNSUPPORTED
for prop in generator.Generator.PROPERTIES_LOCATION
}
# There are a few exceptions (e.g. temporary tables) which are supported or
# can be transpiled to DuckDB, so we explicitly override them accordingly
PROPERTIES_LOCATION[exp.LikeProperty] = exp.Properties.Location.POST_SCHEMA
PROPERTIES_LOCATION[exp.TemporaryProperty] = exp.Properties.Location.POST_CREATE
def timefromparts_sql(self, expression: exp.TimeFromParts) -> str:
nano = expression.args.get("nano")
if nano is not None:
@ -486,10 +517,6 @@ class DuckDB(Dialect):
expression, sep=sep, tablesample_keyword=tablesample_keyword
)
def getpath_sql(self, expression: exp.GetPath) -> str:
expression = prepend_dollar_to_path(expression)
return f"{self.sql(expression, 'this')} -> {self.sql(expression, 'expression')}"
def interval_sql(self, expression: exp.Interval) -> str:
multiplier: t.Optional[int] = None
unit = expression.text("unit").lower()
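DuckDB's to_json_path override matters because JSON pointer and back-of-list paths are not valid JSONPath; per the comment above, they should now pass through untouched instead of triggering the warning in Dialect.to_json_path. A sketch:

    import sqlglot

    # A JSON pointer path is kept verbatim rather than parsed as JSONPath.
    print(sqlglot.transpile("SELECT x -> '/a/0' FROM t", read="duckdb", write="duckdb")[0])
    # expected: SELECT x -> '/a/0' FROM t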

View file

@ -192,6 +192,18 @@ def _to_date_sql(self: Hive.Generator, expression: exp.TsOrDsToDate) -> str:
return f"TO_DATE({this})"
def _parse_ignore_nulls(
exp_class: t.Type[exp.Expression],
) -> t.Callable[[t.List[exp.Expression]], exp.Expression]:
def _parse(args: t.List[exp.Expression]) -> exp.Expression:
this = exp_class(this=seq_get(args, 0))
if seq_get(args, 1) == exp.true():
return exp.IgnoreNulls(this=this)
return this
return _parse
class Hive(Dialect):
ALIAS_POST_TABLESAMPLE = True
IDENTIFIERS_CAN_START_WITH_DIGIT = True
@ -298,8 +310,12 @@ class Hive(Dialect):
expression=exp.TsOrDsToDate(this=seq_get(args, 1)),
),
"DAY": lambda args: exp.Day(this=exp.TsOrDsToDate(this=seq_get(args, 0))),
"FIRST": _parse_ignore_nulls(exp.First),
"FIRST_VALUE": _parse_ignore_nulls(exp.FirstValue),
"FROM_UNIXTIME": format_time_lambda(exp.UnixToStr, "hive", True),
"GET_JSON_OBJECT": exp.JSONExtractScalar.from_arg_list,
"LAST": _parse_ignore_nulls(exp.Last),
"LAST_VALUE": _parse_ignore_nulls(exp.LastValue),
"LOCATE": locate_to_strposition,
"MAP": parse_var_map,
"MONTH": lambda args: exp.Month(this=exp.TsOrDsToDate.from_arg_list(args)),
@ -429,6 +445,7 @@ class Hive(Dialect):
EXTRACT_ALLOWS_QUOTES = False
NVL2_SUPPORTED = False
LAST_DAY_SUPPORTS_DATE_PART = False
JSON_PATH_SINGLE_QUOTE_ESCAPE = True
EXPRESSIONS_WITHOUT_NESTED_CTES = {
exp.Insert,
@ -437,6 +454,13 @@ class Hive(Dialect):
exp.Union,
}
SUPPORTED_JSON_PATH_PARTS = {
exp.JSONPathKey,
exp.JSONPathRoot,
exp.JSONPathSubscript,
exp.JSONPathWildcard,
}
TYPE_MAPPING = {
**generator.Generator.TYPE_MAPPING,
exp.DataType.Type.BIT: "BOOLEAN",
@ -471,9 +495,12 @@ class Hive(Dialect):
exp.DateDiff: _date_diff_sql,
exp.DateStrToDate: datestrtodate_sql,
exp.DateSub: _add_date_sql,
exp.DateToDi: lambda self, e: f"CAST(DATE_FORMAT({self.sql(e, 'this')}, {Hive.DATEINT_FORMAT}) AS INT)",
exp.DiToDate: lambda self, e: f"TO_DATE(CAST({self.sql(e, 'this')} AS STRING), {Hive.DATEINT_FORMAT})",
exp.FileFormatProperty: lambda self, e: f"STORED AS {self.sql(e, 'this') if isinstance(e.this, exp.InputOutputFormat) else e.name.upper()}",
exp.DateToDi: lambda self,
e: f"CAST(DATE_FORMAT({self.sql(e, 'this')}, {Hive.DATEINT_FORMAT}) AS INT)",
exp.DiToDate: lambda self,
e: f"TO_DATE(CAST({self.sql(e, 'this')} AS STRING), {Hive.DATEINT_FORMAT})",
exp.FileFormatProperty: lambda self,
e: f"STORED AS {self.sql(e, 'this') if isinstance(e.this, exp.InputOutputFormat) else e.name.upper()}",
exp.FromBase64: rename_func("UNBASE64"),
exp.If: if_sql(),
exp.ILike: no_ilike_sql,
@ -502,7 +529,8 @@ class Hive(Dialect):
exp.SafeDivide: no_safe_divide_sql,
exp.SchemaCommentProperty: lambda self, e: self.naked_property(e),
exp.ArrayUniqueAgg: rename_func("COLLECT_SET"),
exp.Split: lambda self, e: f"SPLIT({self.sql(e, 'this')}, CONCAT('\\\\Q', {self.sql(e, 'expression')}))",
exp.Split: lambda self,
e: f"SPLIT({self.sql(e, 'this')}, CONCAT('\\\\Q', {self.sql(e, 'expression')}))",
exp.StrPosition: strposition_to_locate_sql,
exp.StrToDate: _str_to_date_sql,
exp.StrToTime: _str_to_time_sql,
@ -514,7 +542,8 @@ class Hive(Dialect):
exp.TimeToStr: _time_to_str,
exp.TimeToUnix: rename_func("UNIX_TIMESTAMP"),
exp.ToBase64: rename_func("BASE64"),
exp.TsOrDiToDi: lambda self, e: f"CAST(SUBSTR(REPLACE(CAST({self.sql(e, 'this')} AS STRING), '-', ''), 1, 8) AS INT)",
exp.TsOrDiToDi: lambda self,
e: f"CAST(SUBSTR(REPLACE(CAST({self.sql(e, 'this')} AS STRING), '-', ''), 1, 8) AS INT)",
exp.TsOrDsAdd: _add_date_sql,
exp.TsOrDsDiff: _date_diff_sql,
exp.TsOrDsToDate: _to_date_sql,
@ -528,8 +557,10 @@ class Hive(Dialect):
exp.SerdeProperties: lambda self, e: self.properties(e, prefix="WITH SERDEPROPERTIES"),
exp.NumberToStr: rename_func("FORMAT_NUMBER"),
exp.National: lambda self, e: self.national_sql(e, prefix=""),
exp.ClusteredColumnConstraint: lambda self, e: f"({self.expressions(e, 'this', indent=False)})",
exp.NonClusteredColumnConstraint: lambda self, e: f"({self.expressions(e, 'this', indent=False)})",
exp.ClusteredColumnConstraint: lambda self,
e: f"({self.expressions(e, 'this', indent=False)})",
exp.NonClusteredColumnConstraint: lambda self,
e: f"({self.expressions(e, 'this', indent=False)})",
exp.NotForReplicationColumnConstraint: lambda self, e: "",
exp.OnProperty: lambda self, e: "",
exp.PrimaryKeyColumnConstraint: lambda self, e: "PRIMARY KEY",
@ -543,6 +574,13 @@ class Hive(Dialect):
exp.WithDataProperty: exp.Properties.Location.UNSUPPORTED,
}
def _jsonpathkey_sql(self, expression: exp.JSONPathKey) -> str:
if isinstance(expression.this, exp.JSONPathWildcard):
self.unsupported("Unsupported wildcard in JSONPathKey expression")
return ""
return super()._jsonpathkey_sql(expression)
def temporary_storage_provider(self, expression: exp.Create) -> exp.Create:
# Hive has no temporary storage provider (there are hive settings though)
return expression
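_parse_ignore_nulls means a TRUE second argument to FIRST/LAST/FIRST_VALUE/LAST_VALUE now surfaces as an exp.IgnoreNulls wrapper in the AST; a minimal sketch:

    from sqlglot import exp, parse_one

    node = parse_one("SELECT FIRST_VALUE(x, TRUE) OVER (PARTITION BY y)", read="hive")
    assert isinstance(node.find(exp.IgnoreNulls).this, exp.FirstValue)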

View file

@ -6,7 +6,7 @@ from sqlglot import exp, generator, parser, tokens, transforms
from sqlglot.dialects.dialect import (
Dialect,
NormalizationStrategy,
arrow_json_extract_scalar_sql,
arrow_json_extract_sql,
date_add_interval_sql,
datestrtodate_sql,
format_time_lambda,
@ -19,8 +19,8 @@ from sqlglot.dialects.dialect import (
no_pivot_sql,
no_tablesample_sql,
no_trycast_sql,
parse_date_delta,
parse_date_delta_with_interval,
path_to_jsonpath,
rename_func,
strposition_to_locate_sql,
)
@ -306,6 +306,7 @@ class MySQL(Dialect):
format=exp.Literal.string("%B"),
),
"STR_TO_DATE": _str_to_date,
"TIMESTAMPDIFF": parse_date_delta(exp.TimestampDiff),
"TO_DAYS": lambda args: exp.paren(
exp.DateDiff(
this=exp.TsOrDsToDate(this=seq_get(args, 0)),
@ -357,6 +358,7 @@ class MySQL(Dialect):
"CREATE TRIGGER": _show_parser("CREATE TRIGGER", target=True),
"CREATE VIEW": _show_parser("CREATE VIEW", target=True),
"DATABASES": _show_parser("DATABASES"),
"SCHEMAS": _show_parser("DATABASES"),
"ENGINE": _show_parser("ENGINE", target=True),
"STORAGE ENGINES": _show_parser("ENGINES"),
"ENGINES": _show_parser("ENGINES"),
@ -630,6 +632,8 @@ class MySQL(Dialect):
VALUES_AS_TABLE = False
NVL2_SUPPORTED = False
LAST_DAY_SUPPORTS_DATE_PART = False
JSON_TYPE_REQUIRED_FOR_EXTRACTION = True
JSON_PATH_BRACKETED_KEY_SUPPORTED = False
JSON_KEY_VALUE_PAIR_SEP = ","
TRANSFORMS = {
@ -646,10 +650,10 @@ class MySQL(Dialect):
exp.DayOfMonth: _remove_ts_or_ds_to_date(rename_func("DAYOFMONTH")),
exp.DayOfWeek: _remove_ts_or_ds_to_date(rename_func("DAYOFWEEK")),
exp.DayOfYear: _remove_ts_or_ds_to_date(rename_func("DAYOFYEAR")),
exp.GetPath: path_to_jsonpath(),
exp.GroupConcat: lambda self, e: f"""GROUP_CONCAT({self.sql(e, "this")} SEPARATOR {self.sql(e, "separator") or "','"})""",
exp.GroupConcat: lambda self,
e: f"""GROUP_CONCAT({self.sql(e, "this")} SEPARATOR {self.sql(e, "separator") or "','"})""",
exp.ILike: no_ilike_sql,
exp.JSONExtractScalar: arrow_json_extract_scalar_sql,
exp.JSONExtractScalar: arrow_json_extract_sql,
exp.Max: max_or_greatest,
exp.Min: min_or_least,
exp.Month: _remove_ts_or_ds_to_date(),
@ -672,6 +676,9 @@ class MySQL(Dialect):
exp.TableSample: no_tablesample_sql,
exp.TimeFromParts: rename_func("MAKETIME"),
exp.TimestampAdd: date_add_interval_sql("DATE", "ADD"),
exp.TimestampDiff: lambda self, e: self.func(
"TIMESTAMPDIFF", e.text("unit"), e.expression, e.this
),
exp.TimestampSub: date_add_interval_sql("DATE", "SUB"),
exp.TimeStrToUnix: rename_func("UNIX_TIMESTAMP"),
exp.TimeStrToTime: lambda self, e: self.sql(exp.cast(e.this, "datetime", copy=True)),
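The SHOW parser addition normalizes MySQL's SCHEMAS synonym onto the existing DATABASES handler; a sketch:

    import sqlglot

    # SHOW SCHEMAS is parsed with the DATABASES show-parser.
    print(sqlglot.parse_one("SHOW SCHEMAS", read="mysql").sql(dialect="mysql"))
    # expected: SHOW DATABASES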

View file

@ -199,7 +199,8 @@ class Oracle(Dialect):
transforms.eliminate_qualify,
]
),
exp.StrToTime: lambda self, e: f"TO_TIMESTAMP({self.sql(e, 'this')}, {self.format_time(e)})",
exp.StrToTime: lambda self,
e: f"TO_TIMESTAMP({self.sql(e, 'this')}, {self.format_time(e)})",
exp.StrToDate: lambda self, e: f"TO_DATE({self.sql(e, 'this')}, {self.format_time(e)})",
exp.Subquery: lambda self, e: self.subquery_sql(e, sep=" "),
exp.Substring: rename_func("SUBSTR"),
@ -208,7 +209,8 @@ class Oracle(Dialect):
exp.TimeToStr: lambda self, e: f"TO_CHAR({self.sql(e, 'this')}, {self.format_time(e)})",
exp.ToChar: lambda self, e: self.function_fallback_sql(e),
exp.Trim: trim_sql,
exp.UnixToTime: lambda self, e: f"TO_DATE('1970-01-01','YYYY-MM-DD') + ({self.sql(e, 'this')} / 86400)",
exp.UnixToTime: lambda self,
e: f"TO_DATE('1970-01-01','YYYY-MM-DD') + ({self.sql(e, 'this')} / 86400)",
}
PROPERTIES_LOCATION = {

View file

@ -7,11 +7,11 @@ from sqlglot.dialects.dialect import (
DATE_ADD_OR_SUB,
Dialect,
any_value_to_max_sql,
arrow_json_extract_scalar_sql,
arrow_json_extract_sql,
bool_xor_sql,
datestrtodate_sql,
format_time_lambda,
json_extract_segments,
json_path_key_only_name,
max_or_greatest,
merge_without_target_sql,
min_or_least,
@ -20,6 +20,7 @@ from sqlglot.dialects.dialect import (
no_paren_current_date_sql,
no_pivot_sql,
no_trycast_sql,
parse_json_extract_path,
parse_timestamp_trunc,
rename_func,
str_position_sql,
@ -292,6 +293,8 @@ class Postgres(Dialect):
**parser.Parser.FUNCTIONS,
"DATE_TRUNC": parse_timestamp_trunc,
"GENERATE_SERIES": _generate_series,
"JSON_EXTRACT_PATH": parse_json_extract_path(exp.JSONExtract),
"JSON_EXTRACT_PATH_TEXT": parse_json_extract_path(exp.JSONExtractScalar),
"MAKE_TIME": exp.TimeFromParts.from_arg_list,
"MAKE_TIMESTAMP": exp.TimestampFromParts.from_arg_list,
"NOW": exp.CurrentTimestamp.from_arg_list,
@ -375,8 +378,15 @@ class Postgres(Dialect):
TABLESAMPLE_SIZE_IS_ROWS = False
TABLESAMPLE_SEED_KEYWORD = "REPEATABLE"
SUPPORTS_SELECT_INTO = True
# https://www.postgresql.org/docs/current/sql-createtable.html
JSON_TYPE_REQUIRED_FOR_EXTRACTION = True
SUPPORTS_UNLOGGED_TABLES = True
LIKE_PROPERTY_INSIDE_SCHEMA = True
SUPPORTED_JSON_PATH_PARTS = {
exp.JSONPathKey,
exp.JSONPathRoot,
exp.JSONPathSubscript,
}
TYPE_MAPPING = {
**generator.Generator.TYPE_MAPPING,
@ -412,11 +422,14 @@ class Postgres(Dialect):
exp.DateSub: _date_add_sql("-"),
exp.Explode: rename_func("UNNEST"),
exp.GroupConcat: _string_agg_sql,
exp.JSONExtract: arrow_json_extract_sql,
exp.JSONExtractScalar: arrow_json_extract_scalar_sql,
exp.JSONExtract: json_extract_segments("JSON_EXTRACT_PATH"),
exp.JSONExtractScalar: json_extract_segments("JSON_EXTRACT_PATH_TEXT"),
exp.JSONBExtract: lambda self, e: self.binary(e, "#>"),
exp.JSONBExtractScalar: lambda self, e: self.binary(e, "#>>"),
exp.JSONBContains: lambda self, e: self.binary(e, "?"),
exp.JSONPathKey: json_path_key_only_name,
exp.JSONPathRoot: lambda *_: "",
exp.JSONPathSubscript: lambda self, e: self.json_path_part(e.this),
exp.LastDay: no_last_day_sql,
exp.LogicalOr: rename_func("BOOL_OR"),
exp.LogicalAnd: rename_func("BOOL_AND"),
@ -443,7 +456,8 @@ class Postgres(Dialect):
]
),
exp.StrPosition: str_position_sql,
exp.StrToTime: lambda self, e: f"TO_TIMESTAMP({self.sql(e, 'this')}, {self.format_time(e)})",
exp.StrToTime: lambda self,
e: f"TO_TIMESTAMP({self.sql(e, 'this')}, {self.format_time(e)})",
exp.StructExtract: struct_extract_sql,
exp.Substring: _substring_sql,
exp.TimeFromParts: rename_func("MAKE_TIME"),
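With JSON_EXTRACT_PATH parsed into a structured path and json_extract_segments generating it back, variadic extraction calls should round-trip on Postgres; a sketch:

    import sqlglot

    print(sqlglot.transpile(
        "SELECT JSON_EXTRACT_PATH(x, 'a', 'b') FROM t", read="postgres", write="postgres"
    )[0])
    # expected: SELECT JSON_EXTRACT_PATH(x, 'a', 'b') FROM t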

View file

@ -18,7 +18,6 @@ from sqlglot.dialects.dialect import (
no_pivot_sql,
no_safe_divide_sql,
no_timestamp_sql,
path_to_jsonpath,
regexp_extract_sql,
rename_func,
right_to_substring_sql,
@ -150,7 +149,7 @@ def _unnest_sequence(expression: exp.Expression) -> exp.Expression:
return expression
def _first_last_sql(self: Presto.Generator, expression: exp.First | exp.Last) -> str:
def _first_last_sql(self: Presto.Generator, expression: exp.Func) -> str:
"""
Trino doesn't support FIRST / LAST as functions, but they're valid in the context
of MATCH_RECOGNIZE, so we need to preserve them in that case. In all other cases
@ -292,6 +291,7 @@ class Presto(Dialect):
STRUCT_DELIMITER = ("(", ")")
LIMIT_ONLY_LITERALS = True
SUPPORTS_SINGLE_ARG_CONCAT = False
LIKE_PROPERTY_INSIDE_SCHEMA = True
PROPERTIES_LOCATION = {
**generator.Generator.PROPERTIES_LOCATION,
@ -324,12 +324,18 @@ class Presto(Dialect):
exp.ArrayContains: rename_func("CONTAINS"),
exp.ArraySize: rename_func("CARDINALITY"),
exp.ArrayUniqueAgg: rename_func("SET_AGG"),
exp.BitwiseAnd: lambda self, e: f"BITWISE_AND({self.sql(e, 'this')}, {self.sql(e, 'expression')})",
exp.BitwiseLeftShift: lambda self, e: f"BITWISE_ARITHMETIC_SHIFT_LEFT({self.sql(e, 'this')}, {self.sql(e, 'expression')})",
exp.AtTimeZone: rename_func("AT_TIMEZONE"),
exp.BitwiseAnd: lambda self,
e: f"BITWISE_AND({self.sql(e, 'this')}, {self.sql(e, 'expression')})",
exp.BitwiseLeftShift: lambda self,
e: f"BITWISE_ARITHMETIC_SHIFT_LEFT({self.sql(e, 'this')}, {self.sql(e, 'expression')})",
exp.BitwiseNot: lambda self, e: f"BITWISE_NOT({self.sql(e, 'this')})",
exp.BitwiseOr: lambda self, e: f"BITWISE_OR({self.sql(e, 'this')}, {self.sql(e, 'expression')})",
exp.BitwiseRightShift: lambda self, e: f"BITWISE_ARITHMETIC_SHIFT_RIGHT({self.sql(e, 'this')}, {self.sql(e, 'expression')})",
exp.BitwiseXor: lambda self, e: f"BITWISE_XOR({self.sql(e, 'this')}, {self.sql(e, 'expression')})",
exp.BitwiseOr: lambda self,
e: f"BITWISE_OR({self.sql(e, 'this')}, {self.sql(e, 'expression')})",
exp.BitwiseRightShift: lambda self,
e: f"BITWISE_ARITHMETIC_SHIFT_RIGHT({self.sql(e, 'this')}, {self.sql(e, 'expression')})",
exp.BitwiseXor: lambda self,
e: f"BITWISE_XOR({self.sql(e, 'this')}, {self.sql(e, 'expression')})",
exp.Cast: transforms.preprocess([transforms.epoch_cast_to_ts]),
exp.CurrentTimestamp: lambda *_: "CURRENT_TIMESTAMP",
exp.DateAdd: lambda self, e: self.func(
@ -344,7 +350,8 @@ class Presto(Dialect):
"DATE_DIFF", exp.Literal.string(e.text("unit") or "DAY"), e.expression, e.this
),
exp.DateStrToDate: datestrtodate_sql,
exp.DateToDi: lambda self, e: f"CAST(DATE_FORMAT({self.sql(e, 'this')}, {Presto.DATEINT_FORMAT}) AS INT)",
exp.DateToDi: lambda self,
e: f"CAST(DATE_FORMAT({self.sql(e, 'this')}, {Presto.DATEINT_FORMAT}) AS INT)",
exp.DateSub: lambda self, e: self.func(
"DATE_ADD",
exp.Literal.string(e.text("unit") or "DAY"),
@ -352,12 +359,14 @@ class Presto(Dialect):
e.this,
),
exp.Decode: lambda self, e: encode_decode_sql(self, e, "FROM_UTF8"),
exp.DiToDate: lambda self, e: f"CAST(DATE_PARSE(CAST({self.sql(e, 'this')} AS VARCHAR), {Presto.DATEINT_FORMAT}) AS DATE)",
exp.DiToDate: lambda self,
e: f"CAST(DATE_PARSE(CAST({self.sql(e, 'this')} AS VARCHAR), {Presto.DATEINT_FORMAT}) AS DATE)",
exp.Encode: lambda self, e: encode_decode_sql(self, e, "TO_UTF8"),
exp.FileFormatProperty: lambda self, e: f"FORMAT='{e.name.upper()}'",
exp.First: _first_last_sql,
exp.FromTimeZone: lambda self, e: f"WITH_TIMEZONE({self.sql(e, 'this')}, {self.sql(e, 'zone')}) AT TIME ZONE 'UTC'",
exp.GetPath: path_to_jsonpath(),
exp.FirstValue: _first_last_sql,
exp.FromTimeZone: lambda self,
e: f"WITH_TIMEZONE({self.sql(e, 'this')}, {self.sql(e, 'zone')}) AT TIME ZONE 'UTC'",
exp.Group: transforms.preprocess([transforms.unalias_group]),
exp.GroupConcat: lambda self, e: self.func(
"ARRAY_JOIN", self.func("ARRAY_AGG", e.this), e.args.get("separator")
@ -368,6 +377,7 @@ class Presto(Dialect):
exp.Initcap: _initcap_sql,
exp.ParseJSON: rename_func("JSON_PARSE"),
exp.Last: _first_last_sql,
exp.LastValue: _first_last_sql,
exp.LastDay: lambda self, e: self.func("LAST_DAY_OF_MONTH", e.this),
exp.Lateral: _explode_to_unnest_sql,
exp.Left: left_to_substring_sql,
@ -394,26 +404,33 @@ class Presto(Dialect):
exp.StrToDate: lambda self, e: f"CAST({_str_to_time_sql(self, e)} AS DATE)",
exp.StrToMap: rename_func("SPLIT_TO_MAP"),
exp.StrToTime: _str_to_time_sql,
exp.StrToUnix: lambda self, e: f"TO_UNIXTIME(DATE_PARSE({self.sql(e, 'this')}, {self.format_time(e)}))",
exp.StrToUnix: lambda self,
e: f"TO_UNIXTIME(DATE_PARSE({self.sql(e, 'this')}, {self.format_time(e)}))",
exp.StructExtract: struct_extract_sql,
exp.Table: transforms.preprocess([_unnest_sequence]),
exp.Timestamp: no_timestamp_sql,
exp.TimestampTrunc: timestamptrunc_sql,
exp.TimeStrToDate: timestrtotime_sql,
exp.TimeStrToTime: timestrtotime_sql,
exp.TimeStrToUnix: lambda self, e: f"TO_UNIXTIME(DATE_PARSE({self.sql(e, 'this')}, {Presto.TIME_FORMAT}))",
exp.TimeToStr: lambda self, e: f"DATE_FORMAT({self.sql(e, 'this')}, {self.format_time(e)})",
exp.TimeStrToUnix: lambda self,
e: f"TO_UNIXTIME(DATE_PARSE({self.sql(e, 'this')}, {Presto.TIME_FORMAT}))",
exp.TimeToStr: lambda self,
e: f"DATE_FORMAT({self.sql(e, 'this')}, {self.format_time(e)})",
exp.TimeToUnix: rename_func("TO_UNIXTIME"),
exp.ToChar: lambda self, e: f"DATE_FORMAT({self.sql(e, 'this')}, {self.format_time(e)})",
exp.ToChar: lambda self,
e: f"DATE_FORMAT({self.sql(e, 'this')}, {self.format_time(e)})",
exp.TryCast: transforms.preprocess([transforms.epoch_cast_to_ts]),
exp.TsOrDiToDi: lambda self, e: f"CAST(SUBSTR(REPLACE(CAST({self.sql(e, 'this')} AS VARCHAR), '-', ''), 1, 8) AS INT)",
exp.TsOrDiToDi: lambda self,
e: f"CAST(SUBSTR(REPLACE(CAST({self.sql(e, 'this')} AS VARCHAR), '-', ''), 1, 8) AS INT)",
exp.TsOrDsAdd: _ts_or_ds_add_sql,
exp.TsOrDsDiff: _ts_or_ds_diff_sql,
exp.TsOrDsToDate: _ts_or_ds_to_date_sql,
exp.Unhex: rename_func("FROM_HEX"),
exp.UnixToStr: lambda self, e: f"DATE_FORMAT(FROM_UNIXTIME({self.sql(e, 'this')}), {self.format_time(e)})",
exp.UnixToStr: lambda self,
e: f"DATE_FORMAT(FROM_UNIXTIME({self.sql(e, 'this')}), {self.format_time(e)})",
exp.UnixToTime: _unix_to_time_sql,
exp.UnixToTimeStr: lambda self, e: f"CAST(FROM_UNIXTIME({self.sql(e, 'this')}) AS VARCHAR)",
exp.UnixToTimeStr: lambda self,
e: f"CAST(FROM_UNIXTIME({self.sql(e, 'this')}) AS VARCHAR)",
exp.VariancePop: rename_func("VAR_POP"),
exp.With: transforms.preprocess([transforms.add_recursive_cte_column_names]),
exp.WithinGroup: transforms.preprocess(
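The new exp.AtTimeZone mapping gives AT TIME ZONE a first-class Presto spelling; a sketch (reading from Postgres is an assumption for illustration):

    import sqlglot

    print(sqlglot.transpile("SELECT ts AT TIME ZONE 'UTC'", read="postgres", write="presto")[0])
    # expected: SELECT AT_TIMEZONE(ts, 'UTC')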

View file

@ -9,6 +9,7 @@ from sqlglot.dialects.dialect import (
concat_ws_to_dpipe_sql,
date_delta_sql,
generatedasidentitycolumnconstraint_sql,
json_extract_segments,
no_tablesample_sql,
rename_func,
)
@ -20,10 +21,6 @@ if t.TYPE_CHECKING:
from sqlglot._typing import E
def _json_sql(self: Redshift.Generator, expression: exp.JSONExtract | exp.JSONExtractScalar) -> str:
return f'{self.sql(expression, "this")}."{expression.expression.name}"'
def _parse_date_delta(expr_type: t.Type[E]) -> t.Callable[[t.List], E]:
def _parse_delta(args: t.List) -> E:
expr = expr_type(this=seq_get(args, 2), expression=seq_get(args, 1), unit=seq_get(args, 0))
@ -62,6 +59,7 @@ class Redshift(Postgres):
"DATE_ADD": _parse_date_delta(exp.TsOrDsAdd),
"DATEDIFF": _parse_date_delta(exp.TsOrDsDiff),
"DATE_DIFF": _parse_date_delta(exp.TsOrDsDiff),
"GETDATE": exp.CurrentTimestamp.from_arg_list,
"LISTAGG": exp.GroupConcat.from_arg_list,
"STRTOL": exp.FromBase.from_arg_list,
}
@ -69,6 +67,7 @@ class Redshift(Postgres):
NO_PAREN_FUNCTION_PARSERS = {
**Postgres.Parser.NO_PAREN_FUNCTION_PARSERS,
"APPROXIMATE": lambda self: self._parse_approximate_count(),
"SYSDATE": lambda self: self.expression(exp.CurrentTimestamp, transaction=True),
}
def _parse_table(
@ -77,6 +76,7 @@ class Redshift(Postgres):
joins: bool = False,
alias_tokens: t.Optional[t.Collection[TokenType]] = None,
parse_bracket: bool = False,
is_db_reference: bool = False,
) -> t.Optional[exp.Expression]:
# Redshift supports UNPIVOTing SUPER objects, e.g. `UNPIVOT foo.obj[0] AS val AT attr`
unpivot = self._match(TokenType.UNPIVOT)
@ -85,6 +85,7 @@ class Redshift(Postgres):
joins=joins,
alias_tokens=alias_tokens,
parse_bracket=parse_bracket,
is_db_reference=is_db_reference,
)
return self.expression(exp.Pivot, this=table, unpivot=True) if unpivot else table
@ -153,7 +154,6 @@ class Redshift(Postgres):
**Postgres.Tokenizer.KEYWORDS,
"HLLSKETCH": TokenType.HLLSKETCH,
"SUPER": TokenType.SUPER,
"SYSDATE": TokenType.CURRENT_TIMESTAMP,
"TOP": TokenType.TOP,
"UNLOAD": TokenType.COMMAND,
"VARBYTE": TokenType.VARBINARY,
@ -180,31 +180,29 @@ class Redshift(Postgres):
exp.DataType.Type.VARBINARY: "VARBYTE",
}
PROPERTIES_LOCATION = {
**Postgres.Generator.PROPERTIES_LOCATION,
exp.LikeProperty: exp.Properties.Location.POST_WITH,
}
TRANSFORMS = {
**Postgres.Generator.TRANSFORMS,
exp.Concat: concat_to_dpipe_sql,
exp.ConcatWs: concat_ws_to_dpipe_sql,
exp.ApproxDistinct: lambda self, e: f"APPROXIMATE COUNT(DISTINCT {self.sql(e, 'this')})",
exp.CurrentTimestamp: lambda self, e: "SYSDATE",
exp.ApproxDistinct: lambda self,
e: f"APPROXIMATE COUNT(DISTINCT {self.sql(e, 'this')})",
exp.CurrentTimestamp: lambda self, e: (
"SYSDATE" if e.args.get("transaction") else "GETDATE()"
),
exp.DateAdd: date_delta_sql("DATEADD"),
exp.DateDiff: date_delta_sql("DATEDIFF"),
exp.DistKeyProperty: lambda self, e: f"DISTKEY({e.name})",
exp.DistStyleProperty: lambda self, e: self.naked_property(e),
exp.FromBase: rename_func("STRTOL"),
exp.GeneratedAsIdentityColumnConstraint: generatedasidentitycolumnconstraint_sql,
exp.JSONExtract: _json_sql,
exp.JSONExtractScalar: _json_sql,
exp.JSONExtract: json_extract_segments("JSON_EXTRACT_PATH_TEXT"),
exp.GroupConcat: rename_func("LISTAGG"),
exp.ParseJSON: rename_func("JSON_PARSE"),
exp.Select: transforms.preprocess(
[transforms.eliminate_distinct_on, transforms.eliminate_semi_and_anti_joins]
),
exp.SortKeyProperty: lambda self, e: f"{'COMPOUND ' if e.args['compound'] else ''}SORTKEY({self.format_args(*e.this)})",
exp.SortKeyProperty: lambda self,
e: f"{'COMPOUND ' if e.args['compound'] else ''}SORTKEY({self.format_args(*e.this)})",
exp.TableSample: no_tablesample_sql,
exp.TsOrDsAdd: date_delta_sql("DATEADD"),
exp.TsOrDsDiff: date_delta_sql("DATEDIFF"),
@ -228,6 +226,13 @@ class Redshift(Postgres):
"""Redshift doesn't have `WITH` as part of their with_properties so we remove it"""
return self.properties(properties, prefix=" ", suffix="")
def cast_sql(self, expression: exp.Cast, safe_prefix: t.Optional[str] = None) -> str:
if expression.is_type(exp.DataType.Type.JSON):
# Redshift doesn't support a JSON type, so casting to it is treated as a noop
return self.sql(expression, "this")
return super().cast_sql(expression, safe_prefix=safe_prefix)
def datatype_sql(self, expression: exp.DataType) -> str:
"""
Redshift converts the `TEXT` data type to `VARCHAR(255)` by default when people more generally mean
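The SYSDATE/GETDATE split preserves transaction-start semantics only where the source asked for them; a sketch:

    import sqlglot

    # Plain CURRENT_TIMESTAMP now becomes GETDATE(); SYSDATE round-trips.
    print(sqlglot.transpile("SELECT CURRENT_TIMESTAMP", read="redshift", write="redshift")[0])
    # expected: SELECT GETDATE()
    print(sqlglot.transpile("SELECT SYSDATE", read="redshift", write="redshift")[0])
    # expected: SELECT SYSDATE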

View file

@ -21,19 +21,13 @@ from sqlglot.dialects.dialect import (
var_map_sql,
)
from sqlglot.expressions import Literal
from sqlglot.helper import seq_get
from sqlglot.helper import is_int, seq_get
from sqlglot.tokens import TokenType
if t.TYPE_CHECKING:
from sqlglot._typing import E
def _check_int(s: str) -> bool:
if s[0] in ("-", "+"):
return s[1:].isdigit()
return s.isdigit()
# from https://docs.snowflake.com/en/sql-reference/functions/to_timestamp.html
def _parse_to_timestamp(args: t.List) -> t.Union[exp.StrToTime, exp.UnixToTime, exp.TimeStrToTime]:
if len(args) == 2:
@ -53,7 +47,7 @@ def _parse_to_timestamp(args: t.List) -> t.Union[exp.StrToTime, exp.UnixToTime,
return exp.TimeStrToTime.from_arg_list(args)
if first_arg.is_string:
if _check_int(first_arg.this):
if is_int(first_arg.this):
# case: <integer>
return exp.UnixToTime.from_arg_list(args)
@ -241,7 +235,6 @@ DATE_PART_MAPPING = {
"NSECOND": "NANOSECOND",
"NSECONDS": "NANOSECOND",
"NANOSECS": "NANOSECOND",
"NSECONDS": "NANOSECOND",
"EPOCH": "EPOCH_SECOND",
"EPOCH_SECONDS": "EPOCH_SECOND",
"EPOCH_MILLISECONDS": "EPOCH_MILLISECOND",
@ -291,7 +284,9 @@ def _parse_colon_get_path(
path = exp.Literal.string(path.sql(dialect="snowflake"))
# The extraction operator : is left-associative
this = self.expression(exp.GetPath, this=this, expression=path)
this = self.expression(
exp.JSONExtract, this=this, expression=self.dialect.to_json_path(path)
)
if target_type:
this = exp.cast(this, target_type)
@ -411,6 +406,9 @@ class Snowflake(Dialect):
"DATEDIFF": _parse_datediff,
"DIV0": _div0_to_if,
"FLATTEN": exp.Explode.from_arg_list,
"GET_PATH": lambda args, dialect: exp.JSONExtract(
this=seq_get(args, 0), expression=dialect.to_json_path(seq_get(args, 1))
),
"IFF": exp.If.from_arg_list,
"LAST_DAY": lambda args: exp.LastDay(
this=seq_get(args, 0), unit=_map_date_part(seq_get(args, 1))
@ -474,6 +472,8 @@ class Snowflake(Dialect):
"TERSE SCHEMAS": _show_parser("SCHEMAS"),
"OBJECTS": _show_parser("OBJECTS"),
"TERSE OBJECTS": _show_parser("OBJECTS"),
"TABLES": _show_parser("TABLES"),
"TERSE TABLES": _show_parser("TABLES"),
"PRIMARY KEYS": _show_parser("PRIMARY KEYS"),
"TERSE PRIMARY KEYS": _show_parser("PRIMARY KEYS"),
"COLUMNS": _show_parser("COLUMNS"),
@ -534,7 +534,9 @@ class Snowflake(Dialect):
return table
def _parse_table_parts(self, schema: bool = False) -> exp.Table:
def _parse_table_parts(
self, schema: bool = False, is_db_reference: bool = False
) -> exp.Table:
# https://docs.snowflake.com/en/user-guide/querying-stage
if self._match(TokenType.STRING, advance=False):
table = self._parse_string()
@ -550,7 +552,9 @@ class Snowflake(Dialect):
self._match(TokenType.L_PAREN)
while self._curr and not self._match(TokenType.R_PAREN):
if self._match_text_seq("FILE_FORMAT", "=>"):
file_format = self._parse_string() or super()._parse_table_parts()
file_format = self._parse_string() or super()._parse_table_parts(
is_db_reference=is_db_reference
)
elif self._match_text_seq("PATTERN", "=>"):
pattern = self._parse_string()
else:
@ -560,7 +564,7 @@ class Snowflake(Dialect):
table = self.expression(exp.Table, this=table, format=file_format, pattern=pattern)
else:
table = super()._parse_table_parts(schema=schema)
table = super()._parse_table_parts(schema=schema, is_db_reference=is_db_reference)
return self._parse_at_before(table)
@ -587,6 +591,8 @@ class Snowflake(Dialect):
# which is syntactically valid but has no effect on the output
terse = self._tokens[self._index - 2].text.upper() == "TERSE"
history = self._match_text_seq("HISTORY")
like = self._parse_string() if self._match(TokenType.LIKE) else None
if self._match(TokenType.IN):
@ -597,7 +603,7 @@ class Snowflake(Dialect):
if self._curr:
scope = self._parse_table_parts()
elif self._curr:
scope_kind = "SCHEMA" if this == "OBJECTS" else "TABLE"
scope_kind = "SCHEMA" if this in ("OBJECTS", "TABLES") else "TABLE"
scope = self._parse_table_parts()
return self.expression(
@ -605,6 +611,7 @@ class Snowflake(Dialect):
**{
"terse": terse,
"this": this,
"history": history,
"like": like,
"scope": scope,
"scope_kind": scope_kind,
@ -715,8 +722,10 @@ class Snowflake(Dialect):
),
exp.GroupConcat: rename_func("LISTAGG"),
exp.If: if_sql(name="IFF", false_value="NULL"),
exp.JSONExtract: lambda self, e: f"{self.sql(e, 'this')}[{self.sql(e, 'expression')}]",
exp.JSONExtract: rename_func("GET_PATH"),
exp.JSONExtractScalar: rename_func("JSON_EXTRACT_PATH_TEXT"),
exp.JSONObject: lambda self, e: self.func("OBJECT_CONSTRUCT_KEEP_NULL", *e.expressions),
exp.JSONPathRoot: lambda *_: "",
exp.LogicalAnd: rename_func("BOOLAND_AGG"),
exp.LogicalOr: rename_func("BOOLOR_AGG"),
exp.Map: lambda self, e: var_map_sql(self, e, "OBJECT_CONSTRUCT"),
@ -745,7 +754,8 @@ class Snowflake(Dialect):
exp.StrPosition: lambda self, e: self.func(
"POSITION", e.args.get("substr"), e.this, e.args.get("position")
),
exp.StrToTime: lambda self, e: f"TO_TIMESTAMP({self.sql(e, 'this')}, {self.format_time(e)})",
exp.StrToTime: lambda self,
e: f"TO_TIMESTAMP({self.sql(e, 'this')}, {self.format_time(e)})",
exp.Struct: lambda self, e: self.func(
"OBJECT_CONSTRUCT",
*(arg for expression in e.expressions for arg in expression.flatten()),
@ -771,6 +781,12 @@ class Snowflake(Dialect):
exp.Xor: rename_func("BOOLXOR"),
}
SUPPORTED_JSON_PATH_PARTS = {
exp.JSONPathKey,
exp.JSONPathRoot,
exp.JSONPathSubscript,
}
TYPE_MAPPING = {
**generator.Generator.TYPE_MAPPING,
exp.DataType.Type.TIMESTAMP: "TIMESTAMPNTZ",
@ -841,6 +857,7 @@ class Snowflake(Dialect):
def show_sql(self, expression: exp.Show) -> str:
terse = "TERSE " if expression.args.get("terse") else ""
history = " HISTORY" if expression.args.get("history") else ""
like = self.sql(expression, "like")
like = f" LIKE {like}" if like else ""
@ -861,9 +878,7 @@ class Snowflake(Dialect):
if from_:
from_ = f" FROM {from_}"
return (
f"SHOW {terse}{expression.name}{like}{scope_kind}{scope}{starts_with}{limit}{from_}"
)
return f"SHOW {terse}{expression.name}{history}{like}{scope_kind}{scope}{starts_with}{limit}{from_}"
def regexpextract_sql(self, expression: exp.RegexpExtract) -> str:
# Other dialects don't support all of the following parameters, so we need to

View file

@ -4,6 +4,7 @@ import typing as t
from sqlglot import exp
from sqlglot.dialects.dialect import rename_func
from sqlglot.dialects.hive import _parse_ignore_nulls
from sqlglot.dialects.spark2 import Spark2
from sqlglot.helper import seq_get
@ -45,9 +46,7 @@ class Spark(Spark2):
class Parser(Spark2.Parser):
FUNCTIONS = {
**Spark2.Parser.FUNCTIONS,
"ANY_VALUE": lambda args: exp.AnyValue(
this=seq_get(args, 0), ignore_nulls=seq_get(args, 1)
),
"ANY_VALUE": _parse_ignore_nulls(exp.AnyValue),
"DATEDIFF": _parse_datediff,
}
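A sketch of the parsing change, assuming Spark's two-argument ANY_VALUE form and that _parse_ignore_nulls wraps the result in an IgnoreNulls node (consistent with the First/Last changes earlier in this diff):

from sqlglot import exp, parse_one

# Hedged sketch: the boolean second argument becomes an IGNORE NULLS wrapper.
node = parse_one("SELECT ANY_VALUE(x, TRUE) FROM t", read="spark").selects[0]
print(isinstance(node, exp.IgnoreNulls))  # expected: True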

View file

@ -187,8 +187,10 @@ class Spark2(Hive):
TRANSFORMS = {
**Hive.Generator.TRANSFORMS,
exp.ApproxDistinct: rename_func("APPROX_COUNT_DISTINCT"),
exp.ArraySum: lambda self, e: f"AGGREGATE({self.sql(e, 'this')}, 0, (acc, x) -> acc + x, acc -> acc)",
exp.AtTimeZone: lambda self, e: f"FROM_UTC_TIMESTAMP({self.sql(e, 'this')}, {self.sql(e, 'zone')})",
exp.ArraySum: lambda self,
e: f"AGGREGATE({self.sql(e, 'this')}, 0, (acc, x) -> acc + x, acc -> acc)",
exp.AtTimeZone: lambda self,
e: f"FROM_UTC_TIMESTAMP({self.sql(e, 'this')}, {self.sql(e, 'zone')})",
exp.BitwiseLeftShift: rename_func("SHIFTLEFT"),
exp.BitwiseRightShift: rename_func("SHIFTRIGHT"),
exp.DateFromParts: rename_func("MAKE_DATE"),
@ -198,7 +200,8 @@ class Spark2(Hive):
exp.DayOfYear: rename_func("DAYOFYEAR"),
exp.FileFormatProperty: lambda self, e: f"USING {e.name.upper()}",
exp.From: transforms.preprocess([_unalias_pivot]),
exp.FromTimeZone: lambda self, e: f"TO_UTC_TIMESTAMP({self.sql(e, 'this')}, {self.sql(e, 'zone')})",
exp.FromTimeZone: lambda self,
e: f"TO_UTC_TIMESTAMP({self.sql(e, 'this')}, {self.sql(e, 'zone')})",
exp.LogicalAnd: rename_func("BOOL_AND"),
exp.LogicalOr: rename_func("BOOL_OR"),
exp.Map: _map_sql,
@ -212,7 +215,8 @@ class Spark2(Hive):
e.args.get("position"),
),
exp.StrToDate: _str_to_date,
exp.StrToTime: lambda self, e: f"TO_TIMESTAMP({self.sql(e, 'this')}, {self.format_time(e)})",
exp.StrToTime: lambda self,
e: f"TO_TIMESTAMP({self.sql(e, 'this')}, {self.format_time(e)})",
exp.TimestampTrunc: lambda self, e: self.func(
"DATE_TRUNC", exp.Literal.string(e.text("unit")), e.this
),

View file

@ -7,7 +7,6 @@ from sqlglot.dialects.dialect import (
Dialect,
NormalizationStrategy,
any_value_to_max_sql,
arrow_json_extract_scalar_sql,
arrow_json_extract_sql,
concat_to_dpipe_sql,
count_if_to_sum,
@ -28,6 +27,12 @@ def _date_add_sql(self: SQLite.Generator, expression: exp.DateAdd) -> str:
return self.func("DATE", expression.this, modifier)
def _json_extract_sql(self: SQLite.Generator, expression: exp.JSONExtract) -> str:
if expression.expressions:
return self.function_fallback_sql(expression)
return arrow_json_extract_sql(self, expression)
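A sketch of the dispatch above: single-path extractions keep the arrow operator, while the variadic form falls back to the JSON_EXTRACT function (outputs are assumptions):

import sqlglot

# Hedged sketch of the two code paths in _json_extract_sql.
print(sqlglot.transpile("SELECT JSON_EXTRACT(x, '$.a') FROM t", read="sqlite", write="sqlite")[0])
# expected roughly: SELECT x -> '$.a' FROM t
print(sqlglot.transpile("SELECT JSON_EXTRACT(x, '$.a', '$.b') FROM t", read="sqlite", write="sqlite")[0])
# expected roughly: SELECT JSON_EXTRACT(x, '$.a', '$.b') FROM t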
def _transform_create(expression: exp.Expression) -> exp.Expression:
"""Move primary key to a column and enforce auto_increment on primary keys."""
schema = expression.this
@ -85,6 +90,14 @@ class SQLite(Dialect):
TABLE_HINTS = False
QUERY_HINTS = False
NVL2_SUPPORTED = False
JSON_PATH_BRACKETED_KEY_SUPPORTED = False
SUPPORTS_CREATE_TABLE_LIKE = False
SUPPORTED_JSON_PATH_PARTS = {
exp.JSONPathKey,
exp.JSONPathRoot,
exp.JSONPathSubscript,
}
TYPE_MAPPING = {
**generator.Generator.TYPE_MAPPING,
@ -120,10 +133,8 @@ class SQLite(Dialect):
exp.DateAdd: _date_add_sql,
exp.DateStrToDate: lambda self, e: self.sql(e, "this"),
exp.ILike: no_ilike_sql,
exp.JSONExtract: arrow_json_extract_sql,
exp.JSONExtractScalar: arrow_json_extract_scalar_sql,
exp.JSONBExtract: arrow_json_extract_sql,
exp.JSONBExtractScalar: arrow_json_extract_scalar_sql,
exp.JSONExtract: _json_extract_sql,
exp.JSONExtractScalar: arrow_json_extract_sql,
exp.Levenshtein: rename_func("EDITDIST3"),
exp.LogicalOr: rename_func("MAX"),
exp.LogicalAnd: rename_func("MIN"),
@ -141,11 +152,18 @@ class SQLite(Dialect):
exp.TryCast: no_trycast_sql,
}
# SQLite doesn't generally support CREATE TABLE .. properties
# https://www.sqlite.org/lang_createtable.html
PROPERTIES_LOCATION = {
k: exp.Properties.Location.UNSUPPORTED
for k, v in generator.Generator.PROPERTIES_LOCATION.items()
prop: exp.Properties.Location.UNSUPPORTED
for prop in generator.Generator.PROPERTIES_LOCATION
}
# There are a few exceptions (e.g. temporary tables) which are supported or
# can be transpiled to SQLite, so we explicitly override them accordingly
PROPERTIES_LOCATION[exp.LikeProperty] = exp.Properties.Location.POST_SCHEMA
PROPERTIES_LOCATION[exp.TemporaryProperty] = exp.Properties.Location.POST_CREATE
LIMIT_FETCH = "LIMIT"
def cast_sql(self, expression: exp.Cast, safe_prefix: t.Optional[str] = None) -> str:

View file

@ -44,12 +44,14 @@ class StarRocks(MySQL):
exp.JSONExtractScalar: arrow_json_extract_sql,
exp.JSONExtract: arrow_json_extract_sql,
exp.RegexpLike: rename_func("REGEXP"),
exp.StrToUnix: lambda self, e: f"UNIX_TIMESTAMP({self.sql(e, 'this')}, {self.format_time(e)})",
exp.StrToUnix: lambda self,
e: f"UNIX_TIMESTAMP({self.sql(e, 'this')}, {self.format_time(e)})",
exp.TimestampTrunc: lambda self, e: self.func(
"DATE_TRUNC", exp.Literal.string(e.text("unit")), e.this
),
exp.TimeStrToDate: rename_func("TO_DATE"),
exp.UnixToStr: lambda self, e: f"FROM_UNIXTIME({self.sql(e, 'this')}, {self.format_time(e)})",
exp.UnixToStr: lambda self,
e: f"FROM_UNIXTIME({self.sql(e, 'this')}, {self.format_time(e)})",
exp.UnixToTime: rename_func("FROM_UNIXTIME"),
}

View file

@ -200,7 +200,8 @@ class Teradata(Dialect):
exp.Select: transforms.preprocess(
[transforms.eliminate_distinct_on, transforms.eliminate_semi_and_anti_joins]
),
exp.StrToDate: lambda self, e: f"CAST({self.sql(e, 'this')} AS DATE FORMAT {self.format_time(e)})",
exp.StrToDate: lambda self,
e: f"CAST({self.sql(e, 'this')} AS DATE FORMAT {self.format_time(e)})",
exp.ToChar: lambda self, e: self.function_fallback_sql(e),
exp.Use: lambda self, e: f"DATABASE {self.sql(e, 'this')}",
}

View file

@ -11,9 +11,16 @@ class Trino(Presto):
class Generator(Presto.Generator):
TRANSFORMS = {
**Presto.Generator.TRANSFORMS,
exp.ArraySum: lambda self, e: f"REDUCE({self.sql(e, 'this')}, 0, (acc, x) -> acc + x, acc -> acc)",
exp.ArraySum: lambda self,
e: f"REDUCE({self.sql(e, 'this')}, 0, (acc, x) -> acc + x, acc -> acc)",
exp.Merge: merge_without_target_sql,
}
SUPPORTED_JSON_PATH_PARTS = {
exp.JSONPathKey,
exp.JSONPathRoot,
exp.JSONPathSubscript,
}
class Tokenizer(Presto.Tokenizer):
HEX_STRINGS = [("X'", "'")]

View file

@ -14,7 +14,6 @@ from sqlglot.dialects.dialect import (
max_or_greatest,
min_or_least,
parse_date_delta,
path_to_jsonpath,
rename_func,
timestrtotime_sql,
trim_sql,
@ -266,13 +265,32 @@ def _parse_timefromparts(args: t.List) -> exp.TimeFromParts:
)
def _parse_len(args: t.List) -> exp.Length:
this = seq_get(args, 0)
def _parse_as_text(
klass: t.Type[exp.Expression],
) -> t.Callable[[t.List[exp.Expression]], exp.Expression]:
def _parse(args: t.List[exp.Expression]) -> exp.Expression:
this = seq_get(args, 0)
if this and not this.is_string:
this = exp.cast(this, exp.DataType.Type.TEXT)
if this and not this.is_string:
this = exp.cast(this, exp.DataType.Type.TEXT)
return exp.Length(this=this)
expression = seq_get(args, 1)
kwargs = {"this": this}
if expression:
kwargs["expression"] = expression
return klass(**kwargs)
return _parse
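A sketch of the round trip this enables, assuming the _uncast_text counterpart shown further down in this file:

import sqlglot

# Hedged sketch: LEN/LEFT/RIGHT wrap non-string arguments in a TEXT cast at
# parse time, and _uncast_text strips that cast again when generating T-SQL.
print(sqlglot.transpile("SELECT LEFT(x, 3)", read="tsql", write="tsql")[0])
# expected to round-trip as: SELECT LEFT(x, 3)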
def _json_extract_sql(
self: TSQL.Generator, expression: exp.JSONExtract | exp.JSONExtractScalar
) -> str:
json_query = rename_func("JSON_QUERY")(self, expression)
json_value = rename_func("JSON_VALUE")(self, expression)
return self.func("ISNULL", json_query, json_value)
class TSQL(Dialect):
@ -441,8 +459,11 @@ class TSQL(Dialect):
"HASHBYTES": _parse_hashbytes,
"IIF": exp.If.from_arg_list,
"ISNULL": exp.Coalesce.from_arg_list,
"JSON_VALUE": exp.JSONExtractScalar.from_arg_list,
"LEN": _parse_len,
"JSON_QUERY": parser.parse_extract_json_with_path(exp.JSONExtract),
"JSON_VALUE": parser.parse_extract_json_with_path(exp.JSONExtractScalar),
"LEN": _parse_as_text(exp.Length),
"LEFT": _parse_as_text(exp.Left),
"RIGHT": _parse_as_text(exp.Right),
"REPLICATE": exp.Repeat.from_arg_list,
"SQUARE": lambda args: exp.Pow(this=seq_get(args, 0), expression=exp.Literal.number(2)),
"SYSDATETIME": exp.CurrentTimestamp.from_arg_list,
@ -677,6 +698,7 @@ class TSQL(Dialect):
SUPPORTS_SINGLE_ARG_CONCAT = False
TABLESAMPLE_SEED_KEYWORD = "REPEATABLE"
SUPPORTS_SELECT_INTO = True
JSON_PATH_BRACKETED_KEY_SUPPORTED = False
EXPRESSIONS_WITHOUT_NESTED_CTES = {
exp.Delete,
@ -688,6 +710,12 @@ class TSQL(Dialect):
exp.Update,
}
SUPPORTED_JSON_PATH_PARTS = {
exp.JSONPathKey,
exp.JSONPathRoot,
exp.JSONPathSubscript,
}
TYPE_MAPPING = {
**generator.Generator.TYPE_MAPPING,
exp.DataType.Type.BOOLEAN: "BIT",
@ -712,9 +740,10 @@ class TSQL(Dialect):
exp.CurrentTimestamp: rename_func("GETDATE"),
exp.Extract: rename_func("DATEPART"),
exp.GeneratedAsIdentityColumnConstraint: generatedasidentitycolumnconstraint_sql,
exp.GetPath: path_to_jsonpath("JSON_VALUE"),
exp.GroupConcat: _string_agg_sql,
exp.If: rename_func("IIF"),
exp.JSONExtract: _json_extract_sql,
exp.JSONExtractScalar: _json_extract_sql,
exp.LastDay: lambda self, e: self.func("EOMONTH", e.this),
exp.Max: max_or_greatest,
exp.MD5: lambda self, e: self.func("HASHBYTES", exp.Literal.string("MD5"), e.this),
@ -831,15 +860,21 @@ class TSQL(Dialect):
exists = expression.args.pop("exists", None)
sql = super().create_sql(expression)
like_property = expression.find(exp.LikeProperty)
if like_property:
ctas_expression = like_property.this
else:
ctas_expression = expression.expression
table = expression.find(exp.Table)
# Convert CTAS statement to SELECT .. INTO ..
if kind == "TABLE" and expression.expression:
ctas_with = expression.expression.args.get("with")
if kind == "TABLE" and ctas_expression:
ctas_with = ctas_expression.args.get("with")
if ctas_with:
ctas_with = ctas_with.pop()
subquery = expression.expression
subquery = ctas_expression
if isinstance(subquery, exp.Subqueryable):
subquery = subquery.subquery()
@ -847,6 +882,9 @@ class TSQL(Dialect):
select_into.set("into", exp.Into(this=table))
select_into.set("with", ctas_with)
if like_property:
select_into.limit(0, copy=False)
sql = self.sql(select_into)
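A sketch of the new LIKE handling, assuming MySQL's CREATE TABLE ... LIKE syntax as input; the exact TOP 0 / aliasing shape of the output is an assumption:

import sqlglot

# Hedged sketch: CREATE TABLE ... LIKE has no direct T-SQL form, so it becomes
# a zero-row SELECT INTO.
print(sqlglot.transpile("CREATE TABLE t2 LIKE t1", read="mysql", write="tsql")[0])
# expected roughly: SELECT TOP 0 * INTO t2 FROM t1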
if exists:
@ -937,9 +975,19 @@ class TSQL(Dialect):
return f"CONSTRAINT {this} {expressions}"
def length_sql(self, expression: exp.Length) -> str:
return self._uncast_text(expression, "LEN")
def right_sql(self, expression: exp.Right) -> str:
return self._uncast_text(expression, "RIGHT")
def left_sql(self, expression: exp.Left) -> str:
return self._uncast_text(expression, "LEFT")
def _uncast_text(self, expression: exp.Expression, name: str) -> str:
this = expression.this
if isinstance(this, exp.Cast) and this.is_type(exp.DataType.Type.TEXT):
this_sql = self.sql(this, "this")
else:
this_sql = self.sql(this)
return self.func("LEN", this_sql)
expression_sql = self.sql(expression, "expression")
return self.func(name, this_sql, expression_sql if expression_sql else None)

View file

@ -10,7 +10,6 @@ import logging
import time
import typing as t
from sqlglot import maybe_parse
from sqlglot.errors import ExecuteError
from sqlglot.executor.python import PythonExecutor
from sqlglot.executor.table import Table, ensure_tables
@ -23,7 +22,6 @@ logger = logging.getLogger("sqlglot")
if t.TYPE_CHECKING:
from sqlglot.dialects.dialect import DialectType
from sqlglot.executor.table import Tables
from sqlglot.expressions import Expression
from sqlglot.schema import Schema

View file

@ -44,9 +44,9 @@ class Context:
for other in self.tables.values():
if self._table.columns != other.columns:
raise Exception(f"Columns are different.")
raise Exception("Columns are different.")
if len(self._table.rows) != len(other.rows):
raise Exception(f"Rows are different.")
raise Exception("Rows are different.")
return self._table

View file

@ -6,7 +6,7 @@ from functools import wraps
from sqlglot import exp
from sqlglot.generator import Generator
from sqlglot.helper import PYTHON_VERSION
from sqlglot.helper import PYTHON_VERSION, is_int, seq_get
class reverse_key:
@ -143,6 +143,22 @@ def arrayjoin(this, expression, null=None):
return expression.join(x for x in (x if x is not None else null for x in this) if x is not None)
@null_if_any("this", "expression")
def jsonextract(this, expression):
for path_segment in expression:
if isinstance(this, dict):
this = this.get(path_segment)
elif isinstance(this, list) and is_int(path_segment):
this = seq_get(this, int(path_segment))
else:
raise NotImplementedError(f"Unable to extract value for {this} at {path_segment}.")
if this is None:
break
return this
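A hedged usage sketch of the new executor helper; path segments arrive as strings, so list indices go through is_int:

from sqlglot.executor.env import ENV

extract = ENV["JSONEXTRACT"]
print(extract({"a": [10, 20]}, ["a", "1"]))  # 20: dict lookup, then list index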
ENV = {
"exp": exp,
# aggs
@ -175,12 +191,12 @@ ENV = {
"DOT": null_if_any(lambda e, this: e[this]),
"EQ": null_if_any(lambda this, e: this == e),
"EXTRACT": null_if_any(lambda this, e: getattr(e, this)),
"GETPATH": null_if_any(lambda this, e: this.get(e)),
"GT": null_if_any(lambda this, e: this > e),
"GTE": null_if_any(lambda this, e: this >= e),
"IF": lambda predicate, true, false: true if predicate else false,
"INTDIV": null_if_any(lambda e, this: e // this),
"INTERVAL": interval,
"JSONEXTRACT": jsonextract,
"LEFT": null_if_any(lambda this, e: this[:e]),
"LIKE": null_if_any(
lambda this, e: bool(re.match(e.replace("_", ".").replace("%", ".*"), this))

View file

@ -9,7 +9,7 @@ from sqlglot.errors import ExecuteError
from sqlglot.executor.context import Context
from sqlglot.executor.env import ENV
from sqlglot.executor.table import RowReader, Table
from sqlglot.helper import csv_reader, subclasses
from sqlglot.helper import csv_reader, ensure_list, subclasses
class PythonExecutor:
@ -368,7 +368,7 @@ def _rename(self, e):
if isinstance(e, exp.Func) and e.is_var_len_args:
*head, tail = values
return self.func(e.key, *head, *tail)
return self.func(e.key, *head, *ensure_list(tail))
return self.func(e.key, *values)
except Exception as ex:
@ -429,18 +429,24 @@ class Python(Dialect):
exp.Between: _rename,
exp.Boolean: lambda self, e: "True" if e.this else "False",
exp.Cast: lambda self, e: f"CAST({self.sql(e.this)}, exp.DataType.Type.{e.args['to']})",
exp.Column: lambda self, e: f"scope[{self.sql(e, 'table') or None}][{self.sql(e.this)}]",
exp.Column: lambda self,
e: f"scope[{self.sql(e, 'table') or None}][{self.sql(e.this)}]",
exp.Concat: lambda self, e: self.func(
"SAFECONCAT" if e.args.get("safe") else "CONCAT", *e.expressions
),
exp.Distinct: lambda self, e: f"set({self.sql(e, 'this')})",
exp.Div: _div_sql,
exp.Extract: lambda self, e: f"EXTRACT('{e.name.lower()}', {self.sql(e, 'expression')})",
exp.In: lambda self, e: f"{self.sql(e, 'this')} in {{{self.expressions(e, flat=True)}}}",
exp.Extract: lambda self,
e: f"EXTRACT('{e.name.lower()}', {self.sql(e, 'expression')})",
exp.In: lambda self,
e: f"{self.sql(e, 'this')} in {{{self.expressions(e, flat=True)}}}",
exp.Interval: lambda self, e: f"INTERVAL({self.sql(e.this)}, '{self.sql(e.unit)}')",
exp.Is: lambda self, e: (
self.binary(e, "==") if isinstance(e.this, exp.Literal) else self.binary(e, "is")
),
exp.JSONPath: lambda self, e: f"[{','.join(self.sql(p) for p in e.expressions[1:])}]",
exp.JSONPathKey: lambda self, e: f"'{self.sql(e.this)}'",
exp.JSONPathSubscript: lambda self, e: f"'{e.this}'",
exp.Lambda: _lambda_sql,
exp.Not: lambda self, e: f"not {self.sql(e.this)}",
exp.Null: lambda *_: "None",

View file

@ -29,6 +29,7 @@ from sqlglot.helper import (
camel_to_snake_case,
ensure_collection,
ensure_list,
is_int,
seq_get,
subclasses,
)
@ -175,13 +176,7 @@ class Expression(metaclass=_Expression):
"""
Checks whether a Literal expression is an integer.
"""
if self.is_number:
try:
int(self.name)
return True
except ValueError:
pass
return False
return self.is_number and is_int(self.name)
@property
def is_star(self) -> bool:
@ -493,8 +488,8 @@ class Expression(metaclass=_Expression):
A AND B AND C -> [A, B, C]
"""
for node, _, _ in self.dfs(prune=lambda n, p, *_: p and not type(n) is self.__class__):
if not type(node) is self.__class__:
for node, _, _ in self.dfs(prune=lambda n, p, *_: p and type(n) is not self.__class__):
if type(node) is not self.__class__:
yield node.unnest() if unnest and not isinstance(node, Subquery) else node
def __str__(self) -> str:
@ -553,10 +548,12 @@ class Expression(metaclass=_Expression):
return new_node
@t.overload
def replace(self, expression: E) -> E: ...
def replace(self, expression: E) -> E:
...
@t.overload
def replace(self, expression: None) -> None: ...
def replace(self, expression: None) -> None:
...
def replace(self, expression):
"""
@ -610,7 +607,8 @@ class Expression(metaclass=_Expression):
>>> sqlglot.parse_one("SELECT x from y").assert_is(Select).select("z").sql()
'SELECT x, z FROM y'
"""
assert isinstance(self, type_)
if not isinstance(self, type_):
raise AssertionError(f"{self} is not {type_}.")
return self
def error_messages(self, args: t.Optional[t.Sequence] = None) -> t.List[str]:
@ -1133,6 +1131,7 @@ class SetItem(Expression):
class Show(Expression):
arg_types = {
"this": True,
"history": False,
"terse": False,
"target": False,
"offset": False,
@ -1676,7 +1675,6 @@ class Index(Expression):
"amp": False, # teradata
"include": False,
"partition_by": False, # teradata
"where": False, # postgres partial indexes
}
@ -2573,7 +2571,7 @@ class HistoricalData(Expression):
class Table(Expression):
arg_types = {
"this": True,
"this": False,
"alias": False,
"db": False,
"catalog": False,
@ -3664,6 +3662,7 @@ class DataType(Expression):
BINARY = auto()
BIT = auto()
BOOLEAN = auto()
BPCHAR = auto()
CHAR = auto()
DATE = auto()
DATE32 = auto()
@ -3805,6 +3804,7 @@ class DataType(Expression):
dtype: DATA_TYPE,
dialect: DialectType = None,
udt: bool = False,
copy: bool = True,
**kwargs,
) -> DataType:
"""
@ -3815,7 +3815,8 @@ class DataType(Expression):
dialect: the dialect to use for parsing `dtype`, in case it's a string.
udt: when set to True, `dtype` will be used as-is if it can't be parsed into a
DataType, thus creating a user-defined type.
kawrgs: additional arguments to pass in the constructor of DataType.
copy: whether or not to copy the data type.
kwargs: additional arguments to pass in the constructor of DataType.
Returns:
The constructed DataType object.
@ -3837,7 +3838,7 @@ class DataType(Expression):
elif isinstance(dtype, DataType.Type):
data_type_exp = DataType(this=dtype)
elif isinstance(dtype, DataType):
return dtype
return maybe_copy(dtype, copy)
else:
raise ValueError(f"Invalid data type: {type(dtype)}. Expected str or DataType.Type")
@ -3855,7 +3856,7 @@ class DataType(Expression):
True, if and only if there is a type in `dtypes` which is equal to this DataType.
"""
for dtype in dtypes:
other = DataType.build(dtype, udt=True)
other = DataType.build(dtype, copy=False, udt=True)
if (
other.expressions
@ -4001,7 +4002,7 @@ class Dot(Binary):
def build(self, expressions: t.Sequence[Expression]) -> Dot:
"""Build a Dot object with a sequence of expressions."""
if len(expressions) < 2:
raise ValueError(f"Dot requires >= 2 expressions.")
raise ValueError("Dot requires >= 2 expressions.")
return t.cast(Dot, reduce(lambda x, y: Dot(this=x, expression=y), expressions))
@ -4128,10 +4129,6 @@ class Sub(Binary):
pass
class ArrayOverlaps(Binary):
pass
# Unary Expressions
# (NOT a)
class Unary(Condition):
@ -4469,6 +4466,10 @@ class ArrayJoin(Func):
arg_types = {"this": True, "expression": True, "null": False}
class ArrayOverlaps(Binary, Func):
pass
class ArraySize(Func):
arg_types = {"this": True, "expression": False}
@ -4490,15 +4491,37 @@ class Avg(AggFunc):
class AnyValue(AggFunc):
arg_types = {"this": True, "having": False, "max": False, "ignore_nulls": False}
arg_types = {"this": True, "having": False, "max": False}
class First(Func):
arg_types = {"this": True, "ignore_nulls": False}
class Lag(AggFunc):
arg_types = {"this": True, "offset": False, "default": False}
class Last(Func):
arg_types = {"this": True, "ignore_nulls": False}
class Lead(AggFunc):
arg_types = {"this": True, "offset": False, "default": False}
# some dialects have a distinction between first and first_value, usually first is an aggregate func
# and first_value is a window func
class First(AggFunc):
pass
class Last(AggFunc):
pass
class FirstValue(AggFunc):
pass
class LastValue(AggFunc):
pass
class NthValue(AggFunc):
arg_types = {"this": True, "offset": True}
class Case(Func):
@ -4611,7 +4634,7 @@ class CurrentTime(Func):
class CurrentTimestamp(Func):
arg_types = {"this": False}
arg_types = {"this": False, "transaction": False}
class CurrentUser(Func):
@ -4712,6 +4735,7 @@ class TimestampSub(Func, TimeUnit):
class TimestampDiff(Func, TimeUnit):
_sql_names = ["TIMESTAMPDIFF", "TIMESTAMP_DIFF"]
arg_types = {"this": True, "expression": True, "unit": False}
@ -4857,6 +4881,59 @@ class IsInf(Func):
_sql_names = ["IS_INF", "ISINF"]
class JSONPath(Expression):
arg_types = {"expressions": True}
@property
def output_name(self) -> str:
last_segment = self.expressions[-1].this
return last_segment if isinstance(last_segment, str) else ""
class JSONPathPart(Expression):
arg_types = {}
class JSONPathFilter(JSONPathPart):
arg_types = {"this": True}
class JSONPathKey(JSONPathPart):
arg_types = {"this": True}
class JSONPathRecursive(JSONPathPart):
arg_types = {"this": False}
class JSONPathRoot(JSONPathPart):
pass
class JSONPathScript(JSONPathPart):
arg_types = {"this": True}
class JSONPathSlice(JSONPathPart):
arg_types = {"start": False, "end": False, "step": False}
class JSONPathSelector(JSONPathPart):
arg_types = {"this": True}
class JSONPathSubscript(JSONPathPart):
arg_types = {"this": True}
class JSONPathUnion(JSONPathPart):
arg_types = {"expressions": True}
class JSONPathWildcard(JSONPathPart):
pass
class FormatJson(Expression):
pass
@ -4940,18 +5017,30 @@ class JSONBContains(Binary):
class JSONExtract(Binary, Func):
arg_types = {"this": True, "expression": True, "expressions": False}
_sql_names = ["JSON_EXTRACT"]
is_var_len_args = True
@property
def output_name(self) -> str:
return self.expression.output_name if not self.expressions else ""
class JSONExtractScalar(JSONExtract):
class JSONExtractScalar(Binary, Func):
arg_types = {"this": True, "expression": True, "expressions": False}
_sql_names = ["JSON_EXTRACT_SCALAR"]
is_var_len_args = True
@property
def output_name(self) -> str:
return self.expression.output_name
class JSONBExtract(JSONExtract):
class JSONBExtract(Binary, Func):
_sql_names = ["JSONB_EXTRACT"]
class JSONBExtractScalar(JSONExtract):
class JSONBExtractScalar(Binary, Func):
_sql_names = ["JSONB_EXTRACT_SCALAR"]
@ -4972,15 +5061,6 @@ class ParseJSON(Func):
is_var_len_args = True
# https://docs.snowflake.com/en/sql-reference/functions/get_path
class GetPath(Func):
arg_types = {"this": True, "expression": True}
@property
def output_name(self) -> str:
return self.expression.output_name
class Least(Func):
arg_types = {"this": True, "expressions": False}
is_var_len_args = True
@ -5476,6 +5556,8 @@ def _norm_arg(arg):
ALL_FUNCTIONS = subclasses(__name__, Func, (AggFunc, Anonymous, Func))
FUNCTION_BY_NAME = {name: func for func in ALL_FUNCTIONS for name in func.sql_names()}
JSON_PATH_PARTS = subclasses(__name__, JSONPathPart, (JSONPathPart,))
# Helpers
@t.overload
@ -5487,7 +5569,8 @@ def maybe_parse(
prefix: t.Optional[str] = None,
copy: bool = False,
**opts,
) -> E: ...
) -> E:
...
@t.overload
@ -5499,7 +5582,8 @@ def maybe_parse(
prefix: t.Optional[str] = None,
copy: bool = False,
**opts,
) -> E: ...
) -> E:
...
def maybe_parse(
@ -5539,7 +5623,7 @@ def maybe_parse(
return sql_or_expression
if sql_or_expression is None:
raise ParseError(f"SQL cannot be None")
raise ParseError("SQL cannot be None")
import sqlglot
@ -5551,11 +5635,13 @@ def maybe_parse(
@t.overload
def maybe_copy(instance: None, copy: bool = True) -> None: ...
def maybe_copy(instance: None, copy: bool = True) -> None:
...
@t.overload
def maybe_copy(instance: E, copy: bool = True) -> E: ...
def maybe_copy(instance: E, copy: bool = True) -> E:
...
def maybe_copy(instance, copy=True):
@ -6174,17 +6260,19 @@ def paren(expression: ExpOrStr, copy: bool = True) -> Paren:
return Paren(this=maybe_parse(expression, copy=copy))
SAFE_IDENTIFIER_RE = re.compile(r"^[_a-zA-Z][\w]*$")
SAFE_IDENTIFIER_RE: t.Pattern[str] = re.compile(r"^[_a-zA-Z][\w]*$")
@t.overload
def to_identifier(name: None, quoted: t.Optional[bool] = None, copy: bool = True) -> None: ...
def to_identifier(name: None, quoted: t.Optional[bool] = None, copy: bool = True) -> None:
...
@t.overload
def to_identifier(
name: str | Identifier, quoted: t.Optional[bool] = None, copy: bool = True
) -> Identifier: ...
) -> Identifier:
...
def to_identifier(name, quoted=None, copy=True):
@ -6256,11 +6344,13 @@ def to_interval(interval: str | Literal) -> Interval:
@t.overload
def to_table(sql_path: str | Table, **kwargs) -> Table: ...
def to_table(sql_path: str | Table, **kwargs) -> Table:
...
@t.overload
def to_table(sql_path: None, **kwargs) -> None: ...
def to_table(sql_path: None, **kwargs) -> None:
...
def to_table(
@ -6460,7 +6550,7 @@ def column(
return this
def cast(expression: ExpOrStr, to: DATA_TYPE, **opts) -> Cast:
def cast(expression: ExpOrStr, to: DATA_TYPE, copy: bool = True, **opts) -> Cast:
"""Cast an expression to a data type.
Example:
@ -6470,12 +6560,13 @@ def cast(expression: ExpOrStr, to: DATA_TYPE, **opts) -> Cast:
Args:
expression: The expression to cast.
to: The datatype to cast to.
copy: Whether or not to copy the supplied expressions.
Returns:
The new Cast instance.
"""
expression = maybe_parse(expression, **opts)
data_type = DataType.build(to, **opts)
expression = maybe_parse(expression, copy=copy, **opts)
data_type = DataType.build(to, copy=copy, **opts)
expression = Cast(this=expression, to=data_type)
expression.type = data_type
return expression
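A minimal sketch of what the new copy parameter buys, assuming in-place mutation is acceptable at the call site:

from sqlglot import exp

# Hedged sketch: copy=False now flows into both maybe_parse and DataType.build,
# so the input expression is reused rather than copied.
col = exp.column("x")
assert exp.cast(col, "int", copy=False).this is col  # no defensive copy taken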

View file

@ -9,6 +9,7 @@ from functools import reduce
from sqlglot import exp
from sqlglot.errors import ErrorLevel, UnsupportedError, concat_messages
from sqlglot.helper import apply_index_offset, csv, seq_get
from sqlglot.jsonpath import ALL_JSON_PATH_PARTS, JSON_PATH_PART_TRANSFORMS
from sqlglot.time import format_time
from sqlglot.tokens import TokenType
@ -21,7 +22,18 @@ logger = logging.getLogger("sqlglot")
ESCAPED_UNICODE_RE = re.compile(r"\\(\d+)")
class Generator:
class _Generator(type):
def __new__(cls, clsname, bases, attrs):
klass = super().__new__(cls, clsname, bases, attrs)
# Remove transforms that correspond to unsupported JSONPathPart expressions
for part in ALL_JSON_PATH_PARTS - klass.SUPPORTED_JSON_PATH_PARTS:
klass.TRANSFORMS.pop(part, None)
return klass
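A sketch of the metaclass hook; MyGenerator is hypothetical, and a fresh TRANSFORMS dict is declared so the pop doesn't mutate the base class mapping (dialect generators in this repo already do this via `{**generator.Generator.TRANSFORMS, ...}`):

from sqlglot import exp
from sqlglot.generator import Generator

class MyGenerator(Generator):
    TRANSFORMS = {**Generator.TRANSFORMS}
    SUPPORTED_JSON_PATH_PARTS = {exp.JSONPathKey, exp.JSONPathRoot, exp.JSONPathSubscript}

# The unsupported parts were stripped from the subclass's transform table.
assert exp.JSONPathWildcard not in MyGenerator.TRANSFORMS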
class Generator(metaclass=_Generator):
"""
Generator converts a given syntax tree to the corresponding SQL string.
@ -58,19 +70,23 @@ class Generator:
Default: True
"""
TRANSFORMS = {
TRANSFORMS: t.Dict[t.Type[exp.Expression], t.Callable[..., str]] = {
**JSON_PATH_PART_TRANSFORMS,
exp.AutoRefreshProperty: lambda self, e: f"AUTO REFRESH {self.sql(e, 'this')}",
exp.CaseSpecificColumnConstraint: lambda self,
e: f"{'NOT ' if e.args.get('not_') else ''}CASESPECIFIC",
exp.CharacterSetColumnConstraint: lambda self, e: f"CHARACTER SET {self.sql(e, 'this')}",
exp.CharacterSetProperty: lambda self,
e: f"{'DEFAULT ' if e.args.get('default') else ''}CHARACTER SET={self.sql(e, 'this')}",
exp.CheckColumnConstraint: lambda self, e: f"CHECK ({self.sql(e, 'this')})",
exp.ClusteredColumnConstraint: lambda self,
e: f"CLUSTERED ({self.expressions(e, 'this', indent=False)})",
exp.CollateColumnConstraint: lambda self, e: f"COLLATE {self.sql(e, 'this')}",
exp.CommentColumnConstraint: lambda self, e: f"COMMENT {self.sql(e, 'this')}",
exp.CopyGrantsProperty: lambda self, e: "COPY GRANTS",
exp.DateAdd: lambda self, e: self.func(
"DATE_ADD", e.this, e.expression, exp.Literal.string(e.text("unit"))
),
exp.CaseSpecificColumnConstraint: lambda self, e: f"{'NOT ' if e.args.get('not_') else ''}CASESPECIFIC",
exp.CharacterSetColumnConstraint: lambda self, e: f"CHARACTER SET {self.sql(e, 'this')}",
exp.CharacterSetProperty: lambda self, e: f"{'DEFAULT ' if e.args.get('default') else ''}CHARACTER SET={self.sql(e, 'this')}",
exp.CheckColumnConstraint: lambda self, e: f"CHECK ({self.sql(e, 'this')})",
exp.ClusteredColumnConstraint: lambda self, e: f"CLUSTERED ({self.expressions(e, 'this', indent=False)})",
exp.CollateColumnConstraint: lambda self, e: f"COLLATE {self.sql(e, 'this')}",
exp.AutoRefreshProperty: lambda self, e: f"AUTO REFRESH {self.sql(e, 'this')}",
exp.CopyGrantsProperty: lambda self, e: "COPY GRANTS",
exp.CommentColumnConstraint: lambda self, e: f"COMMENT {self.sql(e, 'this')}",
exp.DateFormatColumnConstraint: lambda self, e: f"FORMAT {self.sql(e, 'this')}",
exp.DefaultColumnConstraint: lambda self, e: f"DEFAULT {self.sql(e, 'this')}",
exp.EncodeColumnConstraint: lambda self, e: f"ENCODE {self.sql(e, 'this')}",
@ -85,29 +101,33 @@ class Generator:
exp.LocationProperty: lambda self, e: self.naked_property(e),
exp.LogProperty: lambda self, e: f"{'NO ' if e.args.get('no') else ''}LOG",
exp.MaterializedProperty: lambda self, e: "MATERIALIZED",
exp.NonClusteredColumnConstraint: lambda self,
e: f"NONCLUSTERED ({self.expressions(e, 'this', indent=False)})",
exp.NoPrimaryIndexProperty: lambda self, e: "NO PRIMARY INDEX",
exp.NonClusteredColumnConstraint: lambda self, e: f"NONCLUSTERED ({self.expressions(e, 'this', indent=False)})",
exp.NotForReplicationColumnConstraint: lambda self, e: "NOT FOR REPLICATION",
exp.OnCommitProperty: lambda self, e: f"ON COMMIT {'DELETE' if e.args.get('delete') else 'PRESERVE'} ROWS",
exp.OnCommitProperty: lambda self,
e: f"ON COMMIT {'DELETE' if e.args.get('delete') else 'PRESERVE'} ROWS",
exp.OnProperty: lambda self, e: f"ON {self.sql(e, 'this')}",
exp.OnUpdateColumnConstraint: lambda self, e: f"ON UPDATE {self.sql(e, 'this')}",
exp.OutputModelProperty: lambda self, e: f"OUTPUT{self.sql(e, 'this')}",
exp.PathColumnConstraint: lambda self, e: f"PATH {self.sql(e, 'this')}",
exp.RemoteWithConnectionModelProperty: lambda self, e: f"REMOTE WITH CONNECTION {self.sql(e, 'this')}",
exp.RemoteWithConnectionModelProperty: lambda self,
e: f"REMOTE WITH CONNECTION {self.sql(e, 'this')}",
exp.ReturnsProperty: lambda self, e: self.naked_property(e),
exp.SampleProperty: lambda self, e: f"SAMPLE BY {self.sql(e, 'this')}",
exp.SetProperty: lambda self, e: f"{'MULTI' if e.args.get('multi') else ''}SET",
exp.SetConfigProperty: lambda self, e: self.sql(e, "this"),
exp.SetProperty: lambda self, e: f"{'MULTI' if e.args.get('multi') else ''}SET",
exp.SettingsProperty: lambda self, e: f"SETTINGS{self.seg('')}{(self.expressions(e))}",
exp.SqlReadWriteProperty: lambda self, e: e.name,
exp.SqlSecurityProperty: lambda self, e: f"SQL SECURITY {'DEFINER' if e.args.get('definer') else 'INVOKER'}",
exp.SqlSecurityProperty: lambda self,
e: f"SQL SECURITY {'DEFINER' if e.args.get('definer') else 'INVOKER'}",
exp.StabilityProperty: lambda self, e: e.name,
exp.TemporaryProperty: lambda self, e: f"TEMPORARY",
exp.ToTableProperty: lambda self, e: f"TO {self.sql(e.this)}",
exp.TransientProperty: lambda self, e: "TRANSIENT",
exp.TransformModelProperty: lambda self, e: self.func("TRANSFORM", *e.expressions),
exp.TemporaryProperty: lambda self, e: "TEMPORARY",
exp.TitleColumnConstraint: lambda self, e: f"TITLE {self.sql(e, 'this')}",
exp.UppercaseColumnConstraint: lambda self, e: f"UPPERCASE",
exp.ToTableProperty: lambda self, e: f"TO {self.sql(e.this)}",
exp.TransformModelProperty: lambda self, e: self.func("TRANSFORM", *e.expressions),
exp.TransientProperty: lambda self, e: "TRANSIENT",
exp.UppercaseColumnConstraint: lambda self, e: "UPPERCASE",
exp.VarMap: lambda self, e: self.func("MAP", e.args["keys"], e.args["values"]),
exp.VolatileProperty: lambda self, e: "VOLATILE",
exp.WithJournalTableProperty: lambda self, e: f"WITH JOURNAL TABLE={self.sql(e, 'this')}",
@ -117,6 +137,10 @@ class Generator:
# True: Full Support, None: No support, False: No support in window specifications
NULL_ORDERING_SUPPORTED: t.Optional[bool] = True
# Whether or not ignore nulls is inside the agg or outside.
# FIRST(x IGNORE NULLS) OVER vs FIRST (x) IGNORE NULLS OVER
IGNORE_NULLS_IN_FUNC = False
# Whether or not locking reads (i.e. SELECT ... FOR UPDATE/SHARE) are supported
LOCKING_READS_SUPPORTED = False
@ -266,6 +290,24 @@ class Generator:
# Whether or not UNLOGGED tables can be created
SUPPORTS_UNLOGGED_TABLES = False
# Whether or not the CREATE TABLE LIKE statement is supported
SUPPORTS_CREATE_TABLE_LIKE = True
# Whether or not the LikeProperty needs to be specified inside of the schema clause
LIKE_PROPERTY_INSIDE_SCHEMA = False
# Whether or not the JSON extraction operators expect a value of type JSON
JSON_TYPE_REQUIRED_FOR_EXTRACTION = False
# Whether or not bracketed keys like ["foo"] are supported in JSON paths
JSON_PATH_BRACKETED_KEY_SUPPORTED = True
# Whether or not to escape keys using single quotes in JSON paths
JSON_PATH_SINGLE_QUOTE_ESCAPE = False
# The JSONPathPart expressions supported by this dialect
SUPPORTED_JSON_PATH_PARTS = ALL_JSON_PATH_PARTS.copy()
TYPE_MAPPING = {
exp.DataType.Type.NCHAR: "CHAR",
exp.DataType.Type.NVARCHAR: "VARCHAR",
@ -641,8 +683,6 @@ class Generator:
if callable(transform):
sql = transform(self, expression)
elif transform:
sql = transform
elif isinstance(expression, exp.Expression):
exp_handler_name = f"{expression.key}_sql"
@ -802,7 +842,7 @@ class Generator:
desc = expression.args.get("desc")
if desc is not None:
return f"PRIMARY KEY{' DESC' if desc else ' ASC'}"
return f"PRIMARY KEY"
return "PRIMARY KEY"
def uniquecolumnconstraint_sql(self, expression: exp.UniqueColumnConstraint) -> str:
this = self.sql(expression, "this")
@ -1218,9 +1258,21 @@ class Generator:
return f"{property_name}={self.sql(expression, 'this')}"
def likeproperty_sql(self, expression: exp.LikeProperty) -> str:
options = " ".join(f"{e.name} {self.sql(e, 'value')}" for e in expression.expressions)
options = f" {options}" if options else ""
return f"LIKE {self.sql(expression, 'this')}{options}"
if self.SUPPORTS_CREATE_TABLE_LIKE:
options = " ".join(f"{e.name} {self.sql(e, 'value')}" for e in expression.expressions)
options = f" {options}" if options else ""
like = f"LIKE {self.sql(expression, 'this')}{options}"
if self.LIKE_PROPERTY_INSIDE_SCHEMA and not isinstance(expression.parent, exp.Schema):
like = f"({like})"
return like
if expression.expressions:
self.unsupported("Transpilation of LIKE property options is unsupported")
select = exp.select("*").from_(expression.this).limit(0)
return f"AS {self.sql(select)}"
def fallbackproperty_sql(self, expression: exp.FallbackProperty) -> str:
no = "NO " if expression.args.get("no") else ""
@ -2367,6 +2419,31 @@ class Generator:
def jsonkeyvalue_sql(self, expression: exp.JSONKeyValue) -> str:
return f"{self.sql(expression, 'this')}{self.JSON_KEY_VALUE_PAIR_SEP} {self.sql(expression, 'expression')}"
def jsonpath_sql(self, expression: exp.JSONPath) -> str:
path = self.expressions(expression, sep="", flat=True).lstrip(".")
return f"{self.dialect.QUOTE_START}{path}{self.dialect.QUOTE_END}"
def json_path_part(self, expression: int | str | exp.JSONPathPart) -> str:
if isinstance(expression, exp.JSONPathPart):
transform = self.TRANSFORMS.get(expression.__class__)
if not callable(transform):
self.unsupported(f"Unsupported JSONPathPart type {expression.__class__.__name__}")
return ""
return transform(self, expression)
if isinstance(expression, int):
return str(expression)
if self.JSON_PATH_SINGLE_QUOTE_ESCAPE:
escaped = expression.replace("'", "\\'")
escaped = f"\\'{expression}\\'"
else:
escaped = expression.replace('"', '\\"')
escaped = f'"{escaped}"'
return escaped
def formatjson_sql(self, expression: exp.FormatJson) -> str:
return f"{self.sql(expression, 'this')} FORMAT JSON"
@ -2620,6 +2697,9 @@ class Generator:
zone = self.sql(expression, "this")
return f"CURRENT_DATE({zone})" if zone else "CURRENT_DATE"
def currenttimestamp_sql(self, expression: exp.CurrentTimestamp) -> str:
return self.func("CURRENT_TIMESTAMP", expression.this)
def collate_sql(self, expression: exp.Collate) -> str:
if self.COLLATE_IS_FUNC:
return self.function_fallback_sql(expression)
@ -2761,10 +2841,20 @@ class Generator:
return f"DISTINCT{this}{on}"
def ignorenulls_sql(self, expression: exp.IgnoreNulls) -> str:
return f"{self.sql(expression, 'this')} IGNORE NULLS"
return self._embed_ignore_nulls(expression, "IGNORE NULLS")
def respectnulls_sql(self, expression: exp.RespectNulls) -> str:
return f"{self.sql(expression, 'this')} RESPECT NULLS"
return self._embed_ignore_nulls(expression, "RESPECT NULLS")
def _embed_ignore_nulls(self, expression: exp.IgnoreNulls | exp.RespectNulls, text: str) -> str:
if self.IGNORE_NULLS_IN_FUNC:
this = expression.find(exp.AggFunc)
if this:
sql = self.sql(this)
sql = sql[:-1] + f" {text})"
return sql
return f"{self.sql(expression, 'this')} {text}"
def intdiv_sql(self, expression: exp.IntDiv) -> str:
return self.sql(
@ -2935,7 +3025,7 @@ class Generator:
def format_args(self, *args: t.Optional[str | exp.Expression]) -> str:
arg_sqls = tuple(self.sql(arg) for arg in args if arg is not None)
if self.pretty and self.text_width(arg_sqls) > self.max_text_width:
return self.indent("\n" + f",\n".join(arg_sqls) + "\n", skip_first=True, skip_last=True)
return self.indent("\n" + ",\n".join(arg_sqls) + "\n", skip_first=True, skip_last=True)
return ", ".join(arg_sqls)
def text_width(self, args: t.Iterable) -> int:
@ -3279,6 +3369,22 @@ class Generator:
return self.func("LAST_DAY", expression.this)
def _jsonpathkey_sql(self, expression: exp.JSONPathKey) -> str:
this = expression.this
if isinstance(this, exp.JSONPathWildcard):
this = self.json_path_part(this)
return f".{this}" if this else ""
if exp.SAFE_IDENTIFIER_RE.match(this):
return f".{this}"
this = self.json_path_part(this)
return f"[{this}]" if self.JSON_PATH_BRACKETED_KEY_SUPPORTED else f".{this}"
def _jsonpathsubscript_sql(self, expression: exp.JSONPathSubscript) -> str:
this = self.json_path_part(expression.this)
return f"[{this}]" if this else ""
def _simplify_unless_literal(self, expression: E) -> E:
if not isinstance(expression, exp.Literal):
from sqlglot.optimizer.simplify import simplify

View file

@ -53,11 +53,13 @@ def seq_get(seq: t.Sequence[T], index: int) -> t.Optional[T]:
@t.overload
def ensure_list(value: t.Collection[T]) -> t.List[T]: ...
def ensure_list(value: t.Collection[T]) -> t.List[T]:
...
@t.overload
def ensure_list(value: T) -> t.List[T]: ...
def ensure_list(value: T) -> t.List[T]:
...
def ensure_list(value):
@ -79,11 +81,13 @@ def ensure_list(value):
@t.overload
def ensure_collection(value: t.Collection[T]) -> t.Collection[T]: ...
def ensure_collection(value: t.Collection[T]) -> t.Collection[T]:
...
@t.overload
def ensure_collection(value: T) -> t.Collection[T]: ...
def ensure_collection(value: T) -> t.Collection[T]:
...
def ensure_collection(value):
@ -232,7 +236,7 @@ def tsort(dag: t.Dict[T, t.Set[T]]) -> t.List[T]:
for node, deps in tuple(dag.items()):
for dep in deps:
if not dep in dag:
if dep not in dag:
dag[dep] = set()
while dag:
@ -316,6 +320,14 @@ def find_new_name(taken: t.Collection[str], base: str) -> str:
return new
def is_int(text: str) -> bool:
try:
int(text)
return True
except ValueError:
return False
def name_sequence(prefix: str) -> t.Callable[[], str]:
"""Returns a name generator given a prefix (e.g. a0, a1, a2, ... if the prefix is "a")."""
sequence = count()

View file

@ -2,8 +2,8 @@ from __future__ import annotations
import typing as t
import sqlglot.expressions as exp
from sqlglot.errors import ParseError
from sqlglot.expressions import SAFE_IDENTIFIER_RE
from sqlglot.tokens import Token, Tokenizer, TokenType
if t.TYPE_CHECKING:
@ -36,20 +36,8 @@ class JSONPathTokenizer(Tokenizer):
STRING_ESCAPES = ["\\"]
JSONPathNode = t.Dict[str, t.Any]
def _node(kind: str, value: t.Any = None, **kwargs: t.Any) -> JSONPathNode:
node = {"kind": kind, **kwargs}
if value is not None:
node["value"] = value
return node
def parse(path: str) -> t.List[JSONPathNode]:
"""Takes in a JSONPath string and converts into a list of nodes."""
def parse(path: str) -> exp.JSONPath:
"""Takes in a JSON path string and parses it into a JSONPath expression."""
tokens = JSONPathTokenizer().tokenize(path)
size = len(tokens)
@ -89,7 +77,7 @@ def parse(path: str) -> t.List[JSONPathNode]:
if token:
return token.text
if _match(TokenType.STAR):
return _node("wildcard")
return exp.JSONPathWildcard()
if _match(TokenType.PLACEHOLDER) or _match(TokenType.L_PAREN):
script = _prev().text == "("
start = i
@ -100,9 +88,9 @@ def parse(path: str) -> t.List[JSONPathNode]:
if _curr() in (TokenType.R_BRACKET, None):
break
_advance()
return _node(
"script" if script else "filter", path[tokens[start].start : tokens[i].end]
)
expr_type = exp.JSONPathScript if script else exp.JSONPathFilter
return expr_type(this=path[tokens[start].start : tokens[i].end])
number = "-" if _match(TokenType.DASH) else ""
@ -112,6 +100,7 @@ def parse(path: str) -> t.List[JSONPathNode]:
if number:
return int(number)
return False
def _parse_slice() -> t.Any:
@ -121,9 +110,10 @@ def parse(path: str) -> t.List[JSONPathNode]:
if end is None and step is None:
return start
return _node("slice", start=start, end=end, step=step)
def _parse_bracket() -> JSONPathNode:
return exp.JSONPathSlice(start=start, end=end, step=step)
def _parse_bracket() -> exp.JSONPathPart:
literal = _parse_slice()
if isinstance(literal, str) or literal is not False:
@ -136,13 +126,15 @@ def parse(path: str) -> t.List[JSONPathNode]:
if len(indexes) == 1:
if isinstance(literal, str):
node = _node("key", indexes[0])
elif isinstance(literal, dict) and literal["kind"] in ("script", "filter"):
node = _node("selector", indexes[0])
node: exp.JSONPathPart = exp.JSONPathKey(this=indexes[0])
elif isinstance(literal, exp.JSONPathPart) and isinstance(
literal, (exp.JSONPathScript, exp.JSONPathFilter)
):
node = exp.JSONPathSelector(this=indexes[0])
else:
node = _node("subscript", indexes[0])
node = exp.JSONPathSubscript(this=indexes[0])
else:
node = _node("union", indexes)
node = exp.JSONPathUnion(expressions=indexes)
else:
raise ParseError(_error("Cannot have empty segment"))
@ -150,66 +142,56 @@ def parse(path: str) -> t.List[JSONPathNode]:
return node
nodes = []
# We canonicalize the JSON path AST so that it always starts with a
# "root" element, so paths like "field" will be generated as "$.field"
_match(TokenType.DOLLAR)
expressions: t.List[exp.JSONPathPart] = [exp.JSONPathRoot()]
while _curr():
if _match(TokenType.DOLLAR):
nodes.append(_node("root"))
elif _match(TokenType.DOT):
if _match(TokenType.DOT) or _match(TokenType.COLON):
recursive = _prev().text == ".."
value = _match(TokenType.VAR) or _match(TokenType.STAR)
nodes.append(
_node("recursive" if recursive else "child", value=value.text if value else None)
)
if _match(TokenType.VAR) or _match(TokenType.IDENTIFIER):
value: t.Optional[str | exp.JSONPathWildcard] = _prev().text
elif _match(TokenType.STAR):
value = exp.JSONPathWildcard()
else:
value = None
if recursive:
expressions.append(exp.JSONPathRecursive(this=value))
elif value:
expressions.append(exp.JSONPathKey(this=value))
else:
raise ParseError(_error("Expected key name or * after DOT"))
elif _match(TokenType.L_BRACKET):
nodes.append(_parse_bracket())
elif _match(TokenType.VAR):
nodes.append(_node("key", _prev().text))
expressions.append(_parse_bracket())
elif _match(TokenType.VAR) or _match(TokenType.IDENTIFIER):
expressions.append(exp.JSONPathKey(this=_prev().text))
elif _match(TokenType.STAR):
nodes.append(_node("wildcard"))
elif _match(TokenType.PARAMETER):
nodes.append(_node("current"))
expressions.append(exp.JSONPathWildcard())
else:
raise ParseError(_error(f"Unexpected {tokens[i].token_type}"))
return nodes
return exp.JSONPath(expressions=expressions)
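A minimal usage sketch of the rewritten parser; the quoted output shape is an assumption based on jsonpath_sql in the generator hunk above:

from sqlglot import jsonpath

# Hedged sketch: parse() now returns an exp.JSONPath whose first part is always
# a JSONPathRoot, so bare keys are canonicalized under "$".
path = jsonpath.parse("a.b[0]")
print(path.sql())  # expected roughly: '$.a.b[0]'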
MAPPING = {
"child": lambda n: f".{n['value']}" if n.get("value") is not None else "",
"filter": lambda n: f"?{n['value']}",
"key": lambda n: (
f".{n['value']}" if SAFE_IDENTIFIER_RE.match(n["value"]) else f'[{generate([n["value"]])}]'
),
"recursive": lambda n: f"..{n['value']}" if n.get("value") is not None else "..",
"root": lambda _: "$",
"script": lambda n: f"({n['value']}",
"slice": lambda n: ":".join(
"" if p is False else generate([p])
for p in [n["start"], n["end"], n["step"]]
JSON_PATH_PART_TRANSFORMS: t.Dict[t.Type[exp.Expression], t.Callable[..., str]] = {
exp.JSONPathFilter: lambda _, e: f"?{e.this}",
exp.JSONPathKey: lambda self, e: self._jsonpathkey_sql(e),
exp.JSONPathRecursive: lambda _, e: f"..{e.this or ''}",
exp.JSONPathRoot: lambda *_: "$",
exp.JSONPathScript: lambda _, e: f"({e.this}",
exp.JSONPathSelector: lambda self, e: f"[{self.json_path_part(e.this)}]",
exp.JSONPathSlice: lambda self, e: ":".join(
"" if p is False else self.json_path_part(p)
for p in [e.args.get("start"), e.args.get("end"), e.args.get("step")]
if p is not None
),
"selector": lambda n: f"[{generate([n['value']])}]",
"subscript": lambda n: f"[{generate([n['value']])}]",
"union": lambda n: f"[{','.join(generate([p]) for p in n['value'])}]",
"wildcard": lambda _: "*",
exp.JSONPathSubscript: lambda self, e: self._jsonpathsubscript_sql(e),
exp.JSONPathUnion: lambda self,
e: f"[{','.join(self.json_path_part(p) for p in e.expressions)}]",
exp.JSONPathWildcard: lambda *_: "*",
}
def generate(
nodes: t.List[JSONPathNode],
mapping: t.Optional[t.Dict[str, t.Callable[[JSONPathNode], str]]] = None,
) -> str:
mapping = MAPPING if mapping is None else mapping
path = []
for node in nodes:
if isinstance(node, dict):
path.append(mapping[node["kind"]](node))
elif isinstance(node, str):
escaped = node.replace('"', '\\"')
path.append(f'"{escaped}"')
else:
path.append(str(node))
return "".join(path)
ALL_JSON_PATH_PARTS = set(JSON_PATH_PART_TRANSFORMS)

View file

@ -1,3 +1,5 @@
# ruff: noqa: F401
from sqlglot.optimizer.optimizer import RULES, optimize
from sqlglot.optimizer.scope import (
Scope,

View file

@ -10,11 +10,13 @@ if t.TYPE_CHECKING:
@t.overload
def normalize_identifiers(expression: E, dialect: DialectType = None) -> E: ...
def normalize_identifiers(expression: E, dialect: DialectType = None) -> E:
...
@t.overload
def normalize_identifiers(expression: str, dialect: DialectType = None) -> exp.Identifier: ...
def normalize_identifiers(expression: str, dialect: DialectType = None) -> exp.Identifier:
...
def normalize_identifiers(expression, dialect=None):

View file

@ -8,10 +8,10 @@ from sqlglot.schema import ensure_schema
# Sentinel value that means an outer query selecting ALL columns
SELECT_ALL = object()
# Selection to use if selection list is empty
DEFAULT_SELECTION = lambda is_agg: alias(
exp.Max(this=exp.Literal.number(1)) if is_agg else "1", "_"
)
def default_selection(is_agg: bool) -> exp.Alias:
return alias(exp.Max(this=exp.Literal.number(1)) if is_agg else "1", "_")
def pushdown_projections(expression, schema=None, remove_unused_selections=True):
@ -129,7 +129,7 @@ def _remove_unused_selections(scope, parent_selections, schema, alias_count):
# If there are no remaining selections, just select a single constant
if not new_selections:
new_selections.append(DEFAULT_SELECTION(is_agg))
new_selections.append(default_selection(is_agg))
scope.expression.select(*new_selections, append=False, copy=False)

View file

@ -104,7 +104,6 @@ def simplify(
if root:
expression.replace(node)
return node
expression = while_changing(expression, _simplify)
@ -174,16 +173,20 @@ def simplify_not(expression):
if isinstance(this, exp.Paren):
condition = this.unnest()
if isinstance(condition, exp.And):
return exp.or_(
exp.not_(condition.left, copy=False),
exp.not_(condition.right, copy=False),
copy=False,
return exp.paren(
exp.or_(
exp.not_(condition.left, copy=False),
exp.not_(condition.right, copy=False),
copy=False,
)
)
if isinstance(condition, exp.Or):
return exp.and_(
exp.not_(condition.left, copy=False),
exp.not_(condition.right, copy=False),
copy=False,
return exp.paren(
exp.and_(
exp.not_(condition.left, copy=False),
exp.not_(condition.right, copy=False),
copy=False,
)
)
if is_null(condition):
return exp.null()
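A sketch of why the parentheses matter: the rewritten disjunction must not leak into a surrounding conjunction's precedence.

import sqlglot
from sqlglot.optimizer.simplify import simplify

# Hedged sketch: the De Morgan rewrites now keep their parentheses.
print(simplify(sqlglot.parse_one("NOT (a AND b) AND c")).sql())
# expected roughly: (NOT a OR NOT b) AND c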
@ -490,7 +493,7 @@ def simplify_equality(expression: exp.Expression) -> exp.Expression:
if isinstance(expression, COMPARISONS):
l, r = expression.left, expression.right
if not l.__class__ in INVERSE_OPS:
if l.__class__ not in INVERSE_OPS:
return expression
if r.is_number:
@ -714,8 +717,7 @@ def simplify_concat(expression):
"""Reduces all groups that contain string literals by concatenating them."""
if not isinstance(expression, CONCATS) or (
# We can't reduce a CONCAT_WS call if we don't statically know the separator
isinstance(expression, exp.ConcatWs)
and not expression.expressions[0].is_string
isinstance(expression, exp.ConcatWs) and not expression.expressions[0].is_string
):
return expression

View file

@ -60,6 +60,19 @@ def parse_logarithm(args: t.List, dialect: Dialect) -> exp.Func:
return (exp.Ln if dialect.parser_class.LOG_DEFAULTS_TO_LN else exp.Log)(this=this)
def parse_extract_json_with_path(expr_type: t.Type[E]) -> t.Callable[[t.List, Dialect], E]:
def _parser(args: t.List, dialect: Dialect) -> E:
expression = expr_type(
this=seq_get(args, 0), expression=dialect.to_json_path(seq_get(args, 1))
)
if len(args) > 2 and expr_type is exp.JSONExtract:
expression.set("expressions", args[2:])
return expression
return _parser
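A small sketch of the parsing change this enables, using the default dialect's JSON_EXTRACT registration just below:

from sqlglot import exp, parse_one

# Hedged sketch: the path argument is now a structured exp.JSONPath rather
# than a raw string literal.
node = parse_one("SELECT JSON_EXTRACT(x, '$.a') FROM t").selects[0]
print(isinstance(node.expression, exp.JSONPath))  # expected: True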
class _Parser(type):
def __new__(cls, clsname, bases, attrs):
klass = super().__new__(cls, clsname, bases, attrs)
@ -102,6 +115,9 @@ class Parser(metaclass=_Parser):
to=exp.DataType(this=exp.DataType.Type.TEXT),
),
"GLOB": lambda args: exp.Glob(this=seq_get(args, 1), expression=seq_get(args, 0)),
"JSON_EXTRACT": parse_extract_json_with_path(exp.JSONExtract),
"JSON_EXTRACT_SCALAR": parse_extract_json_with_path(exp.JSONExtractScalar),
"JSON_EXTRACT_PATH_TEXT": parse_extract_json_with_path(exp.JSONExtractScalar),
"LIKE": parse_like,
"LOG": parse_logarithm,
"TIME_TO_TIME_STR": lambda args: exp.Cast(
@ -175,6 +191,7 @@ class Parser(metaclass=_Parser):
TokenType.NCHAR,
TokenType.VARCHAR,
TokenType.NVARCHAR,
TokenType.BPCHAR,
TokenType.TEXT,
TokenType.MEDIUMTEXT,
TokenType.LONGTEXT,
@ -295,6 +312,7 @@ class Parser(metaclass=_Parser):
TokenType.ASC,
TokenType.AUTO_INCREMENT,
TokenType.BEGIN,
TokenType.BPCHAR,
TokenType.CACHE,
TokenType.CASE,
TokenType.COLLATE,
@ -531,12 +549,12 @@ class Parser(metaclass=_Parser):
TokenType.ARROW: lambda self, this, path: self.expression(
exp.JSONExtract,
this=this,
expression=path,
expression=self.dialect.to_json_path(path),
),
TokenType.DARROW: lambda self, this, path: self.expression(
exp.JSONExtractScalar,
this=this,
expression=path,
expression=self.dialect.to_json_path(path),
),
TokenType.HASH_ARROW: lambda self, this, path: self.expression(
exp.JSONBExtract,
@ -1334,7 +1352,9 @@ class Parser(metaclass=_Parser):
exp.Drop,
comments=start.comments,
exists=exists or self._parse_exists(),
this=self._parse_table(schema=True),
this=self._parse_table(
schema=True, is_db_reference=self._prev.token_type == TokenType.SCHEMA
),
kind=kind,
temporary=temporary,
materialized=materialized,
@ -1422,7 +1442,9 @@ class Parser(metaclass=_Parser):
elif create_token.token_type == TokenType.INDEX:
this = self._parse_index(index=self._parse_id_var())
elif create_token.token_type in self.DB_CREATABLES:
table_parts = self._parse_table_parts(schema=True)
table_parts = self._parse_table_parts(
schema=True, is_db_reference=create_token.token_type == TokenType.SCHEMA
)
# exp.Properties.Location.POST_NAME
self._match(TokenType.COMMA)
@ -2499,11 +2521,11 @@ class Parser(metaclass=_Parser):
elif self._match_text_seq("ALL", "ROWS", "PER", "MATCH"):
text = "ALL ROWS PER MATCH"
if self._match_text_seq("SHOW", "EMPTY", "MATCHES"):
text += f" SHOW EMPTY MATCHES"
text += " SHOW EMPTY MATCHES"
elif self._match_text_seq("OMIT", "EMPTY", "MATCHES"):
text += f" OMIT EMPTY MATCHES"
text += " OMIT EMPTY MATCHES"
elif self._match_text_seq("WITH", "UNMATCHED", "ROWS"):
text += f" WITH UNMATCHED ROWS"
text += " WITH UNMATCHED ROWS"
rows = exp.var(text)
else:
rows = None
@ -2511,9 +2533,9 @@ class Parser(metaclass=_Parser):
if self._match_text_seq("AFTER", "MATCH", "SKIP"):
text = "AFTER MATCH SKIP"
if self._match_text_seq("PAST", "LAST", "ROW"):
text += f" PAST LAST ROW"
text += " PAST LAST ROW"
elif self._match_text_seq("TO", "NEXT", "ROW"):
text += f" TO NEXT ROW"
text += " TO NEXT ROW"
elif self._match_text_seq("TO", "FIRST"):
text += f" TO FIRST {self._advance_any().text}" # type: ignore
elif self._match_text_seq("TO", "LAST"):
@ -2772,7 +2794,7 @@ class Parser(metaclass=_Parser):
or self._parse_placeholder()
)
def _parse_table_parts(self, schema: bool = False) -> exp.Table:
def _parse_table_parts(self, schema: bool = False, is_db_reference: bool = False) -> exp.Table:
catalog = None
db = None
table: t.Optional[exp.Expression | str] = self._parse_table_part(schema=schema)
@ -2788,8 +2810,15 @@ class Parser(metaclass=_Parser):
db = table
table = self._parse_table_part(schema=schema) or ""
if not table:
if is_db_reference:
catalog = db
db = table
table = None
if not table and not is_db_reference:
self.raise_error(f"Expected table name but got {self._curr}")
if not db and is_db_reference:
self.raise_error(f"Expected database name but got {self._curr}")
return self.expression(
exp.Table, this=table, db=db, catalog=catalog, pivots=self._parse_pivots()
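A sketch of the new is_db_reference behavior, assuming DROP SCHEMA as the entry point (which passes the flag per the earlier hunk in this file):

from sqlglot import parse_one

# Hedged sketch: with is_db_reference set, a two-part name is shifted into
# catalog/db instead of db/table, and the table slot is left empty.
table = parse_one("DROP SCHEMA a.b").this
print(table.catalog, table.db)  # expected: a b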
@ -2801,6 +2830,7 @@ class Parser(metaclass=_Parser):
joins: bool = False,
alias_tokens: t.Optional[t.Collection[TokenType]] = None,
parse_bracket: bool = False,
is_db_reference: bool = False,
) -> t.Optional[exp.Expression]:
lateral = self._parse_lateral()
if lateral:
@ -2823,7 +2853,11 @@ class Parser(metaclass=_Parser):
bracket = parse_bracket and self._parse_bracket(None)
bracket = self.expression(exp.Table, this=bracket) if bracket else None
this = t.cast(
exp.Expression, bracket or self._parse_bracket(self._parse_table_parts(schema=schema))
exp.Expression,
bracket
or self._parse_bracket(
self._parse_table_parts(schema=schema, is_db_reference=is_db_reference)
),
)
if schema:
@ -3650,7 +3684,6 @@ class Parser(metaclass=_Parser):
identifier = allow_identifiers and self._parse_id_var(
any_token=False, tokens=(TokenType.VAR,)
)
if identifier:
tokens = self.dialect.tokenize(identifier.name)
@ -3818,12 +3851,14 @@ class Parser(metaclass=_Parser):
return self.expression(exp.AtTimeZone, this=this, zone=self._parse_unary())
def _parse_column(self) -> t.Optional[exp.Expression]:
this = self._parse_column_reference()
return self._parse_column_ops(this) if this else self._parse_bracket(this)
def _parse_column_reference(self) -> t.Optional[exp.Expression]:
this = self._parse_field()
if isinstance(this, exp.Identifier):
this = self.expression(exp.Column, this=this)
elif not this:
return self._parse_bracket(this)
return self._parse_column_ops(this)
return this
def _parse_column_ops(self, this: t.Optional[exp.Expression]) -> t.Optional[exp.Expression]:
this = self._parse_bracket(this)
@ -3837,13 +3872,7 @@ class Parser(metaclass=_Parser):
if not field:
self.raise_error("Expected type")
elif op and self._curr:
self._advance()
value = self._prev.text
field = (
exp.Literal.number(value)
if self._prev.token_type == TokenType.NUMBER
else exp.Literal.string(value)
)
field = self._parse_column_reference()
else:
field = self._parse_field(anonymous_func=True, any_token=True)
@ -4375,7 +4404,10 @@ class Parser(metaclass=_Parser):
options[kind] = action
return self.expression(
exp.ForeignKey, expressions=expressions, reference=reference, **options # type: ignore
exp.ForeignKey,
expressions=expressions,
reference=reference,
**options, # type: ignore
)
def _parse_primary_key_part(self) -> t.Optional[exp.Expression]:
@ -4692,10 +4724,12 @@ class Parser(metaclass=_Parser):
return None
@t.overload
def _parse_json_object(self, agg: Lit[False]) -> exp.JSONObject: ...
def _parse_json_object(self, agg: Lit[False]) -> exp.JSONObject:
...
@t.overload
def _parse_json_object(self, agg: Lit[True]) -> exp.JSONObjectAgg: ...
def _parse_json_object(self, agg: Lit[True]) -> exp.JSONObjectAgg:
...
def _parse_json_object(self, agg=False):
star = self._parse_star()
@ -4937,6 +4971,13 @@ class Parser(metaclass=_Parser):
# (https://docs.oracle.com/en/database/oracle/oracle-database/19/sqlrf/img_text/first_value.html)
# and Snowflake chose to do the same for familiarity
# https://docs.snowflake.com/en/sql-reference/functions/first_value.html#usage-notes
if isinstance(this, exp.AggFunc):
ignore_respect = this.find(exp.IgnoreNulls, exp.RespectNulls)
if ignore_respect and ignore_respect is not this:
ignore_respect.replace(ignore_respect.this)
this = self.expression(ignore_respect.__class__, this=this)
this = self._parse_respect_or_ignore_nulls(this)
# bigquery select from window x AS (partition by ...)
@ -5732,12 +5773,14 @@ class Parser(metaclass=_Parser):
return True
@t.overload
def _replace_columns_with_dots(self, this: exp.Expression) -> exp.Expression: ...
def _replace_columns_with_dots(self, this: exp.Expression) -> exp.Expression:
...
@t.overload
def _replace_columns_with_dots(
self, this: t.Optional[exp.Expression]
) -> t.Optional[exp.Expression]: ...
) -> t.Optional[exp.Expression]:
...
def _replace_columns_with_dots(self, this):
if isinstance(this, exp.Dot):

View file

@ -125,6 +125,7 @@ class TokenType(AutoName):
NCHAR = auto()
VARCHAR = auto()
NVARCHAR = auto()
BPCHAR = auto()
TEXT = auto()
MEDIUMTEXT = auto()
LONGTEXT = auto()
@ -801,6 +802,7 @@ class Tokenizer(metaclass=_Tokenizer):
"VARCHAR2": TokenType.VARCHAR,
"NVARCHAR": TokenType.NVARCHAR,
"NVARCHAR2": TokenType.NVARCHAR,
"BPCHAR": TokenType.BPCHAR,
"STR": TokenType.TEXT,
"STRING": TokenType.TEXT,
"TEXT": TokenType.TEXT,
@ -1141,7 +1143,7 @@ class Tokenizer(metaclass=_Tokenizer):
self._comments.append(self._text[comment_start_size : -comment_end_size + 1])
self._advance(comment_end_size - 1)
else:
while not self._end and not self.WHITE_SPACE.get(self._peek) is TokenType.BREAK:
while not self._end and self.WHITE_SPACE.get(self._peek) is not TokenType.BREAK:
self._advance(alnum=True)
self._comments.append(self._text[comment_start_size:])
@ -1259,7 +1261,7 @@ class Tokenizer(metaclass=_Tokenizer):
if base:
try:
int(text, base)
except:
except Exception:
raise TokenError(
f"Numeric string contains invalid characters from {self._line}:{self._start}"
)

View file

@ -485,8 +485,8 @@ def preprocess(
expression_type = type(expression)
expression = transforms[0](expression)
for t in transforms[1:]:
expression = t(expression)
for transform in transforms[1:]:
expression = transform(expression)
_sql_handler = getattr(self, expression.key + "_sql", None)
if _sql_handler: