Merging upstream version 11.4.1.

Signed-off-by: Daniel Baumann <daniel@debian.org>

parent ecb42ec17f
commit 63746a3e92

89 changed files with 35352 additions and 33081 deletions
sqlglot/__init__.py

@@ -47,7 +47,7 @@ if t.TYPE_CHECKING:
     T = t.TypeVar("T", bound=Expression)


-__version__ = "11.3.6"
+__version__ = "11.4.1"

 pretty = False
 """Whether to format generated SQL by default."""
sqlglot/__main__.py

@@ -1,5 +1,8 @@
+from __future__ import annotations
+
 import argparse
 import sys
+import typing as t

 import sqlglot

@@ -42,6 +45,12 @@ parser.add_argument(
     action="store_true",
     help="Parse and return the expression tree",
 )
+parser.add_argument(
+    "--tokenize",
+    dest="tokenize",
+    action="store_true",
+    help="Tokenize and return the tokens list",
+)
 parser.add_argument(
     "--error-level",
     dest="error_level",

@@ -57,7 +66,7 @@ error_level = sqlglot.ErrorLevel[args.error_level.upper()]
 sql = sys.stdin.read() if args.sql == "-" else args.sql

 if args.parse:
-    sqls = [
+    objs: t.Union[t.List[str], t.List[sqlglot.tokens.Token]] = [
         repr(expression)
         for expression in sqlglot.parse(
             sql,

@@ -65,8 +74,10 @@ if args.parse:
             error_level=error_level,
         )
     ]
+elif args.tokenize:
+    objs = sqlglot.Dialect.get_or_raise(args.read)().tokenize(sql)
 else:
-    sqls = sqlglot.transpile(
+    objs = sqlglot.transpile(
         sql,
         read=args.read,
         write=args.write,

@@ -75,5 +86,5 @@ else:
         error_level=error_level,
     )

-for sql in sqls:
-    print(sql)
+for obj in objs:
+    print(obj)
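Note: the new --tokenize flag routes through the Dialect.tokenize helper added later in this commit. A minimal sketch of the equivalent library call (the duckdb dialect name is only an illustrative choice):

import sqlglot

# Same as: python -m sqlglot --tokenize --read duckdb "SELECT 1 + 1"
tokens = sqlglot.Dialect.get_or_raise("duckdb")().tokenize("SELECT 1 + 1")
for token in tokens:
    print(token)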
sqlglot/dataframe/sql/dataframe.py

@@ -299,7 +299,7 @@ class DataFrame:
         for expression_type, select_expression in select_expressions:
             select_expression = select_expression.transform(replace_id_value, replacement_mapping)
             if optimize:
-                select_expression = optimize_func(select_expression)
+                select_expression = optimize_func(select_expression, identify="always")
             select_expression = df._replace_cte_names_with_hashes(select_expression)
             expression: t.Union[exp.Select, exp.Cache, exp.Drop]
             if expression_type == exp.Cache:
sqlglot/dialects/bigquery.py

@@ -144,7 +144,6 @@ class BigQuery(Dialect):
             "BEGIN": TokenType.COMMAND,
             "BEGIN TRANSACTION": TokenType.BEGIN,
             "CURRENT_DATETIME": TokenType.CURRENT_DATETIME,
-            "CURRENT_TIME": TokenType.CURRENT_TIME,
             "DECLARE": TokenType.COMMAND,
             "GEOGRAPHY": TokenType.GEOGRAPHY,
             "FLOAT64": TokenType.DOUBLE,

@@ -194,7 +193,6 @@ class BigQuery(Dialect):
         NO_PAREN_FUNCTIONS = {
             **parser.Parser.NO_PAREN_FUNCTIONS,  # type: ignore
             TokenType.CURRENT_DATETIME: exp.CurrentDatetime,
-            TokenType.CURRENT_TIME: exp.CurrentTime,
         }

         NESTED_TYPE_TOKENS = {
sqlglot/dialects/clickhouse.py

@@ -5,6 +5,7 @@ import typing as t
 from sqlglot import exp, generator, parser, tokens
 from sqlglot.dialects.dialect import Dialect, inline_array_sql, var_map_sql
 from sqlglot.errors import ParseError
+from sqlglot.helper import ensure_list, seq_get
 from sqlglot.parser import parse_var_map
 from sqlglot.tokens import TokenType

@@ -40,7 +41,18 @@ class ClickHouse(Dialect):
     class Parser(parser.Parser):
         FUNCTIONS = {
             **parser.Parser.FUNCTIONS,  # type: ignore
+            "EXPONENTIALTIMEDECAYEDAVG": lambda params, args: exp.ExponentialTimeDecayedAvg(
+                this=seq_get(args, 0),
+                time=seq_get(args, 1),
+                decay=seq_get(params, 0),
+            ),
             "MAP": parse_var_map,
+            "HISTOGRAM": lambda params, args: exp.Histogram(
+                this=seq_get(args, 0), bins=seq_get(params, 0)
+            ),
+            "GROUPUNIQARRAY": lambda params, args: exp.GroupUniqArray(
+                this=seq_get(args, 0), size=seq_get(params, 0)
+            ),
             "QUANTILE": lambda params, args: exp.Quantile(this=args, quantile=params),
             "QUANTILES": lambda params, args: exp.Quantiles(parameters=params, expressions=args),
             "QUANTILEIF": lambda params, args: exp.QuantileIf(parameters=params, expressions=args),

@@ -113,22 +125,40 @@ class ClickHouse(Dialect):
         TRANSFORMS = {
             **generator.Generator.TRANSFORMS,  # type: ignore
             exp.Array: inline_array_sql,
-            exp.StrPosition: lambda self, e: f"position({self.format_args(e.this, e.args.get('substr'), e.args.get('position'))})",
+            exp.ExponentialTimeDecayedAvg: lambda self, e: f"exponentialTimeDecayedAvg{self._param_args_sql(e, 'decay', ['this', 'time'])}",
             exp.Final: lambda self, e: f"{self.sql(e, 'this')} FINAL",
+            exp.GroupUniqArray: lambda self, e: f"groupUniqArray{self._param_args_sql(e, 'size', 'this')}",
+            exp.Histogram: lambda self, e: f"histogram{self._param_args_sql(e, 'bins', 'this')}",
             exp.Map: lambda self, e: _lower_func(var_map_sql(self, e)),
-            exp.VarMap: lambda self, e: _lower_func(var_map_sql(self, e)),
             exp.Quantile: lambda self, e: f"quantile{self._param_args_sql(e, 'quantile', 'this')}",
             exp.Quantiles: lambda self, e: f"quantiles{self._param_args_sql(e, 'parameters', 'expressions')}",
             exp.QuantileIf: lambda self, e: f"quantileIf{self._param_args_sql(e, 'parameters', 'expressions')}",
+            exp.StrPosition: lambda self, e: f"position({self.format_args(e.this, e.args.get('substr'), e.args.get('position'))})",
+            exp.VarMap: lambda self, e: _lower_func(var_map_sql(self, e)),
         }

         EXPLICIT_UNION = True

         def _param_args_sql(
-            self, expression: exp.Expression, params_name: str, args_name: str
+            self,
+            expression: exp.Expression,
+            param_names: str | t.List[str],
+            arg_names: str | t.List[str],
         ) -> str:
-            params = self.format_args(self.expressions(expression, params_name))
-            args = self.format_args(self.expressions(expression, args_name))
+            params = self.format_args(
+                *(
+                    arg
+                    for name in ensure_list(param_names)
+                    for arg in ensure_list(expression.args.get(name))
+                )
+            )
+            args = self.format_args(
+                *(
+                    arg
+                    for name in ensure_list(arg_names)
+                    for arg in ensure_list(expression.args.get(name))
+                )
+            )
             return f"({params})({args})"

         def cte_sql(self, expression: exp.CTE) -> str:
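Note: ClickHouse parametric aggregates use a (params)(args) calling convention, which _param_args_sql now renders from one or more param/arg names. A hedged round-trip sketch (output shape inferred from the transforms above):

import sqlglot

# quantile takes its level as a parameter list: quantile(0.5)(x)
sql = "SELECT quantile(0.5)(x) FROM tbl"
print(sqlglot.transpile(sql, read="clickhouse", write="clickhouse")[0])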
sqlglot/dialects/databricks.py

@@ -23,6 +23,7 @@ class Databricks(Spark):
             exp.DateDiff: generate_date_delta_with_unit_sql,
             exp.ToChar: lambda self, e: self.function_fallback_sql(e),
         }
+        TRANSFORMS.pop(exp.Select)  # Remove the ELIMINATE_QUALIFY transformation

         PARAMETER_TOKEN = "$"

sqlglot/dialects/dialect.py

@@ -8,7 +8,7 @@ from sqlglot.generator import Generator
 from sqlglot.helper import flatten, seq_get
 from sqlglot.parser import Parser
 from sqlglot.time import format_time
-from sqlglot.tokens import Tokenizer
+from sqlglot.tokens import Token, Tokenizer
 from sqlglot.trie import new_trie

 E = t.TypeVar("E", bound=exp.Expression)

@@ -160,12 +160,12 @@ class Dialect(metaclass=_Dialect):
         return expression

     def parse(self, sql: str, **opts) -> t.List[t.Optional[exp.Expression]]:
-        return self.parser(**opts).parse(self.tokenizer.tokenize(sql), sql)
+        return self.parser(**opts).parse(self.tokenize(sql), sql)

     def parse_into(
         self, expression_type: exp.IntoType, sql: str, **opts
     ) -> t.List[t.Optional[exp.Expression]]:
-        return self.parser(**opts).parse_into(expression_type, self.tokenizer.tokenize(sql), sql)
+        return self.parser(**opts).parse_into(expression_type, self.tokenize(sql), sql)

     def generate(self, expression: t.Optional[exp.Expression], **opts) -> str:
         return self.generator(**opts).generate(expression)

@@ -173,6 +173,9 @@ class Dialect(metaclass=_Dialect):
     def transpile(self, sql: str, **opts) -> t.List[str]:
         return [self.generate(expression, **opts) for expression in self.parse(sql)]

+    def tokenize(self, sql: str) -> t.List[Token]:
+        return self.tokenizer.tokenize(sql)
+
     @property
     def tokenizer(self) -> Tokenizer:
         if not hasattr(self, "_tokenizer"):

@@ -385,6 +388,21 @@ def parse_date_delta(
     return inner_func


+def date_trunc_to_time(args: t.Sequence) -> exp.DateTrunc | exp.TimestampTrunc:
+    unit = seq_get(args, 0)
+    this = seq_get(args, 1)
+
+    if isinstance(this, exp.Cast) and this.is_type(exp.DataType.Type.DATE):
+        return exp.DateTrunc(unit=unit, this=this)
+    return exp.TimestampTrunc(this=this, unit=unit)
+
+
+def timestamptrunc_sql(self: Generator, expression: exp.TimestampTrunc) -> str:
+    return self.func(
+        "DATE_TRUNC", exp.Literal.string(expression.text("unit") or "day"), expression.this
+    )
+
+
 def locate_to_strposition(args: t.Sequence) -> exp.Expression:
     return exp.StrPosition(
         this=seq_get(args, 1),

@@ -412,6 +430,16 @@ def min_or_least(self: Generator, expression: exp.Min) -> str:
     return rename_func(name)(self, expression)


+def count_if_to_sum(self: Generator, expression: exp.CountIf) -> str:
+    cond = expression.this
+
+    if isinstance(expression.this, exp.Distinct):
+        cond = expression.this.expressions[0]
+        self.unsupported("DISTINCT is not supported when converting COUNT_IF to SUM")
+
+    return self.func("sum", exp.func("if", cond, 1, 0))
+
+
 def trim_sql(self: Generator, expression: exp.Trim) -> str:
     target = self.sql(expression, "this")
     trim_type = self.sql(expression, "position")
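Note: date_trunc_to_time picks the expression type from the second argument. A small sketch of the branch logic, using only names from the hunk above:

from sqlglot import exp
from sqlglot.dialects.dialect import date_trunc_to_time

# A DATE cast stays a DateTrunc; anything else becomes a TimestampTrunc.
date_arg = exp.cast(exp.column("d"), "date")
print(type(date_trunc_to_time([exp.Literal.string("month"), date_arg])).__name__)
print(type(date_trunc_to_time([exp.Literal.string("month"), exp.column("ts")])).__name__)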
sqlglot/dialects/drill.py

@@ -97,6 +97,7 @@ class Drill(Dialect):

         FUNCTIONS = {
             **parser.Parser.FUNCTIONS,  # type: ignore
             "DATE_FORMAT": format_time_lambda(exp.TimeToStr, "drill"),
+            "TO_TIMESTAMP": exp.TimeStrToTime.from_arg_list,
             "TO_CHAR": format_time_lambda(exp.TimeToStr, "drill"),
         }
sqlglot/dialects/duckdb.py

@@ -14,6 +14,7 @@ from sqlglot.dialects.dialect import (
     rename_func,
     str_position_sql,
     str_to_time_sql,
+    timestamptrunc_sql,
     timestrtotime_sql,
     ts_or_ds_to_date_sql,
 )

@@ -148,6 +149,9 @@ class DuckDB(Dialect):
             exp.ArraySize: rename_func("ARRAY_LENGTH"),
             exp.ArraySort: _array_sort_sql,
             exp.ArraySum: rename_func("LIST_SUM"),
+            exp.DayOfMonth: rename_func("DAYOFMONTH"),
+            exp.DayOfWeek: rename_func("DAYOFWEEK"),
+            exp.DayOfYear: rename_func("DAYOFYEAR"),
             exp.DataType: _datatype_sql,
             exp.DateAdd: _date_add,
             exp.DateDiff: lambda self, e: self.func(

@@ -162,6 +166,7 @@ class DuckDB(Dialect):
             exp.JSONBExtract: arrow_json_extract_sql,
             exp.JSONBExtractScalar: arrow_json_extract_scalar_sql,
             exp.LogicalOr: rename_func("BOOL_OR"),
+            exp.LogicalAnd: rename_func("BOOL_AND"),
             exp.Pivot: no_pivot_sql,
             exp.Properties: no_properties_sql,
             exp.RegexpExtract: _regexp_extract_sql,

@@ -175,6 +180,7 @@ class DuckDB(Dialect):
             exp.StrToTime: str_to_time_sql,
             exp.StrToUnix: lambda self, e: f"EPOCH(STRPTIME({self.sql(e, 'this')}, {self.format_time(e)}))",
             exp.Struct: _struct_sql,
+            exp.TimestampTrunc: timestamptrunc_sql,
             exp.TimeStrToDate: lambda self, e: f"CAST({self.sql(e, 'this')} AS DATE)",
             exp.TimeStrToTime: timestrtotime_sql,
             exp.TimeStrToUnix: lambda self, e: f"EPOCH(CAST({self.sql(e, 'this')} AS TIMESTAMP))",

@@ -186,6 +192,7 @@ class DuckDB(Dialect):
             exp.UnixToStr: lambda self, e: f"STRFTIME(TO_TIMESTAMP({self.sql(e, 'this')}), {self.format_time(e)})",
             exp.UnixToTime: rename_func("TO_TIMESTAMP"),
             exp.UnixToTimeStr: lambda self, e: f"CAST(TO_TIMESTAMP({self.sql(e, 'this')}) AS TEXT)",
+            exp.WeekOfYear: rename_func("WEEKOFYEAR"),
         }

         TYPE_MAPPING = {
sqlglot/dialects/hive.py

@@ -1,5 +1,7 @@
 from __future__ import annotations

+import typing as t
+
 from sqlglot import exp, generator, parser, tokens, transforms
 from sqlglot.dialects.dialect import (
     Dialect,

@@ -35,7 +37,7 @@ DATE_DELTA_INTERVAL = {
 DIFF_MONTH_SWITCH = ("YEAR", "QUARTER", "MONTH")


-def _add_date_sql(self, expression):
+def _add_date_sql(self: generator.Generator, expression: exp.DateAdd) -> str:
     unit = expression.text("unit").upper()
     func, multiplier = DATE_DELTA_INTERVAL.get(unit, ("DATE_ADD", 1))
     modified_increment = (

@@ -47,7 +49,7 @@ def _add_date_sql(self, expression):
     return self.func(func, expression.this, modified_increment.this)


-def _date_diff_sql(self, expression):
+def _date_diff_sql(self: generator.Generator, expression: exp.DateDiff) -> str:
     unit = expression.text("unit").upper()
     sql_func = "MONTHS_BETWEEN" if unit in DIFF_MONTH_SWITCH else "DATEDIFF"
     _, multiplier = DATE_DELTA_INTERVAL.get(unit, ("", 1))

@@ -56,21 +58,21 @@ def _date_diff_sql(self, expression):
     return f"{diff_sql}{multiplier_sql}"


-def _array_sort(self, expression):
+def _array_sort(self: generator.Generator, expression: exp.ArraySort) -> str:
     if expression.expression:
         self.unsupported("Hive SORT_ARRAY does not support a comparator")
     return f"SORT_ARRAY({self.sql(expression, 'this')})"


-def _property_sql(self, expression):
+def _property_sql(self: generator.Generator, expression: exp.Property) -> str:
     return f"'{expression.name}'={self.sql(expression, 'value')}"


-def _str_to_unix(self, expression):
+def _str_to_unix(self: generator.Generator, expression: exp.StrToUnix) -> str:
     return self.func("UNIX_TIMESTAMP", expression.this, _time_format(self, expression))


-def _str_to_date(self, expression):
+def _str_to_date(self: generator.Generator, expression: exp.StrToDate) -> str:
     this = self.sql(expression, "this")
     time_format = self.format_time(expression)
     if time_format not in (Hive.time_format, Hive.date_format):

@@ -78,7 +80,7 @@ def _str_to_date(self, expression):
     return f"CAST({this} AS DATE)"


-def _str_to_time(self, expression):
+def _str_to_time(self: generator.Generator, expression: exp.StrToTime) -> str:
     this = self.sql(expression, "this")
     time_format = self.format_time(expression)
     if time_format not in (Hive.time_format, Hive.date_format):

@@ -86,20 +88,22 @@ def _str_to_time(self, expression):
     return f"CAST({this} AS TIMESTAMP)"


-def _time_format(self, expression):
+def _time_format(
+    self: generator.Generator, expression: exp.UnixToStr | exp.StrToUnix
+) -> t.Optional[str]:
     time_format = self.format_time(expression)
     if time_format == Hive.time_format:
         return None
     return time_format


-def _time_to_str(self, expression):
+def _time_to_str(self: generator.Generator, expression: exp.TimeToStr) -> str:
     this = self.sql(expression, "this")
     time_format = self.format_time(expression)
     return f"DATE_FORMAT({this}, {time_format})"


-def _to_date_sql(self, expression):
+def _to_date_sql(self: generator.Generator, expression: exp.TsOrDsToDate) -> str:
     this = self.sql(expression, "this")
     time_format = self.format_time(expression)
     if time_format and time_format not in (Hive.time_format, Hive.date_format):

@@ -107,7 +111,7 @@ def _to_date_sql(self, expression):
     return f"TO_DATE({this})"


-def _unnest_to_explode_sql(self, expression):
+def _unnest_to_explode_sql(self: generator.Generator, expression: exp.Join) -> str:
     unnest = expression.this
     if isinstance(unnest, exp.Unnest):
         alias = unnest.args.get("alias")

@@ -117,7 +121,7 @@ def _unnest_to_explode_sql(self, expression):
                 exp.Lateral(
                     this=udtf(this=expression),
                     view=True,
-                    alias=exp.TableAlias(this=alias.this, columns=[column]),
+                    alias=exp.TableAlias(this=alias.this, columns=[column]),  # type: ignore
                 )
             )
             for expression, column in zip(unnest.expressions, alias.columns if alias else [])

@@ -125,7 +129,7 @@ def _unnest_to_explode_sql(self, expression):
     return self.join_sql(expression)


-def _index_sql(self, expression):
+def _index_sql(self: generator.Generator, expression: exp.Index) -> str:
     this = self.sql(expression, "this")
     table = self.sql(expression, "table")
     columns = self.sql(expression, "columns")

@@ -263,14 +267,15 @@ class Hive(Dialect):
             exp.DataType.Type.TEXT: "STRING",
             exp.DataType.Type.DATETIME: "TIMESTAMP",
             exp.DataType.Type.VARBINARY: "BINARY",
+            exp.DataType.Type.TIMESTAMPTZ: "TIMESTAMP",
         }

         TRANSFORMS = {
             **generator.Generator.TRANSFORMS,  # type: ignore
             **transforms.UNALIAS_GROUP,  # type: ignore
+            **transforms.ELIMINATE_QUALIFY,  # type: ignore
             exp.Property: _property_sql,
             exp.ApproxDistinct: approx_count_distinct_sql,
-            exp.ArrayAgg: rename_func("COLLECT_LIST"),
             exp.ArrayConcat: rename_func("CONCAT"),
             exp.ArraySize: rename_func("SIZE"),
             exp.ArraySort: _array_sort,

@@ -333,13 +338,19 @@ class Hive(Dialect):
             exp.TableFormatProperty: exp.Properties.Location.POST_SCHEMA,
         }

-        def with_properties(self, properties):
+        def arrayagg_sql(self, expression: exp.ArrayAgg) -> str:
+            return self.func(
+                "COLLECT_LIST",
+                expression.this.this if isinstance(expression.this, exp.Order) else expression.this,
+            )
+
+        def with_properties(self, properties: exp.Properties) -> str:
             return self.properties(
                 properties,
                 prefix=self.seg("TBLPROPERTIES"),
             )

-        def datatype_sql(self, expression):
+        def datatype_sql(self, expression: exp.DataType) -> str:
             if (
                 expression.this in (exp.DataType.Type.VARCHAR, exp.DataType.Type.NVARCHAR)
                 and not expression.expressions
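Note: ARRAY_AGG now generates through the new arrayagg_sql method, which unwraps an ORDER BY wrapper since Hive's COLLECT_LIST cannot order. A hedged sketch (expected to render as COLLECT_LIST per the method above):

import sqlglot

print(sqlglot.transpile("SELECT ARRAY_AGG(x) FROM t", write="hive")[0])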
sqlglot/dialects/mysql.py

@@ -177,7 +177,7 @@ class MySQL(Dialect):
             "@@": TokenType.SESSION_PARAMETER,
         }

-        COMMANDS = tokens.Tokenizer.COMMANDS - {TokenType.SET, TokenType.SHOW}
+        COMMANDS = tokens.Tokenizer.COMMANDS - {TokenType.SHOW}

     class Parser(parser.Parser):
         FUNC_TOKENS = {*parser.Parser.FUNC_TOKENS, TokenType.SCHEMA, TokenType.DATABASE}  # type: ignore

@@ -211,7 +211,6 @@ class MySQL(Dialect):
         STATEMENT_PARSERS = {
             **parser.Parser.STATEMENT_PARSERS,  # type: ignore
             TokenType.SHOW: lambda self: self._parse_show(),
-            TokenType.SET: lambda self: self._parse_set(),
         }

         SHOW_PARSERS = {

@@ -269,15 +268,12 @@ class MySQL(Dialect):
         }

         SET_PARSERS = {
-            "GLOBAL": lambda self: self._parse_set_item_assignment("GLOBAL"),
+            **parser.Parser.SET_PARSERS,
             "PERSIST": lambda self: self._parse_set_item_assignment("PERSIST"),
             "PERSIST_ONLY": lambda self: self._parse_set_item_assignment("PERSIST_ONLY"),
-            "SESSION": lambda self: self._parse_set_item_assignment("SESSION"),
-            "LOCAL": lambda self: self._parse_set_item_assignment("LOCAL"),
             "CHARACTER SET": lambda self: self._parse_set_item_charset("CHARACTER SET"),
             "CHARSET": lambda self: self._parse_set_item_charset("CHARACTER SET"),
             "NAMES": lambda self: self._parse_set_item_names(),
-            "TRANSACTION": lambda self: self._parse_set_transaction(),
         }

         PROFILE_TYPES = {

@@ -292,15 +288,6 @@ class MySQL(Dialect):
             "SWAPS",
         }

-        TRANSACTION_CHARACTERISTICS = {
-            "ISOLATION LEVEL REPEATABLE READ",
-            "ISOLATION LEVEL READ COMMITTED",
-            "ISOLATION LEVEL READ UNCOMMITTED",
-            "ISOLATION LEVEL SERIALIZABLE",
-            "READ WRITE",
-            "READ ONLY",
-        }
-
         def _parse_show_mysql(self, this, target=False, full=None, global_=None):
             if target:
                 if isinstance(target, str):

@@ -354,12 +341,6 @@ class MySQL(Dialect):
                 **{"global": global_},
             )

-        def _parse_var_from_options(self, options):
-            for option in options:
-                if self._match_text_seq(*option.split(" ")):
-                    return exp.Var(this=option)
-            return None
-
         def _parse_oldstyle_limit(self):
             limit = None
             offset = None

@@ -372,30 +353,6 @@ class MySQL(Dialect):
                 offset = parts[0]
             return offset, limit

-        def _default_parse_set_item(self):
-            return self._parse_set_item_assignment(kind=None)
-
-        def _parse_set_item_assignment(self, kind):
-            if kind in {"GLOBAL", "SESSION"} and self._match_text_seq("TRANSACTION"):
-                return self._parse_set_transaction(global_=kind == "GLOBAL")
-
-            left = self._parse_primary() or self._parse_id_var()
-            if not self._match(TokenType.EQ):
-                self.raise_error("Expected =")
-            right = self._parse_statement() or self._parse_id_var()
-
-            this = self.expression(
-                exp.EQ,
-                this=left,
-                expression=right,
-            )
-
-            return self.expression(
-                exp.SetItem,
-                this=this,
-                kind=kind,
-            )
-
         def _parse_set_item_charset(self, kind):
             this = self._parse_string() or self._parse_id_var()

@@ -418,18 +375,6 @@ class MySQL(Dialect):
                 kind="NAMES",
             )

-        def _parse_set_transaction(self, global_=False):
-            self._match_text_seq("TRANSACTION")
-            characteristics = self._parse_csv(
-                lambda: self._parse_var_from_options(self.TRANSACTION_CHARACTERISTICS)
-            )
-            return self.expression(
-                exp.SetItem,
-                expressions=characteristics,
-                kind="TRANSACTION",
-                **{"global": global_},
-            )
-
     class Generator(generator.Generator):
         LOCKING_READS_SUPPORTED = True
         NULL_ORDERING_SUPPORTED = False

@@ -523,16 +468,3 @@ class MySQL(Dialect):
             limit_offset = f"{offset}, {limit}" if offset else limit
             return f" LIMIT {limit_offset}"
         return ""
-
-        def setitem_sql(self, expression):
-            kind = self.sql(expression, "kind")
-            kind = f"{kind} " if kind else ""
-            this = self.sql(expression, "this")
-            expressions = self.expressions(expression)
-            collate = self.sql(expression, "collate")
-            collate = f" COLLATE {collate}" if collate else ""
-            global_ = "GLOBAL " if expression.args.get("global") else ""
-            return f"{global_}{kind}{this}{expressions}{collate}"
-
-        def set_sql(self, expression):
-            return f"SET {self.expressions(expression)}"
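Note: SET parsing and generation moved from this dialect into the base parser and generator (see the generator.py hunks later in this commit), so SET statements should now round-trip outside MySQL-specific code paths too. A hedged sketch:

import sqlglot

print(sqlglot.parse_one("SET GLOBAL sql_mode = 'TRADITIONAL'", read="mysql").sql())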
sqlglot/dialects/postgres.py

@@ -12,6 +12,7 @@ from sqlglot.dialects.dialect import (
     no_trycast_sql,
     rename_func,
     str_position_sql,
+    timestamptrunc_sql,
     trim_sql,
 )
 from sqlglot.helper import seq_get

@@ -34,7 +35,7 @@ def _date_add_sql(kind):
         from sqlglot.optimizer.simplify import simplify

         this = self.sql(expression, "this")
-        unit = self.sql(expression, "unit")
+        unit = expression.args.get("unit")
         expression = simplify(expression.args["expression"])

         if not isinstance(expression, exp.Literal):

@@ -92,8 +93,7 @@ def _string_agg_sql(self, expression):
     this = expression.this
     if isinstance(this, exp.Order):
         if this.this:
-            this = this.this
-            this.pop()
+            this = this.this.pop()
         order = self.sql(expression.this)  # Order has a leading space

     return f"STRING_AGG({self.format_args(this, separator)}{order})"

@@ -256,6 +256,9 @@ class Postgres(Dialect):
             "TO_TIMESTAMP": _to_timestamp,
             "TO_CHAR": format_time_lambda(exp.TimeToStr, "postgres"),
             "GENERATE_SERIES": _generate_series,
+            "DATE_TRUNC": lambda args: exp.TimestampTrunc(
+                this=seq_get(args, 1), unit=seq_get(args, 0)
+            ),
         }

         BITWISE = {

@@ -311,6 +314,7 @@ class Postgres(Dialect):
             exp.DateSub: _date_add_sql("-"),
             exp.DateDiff: _date_diff_sql,
             exp.LogicalOr: rename_func("BOOL_OR"),
+            exp.LogicalAnd: rename_func("BOOL_AND"),
             exp.Min: min_or_least,
             exp.ArrayOverlaps: lambda self, e: self.binary(e, "&&"),
             exp.ArrayContains: lambda self, e: self.binary(e, "@>"),

@@ -320,6 +324,7 @@ class Postgres(Dialect):
             exp.StrPosition: str_position_sql,
             exp.StrToTime: lambda self, e: f"TO_TIMESTAMP({self.sql(e, 'this')}, {self.format_time(e)})",
             exp.Substring: _substring_sql,
+            exp.TimestampTrunc: timestamptrunc_sql,
             exp.TimeStrToTime: lambda self, e: f"CAST({self.sql(e, 'this')} AS TIMESTAMP)",
             exp.TimeToStr: lambda self, e: f"TO_CHAR({self.sql(e, 'this')}, {self.format_time(e)})",
             exp.TableSample: no_tablesample_sql,
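Note: Postgres DATE_TRUNC is now parsed into TimestampTrunc and regenerated via timestamptrunc_sql, so it survives a round trip. A sketch:

import sqlglot

print(sqlglot.transpile("SELECT DATE_TRUNC('month', created_at) FROM t", read="postgres", write="postgres")[0])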
sqlglot/dialects/presto.py

@@ -3,12 +3,14 @@ from __future__ import annotations
 from sqlglot import exp, generator, parser, tokens, transforms
 from sqlglot.dialects.dialect import (
     Dialect,
+    date_trunc_to_time,
     format_time_lambda,
     if_sql,
     no_ilike_sql,
     no_safe_divide_sql,
     rename_func,
     struct_extract_sql,
+    timestamptrunc_sql,
     timestrtotime_sql,
 )
 from sqlglot.dialects.mysql import MySQL

@@ -98,10 +100,16 @@ def _ts_or_ds_to_date_sql(self, expression):


 def _ts_or_ds_add_sql(self, expression):
-    this = self.sql(expression, "this")
-    e = self.sql(expression, "expression")
-    unit = self.sql(expression, "unit") or "'day'"
-    return f"DATE_ADD({unit}, {e}, DATE_PARSE(SUBSTR({this}, 1, 10), {Presto.date_format}))"
+    return self.func(
+        "DATE_ADD",
+        exp.Literal.string(expression.text("unit") or "day"),
+        expression.expression,
+        self.func(
+            "DATE_PARSE",
+            self.func("SUBSTR", expression.this, exp.Literal.number(1), exp.Literal.number(10)),
+            Presto.date_format,
+        ),
+    )


 def _sequence_sql(self, expression):

@@ -195,6 +203,7 @@ class Presto(Dialect):
             ),
             "DATE_FORMAT": format_time_lambda(exp.TimeToStr, "presto"),
             "DATE_PARSE": format_time_lambda(exp.StrToTime, "presto"),
+            "DATE_TRUNC": date_trunc_to_time,
             "FROM_UNIXTIME": _from_unixtime,
             "NOW": exp.CurrentTimestamp.from_arg_list,
             "STRPOS": lambda args: exp.StrPosition(

@@ -237,6 +246,7 @@ class Presto(Dialect):
         TRANSFORMS = {
             **generator.Generator.TRANSFORMS,  # type: ignore
             **transforms.UNALIAS_GROUP,  # type: ignore
+            **transforms.ELIMINATE_QUALIFY,  # type: ignore
             exp.ApproxDistinct: _approx_distinct_sql,
             exp.Array: lambda self, e: f"ARRAY[{self.expressions(e, flat=True)}]",
             exp.ArrayConcat: rename_func("CONCAT"),

@@ -250,8 +260,12 @@ class Presto(Dialect):
             exp.BitwiseXor: lambda self, e: f"BITWISE_XOR({self.sql(e, 'this')}, {self.sql(e, 'expression')})",
             exp.CurrentTimestamp: lambda *_: "CURRENT_TIMESTAMP",
             exp.DataType: _datatype_sql,
-            exp.DateAdd: lambda self, e: f"""DATE_ADD({self.sql(e, 'unit') or "'day'"}, {self.sql(e, 'expression')}, {self.sql(e, 'this')})""",
-            exp.DateDiff: lambda self, e: f"""DATE_DIFF({self.sql(e, 'unit') or "'day'"}, {self.sql(e, 'expression')}, {self.sql(e, 'this')})""",
+            exp.DateAdd: lambda self, e: self.func(
+                "DATE_ADD", exp.Literal.string(e.text("unit") or "day"), e.expression, e.this
+            ),
+            exp.DateDiff: lambda self, e: self.func(
+                "DATE_DIFF", exp.Literal.string(e.text("unit") or "day"), e.expression, e.this
+            ),
             exp.DateStrToDate: lambda self, e: f"CAST(DATE_PARSE({self.sql(e, 'this')}, {Presto.date_format}) AS DATE)",
             exp.DateToDi: lambda self, e: f"CAST(DATE_FORMAT({self.sql(e, 'this')}, {Presto.dateint_format}) AS INT)",
             exp.Decode: _decode_sql,

@@ -265,6 +279,7 @@ class Presto(Dialect):
             exp.Lateral: _explode_to_unnest_sql,
             exp.Levenshtein: rename_func("LEVENSHTEIN_DISTANCE"),
             exp.LogicalOr: rename_func("BOOL_OR"),
+            exp.LogicalAnd: rename_func("BOOL_AND"),
             exp.Quantile: _quantile_sql,
             exp.ApproxQuantile: rename_func("APPROX_PERCENTILE"),
             exp.SafeDivide: no_safe_divide_sql,

@@ -277,6 +292,7 @@ class Presto(Dialect):
             exp.StructExtract: struct_extract_sql,
             exp.TableFormatProperty: lambda self, e: f"TABLE_FORMAT='{e.name.upper()}'",
             exp.FileFormatProperty: lambda self, e: f"FORMAT='{e.name.upper()}'",
+            exp.TimestampTrunc: timestamptrunc_sql,
             exp.TimeStrToDate: timestrtotime_sql,
             exp.TimeStrToTime: timestrtotime_sql,
             exp.TimeStrToUnix: lambda self, e: f"TO_UNIXTIME(DATE_PARSE({self.sql(e, 'this')}, {Presto.time_format}))",
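Note: Presto's DATE_ADD/DATE_DIFF generation now goes through self.func with an explicit unit literal instead of raw f-strings, which normalizes quoting and argument formatting. A hedged round-trip sketch:

import sqlglot

print(sqlglot.transpile("SELECT DATE_ADD('day', 1, d) FROM t", read="presto", write="presto")[0])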
sqlglot/dialects/redshift.py

@@ -20,6 +20,11 @@ class Redshift(Postgres):
     class Parser(Postgres.Parser):
         FUNCTIONS = {
             **Postgres.Parser.FUNCTIONS,  # type: ignore
+            "DATEADD": lambda args: exp.DateAdd(
+                this=seq_get(args, 2),
+                expression=seq_get(args, 1),
+                unit=seq_get(args, 0),
+            ),
             "DATEDIFF": lambda args: exp.DateDiff(
                 this=seq_get(args, 2),
                 expression=seq_get(args, 1),

@@ -76,13 +81,16 @@ class Redshift(Postgres):
         TRANSFORMS = {
             **Postgres.Generator.TRANSFORMS,  # type: ignore
             **transforms.ELIMINATE_DISTINCT_ON,  # type: ignore
+            exp.DateAdd: lambda self, e: self.func(
+                "DATEADD", exp.var(e.text("unit") or "day"), e.expression, e.this
+            ),
             exp.DateDiff: lambda self, e: self.func(
-                "DATEDIFF", e.args.get("unit") or "day", e.expression, e.this
+                "DATEDIFF", exp.var(e.text("unit") or "day"), e.expression, e.this
             ),
             exp.DistKeyProperty: lambda self, e: f"DISTKEY({e.name})",
-            exp.SortKeyProperty: lambda self, e: f"{'COMPOUND ' if e.args['compound'] else ''}SORTKEY({self.format_args(*e.this)})",
             exp.DistStyleProperty: lambda self, e: self.naked_property(e),
             exp.Matches: rename_func("DECODE"),
+            exp.SortKeyProperty: lambda self, e: f"{'COMPOUND ' if e.args['compound'] else ''}SORTKEY({self.format_args(*e.this)})",
         }

         # Redshift uses the POW | POWER (expr1, expr2) syntax instead of expr1 ^ expr2 (postgres)
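Note: DATEADD is now parsed into exp.DateAdd and regenerated with an unquoted unit via exp.var, so Redshift date arithmetic round-trips. A hedged sketch:

import sqlglot

print(sqlglot.transpile("SELECT DATEADD(month, 18, '2008-02-28')", read="redshift", write="redshift")[0])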
sqlglot/dialects/snowflake.py

@@ -5,11 +5,13 @@ import typing as t
 from sqlglot import exp, generator, parser, tokens
 from sqlglot.dialects.dialect import (
     Dialect,
+    date_trunc_to_time,
     datestrtodate_sql,
     format_time_lambda,
     inline_array_sql,
     min_or_least,
     rename_func,
+    timestamptrunc_sql,
     timestrtotime_sql,
     ts_or_ds_to_date_sql,
     var_map_sql,

@@ -176,6 +178,7 @@ class Snowflake(Dialect):
             "ARRAYAGG": exp.ArrayAgg.from_arg_list,
             "ARRAY_CONSTRUCT": exp.Array.from_arg_list,
             "ARRAY_TO_STRING": exp.ArrayJoin.from_arg_list,
+            "DATE_TRUNC": date_trunc_to_time,
             "DATEADD": lambda args: exp.DateAdd(
                 this=seq_get(args, 2),
                 expression=seq_get(args, 1),

@@ -186,10 +189,6 @@ class Snowflake(Dialect):
                 expression=seq_get(args, 1),
                 unit=seq_get(args, 0),
             ),
-            "DATE_TRUNC": lambda args: exp.DateTrunc(
-                unit=exp.Literal.string(seq_get(args, 0).name),  # type: ignore
-                this=seq_get(args, 1),
-            ),
             "DECODE": exp.Matches.from_arg_list,
             "DIV0": _div0_to_if,
             "IFF": exp.If.from_arg_list,

@@ -280,6 +279,8 @@ class Snowflake(Dialect):
             exp.DataType: _datatype_sql,
             exp.If: rename_func("IFF"),
             exp.Map: lambda self, e: var_map_sql(self, e, "OBJECT_CONSTRUCT"),
+            exp.LogicalOr: rename_func("BOOLOR_AGG"),
+            exp.LogicalAnd: rename_func("BOOLAND_AGG"),
             exp.VarMap: lambda self, e: var_map_sql(self, e, "OBJECT_CONSTRUCT"),
             exp.PartitionedByProperty: lambda self, e: f"PARTITION BY {self.sql(e, 'this')}",
             exp.Matches: rename_func("DECODE"),

@@ -287,6 +288,7 @@ class Snowflake(Dialect):
                 "POSITION", e.args.get("substr"), e.this, e.args.get("position")
             ),
             exp.StrToTime: lambda self, e: f"TO_TIMESTAMP({self.sql(e, 'this')}, {self.format_time(e)})",
+            exp.TimestampTrunc: timestamptrunc_sql,
             exp.TimeStrToTime: timestrtotime_sql,
             exp.TimeToUnix: lambda self, e: f"EXTRACT(epoch_second FROM {self.sql(e, 'this')})",
             exp.Trim: lambda self, e: self.func("TRIM", e.this, e.expression),
sqlglot/dialects/spark.py

@@ -157,6 +157,7 @@ class Spark(Hive):
             exp.VariancePop: rename_func("VAR_POP"),
             exp.DateFromParts: rename_func("MAKE_DATE"),
             exp.LogicalOr: rename_func("BOOL_OR"),
+            exp.LogicalAnd: rename_func("BOOL_AND"),
             exp.DayOfWeek: rename_func("DAYOFWEEK"),
             exp.DayOfMonth: rename_func("DAYOFMONTH"),
             exp.DayOfYear: rename_func("DAYOFYEAR"),
sqlglot/dialects/sqlite.py

@@ -1,10 +1,11 @@
 from __future__ import annotations

-from sqlglot import exp, generator, parser, tokens
+from sqlglot import exp, generator, parser, tokens, transforms
 from sqlglot.dialects.dialect import (
     Dialect,
     arrow_json_extract_scalar_sql,
     arrow_json_extract_sql,
+    count_if_to_sum,
     no_ilike_sql,
     no_tablesample_sql,
     no_trycast_sql,

@@ -13,23 +14,6 @@ from sqlglot.dialects.dialect import (
 from sqlglot.tokens import TokenType


-# https://www.sqlite.org/lang_aggfunc.html#group_concat
-def _group_concat_sql(self, expression):
-    this = expression.this
-    distinct = expression.find(exp.Distinct)
-    if distinct:
-        this = distinct.expressions[0]
-        distinct = "DISTINCT "
-
-    if isinstance(expression.this, exp.Order):
-        self.unsupported("SQLite GROUP_CONCAT doesn't support ORDER BY.")
-        if expression.this.this and not distinct:
-            this = expression.this.this
-
-    separator = expression.args.get("separator")
-    return f"GROUP_CONCAT({distinct or ''}{self.format_args(this, separator)})"
-
-
 def _date_add_sql(self, expression):
     modifier = expression.expression
     modifier = expression.name if modifier.is_string else self.sql(modifier)

@@ -78,20 +62,32 @@ class SQLite(Dialect):

         TRANSFORMS = {
             **generator.Generator.TRANSFORMS,  # type: ignore
+            **transforms.ELIMINATE_QUALIFY,  # type: ignore
+            exp.CountIf: count_if_to_sum,
             exp.CurrentDate: lambda *_: "CURRENT_DATE",
             exp.CurrentTime: lambda *_: "CURRENT_TIME",
             exp.CurrentTimestamp: lambda *_: "CURRENT_TIMESTAMP",
             exp.DateAdd: _date_add_sql,
+            exp.DateStrToDate: lambda self, e: self.sql(e, "this"),
             exp.ILike: no_ilike_sql,
             exp.JSONExtract: arrow_json_extract_sql,
             exp.JSONExtractScalar: arrow_json_extract_scalar_sql,
             exp.JSONBExtract: arrow_json_extract_sql,
             exp.JSONBExtractScalar: arrow_json_extract_scalar_sql,
             exp.Levenshtein: rename_func("EDITDIST3"),
+            exp.LogicalOr: rename_func("MAX"),
+            exp.LogicalAnd: rename_func("MIN"),
             exp.TableSample: no_tablesample_sql,
-            exp.DateStrToDate: lambda self, e: self.sql(e, "this"),
             exp.TimeStrToTime: lambda self, e: self.sql(e, "this"),
             exp.TryCast: no_trycast_sql,
-            exp.GroupConcat: _group_concat_sql,
         }

+        def cast_sql(self, expression: exp.Cast) -> str:
+            if expression.to.this == exp.DataType.Type.DATE:
+                return self.func("DATE", expression.this)
+
+            return super().cast_sql(expression)
+
+        def datediff_sql(self, expression: exp.DateDiff) -> str:
+            unit = expression.args.get("unit")
+            unit = unit.name.upper() if unit else "DAY"

@@ -119,16 +115,32 @@ class SQLite(Dialect):

            return f"CAST({sql} AS INTEGER)"

-        def fetch_sql(self, expression):
+        def fetch_sql(self, expression: exp.Fetch) -> str:
             return self.limit_sql(exp.Limit(expression=expression.args.get("count")))

-        def least_sql(self, expression):
+        # https://www.sqlite.org/lang_aggfunc.html#group_concat
+        def groupconcat_sql(self, expression):
+            this = expression.this
+            distinct = expression.find(exp.Distinct)
+            if distinct:
+                this = distinct.expressions[0]
+                distinct = "DISTINCT "
+
+            if isinstance(expression.this, exp.Order):
+                self.unsupported("SQLite GROUP_CONCAT doesn't support ORDER BY.")
+                if expression.this.this and not distinct:
+                    this = expression.this.this
+
+            separator = expression.args.get("separator")
+            return f"GROUP_CONCAT({distinct or ''}{self.format_args(this, separator)})"
+
+        def least_sql(self, expression: exp.Least) -> str:
             if len(expression.expressions) > 1:
                 return rename_func("MIN")(self, expression)

             return self.expressions(expression)

-        def transaction_sql(self, expression):
+        def transaction_sql(self, expression: exp.Transaction) -> str:
             this = expression.this
             this = f" {this}" if this else ""
             return f"BEGIN{this} TRANSACTION"
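Note: SQLite has no COUNT_IF, so the new count_if_to_sum helper (dialect.py above) rewrites it as a conditional SUM. A hedged sketch:

import sqlglot

print(sqlglot.transpile("SELECT COUNT_IF(x > 0) FROM t", write="sqlite")[0])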
sqlglot/dialects/starrocks.py

@@ -3,9 +3,18 @@ from __future__ import annotations
 from sqlglot import exp
 from sqlglot.dialects.dialect import arrow_json_extract_sql, rename_func
 from sqlglot.dialects.mysql import MySQL
+from sqlglot.helper import seq_get


 class StarRocks(MySQL):
+    class Parser(MySQL.Parser):  # type: ignore
+        FUNCTIONS = {
+            **MySQL.Parser.FUNCTIONS,
+            "DATE_TRUNC": lambda args: exp.TimestampTrunc(
+                this=seq_get(args, 1), unit=seq_get(args, 0)
+            ),
+        }
+
     class Generator(MySQL.Generator):  # type: ignore
         TYPE_MAPPING = {
             **MySQL.Generator.TYPE_MAPPING,  # type: ignore

@@ -20,6 +29,9 @@ class StarRocks(MySQL):
             exp.JSONExtract: arrow_json_extract_sql,
             exp.DateDiff: rename_func("DATEDIFF"),
             exp.StrToUnix: lambda self, e: f"UNIX_TIMESTAMP({self.sql(e, 'this')}, {self.format_time(e)})",
+            exp.TimestampTrunc: lambda self, e: self.func(
+                "DATE_TRUNC", exp.Literal.string(e.text("unit")), e.this
+            ),
             exp.TimeStrToDate: rename_func("TO_DATE"),
             exp.UnixToStr: lambda self, e: f"FROM_UNIXTIME({self.sql(e, 'this')}, {self.format_time(e)})",
             exp.UnixToTime: rename_func("FROM_UNIXTIME"),
sqlglot/dialects/tsql.py

@@ -117,14 +117,12 @@ def _string_agg_sql(self, e):
     if distinct:
         # exp.Distinct can appear below an exp.Order or an exp.GroupConcat expression
         self.unsupported("T-SQL STRING_AGG doesn't support DISTINCT.")
-        this = distinct.expressions[0]
-        distinct.pop()
+        this = distinct.pop().expressions[0]

     order = ""
     if isinstance(e.this, exp.Order):
         if e.this.this:
-            this = e.this.this
-            e.this.this.pop()
+            this = e.this.this.pop()
         order = f" WITHIN GROUP ({self.sql(e.this)[1:]})"  # Order has a leading space

     separator = e.args.get("separator") or exp.Literal.string(",")
sqlglot/expressions.py

@@ -301,7 +301,7 @@ class Expression(metaclass=_Expression):
             the specified types.

         Args:
-            expression_types (type): the expression type(s) to match.
+            expression_types: the expression type(s) to match.

         Returns:
             The node which matches the criteria or None if no such node was found.

@@ -314,7 +314,7 @@ class Expression(metaclass=_Expression):
         yields those that match at least one of the specified expression types.

         Args:
-            expression_types (type): the expression type(s) to match.
+            expression_types: the expression type(s) to match.

         Returns:
             The generator object.

@@ -328,7 +328,7 @@ class Expression(metaclass=_Expression):
         Returns a nearest parent matching expression_types.

         Args:
-            expression_types (type): the expression type(s) to match.
+            expression_types: the expression type(s) to match.

         Returns:
             The parent node.

@@ -336,8 +336,7 @@ class Expression(metaclass=_Expression):
         ancestor = self.parent
         while ancestor and not isinstance(ancestor, expression_types):
             ancestor = ancestor.parent
-        # ignore type because mypy doesn't know that we're checking type in the loop
-        return ancestor  # type: ignore[return-value]
+        return t.cast(E, ancestor)

     @property
     def parent_select(self):

@@ -549,8 +548,12 @@ class Expression(metaclass=_Expression):
     def pop(self):
         """
         Remove this expression from its AST.
+
+        Returns:
+            The popped expression.
         """
         self.replace(None)
         return self

     def assert_is(self, type_):
         """

@@ -626,6 +629,7 @@ IntoType = t.Union[
     t.Type[Expression],
     t.Collection[t.Union[str, t.Type[Expression]]],
 ]
+ExpOrStr = t.Union[str, Expression]


 class Condition(Expression):

@@ -809,7 +813,7 @@ class Describe(Expression):


 class Set(Expression):
-    arg_types = {"expressions": True}
+    arg_types = {"expressions": False}


 class SetItem(Expression):

@@ -905,6 +909,23 @@ class Column(Condition):
     def output_name(self) -> str:
         return self.name

+    @property
+    def parts(self) -> t.List[Identifier]:
+        """Return the parts of a column in order catalog, db, table, name."""
+        return [part for part in reversed(list(self.args.values())) if part]
+
+    def to_dot(self) -> Dot:
+        """Converts the column into a dot expression."""
+        parts = self.parts
+        parent = self.parent
+
+        while parent:
+            if isinstance(parent, Dot):
+                parts.append(parent.expression)
+            parent = parent.parent
+
+        return Dot.build(parts)
+

 class ColumnDef(Expression):
     arg_types = {

@@ -1033,6 +1054,113 @@ class Constraint(Expression):
 class Delete(Expression):
     arg_types = {"with": False, "this": False, "using": False, "where": False, "returning": False}

+    def delete(
+        self,
+        table: ExpOrStr,
+        dialect: DialectType = None,
+        copy: bool = True,
+        **opts,
+    ) -> Delete:
+        """
+        Create a DELETE expression or replace the table on an existing DELETE expression.
+
+        Example:
+            >>> delete("tbl").sql()
+            'DELETE FROM tbl'
+
+        Args:
+            table: the table from which to delete.
+            dialect: the dialect used to parse the input expression.
+            copy: if `False`, modify this expression instance in-place.
+            opts: other options to use to parse the input expressions.
+
+        Returns:
+            Delete: the modified expression.
+        """
+        return _apply_builder(
+            expression=table,
+            instance=self,
+            arg="this",
+            dialect=dialect,
+            into=Table,
+            copy=copy,
+            **opts,
+        )
+
+    def where(
+        self,
+        *expressions: ExpOrStr,
+        append: bool = True,
+        dialect: DialectType = None,
+        copy: bool = True,
+        **opts,
+    ) -> Delete:
+        """
+        Append to or set the WHERE expressions.
+
+        Example:
+            >>> delete("tbl").where("x = 'a' OR x < 'b'").sql()
+            "DELETE FROM tbl WHERE x = 'a' OR x < 'b'"
+
+        Args:
+            *expressions: the SQL code strings to parse.
+                If an `Expression` instance is passed, it will be used as-is.
+                Multiple expressions are combined with an AND operator.
+            append: if `True`, AND the new expressions to any existing expression.
+                Otherwise, this resets the expression.
+            dialect: the dialect used to parse the input expressions.
+            copy: if `False`, modify this expression instance in-place.
+            opts: other options to use to parse the input expressions.
+
+        Returns:
+            Delete: the modified expression.
+        """
+        return _apply_conjunction_builder(
+            *expressions,
+            instance=self,
+            arg="where",
+            append=append,
+            into=Where,
+            dialect=dialect,
+            copy=copy,
+            **opts,
+        )
+
+    def returning(
+        self,
+        expression: ExpOrStr,
+        dialect: DialectType = None,
+        copy: bool = True,
+        **opts,
+    ) -> Delete:
+        """
+        Set the RETURNING expression. Not supported by all dialects.
+
+        Example:
+            >>> delete("tbl").returning("*", dialect="postgres").sql()
+            'DELETE FROM tbl RETURNING *'
+
+        Args:
+            expression: the SQL code strings to parse.
+                If an `Expression` instance is passed, it will be used as-is.
+            dialect: the dialect used to parse the input expressions.
+            copy: if `False`, modify this expression instance in-place.
+            opts: other options to use to parse the input expressions.
+
+        Returns:
+            Delete: the modified expression.
+        """
+        return _apply_builder(
+            expression=expression,
+            instance=self,
+            arg="returning",
+            prefix="RETURNING",
+            dialect=dialect,
+            copy=copy,
+            into=Returning,
+            **opts,
+        )
+

 class Drop(Expression):
     arg_types = {

@@ -1824,7 +1952,7 @@ class Union(Subqueryable):

     def select(
         self,
-        *expressions: str | Expression,
+        *expressions: ExpOrStr,
         append: bool = True,
         dialect: DialectType = None,
         copy: bool = True,

@@ -2170,7 +2298,7 @@ class Select(Subqueryable):

     def select(
         self,
-        *expressions: str | Expression,
+        *expressions: ExpOrStr,
         append: bool = True,
         dialect: DialectType = None,
         copy: bool = True,

@@ -2875,6 +3003,20 @@ class Dot(Binary):
     def name(self) -> str:
         return self.expression.name

+    @classmethod
+    def build(self, expressions: t.Sequence[Expression]) -> Dot:
+        """Build a Dot object with a sequence of expressions."""
+        if len(expressions) < 2:
+            raise ValueError(f"Dot requires >= 2 expressions.")
+
+        a, b, *expressions = expressions
+        dot = Dot(this=a, expression=b)
+
+        for expression in expressions:
+            dot = Dot(this=dot, expression=expression)
+
+        return dot
+

 class DPipe(Binary):
     pass
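Note: the new Column.parts/Column.to_dot pair with the Dot.build classmethod above. A hedged sketch (the column parts here are arbitrary illustrative names):

from sqlglot import exp

# db.table.column rebuilt as nested Dot nodes; expected to print a.b.c
column = exp.column("c", table="b", db="a")
print(column.to_dot().sql())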
@@ -3049,7 +3191,7 @@ class TimeUnit(Expression):

     def __init__(self, **args):
         unit = args.get("unit")
-        if isinstance(unit, Column):
+        if isinstance(unit, (Column, Literal)):
             args["unit"] = Var(this=unit.name)
         elif isinstance(unit, Week):
             unit.set("this", Var(this=unit.this.name))

@@ -3261,6 +3403,10 @@ class Count(AggFunc):
     arg_types = {"this": False}


+class CountIf(AggFunc):
+    pass
+
+
 class CurrentDate(Func):
     arg_types = {"this": False}


@@ -3407,6 +3553,10 @@ class Explode(Func):
     pass


+class ExponentialTimeDecayedAvg(AggFunc):
+    arg_types = {"this": True, "time": False, "decay": False}
+
+
 class Floor(Func):
     arg_types = {"this": True, "decimals": False}


@@ -3420,10 +3570,18 @@ class GroupConcat(Func):
     arg_types = {"this": True, "separator": False}


+class GroupUniqArray(AggFunc):
+    arg_types = {"this": True, "size": False}
+
+
 class Hex(Func):
     pass


+class Histogram(AggFunc):
+    arg_types = {"this": True, "bins": False}
+
+
 class If(Func):
     arg_types = {"this": True, "true": True, "false": False}


@@ -3493,7 +3651,11 @@ class Log10(Func):


 class LogicalOr(AggFunc):
-    _sql_names = ["LOGICAL_OR", "BOOL_OR"]
+    _sql_names = ["LOGICAL_OR", "BOOL_OR", "BOOLOR_AGG"]
+
+
+class LogicalAnd(AggFunc):
+    _sql_names = ["LOGICAL_AND", "BOOL_AND", "BOOLAND_AGG"]


 class Lower(Func):

@@ -3561,6 +3723,7 @@ class Quantile(AggFunc):
 # https://clickhouse.com/docs/en/sql-reference/aggregate-functions/reference/quantiles/#quantiles
 class Quantiles(AggFunc):
     arg_types = {"parameters": True, "expressions": True}
+    is_var_len_args = True


 class QuantileIf(AggFunc):

@@ -3830,7 +3993,7 @@ ALL_FUNCTIONS = subclasses(__name__, Func, (AggFunc, Anonymous, Func))

 # Helpers
 def maybe_parse(
-    sql_or_expression: str | Expression,
+    sql_or_expression: ExpOrStr,
     *,
     into: t.Optional[IntoType] = None,
     dialect: DialectType = None,

@@ -4091,7 +4254,7 @@ def except_(left, right, distinct=True, dialect=None, **opts):
     return Except(this=left, expression=right, distinct=distinct)


-def select(*expressions: str | Expression, dialect: DialectType = None, **opts) -> Select:
+def select(*expressions: ExpOrStr, dialect: DialectType = None, **opts) -> Select:
     """
     Initializes a syntax tree from one or multiple SELECT expressions.


@@ -4135,7 +4298,14 @@ def from_(*expressions, dialect=None, **opts) -> Select:
     return Select().from_(*expressions, dialect=dialect, **opts)


-def update(table, properties, where=None, from_=None, dialect=None, **opts) -> Update:
+def update(
+    table: str | Table,
+    properties: dict,
+    where: t.Optional[ExpOrStr] = None,
+    from_: t.Optional[ExpOrStr] = None,
+    dialect: DialectType = None,
+    **opts,
+) -> Update:
     """
     Creates an update statement.


@@ -4144,18 +4314,18 @@ def update(table, properties, where=None, from_=None, dialect=None, **opts) -> Update:
     "UPDATE my_table SET x = 1, y = '2', z = NULL FROM baz WHERE id > 1"

     Args:
-        *properties (Dict[str, Any]): dictionary of properties to set which are
+        *properties: dictionary of properties to set which are
             auto converted to sql objects eg None -> NULL
-        where (str): sql conditional parsed into a WHERE statement
-        from_ (str): sql statement parsed into a FROM statement
-        dialect (str): the dialect used to parse the input expressions.
+        where: sql conditional parsed into a WHERE statement
+        from_: sql statement parsed into a FROM statement
+        dialect: the dialect used to parse the input expressions.
         **opts: other options to use to parse the input expressions.

     Returns:
         Update: the syntax tree for the UPDATE statement.
     """
-    update = Update(this=maybe_parse(table, into=Table, dialect=dialect))
-    update.set(
+    update_expr = Update(this=maybe_parse(table, into=Table, dialect=dialect))
+    update_expr.set(
         "expressions",
         [
             EQ(this=maybe_parse(k, dialect=dialect, **opts), expression=convert(v))

@@ -4163,21 +4333,27 @@ def update(table, properties, where=None, from_=None, dialect=None, **opts) -> Update:
         ],
     )
     if from_:
-        update.set(
+        update_expr.set(
             "from",
             maybe_parse(from_, into=From, dialect=dialect, prefix="FROM", **opts),
         )
     if isinstance(where, Condition):
         where = Where(this=where)
     if where:
-        update.set(
+        update_expr.set(
             "where",
             maybe_parse(where, into=Where, dialect=dialect, prefix="WHERE", **opts),
         )
-    return update
+    return update_expr


-def delete(table, where=None, dialect=None, **opts) -> Delete:
+def delete(
+    table: ExpOrStr,
+    where: t.Optional[ExpOrStr] = None,
+    returning: t.Optional[ExpOrStr] = None,
+    dialect: DialectType = None,
+    **opts,
+) -> Delete:
     """
     Builds a delete statement.


@@ -4186,19 +4362,20 @@ def delete(table, where=None, dialect=None, **opts) -> Delete:
     'DELETE FROM my_table WHERE id > 1'

     Args:
-        where (str|Condition): sql conditional parsed into a WHERE statement
-        dialect (str): the dialect used to parse the input expressions.
+        where: sql conditional parsed into a WHERE statement
+        returning: sql conditional parsed into a RETURNING statement
+        dialect: the dialect used to parse the input expressions.
         **opts: other options to use to parse the input expressions.

     Returns:
         Delete: the syntax tree for the DELETE statement.
     """
-    return Delete(
-        this=maybe_parse(table, into=Table, dialect=dialect, **opts),
-        where=Where(this=where)
-        if isinstance(where, Condition)
-        else maybe_parse(where, into=Where, dialect=dialect, prefix="WHERE", **opts),
-    )
+    delete_expr = Delete().delete(table, dialect=dialect, copy=False, **opts)
+    if where:
+        delete_expr = delete_expr.where(where, dialect=dialect, copy=False, **opts)
+    if returning:
+        delete_expr = delete_expr.returning(returning, dialect=dialect, copy=False, **opts)
+    return delete_expr


 def condition(expression, dialect=None, **opts) -> Condition:

@@ -4414,7 +4591,7 @@ def to_column(sql_path: str | Column, **kwargs) -> Column:


 def alias_(
-    expression: str | Expression,
+    expression: ExpOrStr,
     alias: str | Identifier,
     table: bool | t.Sequence[str | Identifier] = False,
     quoted: t.Optional[bool] = None,

@@ -4516,7 +4693,7 @@ def column(
     )


-def cast(expression: str | Expression, to: str | DataType | DataType.Type, **opts) -> Cast:
+def cast(expression: ExpOrStr, to: str | DataType | DataType.Type, **opts) -> Cast:
     """Cast an expression to a data type.

     Example:

@@ -4595,7 +4772,7 @@ def values(
     )


-def var(name: t.Optional[str | Expression]) -> Var:
+def var(name: t.Optional[ExpOrStr]) -> Var:
     """Build a SQL variable.

     Example:

@@ -4612,7 +4789,7 @@ def var(name: t.Optional[str | Expression]) -> Var:
         The new variable node.
     """
     if not name:
-        raise ValueError(f"Cannot convert empty name into var.")
+        raise ValueError("Cannot convert empty name into var.")

     if isinstance(name, Expression):
         name = name.name

@@ -4682,7 +4859,7 @@ def convert(value) -> Expression:
     raise ValueError(f"Cannot convert {value}")


-def replace_children(expression, fun):
+def replace_children(expression, fun, *args, **kwargs):
     """
     Replace children of an expression with the result of a lambda fun(child) -> exp.
     """

@@ -4694,7 +4871,7 @@ def replace_children(expression, fun):

     for cn in child_nodes:
         if isinstance(cn, Expression):
-            for child_node in ensure_collection(fun(cn)):
+            for child_node in ensure_collection(fun(cn, *args, **kwargs)):
                 new_child_nodes.append(child_node)
                 child_node.parent = expression
                 child_node.arg_key = k
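Note: delete() is now composed from the chainable Delete.delete/.where/.returning builder methods. Examples adapted directly from the docstrings above:

from sqlglot import exp

print(exp.delete("my_table", where="id > 1").sql())
print(exp.delete("tbl", returning="*", dialect="postgres").sql(dialect="postgres"))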
@@ -5,7 +5,7 @@ import typing as t

 from sqlglot import exp
 from sqlglot.errors import ErrorLevel, UnsupportedError, concat_messages
-from sqlglot.helper import apply_index_offset, csv, seq_get
+from sqlglot.helper import apply_index_offset, csv, seq_get, should_identify
 from sqlglot.time import format_time
 from sqlglot.tokens import TokenType

@@ -25,8 +25,7 @@ class Generator:
         quote_end (str): specifies which ending character to use to delimit quotes. Default: '.
         identifier_start (str): specifies which starting character to use to delimit identifiers. Default: ".
         identifier_end (str): specifies which ending character to use to delimit identifiers. Default: ".
-        identify (bool): if set to True all identifiers will be delimited by the corresponding
-            character.
+        identify (bool | str): 'always': always quote identifiers; 'safe': quote identifiers only if they don't contain an uppercase character; True defaults to 'always'.
         normalize (bool): if set to True all identifiers will be lower cased
         string_escape (str): specifies a string escape character. Default: '.
         identifier_escape (str): specifies an identifier escape character. Default: ".
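A minimal sketch of the two string modes the new docstring describes (the query and identifier names are arbitrary; the outputs show the expected shape, not verified against this exact version):

import sqlglot

# "safe" quotes only identifiers that contain no uppercase character.
print(sqlglot.transpile("SELECT a, B FROM t", identify="safe")[0])
# roughly: SELECT "a", B FROM "t"

# "always" (or True) quotes every identifier.
print(sqlglot.transpile("SELECT a, B FROM t", identify="always")[0])
# roughly: SELECT "a", "B" FROM "t"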
@@ -57,10 +56,10 @@ class Generator:

     TRANSFORMS = {
         exp.DateAdd: lambda self, e: self.func(
-            "DATE_ADD", e.this, e.expression, e.args.get("unit")
+            "DATE_ADD", e.this, e.expression, exp.Literal.string(e.text("unit"))
         ),
         exp.TsOrDsAdd: lambda self, e: self.func(
-            "TS_OR_DS_ADD", e.this, e.expression, e.args.get("unit")
+            "TS_OR_DS_ADD", e.this, e.expression, exp.Literal.string(e.text("unit"))
         ),
         exp.VarMap: lambda self, e: self.func("MAP", e.args["keys"], e.args["values"]),
         exp.CharacterSetProperty: lambda self, e: f"{'DEFAULT ' if e.args.get('default') else ''}CHARACTER SET={self.sql(e, 'this')}",

@@ -736,7 +735,7 @@ class Generator:
         text = expression.name
         text = text.lower() if self.normalize else text
         text = text.replace(self.identifier_end, self._escaped_identifier_end)
-        if expression.args.get("quoted") or self.identify:
+        if expression.args.get("quoted") or should_identify(text, self.identify):
             text = f"{self.identifier_start}{text}{self.identifier_end}"
         return text

@@ -1176,6 +1175,22 @@ class Generator:
         this = self.sql(expression, "this")
         return f"{this}{self.seg('OFFSET')} {self.sql(expression, 'expression')}"

+    def setitem_sql(self, expression: exp.SetItem) -> str:
+        kind = self.sql(expression, "kind")
+        kind = f"{kind} " if kind else ""
+        this = self.sql(expression, "this")
+        expressions = self.expressions(expression)
+        collate = self.sql(expression, "collate")
+        collate = f" COLLATE {collate}" if collate else ""
+        global_ = "GLOBAL " if expression.args.get("global") else ""
+        return f"{global_}{kind}{this}{expressions}{collate}"
+
+    def set_sql(self, expression: exp.Set) -> str:
+        expressions = (
+            f" {self.expressions(expression, flat=True)}" if expression.expressions else ""
+        )
+        return f"SET{expressions}"
+
     def lock_sql(self, expression: exp.Lock) -> str:
         if self.LOCKING_READS_SUPPORTED:
             lock_type = "UPDATE" if expression.args["update"] else "SHARE"

@@ -1359,8 +1374,8 @@ class Generator:
         sql = self.query_modifiers(
             expression,
             self.wrap(expression),
-            self.expressions(expression, key="pivots", sep=" "),
             alias,
+            self.expressions(expression, key="pivots", sep=" "),
         )

         return self.prepend_ctes(expression, sql)

@@ -1668,7 +1683,7 @@ class Generator:
         expression_sql = self.sql(expression, "expression")
         return f"COMMENT{exists_sql}ON {kind} {this} IS {expression_sql}"

-    def transaction_sql(self, *_) -> str:
+    def transaction_sql(self, expression: exp.Transaction) -> str:
         return "BEGIN"

     def commit_sql(self, expression: exp.Commit) -> str:
@@ -403,3 +403,20 @@ def first(it: t.Iterable[T]) -> T:
     Useful for sets.
     """
     return next(i for i in it)
+
+
+def should_identify(text: str, identify: str | bool) -> bool:
+    """Checks if text should be identified given an identify option.
+
+    Args:
+        text: the text to check.
+        identify: "always" | True - always returns True; "safe" - True only if the text has no uppercase characters.
+
+    Returns:
+        Whether or not a string should be identified.
+    """
+    if identify is True or identify == "always":
+        return True
+    if identify == "safe":
+        return not any(char.isupper() for char in text)
+    return False
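The helper's behavior follows directly from the branches above; a few spot checks:

from sqlglot.helper import should_identify

assert should_identify("foo", "safe")        # no uppercase, so quote it
assert not should_identify("Foo", "safe")    # uppercase present, leave unquoted
assert should_identify("Foo", "always")      # "always" (or True) always quotes
assert not should_identify("Foo", False)     # identify disabled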
@@ -1,9 +1,12 @@
 from __future__ import annotations

+import itertools
+
 from sqlglot import exp
+from sqlglot.helper import should_identify


-def canonicalize(expression: exp.Expression) -> exp.Expression:
+def canonicalize(expression: exp.Expression, identify: str = "safe") -> exp.Expression:
     """Converts a sql expression into a standard form.

     This method relies on annotate_types because many of the

@@ -11,15 +14,18 @@ def canonicalize(expression: exp.Expression) -> exp.Expression:

     Args:
         expression: The expression to canonicalize.
+        identify: Whether or not to force identify identifiers.
     """
-    exp.replace_children(expression, canonicalize)
+    exp.replace_children(expression, canonicalize, identify=identify)

     expression = add_text_to_concat(expression)
     expression = coerce_type(expression)
     expression = remove_redundant_casts(expression)
+    expression = ensure_bool_predicates(expression)

     if isinstance(expression, exp.Identifier):
-        expression.set("quoted", True)
+        if should_identify(expression.this, identify):
+            expression.set("quoted", True)

     return expression

@@ -52,6 +58,17 @@ def remove_redundant_casts(expression: exp.Expression) -> exp.Expression:
     return expression


+def ensure_bool_predicates(expression: exp.Expression) -> exp.Expression:
+    if isinstance(expression, exp.Connector):
+        _replace_int_predicate(expression.left)
+        _replace_int_predicate(expression.right)
+
+    elif isinstance(expression, (exp.Where, exp.Having)):
+        _replace_int_predicate(expression.this)
+
+    return expression
+
+
 def _coerce_date(a: exp.Expression, b: exp.Expression) -> None:
     for a, b in itertools.permutations([a, b]):
         if (

@@ -68,3 +85,8 @@ def _replace_cast(node: exp.Expression, to: str) -> None:
     cast = exp.Cast(this=node.copy(), to=data_type)
     cast.type = data_type
     node.replace(cast)
+
+
+def _replace_int_predicate(expression: exp.Expression) -> None:
+    if expression.type and expression.type.this in exp.DataType.INTEGER_TYPES:
+        expression.replace(exp.NEQ(this=expression.copy(), expression=exp.Literal.number(0)))
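A hedged sketch of the two additions working together (the schema and query are invented; canonicalize operates on a type-annotated tree, hence the annotate_types call):

import sqlglot
from sqlglot.optimizer.annotate_types import annotate_types
from sqlglot.optimizer.canonicalize import canonicalize

# An integer column used as a bare predicate should become `<> 0`, and the
# default identify="safe" should quote the lowercase identifiers.
expr = sqlglot.parse_one("SELECT a FROM t WHERE b")
expr = annotate_types(expr, schema={"t": {"a": "INT", "b": "INT"}})
print(canonicalize(expr).sql())
# roughly: SELECT "a" FROM "t" WHERE "b" <> 0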
@@ -1,7 +1,6 @@
 from collections import defaultdict

 from sqlglot import alias, exp
-from sqlglot.helper import flatten
 from sqlglot.optimizer.qualify_columns import Resolver
 from sqlglot.optimizer.scope import Scope, traverse_scope
 from sqlglot.schema import ensure_schema

@@ -86,14 +85,15 @@ def _remove_unused_selections(scope, parent_selections, schema):
     else:
         order_refs = set()

-    new_selections = defaultdict(list)
+    new_selections = []
     removed = False
     star = False
+
     for selection in scope.selects:
         name = selection.alias_or_name

         if SELECT_ALL in parent_selections or name in parent_selections or name in order_refs:
-            new_selections[name].append(selection)
+            new_selections.append(selection)
         else:
             if selection.is_star:
                 star = True

@@ -101,18 +101,17 @@ def _remove_unused_selections(scope, parent_selections, schema):

     if star:
         resolver = Resolver(scope, schema)
+        names = {s.alias_or_name for s in new_selections}

         for name in sorted(parent_selections):
-            if name not in new_selections:
-                new_selections[name].append(
-                    alias(exp.column(name, table=resolver.get_table(name)), name)
-                )
+            if name not in names:
+                new_selections.append(alias(exp.column(name, table=resolver.get_table(name)), name))

     # If there are no remaining selections, just select a single constant
     if not new_selections:
-        new_selections[""].append(DEFAULT_SELECTION())
+        new_selections.append(DEFAULT_SELECTION())

-    scope.expression.select(*flatten(new_selections.values()), append=False, copy=False)
+    scope.expression.select(*new_selections, append=False, copy=False)

     if removed:
         scope.clear_cache()
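Switching new_selections from a defaultdict to a plain list preserves duplicate projection names; a sketch, assuming the module-level pushdown_projections(expression, schema=None) entry point:

import sqlglot
from sqlglot.optimizer.pushdown_projections import pushdown_projections

# The duplicate outer `a, a` should now survive, while the unused inner `b` is pruned.
expr = sqlglot.parse_one("SELECT a, a FROM (SELECT a, b FROM t)")
print(pushdown_projections(expr).sql())
# roughly: SELECT a, a FROM (SELECT a FROM t)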
@@ -37,6 +37,7 @@ def qualify_columns(expression, schema):
         _qualify_outputs(scope)
         _expand_group_by(scope, resolver)
+        _expand_order_by(scope)

     return expression

@@ -213,6 +214,21 @@ def _qualify_columns(scope, resolver):
             # column_table can be a '' because bigquery unnest has no table alias
             if column_table:
                 column.set("table", column_table)
+        elif column_table not in scope.sources:
+            # structs are used like tables (e.g. "struct"."field"), so they need to be qualified
+            # separately and represented as dot(dot(...(<table>.<column>, field1), field2, ...))
+
+            root, *parts = column.parts
+
+            if root.name in scope.sources:
+                # struct is already qualified, but we still need to change the AST representation
+                column_table = root
+                root, *parts = parts
+            else:
+                column_table = resolver.get_table(root.name)
+
+            if column_table:
+                column.replace(exp.Dot.build([exp.column(root, table=column_table), *parts]))

     columns_missing_from_scope = []
     # Determine whether each reference in the order by clause is to a column or an alias.

@@ -373,10 +389,14 @@ class Resolver:
         if isinstance(node, exp.Subqueryable):
             while node and node.alias != table_name:
                 node = node.parent

+            node_alias = node.args.get("alias")
+            if node_alias:
+                return node_alias.this
             return exp.to_identifier(table_name)

-        return exp.to_identifier(table_name)
+        return exp.to_identifier(
+            table_name, quoted=node.this.quoted if isinstance(node, exp.Table) else None
+        )

     @property
     def all_columns(self):

@@ -34,11 +34,9 @@ def qualify_tables(expression, db=None, catalog=None, schema=None):
             derived_table.set("alias", exp.TableAlias(this=exp.to_identifier(alias_)))
             scope.rename_source(None, alias_)

-        for source in scope.sources.values():
+        for name, source in scope.sources.items():
             if isinstance(source, exp.Table):
-                identifier = isinstance(source.this, exp.Identifier)
-
-                if identifier:
+                if isinstance(source.this, exp.Identifier):
                     if not source.args.get("db"):
                         source.set("db", exp.to_identifier(db))
                     if not source.args.get("catalog"):

@@ -48,7 +46,7 @@ def qualify_tables(expression, db=None, catalog=None, schema=None):
                     source = source.replace(
                         alias(
                             source.copy(),
-                            source.this if identifier else next_name(),
+                            name if name else next_name(),
                             table=True,
                         )
                     )
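A sketch of the renamed loop variable paying off: the source's own name, rather than a generated one, is used when an alias must be added (the db name below is a placeholder):

import sqlglot
from sqlglot.optimizer.qualify_tables import qualify_tables

print(qualify_tables(sqlglot.parse_one("SELECT 1 FROM t"), db="main").sql())
# roughly: SELECT 1 FROM main.t AS t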
@@ -4,6 +4,7 @@ from enum import Enum, auto

 from sqlglot import exp
 from sqlglot.errors import OptimizeError
+from sqlglot.helper import find_new_name


 class ScopeType(Enum):

@@ -293,6 +294,8 @@ class Scope:
         result = {}

         for name, node in referenced_names:
+            if name in result:
+                raise OptimizeError(f"Alias already used: {name}")
             if name in self.sources:
                 result[name] = (node, self.sources[name])

@@ -594,6 +597,8 @@ def _traverse_tables(scope):
             if table_name in scope.sources:
                 # This is a reference to a parent source (e.g. a CTE), not an actual table.
                 sources[source_name] = scope.sources[table_name]
+            elif source_name in sources:
+                sources[find_new_name(sources, table_name)] = expression
             else:
                 sources[source_name] = expression
             continue

@@ -96,6 +96,7 @@ class Parser(metaclass=_Parser):
     NO_PAREN_FUNCTIONS = {
         TokenType.CURRENT_DATE: exp.CurrentDate,
         TokenType.CURRENT_DATETIME: exp.CurrentDate,
+        TokenType.CURRENT_TIME: exp.CurrentTime,
         TokenType.CURRENT_TIMESTAMP: exp.CurrentTimestamp,
     }

@@ -198,7 +199,6 @@ class Parser(metaclass=_Parser):
         TokenType.COMMIT,
         TokenType.COMPOUND,
         TokenType.CONSTRAINT,
-        TokenType.CURRENT_TIME,
         TokenType.DEFAULT,
         TokenType.DELETE,
         TokenType.DESCRIBE,

@@ -370,8 +370,9 @@ class Parser(metaclass=_Parser):
     LAMBDAS = {
         TokenType.ARROW: lambda self, expressions: self.expression(
             exp.Lambda,
-            this=self._parse_conjunction().transform(
-                self._replace_lambda, {node.name for node in expressions}
+            this=self._replace_lambda(
+                self._parse_conjunction(),
+                {node.name for node in expressions},
             ),
             expressions=expressions,
         ),

@@ -441,6 +442,7 @@ class Parser(metaclass=_Parser):
         exp.With: lambda self: self._parse_with(),
         exp.Window: lambda self: self._parse_named_window(),
         exp.Qualify: lambda self: self._parse_qualify(),
+        exp.Returning: lambda self: self._parse_returning(),
         "JOIN_TYPE": lambda self: self._parse_join_side_and_kind(),
     }

@@ -460,6 +462,7 @@ class Parser(metaclass=_Parser):
         TokenType.LOAD_DATA: lambda self: self._parse_load_data(),
         TokenType.MERGE: lambda self: self._parse_merge(),
         TokenType.ROLLBACK: lambda self: self._parse_commit_or_rollback(),
+        TokenType.SET: lambda self: self._parse_set(),
         TokenType.UNCACHE: lambda self: self._parse_uncache(),
         TokenType.UPDATE: lambda self: self._parse_update(),
         TokenType.USE: lambda self: self.expression(

@@ -656,15 +659,15 @@ class Parser(metaclass=_Parser):
     }

     FUNCTION_PARSERS: t.Dict[str, t.Callable] = {
+        "CAST": lambda self: self._parse_cast(self.STRICT_CAST),
         "CONVERT": lambda self: self._parse_convert(self.STRICT_CAST),
-        "TRY_CONVERT": lambda self: self._parse_convert(False),
         "EXTRACT": lambda self: self._parse_extract(),
         "POSITION": lambda self: self._parse_position(),
-        "STRING_AGG": lambda self: self._parse_string_agg(),
         "SUBSTRING": lambda self: self._parse_substring(),
         "TRIM": lambda self: self._parse_trim(),
-        "CAST": lambda self: self._parse_cast(self.STRICT_CAST),
         "TRY_CAST": lambda self: self._parse_cast(False),
+        "STRING_AGG": lambda self: self._parse_string_agg(),
+        "TRY_CONVERT": lambda self: self._parse_convert(False),
     }

     QUERY_MODIFIER_PARSERS = {

@@ -684,13 +687,28 @@ class Parser(metaclass=_Parser):
         "sample": lambda self: self._parse_table_sample(as_modifier=True),
     }

+    SET_PARSERS = {
+        "GLOBAL": lambda self: self._parse_set_item_assignment("GLOBAL"),
+        "LOCAL": lambda self: self._parse_set_item_assignment("LOCAL"),
+        "SESSION": lambda self: self._parse_set_item_assignment("SESSION"),
+        "TRANSACTION": lambda self: self._parse_set_transaction(),
+    }
+
     SHOW_PARSERS: t.Dict[str, t.Callable] = {}
-    SET_PARSERS: t.Dict[str, t.Callable] = {}

     MODIFIABLES = (exp.Subquery, exp.Subqueryable, exp.Table)

     TRANSACTION_KIND = {"DEFERRED", "IMMEDIATE", "EXCLUSIVE"}

+    TRANSACTION_CHARACTERISTICS = {
+        "ISOLATION LEVEL REPEATABLE READ",
+        "ISOLATION LEVEL READ COMMITTED",
+        "ISOLATION LEVEL READ UNCOMMITTED",
+        "ISOLATION LEVEL SERIALIZABLE",
+        "READ WRITE",
+        "READ ONLY",
+    }
+
     INSERT_ALTERNATIVES = {"ABORT", "FAIL", "IGNORE", "REPLACE", "ROLLBACK"}

     WINDOW_ALIAS_TOKENS = ID_VAR_TOKENS - {TokenType.ROWS}

@@ -1775,11 +1793,12 @@ class Parser(metaclass=_Parser):
         self, alias_tokens: t.Optional[t.Collection[TokenType]] = None
     ) -> t.Optional[exp.Expression]:
         any_token = self._match(TokenType.ALIAS)
-        alias = self._parse_id_var(
-            any_token=any_token, tokens=alias_tokens or self.TABLE_ALIAS_TOKENS
+        alias = (
+            self._parse_id_var(any_token=any_token, tokens=alias_tokens or self.TABLE_ALIAS_TOKENS)
+            or self._parse_string_as_identifier()
         )
-        index = self._index

+        index = self._index
         if self._match(TokenType.L_PAREN):
             columns = self._parse_csv(self._parse_function_parameter)
             self._match_r_paren() if columns else self._retreat(index)

@@ -2046,7 +2065,12 @@ class Parser(metaclass=_Parser):
     def _parse_table_parts(self, schema: bool = False) -> exp.Expression:
         catalog = None
         db = None
-        table = (not schema and self._parse_function()) or self._parse_id_var(any_token=False)
+
+        table = (
+            (not schema and self._parse_function())
+            or self._parse_id_var(any_token=False)
+            or self._parse_string_as_identifier()
+        )

         while self._match(TokenType.DOT):
             if catalog:

@@ -2085,6 +2109,8 @@ class Parser(metaclass=_Parser):
             subquery = self._parse_select(table=True)

             if subquery:
+                if not subquery.args.get("pivots"):
+                    subquery.set("pivots", self._parse_pivots())
                 return subquery

         this = self._parse_table_parts(schema=schema)

@@ -3370,9 +3396,9 @@ class Parser(metaclass=_Parser):
     def _parse_window(
         self, this: t.Optional[exp.Expression], alias: bool = False
     ) -> t.Optional[exp.Expression]:
-        if self._match(TokenType.FILTER):
-            where = self._parse_wrapped(self._parse_where)
-            this = self.expression(exp.Filter, this=this, expression=where)
+        if self._match_pair(TokenType.FILTER, TokenType.L_PAREN):
+            this = self.expression(exp.Filter, this=this, expression=self._parse_where())
+            self._match_r_paren()

         # T-SQL allows the OVER (...) syntax after WITHIN GROUP.
         # https://learn.microsoft.com/en-us/sql/t-sql/functions/percentile-disc-transact-sql?view=sql-server-ver16

@@ -3504,6 +3530,9 @@ class Parser(metaclass=_Parser):
             return self.PRIMARY_PARSERS[TokenType.STRING](self, self._prev)
         return self._parse_placeholder()

+    def _parse_string_as_identifier(self) -> t.Optional[exp.Expression]:
+        return exp.to_identifier(self._match(TokenType.STRING) and self._prev.text, quoted=True)
+
     def _parse_number(self) -> t.Optional[exp.Expression]:
         if self._match(TokenType.NUMBER):
             return self.PRIMARY_PARSERS[TokenType.NUMBER](self, self._prev)
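With _parse_string_as_identifier wired into table parts and aliases above, a string literal in table position should now be read as a quoted identifier; an unverified sketch:

import sqlglot

print(sqlglot.parse_one("SELECT * FROM 'my_table'").sql())
# roughly: SELECT * FROM "my_table"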
@@ -3778,23 +3807,6 @@ class Parser(metaclass=_Parser):
         )
         return self._parse_as_command(start)

-    def _parse_show(self) -> t.Optional[exp.Expression]:
-        parser = self._find_parser(self.SHOW_PARSERS, self._show_trie)  # type: ignore
-        if parser:
-            return parser(self)
-        self._advance()
-        return self.expression(exp.Show, this=self._prev.text.upper())
-
-    def _default_parse_set_item(self) -> exp.Expression:
-        return self.expression(
-            exp.SetItem,
-            this=self._parse_statement(),
-        )
-
-    def _parse_set_item(self) -> t.Optional[exp.Expression]:
-        parser = self._find_parser(self.SET_PARSERS, self._set_trie)  # type: ignore
-        return parser(self) if parser else self._default_parse_set_item()
-
     def _parse_merge(self) -> exp.Expression:
         self._match(TokenType.INTO)
         target = self._parse_table()

@@ -3861,8 +3873,71 @@ class Parser(metaclass=_Parser):
             expressions=whens,
         )

+    def _parse_show(self) -> t.Optional[exp.Expression]:
+        parser = self._find_parser(self.SHOW_PARSERS, self._show_trie)  # type: ignore
+        if parser:
+            return parser(self)
+        self._advance()
+        return self.expression(exp.Show, this=self._prev.text.upper())
+
+    def _parse_set_item_assignment(
+        self, kind: t.Optional[str] = None
+    ) -> t.Optional[exp.Expression]:
+        index = self._index
+
+        if kind in {"GLOBAL", "SESSION"} and self._match_text_seq("TRANSACTION"):
+            return self._parse_set_transaction(global_=kind == "GLOBAL")
+
+        left = self._parse_primary() or self._parse_id_var()
+
+        if not self._match_texts(("=", "TO")):
+            self._retreat(index)
+            return None
+
+        right = self._parse_statement() or self._parse_id_var()
+        this = self.expression(
+            exp.EQ,
+            this=left,
+            expression=right,
+        )
+
+        return self.expression(
+            exp.SetItem,
+            this=this,
+            kind=kind,
+        )
+
+    def _parse_set_transaction(self, global_: bool = False) -> exp.Expression:
+        self._match_text_seq("TRANSACTION")
+        characteristics = self._parse_csv(
+            lambda: self._parse_var_from_options(self.TRANSACTION_CHARACTERISTICS)
+        )
+        return self.expression(
+            exp.SetItem,
+            expressions=characteristics,
+            kind="TRANSACTION",
+            **{"global": global_},  # type: ignore
+        )
+
+    def _parse_set_item(self) -> t.Optional[exp.Expression]:
+        parser = self._find_parser(self.SET_PARSERS, self._set_trie)  # type: ignore
+        return parser(self) if parser else self._parse_set_item_assignment(kind=None)
+
     def _parse_set(self) -> exp.Expression:
-        return self.expression(exp.Set, expressions=self._parse_csv(self._parse_set_item))
+        index = self._index
+        set_ = self.expression(exp.Set, expressions=self._parse_csv(self._parse_set_item))
+
+        if self._curr:
+            self._retreat(index)
+            return self._parse_as_command(self._prev)
+
+        return set_
+
+    def _parse_var_from_options(self, options: t.Collection[str]) -> t.Optional[exp.Expression]:
+        for option in options:
+            if self._match_text_seq(*option.split(" ")):
+                return exp.Var(this=option)
+        return None
+
     def _parse_as_command(self, start: Token) -> exp.Command:
         while self._curr:
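A sketch of the new SET machinery end to end (the statements are arbitrary, and the exact round-trip output is assumed rather than verified):

import sqlglot

# A plain assignment, optionally prefixed with GLOBAL/SESSION/LOCAL.
print(sqlglot.transpile("SET GLOBAL x = 1")[0])           # roughly: SET GLOBAL x = 1

# Transaction characteristics become a SetItem with kind="TRANSACTION".
print(sqlglot.transpile("SET TRANSACTION READ ONLY")[0])  # roughly: SET TRANSACTION READ ONLY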
@@ -3874,6 +3949,9 @@ class Parser(metaclass=_Parser):
     def _find_parser(
         self, parsers: t.Dict[str, t.Callable], trie: t.Dict
     ) -> t.Optional[t.Callable]:
+        if not self._curr:
+            return None
+
         index = self._index
         this = []
         while True:

@@ -3973,7 +4051,16 @@ class Parser(metaclass=_Parser):
         return this

     def _replace_lambda(self, node, lambda_variables):
-        if isinstance(node, exp.Column):
-            if node.name in lambda_variables:
-                return node.this
+        for column in node.find_all(exp.Column):
+            if column.parts[0].name in lambda_variables:
+                dot_or_id = column.to_dot() if column.table else column.this
+                parent = column.parent
+
+                while isinstance(parent, exp.Dot):
+                    if not isinstance(parent.parent, exp.Dot):
+                        parent.replace(dot_or_id)
+                        break
+                    parent = parent.parent
+                else:
+                    column.replace(dot_or_id)
         return node
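The rewritten _replace_lambda resolves lambda variables even under dotted access; a sketch (the function and field names are arbitrary):

import sqlglot

# `x` is a lambda variable, so `x.total` should resolve to a Dot over the
# variable instead of a column reference, and the SQL should round-trip.
print(sqlglot.parse_one("SELECT FILTER(xs, x -> x.total > 0)").sql())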
@@ -502,6 +502,7 @@ class Tokenizer(metaclass=_Tokenizer):
        "CUBE": TokenType.CUBE,
        "CURRENT_DATE": TokenType.CURRENT_DATE,
        "CURRENT ROW": TokenType.CURRENT_ROW,
+       "CURRENT_TIME": TokenType.CURRENT_TIME,
        "CURRENT_TIMESTAMP": TokenType.CURRENT_TIMESTAMP,
        "DATABASE": TokenType.DATABASE,
        "DEFAULT": TokenType.DEFAULT,

@@ -725,7 +726,6 @@ class Tokenizer(metaclass=_Tokenizer):
         TokenType.COMMAND,
         TokenType.EXECUTE,
         TokenType.FETCH,
-        TokenType.SET,
         TokenType.SHOW,
     }

@@ -851,8 +851,10 @@ class Tokenizer(metaclass=_Tokenizer):

         # If we have either a semicolon or a begin token before the command's token, we'll parse
         # whatever follows the command's token as a string
-        if token_type in self.COMMANDS and (
-            len(self.tokens) == 1 or self.tokens[-2].token_type in self.COMMAND_PREFIX_TOKENS
+        if (
+            token_type in self.COMMANDS
+            and self._peek != ";"
+            and (len(self.tokens) == 1 or self.tokens[-2].token_type in self.COMMAND_PREFIX_TOKENS)
         ):
             start = self._current
             tokens = len(self.tokens)

@@ -2,13 +2,12 @@ from __future__ import annotations

 import typing as t

+from sqlglot import expressions as exp
+from sqlglot.helper import find_new_name
+
 if t.TYPE_CHECKING:
     from sqlglot.generator import Generator

-from sqlglot import expressions as exp
-

 def unalias_group(expression: exp.Expression) -> exp.Expression:
     """

@@ -61,8 +60,7 @@ def eliminate_distinct_on(expression: exp.Expression) -> exp.Expression:
         and expression.args["distinct"].args.get("on")
         and isinstance(expression.args["distinct"].args["on"], exp.Tuple)
     ):
-        distinct_cols = expression.args["distinct"].args["on"].expressions
-        expression.args["distinct"].pop()
+        distinct_cols = expression.args["distinct"].pop().args["on"].expressions
         outer_selects = expression.selects
         row_number = find_new_name(expression.named_selects, "_row_number")
         window = exp.Window(

@@ -71,14 +69,49 @@ def eliminate_distinct_on(expression: exp.Expression) -> exp.Expression:
         )
         order = expression.args.get("order")
         if order:
-            window.set("order", order.copy())
-            order.pop()
+            window.set("order", order.pop().copy())
         window = exp.alias_(window, row_number)
         expression.select(window, copy=False)
         return exp.select(*outer_selects).from_(expression.subquery()).where(f'"{row_number}" = 1')
     return expression


+def eliminate_qualify(expression: exp.Expression) -> exp.Expression:
+    """
+    Convert SELECT statements that contain the QUALIFY clause into subqueries, filtered equivalently.
+
+    The idea behind this transformation can be seen in Snowflake's documentation for QUALIFY:
+    https://docs.snowflake.com/en/sql-reference/constructs/qualify
+
+    Some dialects don't support window functions in the WHERE clause, so we need to include them as
+    projections in the subquery, in order to refer to them in the outer filter using aliases. Also,
+    if a column is referenced in the QUALIFY clause but is not selected, we need to include it too,
+    otherwise we won't be able to refer to it in the outer query's WHERE clause.
+    """
+    if isinstance(expression, exp.Select) and expression.args.get("qualify"):
+        taken = set(expression.named_selects)
+        for select in expression.selects:
+            if not select.alias_or_name:
+                alias = find_new_name(taken, "_c")
+                select.replace(exp.alias_(select.copy(), alias))
+                taken.add(alias)
+
+        outer_selects = exp.select(*[select.alias_or_name for select in expression.selects])
+        qualify_filters = expression.args["qualify"].pop().this
+
+        for expr in qualify_filters.find_all((exp.Window, exp.Column)):
+            if isinstance(expr, exp.Window):
+                alias = find_new_name(expression.named_selects, "_w")
+                expression.select(exp.alias_(expr.copy(), alias), copy=False)
+                expr.replace(exp.column(alias))
+            elif expr.name not in expression.named_selects:
+                expression.select(expr.copy(), copy=False)
+
+        return outer_selects.from_(expression.subquery(alias="_t")).where(qualify_filters)
+
+    return expression
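A sketch of the new transform in isolation (the names are placeholders; the exact subquery shape is inferred from the code above):

import sqlglot
from sqlglot.transforms import eliminate_qualify

expr = sqlglot.parse_one(
    "SELECT i, ROW_NUMBER() OVER (PARTITION BY p ORDER BY o) AS row_num FROM t QUALIFY row_num = 1"
)
print(eliminate_qualify(expr).sql())
# roughly: SELECT i, row_num
#          FROM (SELECT i, ROW_NUMBER() OVER (PARTITION BY p ORDER BY o) AS row_num FROM t) AS _t
#          WHERE row_num = 1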
 def remove_precision_parameterized_types(expression: exp.Expression) -> exp.Expression:
     """
     Some dialects only allow the precision for parameterized types to be defined in the DDL and not in other expressions.

@@ -139,6 +172,7 @@ def delegate(attr: str) -> t.Callable:

 UNALIAS_GROUP = {exp.Group: preprocess([unalias_group], delegate("group_sql"))}
 ELIMINATE_DISTINCT_ON = {exp.Select: preprocess([eliminate_distinct_on], delegate("select_sql"))}
+ELIMINATE_QUALIFY = {exp.Select: preprocess([eliminate_qualify], delegate("select_sql"))}
 REMOVE_PRECISION_PARAMETERIZED_TYPES = {
     exp.Cast: preprocess([remove_precision_parameterized_types], delegate("cast_sql"))
 }