
Merging upstream version 11.4.1.

Signed-off-by: Daniel Baumann <daniel@debian.org>
Author: Daniel Baumann, 2025-02-13 15:46:19 +01:00
Parent: ecb42ec17f
Commit: 63746a3e92
Signed by: daniel (GPG key ID: FBB4F0E80A80222F)
89 changed files with 35352 additions and 33081 deletions

sqlglot/__init__.py

@@ -47,7 +47,7 @@ if t.TYPE_CHECKING:
T = t.TypeVar("T", bound=Expression)
__version__ = "11.3.6"
__version__ = "11.4.1"
pretty = False
"""Whether to format generated SQL by default."""

sqlglot/__main__.py

@@ -1,5 +1,8 @@
from __future__ import annotations
import argparse
import sys
import typing as t
import sqlglot
@@ -42,6 +45,12 @@ parser.add_argument(
action="store_true",
help="Parse and return the expression tree",
)
parser.add_argument(
"--tokenize",
dest="tokenize",
action="store_true",
help="Tokenize and return the tokens list",
)
parser.add_argument(
"--error-level",
dest="error_level",
@@ -57,7 +66,7 @@ error_level = sqlglot.ErrorLevel[args.error_level.upper()]
sql = sys.stdin.read() if args.sql == "-" else args.sql
if args.parse:
sqls = [
objs: t.Union[t.List[str], t.List[sqlglot.tokens.Token]] = [
repr(expression)
for expression in sqlglot.parse(
sql,
@@ -65,8 +74,10 @@ if args.parse:
error_level=error_level,
)
]
elif args.tokenize:
objs = sqlglot.Dialect.get_or_raise(args.read)().tokenize(sql)
else:
sqls = sqlglot.transpile(
objs = sqlglot.transpile(
sql,
read=args.read,
write=args.write,
@@ -75,5 +86,5 @@ else:
error_level=error_level,
)
for sql in sqls:
print(sql)
for obj in objs:
print(obj)
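
The new --tokenize branch is a thin wrapper over the Dialect API (see the
dialect.py hunk below). A minimal sketch of the same call from Python,
assuming the mysql dialect:

    import sqlglot

    # The same call the CLI makes for --tokenize: tokenize without parsing.
    for token in sqlglot.Dialect.get_or_raise("mysql")().tokenize("SELECT a FROM b"):
        print(token)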

sqlglot/dataframe/sql/dataframe.py

@@ -299,7 +299,7 @@ class DataFrame:
for expression_type, select_expression in select_expressions:
select_expression = select_expression.transform(replace_id_value, replacement_mapping)
if optimize:
select_expression = optimize_func(select_expression)
select_expression = optimize_func(select_expression, identify="always")
select_expression = df._replace_cte_names_with_hashes(select_expression)
expression: t.Union[exp.Select, exp.Cache, exp.Drop]
if expression_type == exp.Cache:

sqlglot/dialects/bigquery.py

@@ -144,7 +144,6 @@ class BigQuery(Dialect):
"BEGIN": TokenType.COMMAND,
"BEGIN TRANSACTION": TokenType.BEGIN,
"CURRENT_DATETIME": TokenType.CURRENT_DATETIME,
"CURRENT_TIME": TokenType.CURRENT_TIME,
"DECLARE": TokenType.COMMAND,
"GEOGRAPHY": TokenType.GEOGRAPHY,
"FLOAT64": TokenType.DOUBLE,
@@ -194,7 +193,6 @@ class BigQuery(Dialect):
NO_PAREN_FUNCTIONS = {
**parser.Parser.NO_PAREN_FUNCTIONS, # type: ignore
TokenType.CURRENT_DATETIME: exp.CurrentDatetime,
TokenType.CURRENT_TIME: exp.CurrentTime,
}
NESTED_TYPE_TOKENS = {
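
CURRENT_TIME moves out of the BigQuery-specific keyword map and into the base
parser's NO_PAREN_FUNCTIONS (see the parser.py hunk at the bottom), so every
dialect now parses it. A quick check, assuming the default dialect:

    import sqlglot

    # CURRENT_TIME is now a no-paren function in the base parser.
    print(repr(sqlglot.parse_one("SELECT CURRENT_TIME")))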

sqlglot/dialects/clickhouse.py

@@ -5,6 +5,7 @@ import typing as t
from sqlglot import exp, generator, parser, tokens
from sqlglot.dialects.dialect import Dialect, inline_array_sql, var_map_sql
from sqlglot.errors import ParseError
from sqlglot.helper import ensure_list, seq_get
from sqlglot.parser import parse_var_map
from sqlglot.tokens import TokenType
@@ -40,7 +41,18 @@ class ClickHouse(Dialect):
class Parser(parser.Parser):
FUNCTIONS = {
**parser.Parser.FUNCTIONS, # type: ignore
"EXPONENTIALTIMEDECAYEDAVG": lambda params, args: exp.ExponentialTimeDecayedAvg(
this=seq_get(args, 0),
time=seq_get(args, 1),
decay=seq_get(params, 0),
),
"MAP": parse_var_map,
"HISTOGRAM": lambda params, args: exp.Histogram(
this=seq_get(args, 0), bins=seq_get(params, 0)
),
"GROUPUNIQARRAY": lambda params, args: exp.GroupUniqArray(
this=seq_get(args, 0), size=seq_get(params, 0)
),
"QUANTILE": lambda params, args: exp.Quantile(this=args, quantile=params),
"QUANTILES": lambda params, args: exp.Quantiles(parameters=params, expressions=args),
"QUANTILEIF": lambda params, args: exp.QuantileIf(parameters=params, expressions=args),
@@ -113,22 +125,40 @@ class ClickHouse(Dialect):
TRANSFORMS = {
**generator.Generator.TRANSFORMS, # type: ignore
exp.Array: inline_array_sql,
exp.StrPosition: lambda self, e: f"position({self.format_args(e.this, e.args.get('substr'), e.args.get('position'))})",
exp.ExponentialTimeDecayedAvg: lambda self, e: f"exponentialTimeDecayedAvg{self._param_args_sql(e, 'decay', ['this', 'time'])}",
exp.Final: lambda self, e: f"{self.sql(e, 'this')} FINAL",
exp.GroupUniqArray: lambda self, e: f"groupUniqArray{self._param_args_sql(e, 'size', 'this')}",
exp.Histogram: lambda self, e: f"histogram{self._param_args_sql(e, 'bins', 'this')}",
exp.Map: lambda self, e: _lower_func(var_map_sql(self, e)),
exp.VarMap: lambda self, e: _lower_func(var_map_sql(self, e)),
exp.Quantile: lambda self, e: f"quantile{self._param_args_sql(e, 'quantile', 'this')}",
exp.Quantiles: lambda self, e: f"quantiles{self._param_args_sql(e, 'parameters', 'expressions')}",
exp.QuantileIf: lambda self, e: f"quantileIf{self._param_args_sql(e, 'parameters', 'expressions')}",
exp.StrPosition: lambda self, e: f"position({self.format_args(e.this, e.args.get('substr'), e.args.get('position'))})",
exp.VarMap: lambda self, e: _lower_func(var_map_sql(self, e)),
}
EXPLICIT_UNION = True
def _param_args_sql(
self, expression: exp.Expression, params_name: str, args_name: str
self,
expression: exp.Expression,
param_names: str | t.List[str],
arg_names: str | t.List[str],
) -> str:
params = self.format_args(self.expressions(expression, params_name))
args = self.format_args(self.expressions(expression, args_name))
params = self.format_args(
*(
arg
for name in ensure_list(param_names)
for arg in ensure_list(expression.args.get(name))
)
)
args = self.format_args(
*(
arg
for name in ensure_list(arg_names)
for arg in ensure_list(expression.args.get(name))
)
)
return f"({params})({args})"
def cte_sql(self, expression: exp.CTE) -> str:
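
The new FUNCTIONS entries all parse ClickHouse's parametric-aggregate call
shape, func(params)(args), which _param_args_sql regenerates. A round-trip
sketch, assuming histogram(5)(x) is representative:

    import sqlglot

    # histogram(bins)(col) parses into exp.Histogram(this=col, bins=bins)
    # and is printed back through _param_args_sql.
    sql = "SELECT histogram(5)(x) FROM t"
    print(sqlglot.transpile(sql, read="clickhouse", write="clickhouse")[0])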

sqlglot/dialects/databricks.py

@@ -23,6 +23,7 @@ class Databricks(Spark):
exp.DateDiff: generate_date_delta_with_unit_sql,
exp.ToChar: lambda self, e: self.function_fallback_sql(e),
}
TRANSFORMS.pop(exp.Select) # Remove the ELIMINATE_QUALIFY transformation
PARAMETER_TOKEN = "$"

sqlglot/dialects/dialect.py

@@ -8,7 +8,7 @@ from sqlglot.generator import Generator
from sqlglot.helper import flatten, seq_get
from sqlglot.parser import Parser
from sqlglot.time import format_time
from sqlglot.tokens import Tokenizer
from sqlglot.tokens import Token, Tokenizer
from sqlglot.trie import new_trie
E = t.TypeVar("E", bound=exp.Expression)
@@ -160,12 +160,12 @@ class Dialect(metaclass=_Dialect):
return expression
def parse(self, sql: str, **opts) -> t.List[t.Optional[exp.Expression]]:
return self.parser(**opts).parse(self.tokenizer.tokenize(sql), sql)
return self.parser(**opts).parse(self.tokenize(sql), sql)
def parse_into(
self, expression_type: exp.IntoType, sql: str, **opts
) -> t.List[t.Optional[exp.Expression]]:
return self.parser(**opts).parse_into(expression_type, self.tokenizer.tokenize(sql), sql)
return self.parser(**opts).parse_into(expression_type, self.tokenize(sql), sql)
def generate(self, expression: t.Optional[exp.Expression], **opts) -> str:
return self.generator(**opts).generate(expression)
@@ -173,6 +173,9 @@ class Dialect(metaclass=_Dialect):
def transpile(self, sql: str, **opts) -> t.List[str]:
return [self.generate(expression, **opts) for expression in self.parse(sql)]
def tokenize(self, sql: str) -> t.List[Token]:
return self.tokenizer.tokenize(sql)
@property
def tokenizer(self) -> Tokenizer:
if not hasattr(self, "_tokenizer"):
@@ -385,6 +388,21 @@ def parse_date_delta(
return inner_func
def date_trunc_to_time(args: t.Sequence) -> exp.DateTrunc | exp.TimestampTrunc:
unit = seq_get(args, 0)
this = seq_get(args, 1)
if isinstance(this, exp.Cast) and this.is_type(exp.DataType.Type.DATE):
return exp.DateTrunc(unit=unit, this=this)
return exp.TimestampTrunc(this=this, unit=unit)
def timestamptrunc_sql(self: Generator, expression: exp.TimestampTrunc) -> str:
return self.func(
"DATE_TRUNC", exp.Literal.string(expression.text("unit") or "day"), expression.this
)
def locate_to_strposition(args: t.Sequence) -> exp.Expression:
return exp.StrPosition(
this=seq_get(args, 1),
@@ -412,6 +430,16 @@ def min_or_least(self: Generator, expression: exp.Min) -> str:
return rename_func(name)(self, expression)
def count_if_to_sum(self: Generator, expression: exp.CountIf) -> str:
cond = expression.this
if isinstance(expression.this, exp.Distinct):
cond = expression.this.expressions[0]
self.unsupported("DISTINCT is not supported when converting COUNT_IF to SUM")
return self.func("sum", exp.func("if", cond, 1, 0))
def trim_sql(self: Generator, expression: exp.Trim) -> str:
target = self.sql(expression, "this")
trim_type = self.sql(expression, "position")
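
date_trunc_to_time and timestamptrunc_sql give dialects a shared DATE_TRUNC
round trip, and count_if_to_sum rewrites COUNT_IF for engines that lack it.
A sketch, assuming presto in and snowflake out:

    import sqlglot

    # DATE_TRUNC('day', ts) parses into exp.TimestampTrunc via
    # date_trunc_to_time and regenerates through timestamptrunc_sql.
    sql = "SELECT DATE_TRUNC('day', ts) FROM t"
    print(sqlglot.transpile(sql, read="presto", write="snowflake")[0])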

sqlglot/dialects/drill.py

@@ -97,6 +97,7 @@ class Drill(Dialect):
FUNCTIONS = {
**parser.Parser.FUNCTIONS, # type: ignore
"DATE_FORMAT": format_time_lambda(exp.TimeToStr, "drill"),
"TO_TIMESTAMP": exp.TimeStrToTime.from_arg_list,
"TO_CHAR": format_time_lambda(exp.TimeToStr, "drill"),
}

sqlglot/dialects/duckdb.py

@@ -14,6 +14,7 @@ from sqlglot.dialects.dialect import (
rename_func,
str_position_sql,
str_to_time_sql,
timestamptrunc_sql,
timestrtotime_sql,
ts_or_ds_to_date_sql,
)
@@ -148,6 +149,9 @@ class DuckDB(Dialect):
exp.ArraySize: rename_func("ARRAY_LENGTH"),
exp.ArraySort: _array_sort_sql,
exp.ArraySum: rename_func("LIST_SUM"),
exp.DayOfMonth: rename_func("DAYOFMONTH"),
exp.DayOfWeek: rename_func("DAYOFWEEK"),
exp.DayOfYear: rename_func("DAYOFYEAR"),
exp.DataType: _datatype_sql,
exp.DateAdd: _date_add,
exp.DateDiff: lambda self, e: self.func(
@@ -162,6 +166,7 @@ class DuckDB(Dialect):
exp.JSONBExtract: arrow_json_extract_sql,
exp.JSONBExtractScalar: arrow_json_extract_scalar_sql,
exp.LogicalOr: rename_func("BOOL_OR"),
exp.LogicalAnd: rename_func("BOOL_AND"),
exp.Pivot: no_pivot_sql,
exp.Properties: no_properties_sql,
exp.RegexpExtract: _regexp_extract_sql,
@@ -175,6 +180,7 @@ class DuckDB(Dialect):
exp.StrToTime: str_to_time_sql,
exp.StrToUnix: lambda self, e: f"EPOCH(STRPTIME({self.sql(e, 'this')}, {self.format_time(e)}))",
exp.Struct: _struct_sql,
exp.TimestampTrunc: timestamptrunc_sql,
exp.TimeStrToDate: lambda self, e: f"CAST({self.sql(e, 'this')} AS DATE)",
exp.TimeStrToTime: timestrtotime_sql,
exp.TimeStrToUnix: lambda self, e: f"EPOCH(CAST({self.sql(e, 'this')} AS TIMESTAMP))",
@@ -186,6 +192,7 @@ class DuckDB(Dialect):
exp.UnixToStr: lambda self, e: f"STRFTIME(TO_TIMESTAMP({self.sql(e, 'this')}), {self.format_time(e)})",
exp.UnixToTime: rename_func("TO_TIMESTAMP"),
exp.UnixToTimeStr: lambda self, e: f"CAST(TO_TIMESTAMP({self.sql(e, 'this')}) AS TEXT)",
exp.WeekOfYear: rename_func("WEEKOFYEAR"),
}
TYPE_MAPPING = {

sqlglot/dialects/hive.py

@@ -1,5 +1,7 @@
from __future__ import annotations
import typing as t
from sqlglot import exp, generator, parser, tokens, transforms
from sqlglot.dialects.dialect import (
Dialect,
@@ -35,7 +37,7 @@ DATE_DELTA_INTERVAL = {
DIFF_MONTH_SWITCH = ("YEAR", "QUARTER", "MONTH")
def _add_date_sql(self, expression):
def _add_date_sql(self: generator.Generator, expression: exp.DateAdd) -> str:
unit = expression.text("unit").upper()
func, multiplier = DATE_DELTA_INTERVAL.get(unit, ("DATE_ADD", 1))
modified_increment = (
@@ -47,7 +49,7 @@ def _add_date_sql(self, expression):
return self.func(func, expression.this, modified_increment.this)
def _date_diff_sql(self, expression):
def _date_diff_sql(self: generator.Generator, expression: exp.DateDiff) -> str:
unit = expression.text("unit").upper()
sql_func = "MONTHS_BETWEEN" if unit in DIFF_MONTH_SWITCH else "DATEDIFF"
_, multiplier = DATE_DELTA_INTERVAL.get(unit, ("", 1))
@@ -56,21 +58,21 @@ def _date_diff_sql(self, expression):
return f"{diff_sql}{multiplier_sql}"
def _array_sort(self, expression):
def _array_sort(self: generator.Generator, expression: exp.ArraySort) -> str:
if expression.expression:
self.unsupported("Hive SORT_ARRAY does not support a comparator")
return f"SORT_ARRAY({self.sql(expression, 'this')})"
def _property_sql(self, expression):
def _property_sql(self: generator.Generator, expression: exp.Property) -> str:
return f"'{expression.name}'={self.sql(expression, 'value')}"
def _str_to_unix(self, expression):
def _str_to_unix(self: generator.Generator, expression: exp.StrToUnix) -> str:
return self.func("UNIX_TIMESTAMP", expression.this, _time_format(self, expression))
def _str_to_date(self, expression):
def _str_to_date(self: generator.Generator, expression: exp.StrToDate) -> str:
this = self.sql(expression, "this")
time_format = self.format_time(expression)
if time_format not in (Hive.time_format, Hive.date_format):
@@ -78,7 +80,7 @@ def _str_to_date(self, expression):
return f"CAST({this} AS DATE)"
def _str_to_time(self, expression):
def _str_to_time(self: generator.Generator, expression: exp.StrToTime) -> str:
this = self.sql(expression, "this")
time_format = self.format_time(expression)
if time_format not in (Hive.time_format, Hive.date_format):
@@ -86,20 +88,22 @@ def _str_to_time(self, expression):
return f"CAST({this} AS TIMESTAMP)"
def _time_format(self, expression):
def _time_format(
self: generator.Generator, expression: exp.UnixToStr | exp.StrToUnix
) -> t.Optional[str]:
time_format = self.format_time(expression)
if time_format == Hive.time_format:
return None
return time_format
def _time_to_str(self, expression):
def _time_to_str(self: generator.Generator, expression: exp.TimeToStr) -> str:
this = self.sql(expression, "this")
time_format = self.format_time(expression)
return f"DATE_FORMAT({this}, {time_format})"
def _to_date_sql(self, expression):
def _to_date_sql(self: generator.Generator, expression: exp.TsOrDsToDate) -> str:
this = self.sql(expression, "this")
time_format = self.format_time(expression)
if time_format and time_format not in (Hive.time_format, Hive.date_format):
@@ -107,7 +111,7 @@ def _to_date_sql(self, expression):
return f"TO_DATE({this})"
def _unnest_to_explode_sql(self, expression):
def _unnest_to_explode_sql(self: generator.Generator, expression: exp.Join) -> str:
unnest = expression.this
if isinstance(unnest, exp.Unnest):
alias = unnest.args.get("alias")
@@ -117,7 +121,7 @@ def _unnest_to_explode_sql(self, expression):
exp.Lateral(
this=udtf(this=expression),
view=True,
alias=exp.TableAlias(this=alias.this, columns=[column]),
alias=exp.TableAlias(this=alias.this, columns=[column]), # type: ignore
)
)
for expression, column in zip(unnest.expressions, alias.columns if alias else [])
@@ -125,7 +129,7 @@ def _unnest_to_explode_sql(self, expression):
return self.join_sql(expression)
def _index_sql(self, expression):
def _index_sql(self: generator.Generator, expression: exp.Index) -> str:
this = self.sql(expression, "this")
table = self.sql(expression, "table")
columns = self.sql(expression, "columns")
@@ -263,14 +267,15 @@ class Hive(Dialect):
exp.DataType.Type.TEXT: "STRING",
exp.DataType.Type.DATETIME: "TIMESTAMP",
exp.DataType.Type.VARBINARY: "BINARY",
exp.DataType.Type.TIMESTAMPTZ: "TIMESTAMP",
}
TRANSFORMS = {
**generator.Generator.TRANSFORMS, # type: ignore
**transforms.UNALIAS_GROUP, # type: ignore
**transforms.ELIMINATE_QUALIFY, # type: ignore
exp.Property: _property_sql,
exp.ApproxDistinct: approx_count_distinct_sql,
exp.ArrayAgg: rename_func("COLLECT_LIST"),
exp.ArrayConcat: rename_func("CONCAT"),
exp.ArraySize: rename_func("SIZE"),
exp.ArraySort: _array_sort,
@@ -333,13 +338,19 @@ class Hive(Dialect):
exp.TableFormatProperty: exp.Properties.Location.POST_SCHEMA,
}
def with_properties(self, properties):
def arrayagg_sql(self, expression: exp.ArrayAgg) -> str:
return self.func(
"COLLECT_LIST",
expression.this.this if isinstance(expression.this, exp.Order) else expression.this,
)
def with_properties(self, properties: exp.Properties) -> str:
return self.properties(
properties,
prefix=self.seg("TBLPROPERTIES"),
)
def datatype_sql(self, expression):
def datatype_sql(self, expression: exp.DataType) -> str:
if (
expression.this in (exp.DataType.Type.VARCHAR, exp.DataType.Type.NVARCHAR)
and not expression.expressions

sqlglot/dialects/mysql.py

@@ -177,7 +177,7 @@ class MySQL(Dialect):
"@@": TokenType.SESSION_PARAMETER,
}
COMMANDS = tokens.Tokenizer.COMMANDS - {TokenType.SET, TokenType.SHOW}
COMMANDS = tokens.Tokenizer.COMMANDS - {TokenType.SHOW}
class Parser(parser.Parser):
FUNC_TOKENS = {*parser.Parser.FUNC_TOKENS, TokenType.SCHEMA, TokenType.DATABASE} # type: ignore
@@ -211,7 +211,6 @@ class MySQL(Dialect):
STATEMENT_PARSERS = {
**parser.Parser.STATEMENT_PARSERS, # type: ignore
TokenType.SHOW: lambda self: self._parse_show(),
TokenType.SET: lambda self: self._parse_set(),
}
SHOW_PARSERS = {
@@ -269,15 +268,12 @@ class MySQL(Dialect):
}
SET_PARSERS = {
"GLOBAL": lambda self: self._parse_set_item_assignment("GLOBAL"),
**parser.Parser.SET_PARSERS,
"PERSIST": lambda self: self._parse_set_item_assignment("PERSIST"),
"PERSIST_ONLY": lambda self: self._parse_set_item_assignment("PERSIST_ONLY"),
"SESSION": lambda self: self._parse_set_item_assignment("SESSION"),
"LOCAL": lambda self: self._parse_set_item_assignment("LOCAL"),
"CHARACTER SET": lambda self: self._parse_set_item_charset("CHARACTER SET"),
"CHARSET": lambda self: self._parse_set_item_charset("CHARACTER SET"),
"NAMES": lambda self: self._parse_set_item_names(),
"TRANSACTION": lambda self: self._parse_set_transaction(),
}
PROFILE_TYPES = {
@@ -292,15 +288,6 @@ class MySQL(Dialect):
"SWAPS",
}
TRANSACTION_CHARACTERISTICS = {
"ISOLATION LEVEL REPEATABLE READ",
"ISOLATION LEVEL READ COMMITTED",
"ISOLATION LEVEL READ UNCOMMITTED",
"ISOLATION LEVEL SERIALIZABLE",
"READ WRITE",
"READ ONLY",
}
def _parse_show_mysql(self, this, target=False, full=None, global_=None):
if target:
if isinstance(target, str):
@@ -354,12 +341,6 @@ class MySQL(Dialect):
**{"global": global_},
)
def _parse_var_from_options(self, options):
for option in options:
if self._match_text_seq(*option.split(" ")):
return exp.Var(this=option)
return None
def _parse_oldstyle_limit(self):
limit = None
offset = None
@@ -372,30 +353,6 @@ class MySQL(Dialect):
offset = parts[0]
return offset, limit
def _default_parse_set_item(self):
return self._parse_set_item_assignment(kind=None)
def _parse_set_item_assignment(self, kind):
if kind in {"GLOBAL", "SESSION"} and self._match_text_seq("TRANSACTION"):
return self._parse_set_transaction(global_=kind == "GLOBAL")
left = self._parse_primary() or self._parse_id_var()
if not self._match(TokenType.EQ):
self.raise_error("Expected =")
right = self._parse_statement() or self._parse_id_var()
this = self.expression(
exp.EQ,
this=left,
expression=right,
)
return self.expression(
exp.SetItem,
this=this,
kind=kind,
)
def _parse_set_item_charset(self, kind):
this = self._parse_string() or self._parse_id_var()
@@ -418,18 +375,6 @@ class MySQL(Dialect):
kind="NAMES",
)
def _parse_set_transaction(self, global_=False):
self._match_text_seq("TRANSACTION")
characteristics = self._parse_csv(
lambda: self._parse_var_from_options(self.TRANSACTION_CHARACTERISTICS)
)
return self.expression(
exp.SetItem,
expressions=characteristics,
kind="TRANSACTION",
**{"global": global_},
)
class Generator(generator.Generator):
LOCKING_READS_SUPPORTED = True
NULL_ORDERING_SUPPORTED = False
@@ -523,16 +468,3 @@ class MySQL(Dialect):
limit_offset = f"{offset}, {limit}" if offset else limit
return f" LIMIT {limit_offset}"
return ""
def setitem_sql(self, expression):
kind = self.sql(expression, "kind")
kind = f"{kind} " if kind else ""
this = self.sql(expression, "this")
expressions = self.expressions(expression)
collate = self.sql(expression, "collate")
collate = f" COLLATE {collate}" if collate else ""
global_ = "GLOBAL " if expression.args.get("global") else ""
return f"{global_}{kind}{this}{expressions}{collate}"
def set_sql(self, expression):
return f"SET {self.expressions(expression)}"

sqlglot/dialects/postgres.py

@@ -12,6 +12,7 @@ from sqlglot.dialects.dialect import (
no_trycast_sql,
rename_func,
str_position_sql,
timestamptrunc_sql,
trim_sql,
)
from sqlglot.helper import seq_get
@@ -34,7 +35,7 @@ def _date_add_sql(kind):
from sqlglot.optimizer.simplify import simplify
this = self.sql(expression, "this")
unit = self.sql(expression, "unit")
unit = expression.args.get("unit")
expression = simplify(expression.args["expression"])
if not isinstance(expression, exp.Literal):
@@ -92,8 +93,7 @@ def _string_agg_sql(self, expression):
this = expression.this
if isinstance(this, exp.Order):
if this.this:
this = this.this
this.pop()
this = this.this.pop()
order = self.sql(expression.this) # Order has a leading space
return f"STRING_AGG({self.format_args(this, separator)}{order})"
@@ -256,6 +256,9 @@ class Postgres(Dialect):
"TO_TIMESTAMP": _to_timestamp,
"TO_CHAR": format_time_lambda(exp.TimeToStr, "postgres"),
"GENERATE_SERIES": _generate_series,
"DATE_TRUNC": lambda args: exp.TimestampTrunc(
this=seq_get(args, 1), unit=seq_get(args, 0)
),
}
BITWISE = {
@@ -311,6 +314,7 @@ class Postgres(Dialect):
exp.DateSub: _date_add_sql("-"),
exp.DateDiff: _date_diff_sql,
exp.LogicalOr: rename_func("BOOL_OR"),
exp.LogicalAnd: rename_func("BOOL_AND"),
exp.Min: min_or_least,
exp.ArrayOverlaps: lambda self, e: self.binary(e, "&&"),
exp.ArrayContains: lambda self, e: self.binary(e, "@>"),
@ -320,6 +324,7 @@ class Postgres(Dialect):
exp.StrPosition: str_position_sql,
exp.StrToTime: lambda self, e: f"TO_TIMESTAMP({self.sql(e, 'this')}, {self.format_time(e)})",
exp.Substring: _substring_sql,
exp.TimestampTrunc: timestamptrunc_sql,
exp.TimeStrToTime: lambda self, e: f"CAST({self.sql(e, 'this')} AS TIMESTAMP)",
exp.TimeToStr: lambda self, e: f"TO_CHAR({self.sql(e, 'this')}, {self.format_time(e)})",
exp.TableSample: no_tablesample_sql,

sqlglot/dialects/presto.py

@@ -3,12 +3,14 @@ from __future__ import annotations
from sqlglot import exp, generator, parser, tokens, transforms
from sqlglot.dialects.dialect import (
Dialect,
date_trunc_to_time,
format_time_lambda,
if_sql,
no_ilike_sql,
no_safe_divide_sql,
rename_func,
struct_extract_sql,
timestamptrunc_sql,
timestrtotime_sql,
)
from sqlglot.dialects.mysql import MySQL
@@ -98,10 +100,16 @@ def _ts_or_ds_to_date_sql(self, expression):
def _ts_or_ds_add_sql(self, expression):
this = self.sql(expression, "this")
e = self.sql(expression, "expression")
unit = self.sql(expression, "unit") or "'day'"
return f"DATE_ADD({unit}, {e}, DATE_PARSE(SUBSTR({this}, 1, 10), {Presto.date_format}))"
return self.func(
"DATE_ADD",
exp.Literal.string(expression.text("unit") or "day"),
expression.expression,
self.func(
"DATE_PARSE",
self.func("SUBSTR", expression.this, exp.Literal.number(1), exp.Literal.number(10)),
Presto.date_format,
),
)
def _sequence_sql(self, expression):
@@ -195,6 +203,7 @@ class Presto(Dialect):
),
"DATE_FORMAT": format_time_lambda(exp.TimeToStr, "presto"),
"DATE_PARSE": format_time_lambda(exp.StrToTime, "presto"),
"DATE_TRUNC": date_trunc_to_time,
"FROM_UNIXTIME": _from_unixtime,
"NOW": exp.CurrentTimestamp.from_arg_list,
"STRPOS": lambda args: exp.StrPosition(
@@ -237,6 +246,7 @@ class Presto(Dialect):
TRANSFORMS = {
**generator.Generator.TRANSFORMS, # type: ignore
**transforms.UNALIAS_GROUP, # type: ignore
**transforms.ELIMINATE_QUALIFY, # type: ignore
exp.ApproxDistinct: _approx_distinct_sql,
exp.Array: lambda self, e: f"ARRAY[{self.expressions(e, flat=True)}]",
exp.ArrayConcat: rename_func("CONCAT"),
@@ -250,8 +260,12 @@ class Presto(Dialect):
exp.BitwiseXor: lambda self, e: f"BITWISE_XOR({self.sql(e, 'this')}, {self.sql(e, 'expression')})",
exp.CurrentTimestamp: lambda *_: "CURRENT_TIMESTAMP",
exp.DataType: _datatype_sql,
exp.DateAdd: lambda self, e: f"""DATE_ADD({self.sql(e, 'unit') or "'day'"}, {self.sql(e, 'expression')}, {self.sql(e, 'this')})""",
exp.DateDiff: lambda self, e: f"""DATE_DIFF({self.sql(e, 'unit') or "'day'"}, {self.sql(e, 'expression')}, {self.sql(e, 'this')})""",
exp.DateAdd: lambda self, e: self.func(
"DATE_ADD", exp.Literal.string(e.text("unit") or "day"), e.expression, e.this
),
exp.DateDiff: lambda self, e: self.func(
"DATE_DIFF", exp.Literal.string(e.text("unit") or "day"), e.expression, e.this
),
exp.DateStrToDate: lambda self, e: f"CAST(DATE_PARSE({self.sql(e, 'this')}, {Presto.date_format}) AS DATE)",
exp.DateToDi: lambda self, e: f"CAST(DATE_FORMAT({self.sql(e, 'this')}, {Presto.dateint_format}) AS INT)",
exp.Decode: _decode_sql,
@@ -265,6 +279,7 @@ class Presto(Dialect):
exp.Lateral: _explode_to_unnest_sql,
exp.Levenshtein: rename_func("LEVENSHTEIN_DISTANCE"),
exp.LogicalOr: rename_func("BOOL_OR"),
exp.LogicalAnd: rename_func("BOOL_AND"),
exp.Quantile: _quantile_sql,
exp.ApproxQuantile: rename_func("APPROX_PERCENTILE"),
exp.SafeDivide: no_safe_divide_sql,
@@ -277,6 +292,7 @@ class Presto(Dialect):
exp.StructExtract: struct_extract_sql,
exp.TableFormatProperty: lambda self, e: f"TABLE_FORMAT='{e.name.upper()}'",
exp.FileFormatProperty: lambda self, e: f"FORMAT='{e.name.upper()}'",
exp.TimestampTrunc: timestamptrunc_sql,
exp.TimeStrToDate: timestrtotime_sql,
exp.TimeStrToTime: timestrtotime_sql,
exp.TimeStrToUnix: lambda self, e: f"TO_UNIXTIME(DATE_PARSE({self.sql(e, 'this')}, {Presto.time_format}))",

sqlglot/dialects/redshift.py

@@ -20,6 +20,11 @@ class Redshift(Postgres):
class Parser(Postgres.Parser):
FUNCTIONS = {
**Postgres.Parser.FUNCTIONS, # type: ignore
"DATEADD": lambda args: exp.DateAdd(
this=seq_get(args, 2),
expression=seq_get(args, 1),
unit=seq_get(args, 0),
),
"DATEDIFF": lambda args: exp.DateDiff(
this=seq_get(args, 2),
expression=seq_get(args, 1),
@@ -76,13 +81,16 @@ class Redshift(Postgres):
TRANSFORMS = {
**Postgres.Generator.TRANSFORMS, # type: ignore
**transforms.ELIMINATE_DISTINCT_ON, # type: ignore
exp.DateAdd: lambda self, e: self.func(
"DATEADD", exp.var(e.text("unit") or "day"), e.expression, e.this
),
exp.DateDiff: lambda self, e: self.func(
"DATEDIFF", e.args.get("unit") or "day", e.expression, e.this
"DATEDIFF", exp.var(e.text("unit") or "day"), e.expression, e.this
),
exp.DistKeyProperty: lambda self, e: f"DISTKEY({e.name})",
exp.SortKeyProperty: lambda self, e: f"{'COMPOUND ' if e.args['compound'] else ''}SORTKEY({self.format_args(*e.this)})",
exp.DistStyleProperty: lambda self, e: self.naked_property(e),
exp.Matches: rename_func("DECODE"),
exp.SortKeyProperty: lambda self, e: f"{'COMPOUND ' if e.args['compound'] else ''}SORTKEY({self.format_args(*e.this)})",
}
# Redshift uses the POW | POWER (expr1, expr2) syntax instead of expr1 ^ expr2 (postgres)
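
With the new DATEADD parser entry and the DateAdd transform, Redshift's
unit-first signature now round-trips. A sketch:

    import sqlglot

    # DATEADD(unit, n, date) <-> exp.DateAdd(this=date, expression=n, unit=unit)
    sql = "SELECT DATEADD(month, 18, '2008-02-28')"
    print(sqlglot.transpile(sql, read="redshift", write="redshift")[0])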

sqlglot/dialects/snowflake.py

@@ -5,11 +5,13 @@ import typing as t
from sqlglot import exp, generator, parser, tokens
from sqlglot.dialects.dialect import (
Dialect,
date_trunc_to_time,
datestrtodate_sql,
format_time_lambda,
inline_array_sql,
min_or_least,
rename_func,
timestamptrunc_sql,
timestrtotime_sql,
ts_or_ds_to_date_sql,
var_map_sql,
@@ -176,6 +178,7 @@ class Snowflake(Dialect):
"ARRAYAGG": exp.ArrayAgg.from_arg_list,
"ARRAY_CONSTRUCT": exp.Array.from_arg_list,
"ARRAY_TO_STRING": exp.ArrayJoin.from_arg_list,
"DATE_TRUNC": date_trunc_to_time,
"DATEADD": lambda args: exp.DateAdd(
this=seq_get(args, 2),
expression=seq_get(args, 1),
@@ -186,10 +189,6 @@ class Snowflake(Dialect):
expression=seq_get(args, 1),
unit=seq_get(args, 0),
),
"DATE_TRUNC": lambda args: exp.DateTrunc(
unit=exp.Literal.string(seq_get(args, 0).name), # type: ignore
this=seq_get(args, 1),
),
"DECODE": exp.Matches.from_arg_list,
"DIV0": _div0_to_if,
"IFF": exp.If.from_arg_list,
@@ -280,6 +279,8 @@ class Snowflake(Dialect):
exp.DataType: _datatype_sql,
exp.If: rename_func("IFF"),
exp.Map: lambda self, e: var_map_sql(self, e, "OBJECT_CONSTRUCT"),
exp.LogicalOr: rename_func("BOOLOR_AGG"),
exp.LogicalAnd: rename_func("BOOLAND_AGG"),
exp.VarMap: lambda self, e: var_map_sql(self, e, "OBJECT_CONSTRUCT"),
exp.PartitionedByProperty: lambda self, e: f"PARTITION BY {self.sql(e, 'this')}",
exp.Matches: rename_func("DECODE"),
@@ -287,6 +288,7 @@ class Snowflake(Dialect):
"POSITION", e.args.get("substr"), e.this, e.args.get("position")
),
exp.StrToTime: lambda self, e: f"TO_TIMESTAMP({self.sql(e, 'this')}, {self.format_time(e)})",
exp.TimestampTrunc: timestamptrunc_sql,
exp.TimeStrToTime: timestrtotime_sql,
exp.TimeToUnix: lambda self, e: f"EXTRACT(epoch_second FROM {self.sql(e, 'this')})",
exp.Trim: lambda self, e: self.func("TRIM", e.this, e.expression),
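
Together with the widened _sql_names on LogicalOr/LogicalAnd in
expressions.py below, these transforms map Snowflake's BOOLOR_AGG and
BOOLAND_AGG onto other engines' aggregates. A sketch, assuming postgres
output:

    import sqlglot

    # BOOLOR_AGG parses into exp.LogicalOr, which postgres renders as BOOL_OR.
    print(sqlglot.transpile("SELECT BOOLOR_AGG(c) FROM t", read="snowflake", write="postgres")[0])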

sqlglot/dialects/spark.py

@@ -157,6 +157,7 @@ class Spark(Hive):
exp.VariancePop: rename_func("VAR_POP"),
exp.DateFromParts: rename_func("MAKE_DATE"),
exp.LogicalOr: rename_func("BOOL_OR"),
exp.LogicalAnd: rename_func("BOOL_AND"),
exp.DayOfWeek: rename_func("DAYOFWEEK"),
exp.DayOfMonth: rename_func("DAYOFMONTH"),
exp.DayOfYear: rename_func("DAYOFYEAR"),

sqlglot/dialects/sqlite.py

@@ -1,10 +1,11 @@
from __future__ import annotations
from sqlglot import exp, generator, parser, tokens
from sqlglot import exp, generator, parser, tokens, transforms
from sqlglot.dialects.dialect import (
Dialect,
arrow_json_extract_scalar_sql,
arrow_json_extract_sql,
count_if_to_sum,
no_ilike_sql,
no_tablesample_sql,
no_trycast_sql,
@@ -13,23 +14,6 @@ from sqlglot.dialects.dialect import (
from sqlglot.tokens import TokenType
# https://www.sqlite.org/lang_aggfunc.html#group_concat
def _group_concat_sql(self, expression):
this = expression.this
distinct = expression.find(exp.Distinct)
if distinct:
this = distinct.expressions[0]
distinct = "DISTINCT "
if isinstance(expression.this, exp.Order):
self.unsupported("SQLite GROUP_CONCAT doesn't support ORDER BY.")
if expression.this.this and not distinct:
this = expression.this.this
separator = expression.args.get("separator")
return f"GROUP_CONCAT({distinct or ''}{self.format_args(this, separator)})"
def _date_add_sql(self, expression):
modifier = expression.expression
modifier = expression.name if modifier.is_string else self.sql(modifier)
@@ -78,20 +62,32 @@ class SQLite(Dialect):
TRANSFORMS = {
**generator.Generator.TRANSFORMS, # type: ignore
**transforms.ELIMINATE_QUALIFY, # type: ignore
exp.CountIf: count_if_to_sum,
exp.CurrentDate: lambda *_: "CURRENT_DATE",
exp.CurrentTime: lambda *_: "CURRENT_TIME",
exp.CurrentTimestamp: lambda *_: "CURRENT_TIMESTAMP",
exp.DateAdd: _date_add_sql,
exp.DateStrToDate: lambda self, e: self.sql(e, "this"),
exp.ILike: no_ilike_sql,
exp.JSONExtract: arrow_json_extract_sql,
exp.JSONExtractScalar: arrow_json_extract_scalar_sql,
exp.JSONBExtract: arrow_json_extract_sql,
exp.JSONBExtractScalar: arrow_json_extract_scalar_sql,
exp.Levenshtein: rename_func("EDITDIST3"),
exp.LogicalOr: rename_func("MAX"),
exp.LogicalAnd: rename_func("MIN"),
exp.TableSample: no_tablesample_sql,
exp.DateStrToDate: lambda self, e: self.sql(e, "this"),
exp.TimeStrToTime: lambda self, e: self.sql(e, "this"),
exp.TryCast: no_trycast_sql,
exp.GroupConcat: _group_concat_sql,
}
def cast_sql(self, expression: exp.Cast) -> str:
if expression.to.this == exp.DataType.Type.DATE:
return self.func("DATE", expression.this)
return super().cast_sql(expression)
def datediff_sql(self, expression: exp.DateDiff) -> str:
unit = expression.args.get("unit")
unit = unit.name.upper() if unit else "DAY"
@@ -119,16 +115,32 @@ class SQLite(Dialect):
return f"CAST({sql} AS INTEGER)"
def fetch_sql(self, expression):
def fetch_sql(self, expression: exp.Fetch) -> str:
return self.limit_sql(exp.Limit(expression=expression.args.get("count")))
def least_sql(self, expression):
# https://www.sqlite.org/lang_aggfunc.html#group_concat
def groupconcat_sql(self, expression):
this = expression.this
distinct = expression.find(exp.Distinct)
if distinct:
this = distinct.expressions[0]
distinct = "DISTINCT "
if isinstance(expression.this, exp.Order):
self.unsupported("SQLite GROUP_CONCAT doesn't support ORDER BY.")
if expression.this.this and not distinct:
this = expression.this.this
separator = expression.args.get("separator")
return f"GROUP_CONCAT({distinct or ''}{self.format_args(this, separator)})"
def least_sql(self, expression: exp.Least) -> str:
if len(expression.expressions) > 1:
return rename_func("MIN")(self, expression)
return self.expressions(expression)
def transaction_sql(self, expression):
def transaction_sql(self, expression: exp.Transaction) -> str:
this = expression.this
this = f" {this}" if this else ""
return f"BEGIN{this} TRANSACTION"

sqlglot/dialects/starrocks.py

@@ -3,9 +3,18 @@ from __future__ import annotations
from sqlglot import exp
from sqlglot.dialects.dialect import arrow_json_extract_sql, rename_func
from sqlglot.dialects.mysql import MySQL
from sqlglot.helper import seq_get
class StarRocks(MySQL):
class Parser(MySQL.Parser): # type: ignore
FUNCTIONS = {
**MySQL.Parser.FUNCTIONS,
"DATE_TRUNC": lambda args: exp.TimestampTrunc(
this=seq_get(args, 1), unit=seq_get(args, 0)
),
}
class Generator(MySQL.Generator): # type: ignore
TYPE_MAPPING = {
**MySQL.Generator.TYPE_MAPPING, # type: ignore
@@ -20,6 +29,9 @@ class StarRocks(MySQL):
exp.JSONExtract: arrow_json_extract_sql,
exp.DateDiff: rename_func("DATEDIFF"),
exp.StrToUnix: lambda self, e: f"UNIX_TIMESTAMP({self.sql(e, 'this')}, {self.format_time(e)})",
exp.TimestampTrunc: lambda self, e: self.func(
"DATE_TRUNC", exp.Literal.string(e.text("unit")), e.this
),
exp.TimeStrToDate: rename_func("TO_DATE"),
exp.UnixToStr: lambda self, e: f"FROM_UNIXTIME({self.sql(e, 'this')}, {self.format_time(e)})",
exp.UnixToTime: rename_func("FROM_UNIXTIME"),

sqlglot/dialects/tsql.py

@@ -117,14 +117,12 @@ def _string_agg_sql(self, e):
if distinct:
# exp.Distinct can appear below an exp.Order or an exp.GroupConcat expression
self.unsupported("T-SQL STRING_AGG doesn't support DISTINCT.")
this = distinct.expressions[0]
distinct.pop()
this = distinct.pop().expressions[0]
order = ""
if isinstance(e.this, exp.Order):
if e.this.this:
this = e.this.this
e.this.this.pop()
this = e.this.this.pop()
order = f" WITHIN GROUP ({self.sql(e.this)[1:]})" # Order has a leading space
separator = e.args.get("separator") or exp.Literal.string(",")

sqlglot/expressions.py

@@ -301,7 +301,7 @@ class Expression(metaclass=_Expression):
the specified types.
Args:
expression_types (type): the expression type(s) to match.
expression_types: the expression type(s) to match.
Returns:
The node which matches the criteria or None if no such node was found.
@@ -314,7 +314,7 @@ class Expression(metaclass=_Expression):
yields those that match at least one of the specified expression types.
Args:
expression_types (type): the expression type(s) to match.
expression_types: the expression type(s) to match.
Returns:
The generator object.
@@ -328,7 +328,7 @@ class Expression(metaclass=_Expression):
Returns a nearest parent matching expression_types.
Args:
expression_types (type): the expression type(s) to match.
expression_types: the expression type(s) to match.
Returns:
The parent node.
@@ -336,8 +336,7 @@ class Expression(metaclass=_Expression):
ancestor = self.parent
while ancestor and not isinstance(ancestor, expression_types):
ancestor = ancestor.parent
# ignore type because mypy doesn't know that we're checking type in the loop
return ancestor # type: ignore[return-value]
return t.cast(E, ancestor)
@property
def parent_select(self):
@@ -549,8 +548,12 @@ class Expression(metaclass=_Expression):
def pop(self):
"""
Remove this expression from its AST.
Returns:
The popped expression.
"""
self.replace(None)
return self
def assert_is(self, type_):
"""
@@ -626,6 +629,7 @@ IntoType = t.Union[
t.Type[Expression],
t.Collection[t.Union[str, t.Type[Expression]]],
]
ExpOrStr = t.Union[str, Expression]
class Condition(Expression):
@@ -809,7 +813,7 @@ class Describe(Expression):
class Set(Expression):
arg_types = {"expressions": True}
arg_types = {"expressions": False}
class SetItem(Expression):
@@ -905,6 +909,23 @@ class Column(Condition):
def output_name(self) -> str:
return self.name
@property
def parts(self) -> t.List[Identifier]:
"""Return the parts of a column in order catalog, db, table, name."""
return [part for part in reversed(list(self.args.values())) if part]
def to_dot(self) -> Dot:
"""Converts the column into a dot expression."""
parts = self.parts
parent = self.parent
while parent:
if isinstance(parent, Dot):
parts.append(parent.expression)
parent = parent.parent
return Dot.build(parts)
class ColumnDef(Expression):
arg_types = {
@@ -1033,6 +1054,113 @@ class Constraint(Expression):
class Delete(Expression):
arg_types = {"with": False, "this": False, "using": False, "where": False, "returning": False}
def delete(
self,
table: ExpOrStr,
dialect: DialectType = None,
copy: bool = True,
**opts,
) -> Delete:
"""
Create a DELETE expression or replace the table on an existing DELETE expression.
Example:
>>> delete("tbl").sql()
'DELETE FROM tbl'
Args:
table: the table from which to delete.
dialect: the dialect used to parse the input expression.
copy: if `False`, modify this expression instance in-place.
opts: other options to use to parse the input expressions.
Returns:
Delete: the modified expression.
"""
return _apply_builder(
expression=table,
instance=self,
arg="this",
dialect=dialect,
into=Table,
copy=copy,
**opts,
)
def where(
self,
*expressions: ExpOrStr,
append: bool = True,
dialect: DialectType = None,
copy: bool = True,
**opts,
) -> Delete:
"""
Append to or set the WHERE expressions.
Example:
>>> delete("tbl").where("x = 'a' OR x < 'b'").sql()
"DELETE FROM tbl WHERE x = 'a' OR x < 'b'"
Args:
*expressions: the SQL code strings to parse.
If an `Expression` instance is passed, it will be used as-is.
Multiple expressions are combined with an AND operator.
append: if `True`, AND the new expressions to any existing expression.
Otherwise, this resets the expression.
dialect: the dialect used to parse the input expressions.
copy: if `False`, modify this expression instance in-place.
opts: other options to use to parse the input expressions.
Returns:
Delete: the modified expression.
"""
return _apply_conjunction_builder(
*expressions,
instance=self,
arg="where",
append=append,
into=Where,
dialect=dialect,
copy=copy,
**opts,
)
def returning(
self,
expression: ExpOrStr,
dialect: DialectType = None,
copy: bool = True,
**opts,
) -> Delete:
"""
Set the RETURNING expression. Not supported by all dialects.
Example:
>>> delete("tbl").returning("*", dialect="postgres").sql()
'DELETE FROM tbl RETURNING *'
Args:
expression: the SQL code strings to parse.
If an `Expression` instance is passed, it will be used as-is.
dialect: the dialect used to parse the input expressions.
copy: if `False`, modify this expression instance in-place.
opts: other options to use to parse the input expressions.
Returns:
Delete: the modified expression.
"""
return _apply_builder(
expression=expression,
instance=self,
arg="returning",
prefix="RETURNING",
dialect=dialect,
copy=copy,
into=Returning,
**opts,
)
class Drop(Expression):
arg_types = {
@@ -1824,7 +1952,7 @@ class Union(Subqueryable):
def select(
self,
*expressions: str | Expression,
*expressions: ExpOrStr,
append: bool = True,
dialect: DialectType = None,
copy: bool = True,
@@ -2170,7 +2298,7 @@ class Select(Subqueryable):
def select(
self,
*expressions: str | Expression,
*expressions: ExpOrStr,
append: bool = True,
dialect: DialectType = None,
copy: bool = True,
@@ -2875,6 +3003,20 @@ class Dot(Binary):
def name(self) -> str:
return self.expression.name
@classmethod
def build(self, expressions: t.Sequence[Expression]) -> Dot:
"""Build a Dot object with a sequence of expressions."""
if len(expressions) < 2:
raise ValueError(f"Dot requires >= 2 expressions.")
a, b, *expressions = expressions
dot = Dot(this=a, expression=b)
for expression in expressions:
dot = Dot(this=dot, expression=expression)
return dot
class DPipe(Binary):
pass
@@ -3049,7 +3191,7 @@ class TimeUnit(Expression):
def __init__(self, **args):
unit = args.get("unit")
if isinstance(unit, Column):
if isinstance(unit, (Column, Literal)):
args["unit"] = Var(this=unit.name)
elif isinstance(unit, Week):
unit.set("this", Var(this=unit.this.name))
@@ -3261,6 +3403,10 @@ class Count(AggFunc):
arg_types = {"this": False}
class CountIf(AggFunc):
pass
class CurrentDate(Func):
arg_types = {"this": False}
@@ -3407,6 +3553,10 @@ class Explode(Func):
pass
class ExponentialTimeDecayedAvg(AggFunc):
arg_types = {"this": True, "time": False, "decay": False}
class Floor(Func):
arg_types = {"this": True, "decimals": False}
@@ -3420,10 +3570,18 @@ class GroupConcat(Func):
arg_types = {"this": True, "separator": False}
class GroupUniqArray(AggFunc):
arg_types = {"this": True, "size": False}
class Hex(Func):
pass
class Histogram(AggFunc):
arg_types = {"this": True, "bins": False}
class If(Func):
arg_types = {"this": True, "true": True, "false": False}
@@ -3493,7 +3651,11 @@ class Log10(Func):
class LogicalOr(AggFunc):
_sql_names = ["LOGICAL_OR", "BOOL_OR"]
_sql_names = ["LOGICAL_OR", "BOOL_OR", "BOOLOR_AGG"]
class LogicalAnd(AggFunc):
_sql_names = ["LOGICAL_AND", "BOOL_AND", "BOOLAND_AGG"]
class Lower(Func):
@@ -3561,6 +3723,7 @@ class Quantile(AggFunc):
# https://clickhouse.com/docs/en/sql-reference/aggregate-functions/reference/quantiles/#quantiles
class Quantiles(AggFunc):
arg_types = {"parameters": True, "expressions": True}
is_var_len_args = True
class QuantileIf(AggFunc):
@@ -3830,7 +3993,7 @@ ALL_FUNCTIONS = subclasses(__name__, Func, (AggFunc, Anonymous, Func))
# Helpers
def maybe_parse(
sql_or_expression: str | Expression,
sql_or_expression: ExpOrStr,
*,
into: t.Optional[IntoType] = None,
dialect: DialectType = None,
@@ -4091,7 +4254,7 @@ def except_(left, right, distinct=True, dialect=None, **opts):
return Except(this=left, expression=right, distinct=distinct)
def select(*expressions: str | Expression, dialect: DialectType = None, **opts) -> Select:
def select(*expressions: ExpOrStr, dialect: DialectType = None, **opts) -> Select:
"""
Initializes a syntax tree from one or multiple SELECT expressions.
@@ -4135,7 +4298,14 @@ def from_(*expressions, dialect=None, **opts) -> Select:
return Select().from_(*expressions, dialect=dialect, **opts)
def update(table, properties, where=None, from_=None, dialect=None, **opts) -> Update:
def update(
table: str | Table,
properties: dict,
where: t.Optional[ExpOrStr] = None,
from_: t.Optional[ExpOrStr] = None,
dialect: DialectType = None,
**opts,
) -> Update:
"""
Creates an update statement.
@@ -4144,18 +4314,18 @@ def update(table, properties, where=None, from_=None, dialect=None, **opts) -> Update:
"UPDATE my_table SET x = 1, y = '2', z = NULL FROM baz WHERE id > 1"
Args:
*properties (Dict[str, Any]): dictionary of properties to set which are
*properties: dictionary of properties to set which are
auto converted to sql objects eg None -> NULL
where (str): sql conditional parsed into a WHERE statement
from_ (str): sql statement parsed into a FROM statement
dialect (str): the dialect used to parse the input expressions.
where: sql conditional parsed into a WHERE statement
from_: sql statement parsed into a FROM statement
dialect: the dialect used to parse the input expressions.
**opts: other options to use to parse the input expressions.
Returns:
Update: the syntax tree for the UPDATE statement.
"""
update = Update(this=maybe_parse(table, into=Table, dialect=dialect))
update.set(
update_expr = Update(this=maybe_parse(table, into=Table, dialect=dialect))
update_expr.set(
"expressions",
[
EQ(this=maybe_parse(k, dialect=dialect, **opts), expression=convert(v))
@@ -4163,21 +4333,27 @@ def update(table, properties, where=None, from_=None, dialect=None, **opts) -> Update:
],
)
if from_:
update.set(
update_expr.set(
"from",
maybe_parse(from_, into=From, dialect=dialect, prefix="FROM", **opts),
)
if isinstance(where, Condition):
where = Where(this=where)
if where:
update.set(
update_expr.set(
"where",
maybe_parse(where, into=Where, dialect=dialect, prefix="WHERE", **opts),
)
return update
return update_expr
def delete(table, where=None, dialect=None, **opts) -> Delete:
def delete(
table: ExpOrStr,
where: t.Optional[ExpOrStr] = None,
returning: t.Optional[ExpOrStr] = None,
dialect: DialectType = None,
**opts,
) -> Delete:
"""
Builds a delete statement.
@@ -4186,19 +4362,20 @@ def delete(table, where=None, dialect=None, **opts) -> Delete:
'DELETE FROM my_table WHERE id > 1'
Args:
where (str|Condition): sql conditional parsed into a WHERE statement
dialect (str): the dialect used to parse the input expressions.
where: sql conditional parsed into a WHERE statement
returning: sql conditional parsed into a RETURNING statement
dialect: the dialect used to parse the input expressions.
**opts: other options to use to parse the input expressions.
Returns:
Delete: the syntax tree for the DELETE statement.
"""
return Delete(
this=maybe_parse(table, into=Table, dialect=dialect, **opts),
where=Where(this=where)
if isinstance(where, Condition)
else maybe_parse(where, into=Where, dialect=dialect, prefix="WHERE", **opts),
)
delete_expr = Delete().delete(table, dialect=dialect, copy=False, **opts)
if where:
delete_expr = delete_expr.where(where, dialect=dialect, copy=False, **opts)
if returning:
delete_expr = delete_expr.returning(returning, dialect=dialect, copy=False, **opts)
return delete_expr
def condition(expression, dialect=None, **opts) -> Condition:
@@ -4414,7 +4591,7 @@ def to_column(sql_path: str | Column, **kwargs) -> Column:
def alias_(
expression: str | Expression,
expression: ExpOrStr,
alias: str | Identifier,
table: bool | t.Sequence[str | Identifier] = False,
quoted: t.Optional[bool] = None,
@@ -4516,7 +4693,7 @@ def column(
)
def cast(expression: str | Expression, to: str | DataType | DataType.Type, **opts) -> Cast:
def cast(expression: ExpOrStr, to: str | DataType | DataType.Type, **opts) -> Cast:
"""Cast an expression to a data type.
Example:
@@ -4595,7 +4772,7 @@ def values(
)
def var(name: t.Optional[str | Expression]) -> Var:
def var(name: t.Optional[ExpOrStr]) -> Var:
"""Build a SQL variable.
Example:
@@ -4612,7 +4789,7 @@ def var(name: t.Optional[str | Expression]) -> Var:
The new variable node.
"""
if not name:
raise ValueError(f"Cannot convert empty name into var.")
raise ValueError("Cannot convert empty name into var.")
if isinstance(name, Expression):
name = name.name
@@ -4682,7 +4859,7 @@ def convert(value) -> Expression:
raise ValueError(f"Cannot convert {value}")
def replace_children(expression, fun):
def replace_children(expression, fun, *args, **kwargs):
"""
Replace children of an expression with the result of a lambda fun(child) -> exp.
"""
@@ -4694,7 +4871,7 @@ def replace_children(expression, fun):
for cn in child_nodes:
if isinstance(cn, Expression):
for child_node in ensure_collection(fun(cn)):
for child_node in ensure_collection(fun(cn, *args, **kwargs)):
new_child_nodes.append(child_node)
child_node.parent = expression
child_node.arg_key = k
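
The module-level delete() below routes through these new builders; the
docstring examples compose directly. A sketch, assuming the postgres dialect
for RETURNING:

    from sqlglot.expressions import delete

    print(delete("tbl").where("x = 'a' OR x < 'b'").sql())
    print(delete("tbl").returning("*", dialect="postgres").sql())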

sqlglot/generator.py

@@ -5,7 +5,7 @@ import typing as t
from sqlglot import exp
from sqlglot.errors import ErrorLevel, UnsupportedError, concat_messages
from sqlglot.helper import apply_index_offset, csv, seq_get
from sqlglot.helper import apply_index_offset, csv, seq_get, should_identify
from sqlglot.time import format_time
from sqlglot.tokens import TokenType
@@ -25,8 +25,7 @@ class Generator:
quote_end (str): specifies which ending character to use to delimit quotes. Default: '.
identifier_start (str): specifies which starting character to use to delimit identifiers. Default: ".
identifier_end (str): specifies which ending character to use to delimit identifiers. Default: ".
identify (bool): if set to True all identifiers will be delimited by the corresponding
character.
identify (bool | str): 'always': always quote; 'safe': quote identifiers that contain no uppercase characters; True defaults to 'always'.
normalize (bool): if set to True all identifiers will lower cased
string_escape (str): specifies a string escape character. Default: '.
identifier_escape (str): specifies an identifier escape character. Default: ".
@@ -57,10 +56,10 @@ class Generator:
TRANSFORMS = {
exp.DateAdd: lambda self, e: self.func(
"DATE_ADD", e.this, e.expression, e.args.get("unit")
"DATE_ADD", e.this, e.expression, exp.Literal.string(e.text("unit"))
),
exp.TsOrDsAdd: lambda self, e: self.func(
"TS_OR_DS_ADD", e.this, e.expression, e.args.get("unit")
"TS_OR_DS_ADD", e.this, e.expression, exp.Literal.string(e.text("unit"))
),
exp.VarMap: lambda self, e: self.func("MAP", e.args["keys"], e.args["values"]),
exp.CharacterSetProperty: lambda self, e: f"{'DEFAULT ' if e.args.get('default') else ''}CHARACTER SET={self.sql(e, 'this')}",
@@ -736,7 +735,7 @@ class Generator:
text = expression.name
text = text.lower() if self.normalize else text
text = text.replace(self.identifier_end, self._escaped_identifier_end)
if expression.args.get("quoted") or self.identify:
if expression.args.get("quoted") or should_identify(text, self.identify):
text = f"{self.identifier_start}{text}{self.identifier_end}"
return text
@@ -1176,6 +1175,22 @@ class Generator:
this = self.sql(expression, "this")
return f"{this}{self.seg('OFFSET')} {self.sql(expression, 'expression')}"
def setitem_sql(self, expression: exp.SetItem) -> str:
kind = self.sql(expression, "kind")
kind = f"{kind} " if kind else ""
this = self.sql(expression, "this")
expressions = self.expressions(expression)
collate = self.sql(expression, "collate")
collate = f" COLLATE {collate}" if collate else ""
global_ = "GLOBAL " if expression.args.get("global") else ""
return f"{global_}{kind}{this}{expressions}{collate}"
def set_sql(self, expression: exp.Set) -> str:
expressions = (
f" {self.expressions(expression, flat=True)}" if expression.expressions else ""
)
return f"SET{expressions}"
def lock_sql(self, expression: exp.Lock) -> str:
if self.LOCKING_READS_SUPPORTED:
lock_type = "UPDATE" if expression.args["update"] else "SHARE"
@@ -1359,8 +1374,8 @@ class Generator:
sql = self.query_modifiers(
expression,
self.wrap(expression),
self.expressions(expression, key="pivots", sep=" "),
alias,
self.expressions(expression, key="pivots", sep=" "),
)
return self.prepend_ctes(expression, sql)
@@ -1668,7 +1683,7 @@ class Generator:
expression_sql = self.sql(expression, "expression")
return f"COMMENT{exists_sql}ON {kind} {this} IS {expression_sql}"
def transaction_sql(self, *_) -> str:
def transaction_sql(self, expression: exp.Transaction) -> str:
return "BEGIN"
def commit_sql(self, expression: exp.Commit) -> str:

sqlglot/helper.py

@@ -403,3 +403,20 @@ def first(it: t.Iterable[T]) -> T:
Useful for sets.
"""
return next(i for i in it)
def should_identify(text: str, identify: str | bool) -> bool:
"""Checks if text should be identified given an identify option.
Args:
text: the text to check.
identify: "always" | True - always returns true, "safe" - true if no upper case
Returns:
Whether or not a string should be identified.
"""
if identify is True or identify == "always":
return True
if identify == "safe":
return not any(char.isupper() for char in text)
return False
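
A minimal check of the three identify modes:

    from sqlglot.helper import should_identify

    print(should_identify("foo", "safe"))    # True: no uppercase characters
    print(should_identify("Foo", "safe"))    # False: mixed case is left alone
    print(should_identify("Foo", "always"))  # True: always quote
    print(should_identify("Foo", False))     # False: never quote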

sqlglot/optimizer/canonicalize.py

@@ -1,9 +1,12 @@
from __future__ import annotations
import itertools
from sqlglot import exp
from sqlglot.helper import should_identify
def canonicalize(expression: exp.Expression) -> exp.Expression:
def canonicalize(expression: exp.Expression, identify: str = "safe") -> exp.Expression:
"""Converts a sql expression into a standard form.
This method relies on annotate_types because many of the
@@ -11,15 +14,18 @@ def canonicalize(expression: exp.Expression) -> exp.Expression:
Args:
expression: The expression to canonicalize.
identify: Whether or not to force identify identifier.
"""
exp.replace_children(expression, canonicalize)
exp.replace_children(expression, canonicalize, identify=identify)
expression = add_text_to_concat(expression)
expression = coerce_type(expression)
expression = remove_redundant_casts(expression)
expression = ensure_bool_predicates(expression)
if isinstance(expression, exp.Identifier):
expression.set("quoted", True)
if should_identify(expression.this, identify):
expression.set("quoted", True)
return expression
@@ -52,6 +58,17 @@ def remove_redundant_casts(expression: exp.Expression) -> exp.Expression:
return expression
def ensure_bool_predicates(expression: exp.Expression) -> exp.Expression:
if isinstance(expression, exp.Connector):
_replace_int_predicate(expression.left)
_replace_int_predicate(expression.right)
elif isinstance(expression, (exp.Where, exp.Having)):
_replace_int_predicate(expression.this)
return expression
def _coerce_date(a: exp.Expression, b: exp.Expression) -> None:
for a, b in itertools.permutations([a, b]):
if (
@@ -68,3 +85,8 @@ def _replace_cast(node: exp.Expression, to: str) -> None:
cast = exp.Cast(this=node.copy(), to=data_type)
cast.type = data_type
node.replace(cast)
def _replace_int_predicate(expression: exp.Expression) -> None:
if expression.type and expression.type.this in exp.DataType.INTEGER_TYPES:
expression.replace(exp.NEQ(this=expression.copy(), expression=exp.Literal.number(0)))
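
canonicalize is normally reached through the optimizer, which forwards
keyword arguments to rules by parameter name; the dataframe hunk above passes
identify="always" the same way. A sketch, assuming the default rule set:

    import sqlglot
    from sqlglot.optimizer import optimize

    # identify is forwarded to canonicalize: "always" quotes every identifier,
    # while the new "safe" default only quotes all-lowercase ones.
    print(optimize(sqlglot.parse_one("SELECT a FROM x"), identify="always").sql())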

sqlglot/optimizer/pushdown_projections.py

@@ -1,7 +1,6 @@
from collections import defaultdict
from sqlglot import alias, exp
from sqlglot.helper import flatten
from sqlglot.optimizer.qualify_columns import Resolver
from sqlglot.optimizer.scope import Scope, traverse_scope
from sqlglot.schema import ensure_schema
@@ -86,14 +85,15 @@ def _remove_unused_selections(scope, parent_selections, schema):
else:
order_refs = set()
new_selections = defaultdict(list)
new_selections = []
removed = False
star = False
for selection in scope.selects:
name = selection.alias_or_name
if SELECT_ALL in parent_selections or name in parent_selections or name in order_refs:
new_selections[name].append(selection)
new_selections.append(selection)
else:
if selection.is_star:
star = True
@@ -101,18 +101,17 @@ def _remove_unused_selections(scope, parent_selections, schema):
if star:
resolver = Resolver(scope, schema)
names = {s.alias_or_name for s in new_selections}
for name in sorted(parent_selections):
if name not in new_selections:
new_selections[name].append(
alias(exp.column(name, table=resolver.get_table(name)), name)
)
if name not in names:
new_selections.append(alias(exp.column(name, table=resolver.get_table(name)), name))
# If there are no remaining selections, just select a single constant
if not new_selections:
new_selections[""].append(DEFAULT_SELECTION())
new_selections.append(DEFAULT_SELECTION())
scope.expression.select(*flatten(new_selections.values()), append=False, copy=False)
scope.expression.select(*new_selections, append=False, copy=False)
if removed:
scope.clear_cache()

sqlglot/optimizer/qualify_columns.py

@@ -37,6 +37,7 @@ def qualify_columns(expression, schema):
_qualify_outputs(scope)
_expand_group_by(scope, resolver)
_expand_order_by(scope)
return expression
@@ -213,6 +214,21 @@ def _qualify_columns(scope, resolver):
# column_table can be a '' because bigquery unnest has no table alias
if column_table:
column.set("table", column_table)
elif column_table not in scope.sources:
# structs are used like tables (e.g. "struct"."field"), so they need to be qualified
# separately and represented as dot(dot(...(<table>.<column>, field1), field2, ...))
root, *parts = column.parts
if root.name in scope.sources:
# struct is already qualified, but we still need to change the AST representation
column_table = root
root, *parts = parts
else:
column_table = resolver.get_table(root.name)
if column_table:
column.replace(exp.Dot.build([exp.column(root, table=column_table), *parts]))
columns_missing_from_scope = []
# Determine whether each reference in the order by clause is to a column or an alias.
@@ -373,10 +389,14 @@ class Resolver:
if isinstance(node, exp.Subqueryable):
while node and node.alias != table_name:
node = node.parent
node_alias = node.args.get("alias")
if node_alias:
return node_alias.this
return exp.to_identifier(table_name)
return exp.to_identifier(
table_name, quoted=node.this.quoted if isinstance(node, exp.Table) else None
)
@property
def all_columns(self):

sqlglot/optimizer/qualify_tables.py

@@ -34,11 +34,9 @@ def qualify_tables(expression, db=None, catalog=None, schema=None):
derived_table.set("alias", exp.TableAlias(this=exp.to_identifier(alias_)))
scope.rename_source(None, alias_)
for source in scope.sources.values():
for name, source in scope.sources.items():
if isinstance(source, exp.Table):
identifier = isinstance(source.this, exp.Identifier)
if identifier:
if isinstance(source.this, exp.Identifier):
if not source.args.get("db"):
source.set("db", exp.to_identifier(db))
if not source.args.get("catalog"):
@@ -48,7 +46,7 @@ def qualify_tables(expression, db=None, catalog=None, schema=None):
source = source.replace(
alias(
source.copy(),
source.this if identifier else next_name(),
name if name else next_name(),
table=True,
)
)

sqlglot/optimizer/scope.py

@@ -4,6 +4,7 @@ from enum import Enum, auto
from sqlglot import exp
from sqlglot.errors import OptimizeError
from sqlglot.helper import find_new_name
class ScopeType(Enum):
@@ -293,6 +294,8 @@ class Scope:
result = {}
for name, node in referenced_names:
if name in result:
raise OptimizeError(f"Alias already used: {name}")
if name in self.sources:
result[name] = (node, self.sources[name])
@@ -594,6 +597,8 @@ def _traverse_tables(scope):
if table_name in scope.sources:
# This is a reference to a parent source (e.g. a CTE), not an actual table.
sources[source_name] = scope.sources[table_name]
elif source_name in sources:
sources[find_new_name(sources, table_name)] = expression
else:
sources[source_name] = expression
continue

sqlglot/parser.py

@@ -96,6 +96,7 @@ class Parser(metaclass=_Parser):
NO_PAREN_FUNCTIONS = {
TokenType.CURRENT_DATE: exp.CurrentDate,
TokenType.CURRENT_DATETIME: exp.CurrentDate,
TokenType.CURRENT_TIME: exp.CurrentTime,
TokenType.CURRENT_TIMESTAMP: exp.CurrentTimestamp,
}
@@ -198,7 +199,6 @@ class Parser(metaclass=_Parser):
TokenType.COMMIT,
TokenType.COMPOUND,
TokenType.CONSTRAINT,
TokenType.CURRENT_TIME,
TokenType.DEFAULT,
TokenType.DELETE,
TokenType.DESCRIBE,
@@ -370,8 +370,9 @@ class Parser(metaclass=_Parser):
LAMBDAS = {
TokenType.ARROW: lambda self, expressions: self.expression(
exp.Lambda,
this=self._parse_conjunction().transform(
self._replace_lambda, {node.name for node in expressions}
this=self._replace_lambda(
self._parse_conjunction(),
{node.name for node in expressions},
),
expressions=expressions,
),
@ -441,6 +442,7 @@ class Parser(metaclass=_Parser):
exp.With: lambda self: self._parse_with(),
exp.Window: lambda self: self._parse_named_window(),
exp.Qualify: lambda self: self._parse_qualify(),
exp.Returning: lambda self: self._parse_returning(),
"JOIN_TYPE": lambda self: self._parse_join_side_and_kind(),
}
@ -460,6 +462,7 @@ class Parser(metaclass=_Parser):
TokenType.LOAD_DATA: lambda self: self._parse_load_data(),
TokenType.MERGE: lambda self: self._parse_merge(),
TokenType.ROLLBACK: lambda self: self._parse_commit_or_rollback(),
TokenType.SET: lambda self: self._parse_set(),
TokenType.UNCACHE: lambda self: self._parse_uncache(),
TokenType.UPDATE: lambda self: self._parse_update(),
TokenType.USE: lambda self: self.expression(
@ -656,15 +659,15 @@ class Parser(metaclass=_Parser):
}
FUNCTION_PARSERS: t.Dict[str, t.Callable] = {
"CAST": lambda self: self._parse_cast(self.STRICT_CAST),
"CONVERT": lambda self: self._parse_convert(self.STRICT_CAST),
"TRY_CONVERT": lambda self: self._parse_convert(False),
"EXTRACT": lambda self: self._parse_extract(),
"POSITION": lambda self: self._parse_position(),
"STRING_AGG": lambda self: self._parse_string_agg(),
"SUBSTRING": lambda self: self._parse_substring(),
"TRIM": lambda self: self._parse_trim(),
"CAST": lambda self: self._parse_cast(self.STRICT_CAST),
"TRY_CAST": lambda self: self._parse_cast(False),
"STRING_AGG": lambda self: self._parse_string_agg(),
"TRY_CONVERT": lambda self: self._parse_convert(False),
}
QUERY_MODIFIER_PARSERS = {
@ -684,13 +687,28 @@ class Parser(metaclass=_Parser):
"sample": lambda self: self._parse_table_sample(as_modifier=True),
}
SET_PARSERS = {
"GLOBAL": lambda self: self._parse_set_item_assignment("GLOBAL"),
"LOCAL": lambda self: self._parse_set_item_assignment("LOCAL"),
"SESSION": lambda self: self._parse_set_item_assignment("SESSION"),
"TRANSACTION": lambda self: self._parse_set_transaction(),
}
SHOW_PARSERS: t.Dict[str, t.Callable] = {}
SET_PARSERS: t.Dict[str, t.Callable] = {}
MODIFIABLES = (exp.Subquery, exp.Subqueryable, exp.Table)
TRANSACTION_KIND = {"DEFERRED", "IMMEDIATE", "EXCLUSIVE"}
TRANSACTION_CHARACTERISTICS = {
"ISOLATION LEVEL REPEATABLE READ",
"ISOLATION LEVEL READ COMMITTED",
"ISOLATION LEVEL READ UNCOMMITTED",
"ISOLATION LEVEL SERIALIZABLE",
"READ WRITE",
"READ ONLY",
}
INSERT_ALTERNATIVES = {"ABORT", "FAIL", "IGNORE", "REPLACE", "ROLLBACK"}
WINDOW_ALIAS_TOKENS = ID_VAR_TOKENS - {TokenType.ROWS}
@ -1775,11 +1793,12 @@ class Parser(metaclass=_Parser):
self, alias_tokens: t.Optional[t.Collection[TokenType]] = None
) -> t.Optional[exp.Expression]:
any_token = self._match(TokenType.ALIAS)
alias = self._parse_id_var(
any_token=any_token, tokens=alias_tokens or self.TABLE_ALIAS_TOKENS
alias = (
self._parse_id_var(any_token=any_token, tokens=alias_tokens or self.TABLE_ALIAS_TOKENS)
or self._parse_string_as_identifier()
)
index = self._index
if self._match(TokenType.L_PAREN):
columns = self._parse_csv(self._parse_function_parameter)
self._match_r_paren() if columns else self._retreat(index)
@ -2046,7 +2065,12 @@ class Parser(metaclass=_Parser):
def _parse_table_parts(self, schema: bool = False) -> exp.Expression:
catalog = None
db = None
table = (not schema and self._parse_function()) or self._parse_id_var(any_token=False)
table = (
(not schema and self._parse_function())
or self._parse_id_var(any_token=False)
or self._parse_string_as_identifier()
)
while self._match(TokenType.DOT):
if catalog:
@ -2085,6 +2109,8 @@ class Parser(metaclass=_Parser):
subquery = self._parse_select(table=True)
if subquery:
if not subquery.args.get("pivots"):
subquery.set("pivots", self._parse_pivots())
return subquery
this = self._parse_table_parts(schema=schema)
@ -3370,9 +3396,9 @@ class Parser(metaclass=_Parser):
def _parse_window(
self, this: t.Optional[exp.Expression], alias: bool = False
) -> t.Optional[exp.Expression]:
if self._match(TokenType.FILTER):
where = self._parse_wrapped(self._parse_where)
this = self.expression(exp.Filter, this=this, expression=where)
if self._match_pair(TokenType.FILTER, TokenType.L_PAREN):
this = self.expression(exp.Filter, this=this, expression=self._parse_where())
self._match_r_paren()
# T-SQL allows the OVER (...) syntax after WITHIN GROUP.
# https://learn.microsoft.com/en-us/sql/t-sql/functions/percentile-disc-transact-sql?view=sql-server-ver16
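With the stricter match, FILTER must be immediately followed by a parenthesized WHERE; an indicative round-trip:

import sqlglot

print(sqlglot.transpile("SELECT COUNT(*) FILTER (WHERE x > 1) FROM t")[0])
# e.g. SELECT COUNT(*) FILTER(WHERE x > 1) FROM t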
@ -3504,6 +3530,9 @@ class Parser(metaclass=_Parser):
return self.PRIMARY_PARSERS[TokenType.STRING](self, self._prev)
return self._parse_placeholder()
def _parse_string_as_identifier(self) -> t.Optional[exp.Expression]:
return exp.to_identifier(self._match(TokenType.STRING) and self._prev.text, quoted=True)
def _parse_number(self) -> t.Optional[exp.Expression]:
if self._match(TokenType.NUMBER):
return self.PRIMARY_PARSERS[TokenType.NUMBER](self, self._prev)
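A sketch of what the string-as-identifier fallback enables in table position (output indicative):

import sqlglot
from sqlglot import exp

table = sqlglot.parse_one("SELECT * FROM 'my_table'").find(exp.Table)
print(table.sql())  # e.g. "my_table", now a quoted identifier rather than a string literal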
@ -3778,23 +3807,6 @@ class Parser(metaclass=_Parser):
)
return self._parse_as_command(start)
def _parse_show(self) -> t.Optional[exp.Expression]:
parser = self._find_parser(self.SHOW_PARSERS, self._show_trie) # type: ignore
if parser:
return parser(self)
self._advance()
return self.expression(exp.Show, this=self._prev.text.upper())
def _default_parse_set_item(self) -> exp.Expression:
return self.expression(
exp.SetItem,
this=self._parse_statement(),
)
def _parse_set_item(self) -> t.Optional[exp.Expression]:
parser = self._find_parser(self.SET_PARSERS, self._set_trie) # type: ignore
return parser(self) if parser else self._default_parse_set_item()
def _parse_merge(self) -> exp.Expression:
self._match(TokenType.INTO)
target = self._parse_table()
@ -3861,8 +3873,71 @@ class Parser(metaclass=_Parser):
expressions=whens,
)
def _parse_show(self) -> t.Optional[exp.Expression]:
parser = self._find_parser(self.SHOW_PARSERS, self._show_trie) # type: ignore
if parser:
return parser(self)
self._advance()
return self.expression(exp.Show, this=self._prev.text.upper())
def _parse_set_item_assignment(
self, kind: t.Optional[str] = None
) -> t.Optional[exp.Expression]:
index = self._index
if kind in {"GLOBAL", "SESSION"} and self._match_text_seq("TRANSACTION"):
return self._parse_set_transaction(global_=kind == "GLOBAL")
left = self._parse_primary() or self._parse_id_var()
if not self._match_texts(("=", "TO")):
self._retreat(index)
return None
right = self._parse_statement() or self._parse_id_var()
this = self.expression(
exp.EQ,
this=left,
expression=right,
)
return self.expression(
exp.SetItem,
this=this,
kind=kind,
)
def _parse_set_transaction(self, global_: bool = False) -> exp.Expression:
self._match_text_seq("TRANSACTION")
characteristics = self._parse_csv(
lambda: self._parse_var_from_options(self.TRANSACTION_CHARACTERISTICS)
)
return self.expression(
exp.SetItem,
expressions=characteristics,
kind="TRANSACTION",
**{"global": global_}, # type: ignore
)
def _parse_set_item(self) -> t.Optional[exp.Expression]:
parser = self._find_parser(self.SET_PARSERS, self._set_trie) # type: ignore
return parser(self) if parser else self._parse_set_item_assignment(kind=None)
def _parse_set(self) -> exp.Expression:
return self.expression(exp.Set, expressions=self._parse_csv(self._parse_set_item))
index = self._index
set_ = self.expression(exp.Set, expressions=self._parse_csv(self._parse_set_item))
if self._curr:
self._retreat(index)
return self._parse_as_command(self._prev)
return set_
def _parse_var_from_options(self, options: t.Collection[str]) -> t.Optional[exp.Expression]:
for option in options:
if self._match_text_seq(*option.split(" ")):
return exp.Var(this=option)
return None
def _parse_as_command(self, start: Token) -> exp.Command:
while self._curr:
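Putting the SET machinery together, a few indicative round-trips under the default dialect (the exact output shapes are assumptions):

import sqlglot

print(sqlglot.transpile("SET x = 1")[0])                  # e.g. SET x = 1
print(sqlglot.transpile("SET SESSION x = 1")[0])          # e.g. SET SESSION x = 1
print(sqlglot.transpile("SET TRANSACTION READ ONLY")[0])  # e.g. SET TRANSACTION READ ONLY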
@ -3874,6 +3949,9 @@ class Parser(metaclass=_Parser):
def _find_parser(
self, parsers: t.Dict[str, t.Callable], trie: t.Dict
) -> t.Optional[t.Callable]:
if not self._curr:
return None
index = self._index
this = []
while True:
@ -3973,7 +4051,16 @@ class Parser(metaclass=_Parser):
return this
def _replace_lambda(self, node, lambda_variables):
if isinstance(node, exp.Column):
if node.name in lambda_variables:
return node.this
for column in node.find_all(exp.Column):
if column.parts[0].name in lambda_variables:
dot_or_id = column.to_dot() if column.table else column.this
parent = column.parent
while isinstance(parent, exp.Dot):
if not isinstance(parent.parent, exp.Dot):
parent.replace(dot_or_id)
break
parent = parent.parent
else:
column.replace(dot_or_id)
return node
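An indicative check that lambda parameters used as struct roots (x.total) are rewritten too, not only bare references; MY_UDF is a stand-in name, since lambdas are parsed in any function's argument position:

import sqlglot

sql = "SELECT MY_UDF(items, x -> x.total > 0) FROM t"
print(sqlglot.transpile(sql)[0])  # round-trips: SELECT MY_UDF(items, x -> x.total > 0) FROM t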

View file

@ -502,6 +502,7 @@ class Tokenizer(metaclass=_Tokenizer):
"CUBE": TokenType.CUBE,
"CURRENT_DATE": TokenType.CURRENT_DATE,
"CURRENT ROW": TokenType.CURRENT_ROW,
"CURRENT_TIME": TokenType.CURRENT_TIME,
"CURRENT_TIMESTAMP": TokenType.CURRENT_TIMESTAMP,
"DATABASE": TokenType.DATABASE,
"DEFAULT": TokenType.DEFAULT,
@ -725,7 +726,6 @@ class Tokenizer(metaclass=_Tokenizer):
TokenType.COMMAND,
TokenType.EXECUTE,
TokenType.FETCH,
TokenType.SET,
TokenType.SHOW,
}
@ -851,8 +851,10 @@ class Tokenizer(metaclass=_Tokenizer):
# If we have either a semicolon or a begin token before the command's token, we'll parse
# whatever follows the command's token as a string
if token_type in self.COMMANDS and (
len(self.tokens) == 1 or self.tokens[-2].token_type in self.COMMAND_PREFIX_TOKENS
if (
token_type in self.COMMANDS
and self._peek != ";"
and (len(self.tokens) == 1 or self.tokens[-2].token_type in self.COMMAND_PREFIX_TOKENS)
):
start = self._current
tokens = len(self.tokens)

View file

@ -2,13 +2,12 @@ from __future__ import annotations
import typing as t
from sqlglot import expressions as exp
from sqlglot.helper import find_new_name
if t.TYPE_CHECKING:
from sqlglot.generator import Generator
from sqlglot import expressions as exp
def unalias_group(expression: exp.Expression) -> exp.Expression:
"""
@ -61,8 +60,7 @@ def eliminate_distinct_on(expression: exp.Expression) -> exp.Expression:
and expression.args["distinct"].args.get("on")
and isinstance(expression.args["distinct"].args["on"], exp.Tuple)
):
distinct_cols = expression.args["distinct"].args["on"].expressions
expression.args["distinct"].pop()
distinct_cols = expression.args["distinct"].pop().args["on"].expressions
outer_selects = expression.selects
row_number = find_new_name(expression.named_selects, "_row_number")
window = exp.Window(
@ -71,14 +69,49 @@ def eliminate_distinct_on(expression: exp.Expression) -> exp.Expression:
)
order = expression.args.get("order")
if order:
window.set("order", order.copy())
order.pop()
window.set("order", order.pop().copy())
window = exp.alias_(window, row_number)
expression.select(window, copy=False)
return exp.select(*outer_selects).from_(expression.subquery()).where(f'"{row_number}" = 1')
return expression
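A worked example of the tightened eliminate_distinct_on (output shape indicative):

import sqlglot
from sqlglot.transforms import eliminate_distinct_on

sql = "SELECT DISTINCT ON (a) a, b FROM t ORDER BY c DESC"
print(eliminate_distinct_on(sqlglot.parse_one(sql)).sql())
# roughly: SELECT a, b FROM (SELECT a, b, ROW_NUMBER() OVER (PARTITION BY a ORDER BY c DESC) AS _row_number FROM t) WHERE "_row_number" = 1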
def eliminate_qualify(expression: exp.Expression) -> exp.Expression:
"""
Convert SELECT statements that contain the QUALIFY clause into subqueries, filtered equivalently.
The idea behind this transformation can be seen in Snowflake's documentation for QUALIFY:
https://docs.snowflake.com/en/sql-reference/constructs/qualify
Some dialects don't support window functions in the WHERE clause, so we need to include them as
projections in the subquery, in order to refer to them in the outer filter using aliases. Also,
if a column is referenced in the QUALIFY clause but is not selected, we need to include it too,
otherwise we won't be able to refer to it in the outer query's WHERE clause.
"""
if isinstance(expression, exp.Select) and expression.args.get("qualify"):
taken = set(expression.named_selects)
for select in expression.selects:
if not select.alias_or_name:
alias = find_new_name(taken, "_c")
select.replace(exp.alias_(select.copy(), alias))
taken.add(alias)
outer_selects = exp.select(*[select.alias_or_name for select in expression.selects])
qualify_filters = expression.args["qualify"].pop().this
for expr in qualify_filters.find_all((exp.Window, exp.Column)):
if isinstance(expr, exp.Window):
alias = find_new_name(expression.named_selects, "_w")
expression.select(exp.alias_(expr.copy(), alias), copy=False)
expr.replace(exp.column(alias))
elif expr.name not in expression.named_selects:
expression.select(expr.copy(), copy=False)
return outer_selects.from_(expression.subquery(alias="_t")).where(qualify_filters)
return expression
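A worked example of eliminate_qualify (expected shape, assuming the default dialect parses QUALIFY):

import sqlglot
from sqlglot.transforms import eliminate_qualify

sql = "SELECT i, a + 1 FROM qt QUALIFY ROW_NUMBER() OVER (PARTITION BY p) = 1"
print(eliminate_qualify(sqlglot.parse_one(sql)).sql())
# SELECT i, _c FROM (SELECT i, a + 1 AS _c, ROW_NUMBER() OVER (PARTITION BY p) AS _w, p FROM qt) AS _t WHERE _w = 1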
def remove_precision_parameterized_types(expression: exp.Expression) -> exp.Expression:
"""
Some dialects only allow the precision for parameterized types to be defined in the DDL and not in other expressions.
@ -139,6 +172,7 @@ def delegate(attr: str) -> t.Callable:
UNALIAS_GROUP = {exp.Group: preprocess([unalias_group], delegate("group_sql"))}
ELIMINATE_DISTINCT_ON = {exp.Select: preprocess([eliminate_distinct_on], delegate("select_sql"))}
ELIMINATE_QUALIFY = {exp.Select: preprocess([eliminate_qualify], delegate("select_sql"))}
REMOVE_PRECISION_PARAMETERIZED_TYPES = {
exp.Cast: preprocess([remove_precision_parameterized_types], delegate("cast_sql"))
}