
Merging upstream version 25.18.0.

Signed-off-by: Daniel Baumann <daniel@debian.org>
Daniel Baumann 2025-02-13 21:52:55 +01:00
parent 75ba8bde98
commit f2390c2221
Signed by: daniel
GPG key ID: FBB4F0E80A80222F
81 changed files with 34133 additions and 33517 deletions

sqlglot/dialects/clickhouse.py

@@ -20,6 +20,7 @@ from sqlglot.dialects.dialect import (
var_map_sql,
timestamptrunc_sql,
unit_to_var,
trim_sql,
)
from sqlglot.generator import Generator
from sqlglot.helper import is_int, seq_get
@@ -875,6 +876,7 @@ class ClickHouse(Dialect):
exp.SHA2: sha256_sql,
exp.UnixToTime: _unix_to_time_sql,
exp.TimestampTrunc: timestamptrunc_sql(zone=True),
exp.Trim: trim_sql,
exp.Variance: rename_func("varSamp"),
exp.SchemaCommentProperty: lambda self, e: self.naked_property(e),
exp.Stddev: rename_func("stddevSamp"),
@@ -890,6 +892,7 @@ class ClickHouse(Dialect):
# There's no list in docs, but it can be found in Clickhouse code
# see `ClickHouse/src/Parsers/ParserCreate*.cpp`
ON_CLUSTER_TARGETS = {
"SCHEMA", # Transpiled CREATE SCHEMA may have OnCluster property set
"DATABASE",
"TABLE",
"VIEW",

sqlglot/dialects/dialect.py

@@ -11,7 +11,7 @@ from sqlglot.generator import Generator
from sqlglot.helper import AutoName, flatten, is_int, seq_get, subclasses
from sqlglot.jsonpath import JSONPathTokenizer, parse as parse_json_path
from sqlglot.parser import Parser
-from sqlglot.time import TIMEZONES, format_time
+from sqlglot.time import TIMEZONES, format_time, subsecond_precision
from sqlglot.tokens import Token, Tokenizer, TokenType
from sqlglot.trie import new_trie
@@ -1243,13 +1243,24 @@ def right_to_substring_sql(self: Generator, expression: exp.Left) -> str:
     )
 
-def timestrtotime_sql(self: Generator, expression: exp.TimeStrToTime) -> str:
-    datatype = (
+def timestrtotime_sql(
+    self: Generator,
+    expression: exp.TimeStrToTime,
+    include_precision: bool = False,
+) -> str:
+    datatype = exp.DataType.build(
         exp.DataType.Type.TIMESTAMPTZ
         if expression.args.get("zone")
         else exp.DataType.Type.TIMESTAMP
     )
+
+    if isinstance(expression.this, exp.Literal) and include_precision:
+        precision = subsecond_precision(expression.this.name)
+        if precision > 0:
+            datatype = exp.DataType.build(
+                datatype.this, expressions=[exp.DataTypeParam(this=exp.Literal.number(precision))]
+            )
+
     return self.sql(exp.cast(expression.this, datatype, dialect=self.dialect))
@@ -1295,7 +1306,7 @@ def trim_sql(self: Generator, expression: exp.Trim) -> str:
     collation = self.sql(expression, "collation")
 
     # Use TRIM/LTRIM/RTRIM syntax if the expression isn't database-specific
-    if not remove_chars and not collation:
+    if not remove_chars:
         return self.trim_sql(expression)
 
     trim_type = f"{trim_type} " if trim_type else ""
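
A quick way to exercise the new include_precision path (illustrative only; TIME_STR_TO_TIME is sqlglot's canonical spelling of exp.TimeStrToTime, and Trino opts into the precision below):

import sqlglot

sql = "SELECT TIME_STR_TO_TIME('2023-01-01 12:13:14.123456')"
print(sqlglot.transpile(sql, write="trino")[0])
# Expected shape: SELECT CAST('2023-01-01 12:13:14.123456' AS TIMESTAMP(6))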

sqlglot/dialects/duckdb.py

@@ -33,6 +33,7 @@ from sqlglot.dialects.dialect import (
timestrtotime_sql,
unit_to_var,
unit_to_str,
sha256_sql,
)
from sqlglot.helper import seq_get
from sqlglot.tokens import TokenType
@@ -41,6 +42,14 @@ DATETIME_DELTA = t.Union[
exp.DateAdd, exp.TimeAdd, exp.DatetimeAdd, exp.TsOrDsAdd, exp.DateSub, exp.DatetimeSub
]
WINDOW_FUNCS_WITH_IGNORE_NULLS = (
exp.FirstValue,
exp.LastValue,
exp.Lag,
exp.Lead,
exp.NthValue,
)
def _date_delta_sql(self: DuckDB.Generator, expression: DATETIME_DELTA) -> str:
this = expression.this
@@ -376,6 +385,7 @@ class DuckDB(Dialect):
}
FUNCTIONS.pop("DATE_SUB")
FUNCTIONS.pop("GLOB")
FUNCTION_PARSERS = parser.Parser.FUNCTION_PARSERS.copy()
FUNCTION_PARSERS.pop("DECODE")
@@ -539,6 +549,7 @@ class DuckDB(Dialect):
exp.ReturnsProperty: lambda self, e: "TABLE" if isinstance(e.this, exp.Schema) else "",
exp.Rand: rename_func("RANDOM"),
exp.SafeDivide: no_safe_divide_sql,
exp.SHA2: sha256_sql,
exp.Split: rename_func("STR_SPLIT"),
exp.SortArray: _sort_array_sql,
exp.StrPosition: str_position_sql,
@@ -546,6 +557,7 @@ class DuckDB(Dialect):
"EPOCH", self.func("STRPTIME", e.this, self.format_time(e))
),
exp.Struct: _struct_sql,
exp.Transform: rename_func("LIST_TRANSFORM"),
exp.TimeAdd: _date_delta_sql,
exp.Time: no_time_sql,
exp.TimeDiff: _timediff_sql,
@@ -753,7 +765,6 @@
     def tablesample_sql(
         self,
         expression: exp.TableSample,
-        sep: str = " AS ",
         tablesample_keyword: t.Optional[str] = None,
     ) -> str:
         if not isinstance(expression.parent, exp.Select):
@@ -769,9 +780,7 @@
             )
             expression.set("method", exp.var("RESERVOIR"))
 
-        return super().tablesample_sql(
-            expression, sep=sep, tablesample_keyword=tablesample_keyword
-        )
+        return super().tablesample_sql(expression, tablesample_keyword=tablesample_keyword)
def interval_sql(self, expression: exp.Interval) -> str:
multiplier: t.Optional[int] = None
@@ -910,3 +919,11 @@ class DuckDB(Dialect):
return self.sql(select)
return super().unnest_sql(expression)
def ignorenulls_sql(self, expression: exp.IgnoreNulls) -> str:
if isinstance(expression.this, WINDOW_FUNCS_WITH_IGNORE_NULLS):
# DuckDB should render IGNORE NULLS only for the general-purpose
# window functions that accept it, e.g. FIRST_VALUE(... IGNORE NULLS) OVER (...)
return super().ignorenulls_sql(expression)
return self.sql(expression, "this")
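
Illustration of the gate above (mine, not from the commit): LAST_VALUE is in WINDOW_FUNCS_WITH_IGNORE_NULLS, so the modifier survives; for other functions it would be dropped and only the inner expression rendered.

import sqlglot

sql = "SELECT LAST_VALUE(x IGNORE NULLS) OVER (ORDER BY y) FROM t"
print(sqlglot.transpile(sql, read="bigquery", write="duckdb")[0])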

sqlglot/dialects/hive.py

@@ -436,6 +436,14 @@ class Hive(Dialect):
self._match(TokenType.R_BRACE)
return self.expression(exp.Parameter, this=this, expression=expression)
def _to_prop_eq(self, expression: exp.Expression, index: int) -> exp.Expression:
if isinstance(expression, exp.Column):
key = expression.this
else:
key = exp.to_identifier(f"col{index + 1}")
return self.expression(exp.PropertyEQ, this=key, expression=expression)
class Generator(generator.Generator):
LIMIT_FETCH = "LIMIT"
TABLESAMPLE_WITH_METHOD = False
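
Sketch of the effect of the new _to_prop_eq override (output shape assumed, not taken from the diff): unaliased non-column struct fields get Hive's positional names col1, col2, ..., while bare columns keep their own name as the key.

import sqlglot

print(sqlglot.transpile("SELECT STRUCT(a, 1)", read="hive", write="duckdb")[0])
# Roughly: SELECT {'a': a, 'col2': 1}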

sqlglot/dialects/mysql.py

@@ -24,6 +24,8 @@ from sqlglot.dialects.dialect import (
rename_func,
strposition_to_locate_sql,
unit_to_var,
trim_sql,
timestrtotime_sql,
)
from sqlglot.helper import seq_get
from sqlglot.tokens import TokenType
@@ -95,21 +97,6 @@ def _str_to_date_sql(
     return self.func("STR_TO_DATE", expression.this, self.format_time(expression))
 
-def _trim_sql(self: MySQL.Generator, expression: exp.Trim) -> str:
-    target = self.sql(expression, "this")
-    trim_type = self.sql(expression, "position")
-    remove_chars = self.sql(expression, "expression")
-
-    # Use TRIM/LTRIM/RTRIM syntax if the expression isn't mysql-specific
-    if not remove_chars:
-        return self.trim_sql(expression)
-
-    trim_type = f"{trim_type} " if trim_type else ""
-    remove_chars = f"{remove_chars} " if remove_chars else ""
-    from_part = "FROM " if trim_type or remove_chars else ""
-    return f"TRIM({trim_type}{remove_chars}{from_part}{target})"
def _unix_to_time_sql(self: MySQL.Generator, expression: exp.UnixToTime) -> str:
scale = expression.args.get("scale")
timestamp = expression.this
@@ -348,6 +335,7 @@ class MySQL(Dialect):
"VALUES": lambda self: self.expression(
exp.Anonymous, this="VALUES", expressions=[self._parse_id_var()]
),
"JSON_VALUE": lambda self: self._parse_json_value(),
}
STATEMENT_PARSERS = {
@@ -677,6 +665,33 @@ class MySQL(Dialect):
return self.expression(exp.GroupConcat, this=this, separator=separator)
def _parse_json_value(self) -> exp.JSONValue:
def _parse_on_options() -> t.Optional[exp.Expression] | str:
if self._match_texts(("NULL", "ERROR")):
value = self._prev.text.upper()
else:
value = self._match(TokenType.DEFAULT) and self._parse_bitwise()
self._match_text_seq("ON")
self._match_texts(("EMPTY", "ERROR"))
return value
this = self._parse_bitwise()
self._match(TokenType.COMMA)
path = self._parse_bitwise()
returning = self._match(TokenType.RETURNING) and self._parse_type()
return self.expression(
exp.JSONValue,
this=this,
path=self.dialect.to_json_path(path),
returning=returning,
on_error=_parse_on_options(),
on_empty=_parse_on_options(),
)
class Generator(generator.Generator):
INTERVAL_ALLOWS_PLURAL_FORM = False
LOCKING_READS_SUPPORTED = True
@@ -742,13 +757,15 @@ class MySQL(Dialect):
             ),
             exp.TimestampSub: date_add_interval_sql("DATE", "SUB"),
             exp.TimeStrToUnix: rename_func("UNIX_TIMESTAMP"),
-            exp.TimeStrToTime: lambda self, e: self.sql(
-                exp.cast(e.this, exp.DataType.Type.DATETIME, copy=True)
+            exp.TimeStrToTime: lambda self, e: timestrtotime_sql(
+                self,
+                e,
+                include_precision=not e.args.get("zone"),
             ),
             exp.TimeToStr: _remove_ts_or_ds_to_date(
                 lambda self, e: self.func("DATE_FORMAT", e.this, self.format_time(e))
             ),
-            exp.Trim: _trim_sql,
+            exp.Trim: trim_sql,
             exp.TryCast: no_trycast_sql,
             exp.TsOrDsAdd: date_add_sql("ADD"),
             exp.TsOrDsDiff: lambda self, e: self.func("DATEDIFF", e.this, e.expression),
@@ -1224,3 +1241,7 @@ class MySQL(Dialect):
dt = expression.args.get("timestamp")
return self.func("CONVERT_TZ", dt, from_tz, to_tz)
def attimezone_sql(self, expression: exp.AtTimeZone) -> str:
self.unsupported("AT TIME ZONE is not supported by MySQL")
return self.sql(expression.this)
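
An illustrative round-trip of the new JSON_VALUE support (my example; the RETURNING clause and ON ERROR option exercise _parse_json_value and the new jsonvalue_sql generator):

import sqlglot

sql = "SELECT JSON_VALUE(doc, '$.price' RETURNING DECIMAL(10, 2) NULL ON ERROR)"
print(sqlglot.transpile(sql, read="mysql", write="mysql")[0])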

sqlglot/dialects/oracle.py

@@ -33,6 +33,15 @@ def _build_timetostr_or_tochar(args: t.List) -> exp.TimeToStr | exp.ToChar:
return exp.ToChar.from_arg_list(args)
def _trim_sql(self: Oracle.Generator, expression: exp.Trim) -> str:
position = expression.args.get("position")
if position and position.upper() in ("LEADING", "TRAILING"):
return self.trim_sql(expression)
return trim_sql(self, expression)
class Oracle(Dialect):
ALIAS_POST_TABLESAMPLE = True
LOCKING_READS_SUPPORTED = True
@@ -267,12 +276,12 @@ class Oracle(Dialect):
             exp.Subquery: lambda self, e: self.subquery_sql(e, sep=" "),
             exp.Substring: rename_func("SUBSTR"),
             exp.Table: lambda self, e: self.table_sql(e, sep=" "),
-            exp.TableSample: lambda self, e: self.tablesample_sql(e, sep=" "),
+            exp.TableSample: lambda self, e: self.tablesample_sql(e),
             exp.TemporaryProperty: lambda _, e: f"{e.name or 'GLOBAL'} TEMPORARY",
             exp.TimeToStr: lambda self, e: self.func("TO_CHAR", e.this, self.format_time(e)),
             exp.ToChar: lambda self, e: self.function_fallback_sql(e),
             exp.ToNumber: to_number_with_nls_param,
-            exp.Trim: trim_sql,
+            exp.Trim: _trim_sql,
             exp.UnixToTime: lambda self,
             e: f"TO_DATE('1970-01-01', 'YYYY-MM-DD') + ({self.sql(e, 'this')} / 86400)",
         }
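
Sketch of why Oracle keeps its own _trim_sql (mine, not from the commit): a two-argument LTRIM should stay LTRIM, since Oracle's TRIM only accepts a single trim character.

import sqlglot

print(sqlglot.transpile("SELECT LTRIM(x, 'ab') FROM t", read="oracle", write="oracle")[0])
# Expected: SELECT LTRIM(x, 'ab') FROM t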

sqlglot/dialects/snowflake.py

@@ -93,7 +93,9 @@ def _build_date_time_add(expr_type: t.Type[E]) -> t.Callable[[t.List], E]:
# https://docs.snowflake.com/en/sql-reference/functions/div0
def _build_if_from_div0(args: t.List) -> exp.If:
-    cond = exp.EQ(this=seq_get(args, 1), expression=exp.Literal.number(0))
+    cond = exp.EQ(this=seq_get(args, 1), expression=exp.Literal.number(0)).and_(
+        exp.Is(this=seq_get(args, 0), expression=exp.null()).not_()
+    )
true = exp.Literal.number(0)
false = exp.Div(this=seq_get(args, 0), expression=seq_get(args, 1))
return exp.If(this=cond, true=true, false=false)
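
Illustration (mine; output shape assumed): DIV0 now also guards the numerator, so DIV0(NULL, 0) stays NULL instead of collapsing to 0.

import sqlglot

print(sqlglot.transpile("SELECT DIV0(a, b)", read="snowflake", write="snowflake")[0])
# Expected shape: SELECT IFF(b = 0 AND NOT a IS NULL, 0, a / b)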

sqlglot/dialects/starrocks.py

@@ -1,5 +1,7 @@
from __future__ import annotations
import typing as t
from sqlglot import exp
from sqlglot.dialects.dialect import (
approx_count_distinct_sql,
@@ -7,6 +9,7 @@ from sqlglot.dialects.dialect import (
build_timestamp_trunc,
rename_func,
unit_to_str,
inline_array_sql,
)
from sqlglot.dialects.mysql import MySQL
from sqlglot.helper import seq_get
@@ -26,6 +29,19 @@ class StarRocks(MySQL):
"REGEXP": exp.RegexpLike.from_arg_list,
}
def _parse_unnest(self, with_alias: bool = True) -> t.Optional[exp.Unnest]:
unnest = super()._parse_unnest(with_alias=with_alias)
if unnest:
alias = unnest.args.get("alias")
if alias and not alias.args.get("columns"):
# StarRocks defaults to naming the UNNEST column "unnest"
# if it's not otherwise specified
alias.set("columns", [exp.to_identifier("unnest")])
return unnest
class Generator(MySQL.Generator):
CAST_MAPPING = {}
@@ -38,6 +54,7 @@ class StarRocks(MySQL):
TRANSFORMS = {
**MySQL.Generator.TRANSFORMS,
exp.Array: inline_array_sql,
exp.ApproxDistinct: approx_count_distinct_sql,
exp.DateDiff: lambda self, e: self.func(
"DATE_DIFF", unit_to_str(e), e.this, e.expression

sqlglot/dialects/teradata.py

@@ -278,7 +278,6 @@ class Teradata(Dialect):
     def tablesample_sql(
         self,
         expression: exp.TableSample,
-        sep: str = " AS ",
         tablesample_keyword: t.Optional[str] = None,
     ) -> str:
         return f"{self.sql(expression, 'this')} SAMPLE {self.expressions(expression)}"

sqlglot/dialects/trino.py

@@ -1,7 +1,7 @@
from __future__ import annotations
from sqlglot import exp
-from sqlglot.dialects.dialect import merge_without_target_sql, trim_sql
+from sqlglot.dialects.dialect import merge_without_target_sql, trim_sql, timestrtotime_sql
from sqlglot.dialects.presto import Presto
@@ -21,6 +21,7 @@ class Trino(Presto):
exp.ArraySum: lambda self,
e: f"REDUCE({self.sql(e, 'this')}, 0, (acc, x) -> acc + x, acc -> acc)",
exp.Merge: merge_without_target_sql,
exp.TimeStrToTime: lambda self, e: timestrtotime_sql(self, e, include_precision=True),
exp.Trim: trim_sql,
}

sqlglot/dialects/tsql.py

@@ -351,12 +351,13 @@ def _timestrtotime_sql(self: TSQL.Generator, expression: exp.TimeStrToTime):
 class TSQL(Dialect):
-    NORMALIZATION_STRATEGY = NormalizationStrategy.CASE_INSENSITIVE
-    TIME_FORMAT = "'yyyy-mm-dd hh:mm:ss'"
     SUPPORTS_SEMI_ANTI_JOIN = False
     LOG_BASE_FIRST = False
     TYPED_DIVISION = True
     CONCAT_COALESCE = True
+    NORMALIZATION_STRATEGY = NormalizationStrategy.CASE_INSENSITIVE
+    TIME_FORMAT = "'yyyy-mm-dd hh:mm:ss'"
TIME_MAPPING = {
"year": "%Y",
@@ -395,7 +396,7 @@ class TSQL(Dialect):
         "HH": "%H",
         "H": "%-H",
         "h": "%-I",
-        "S": "%f",
+        "ffffff": "%f",
         "yyyy": "%Y",
         "yy": "%y",
     }
@@ -983,7 +984,9 @@ class TSQL(Dialect):
         return super().setitem_sql(expression)
 
     def boolean_sql(self, expression: exp.Boolean) -> str:
-        if type(expression.parent) in BIT_TYPES:
+        if type(expression.parent) in BIT_TYPES or isinstance(
+            expression.find_ancestor(exp.Values, exp.Select), exp.Values
+        ):
             return "1" if expression.this else "0"
 
         return "(1 = 1)" if expression.this else "(1 = 0)"

sqlglot/expressions.py

@@ -2172,6 +2172,7 @@ class Insert(DDL, DML):
"stored": False,
"partition": False,
"settings": False,
"source": False,
}
def with_(
@@ -2280,6 +2281,18 @@ class Group(Expression):
}
class Cube(Expression):
arg_types = {"expressions": False}
class Rollup(Expression):
arg_types = {"expressions": False}
class GroupingSets(Expression):
arg_types = {"expressions": True}
class Lambda(Expression):
arg_types = {"this": True, "expressions": True}
@@ -3074,6 +3087,7 @@ class Table(Expression):
"partition": False,
"changes": False,
"rows_from": False,
"sample": False,
}
@property
@@ -3846,7 +3860,6 @@ class Subquery(DerivedTable, Query):
 class TableSample(Expression):
     arg_types = {
-        "this": False,
         "expressions": False,
         "method": False,
         "bucket_numerator": False,
@@ -5441,6 +5454,11 @@ class IsInf(Func):
_sql_names = ["IS_INF", "ISINF"]
# https://www.postgresql.org/docs/current/functions-json.html
class JSON(Expression):
arg_types = {"this": False, "with": False, "unique": False}
class JSONPath(Expression):
arg_types = {"expressions": True}
@@ -5553,6 +5571,17 @@ class JSONSchema(Expression):
arg_types = {"expressions": True}
# https://dev.mysql.com/doc/refman/8.4/en/json-search-functions.html#function_json-value
class JSONValue(Expression):
arg_types = {
"this": True,
"path": True,
"returning": False,
"on_empty": False,
"on_error": False,
}
# https://docs.oracle.com/en/database/oracle/oracle-database/19/sqlrf/JSON_TABLE.html
class JSONTable(Func):
arg_types = {

sqlglot/generator.py

@@ -1669,7 +1669,10 @@ class Generator(metaclass=_Generator):
         settings = self.sql(expression, "settings")
         settings = f" {settings}" if settings else ""
 
-        sql = f"INSERT{hint}{alternative}{ignore}{this}{stored}{by_name}{exists}{partition_by}{settings}{where}{expression_sql}"
+        source = self.sql(expression, "source")
+        source = f"TABLE {source}" if source else ""
+
+        sql = f"INSERT{hint}{alternative}{ignore}{this}{stored}{by_name}{exists}{partition_by}{settings}{where}{expression_sql}{source}"
 
         return self.prepend_ctes(expression, sql)
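
Sketch of the new "source" arg on exp.Insert (my example; the Spark/Databricks-style INSERT ... TABLE form is an assumption about intended usage):

import sqlglot

ast = sqlglot.parse_one("INSERT INTO t1 TABLE t2")
print(ast.args["source"])  # the TABLE source, parsed as an exp.Table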
def intersect_sql(self, expression: exp.Intersect) -> str:
@@ -1764,6 +1767,15 @@ class Generator(metaclass=_Generator):
version = f" {version}" if version else ""
alias = self.sql(expression, "alias")
alias = f"{sep}{alias}" if alias else ""
sample = self.sql(expression, "sample")
if self.dialect.ALIAS_POST_TABLESAMPLE:
sample_pre_alias = sample
sample_post_alias = ""
else:
sample_pre_alias = ""
sample_post_alias = sample
hints = self.expressions(expression, key="hints", sep=" ")
hints = f" {hints}" if hints and self.TABLE_HINTS else ""
pivots = self.expressions(expression, key="pivots", sep="", flat=True)
@@ -1794,23 +1806,13 @@
         if rows_from:
             table = f"ROWS FROM {self.wrap(rows_from)}"
 
-        return f"{only}{table}{changes}{partition}{version}{file_format}{alias}{hints}{pivots}{joins}{laterals}{ordinality}"
+        return f"{only}{table}{changes}{partition}{version}{file_format}{sample_pre_alias}{alias}{hints}{pivots}{sample_post_alias}{joins}{laterals}{ordinality}"
     def tablesample_sql(
         self,
         expression: exp.TableSample,
-        sep: str = " AS ",
         tablesample_keyword: t.Optional[str] = None,
     ) -> str:
-        if self.dialect.ALIAS_POST_TABLESAMPLE and expression.this and expression.this.alias:
-            table = expression.this.copy()
-            table.set("alias", None)
-            this = self.sql(table)
-            alias = f"{sep}{self.sql(expression.this, 'alias')}"
-        else:
-            this = self.sql(expression, "this")
-            alias = ""
-
         method = self.sql(expression, "method")
         method = f"{method} " if method and self.TABLESAMPLE_WITH_METHOD else ""
         numerator = self.sql(expression, "bucket_numerator")
@@ -1833,9 +1835,7 @@
         if self.TABLESAMPLE_REQUIRES_PARENS:
             expr = f"({expr})"
 
-        return (
-            f"{this} {tablesample_keyword or self.TABLESAMPLE_KEYWORDS} {method}{expr}{seed}{alias}"
-        )
+        return f" {tablesample_keyword or self.TABLESAMPLE_KEYWORDS} {method}{expr}{seed}"
def pivot_sql(self, expression: exp.Pivot) -> str:
expressions = self.expressions(expression, flat=True)
@@ -1946,6 +1946,18 @@ class Generator(metaclass=_Generator):
def from_sql(self, expression: exp.From) -> str:
return f"{self.seg('FROM')} {self.sql(expression, 'this')}"
def groupingsets_sql(self, expression: exp.GroupingSets) -> str:
grouping_sets = self.expressions(expression, indent=False)
return f"GROUPING SETS {self.wrap(grouping_sets)}"
def rollup_sql(self, expression: exp.Rollup) -> str:
expressions = self.expressions(expression, indent=False)
return f"ROLLUP {self.wrap(expressions)}" if expressions else "WITH ROLLUP"
def cube_sql(self, expression: exp.Cube) -> str:
expressions = self.expressions(expression, indent=False)
return f"CUBE {self.wrap(expressions)}" if expressions else "WITH CUBE"
def group_sql(self, expression: exp.Group) -> str:
group_by_all = expression.args.get("all")
if group_by_all is True:
@@ -1957,34 +1969,23 @@
         group_by = self.op_expressions(f"GROUP BY{modifier}", expression)
 
-        grouping_sets = self.expressions(expression, key="grouping_sets", indent=False)
-        grouping_sets = (
-            f"{self.seg('GROUPING SETS')} {self.wrap(grouping_sets)}" if grouping_sets else ""
-        )
-
-        cube = expression.args.get("cube", [])
-        if seq_get(cube, 0) is True:
-            return f"{group_by}{self.seg('WITH CUBE')}"
-        else:
-            cube_sql = self.expressions(expression, key="cube", indent=False)
-            cube_sql = f"{self.seg('CUBE')} {self.wrap(cube_sql)}" if cube_sql else ""
-
-        rollup = expression.args.get("rollup", [])
-        if seq_get(rollup, 0) is True:
-            return f"{group_by}{self.seg('WITH ROLLUP')}"
-        else:
-            rollup_sql = self.expressions(expression, key="rollup", indent=False)
-            rollup_sql = f"{self.seg('ROLLUP')} {self.wrap(rollup_sql)}" if rollup_sql else ""
+        grouping_sets = self.expressions(expression, key="grouping_sets")
+        cube = self.expressions(expression, key="cube")
+        rollup = self.expressions(expression, key="rollup")
 
         groupings = csv(
-            grouping_sets,
-            cube_sql,
-            rollup_sql,
+            self.seg(grouping_sets) if grouping_sets else "",
+            self.seg(cube) if cube else "",
+            self.seg(rollup) if rollup else "",
             self.seg("WITH TOTALS") if expression.args.get("totals") else "",
             sep=self.GROUPINGS_SEP,
         )
 
-        if expression.args.get("expressions") and groupings:
+        if (
+            expression.expressions
+            and groupings
+            and groupings.strip() not in ("WITH CUBE", "WITH ROLLUP")
+        ):
             group_by = f"{group_by}{self.GROUPINGS_SEP}"
 
         return f"{group_by}{groupings}"
@@ -2446,6 +2447,13 @@ class Generator(metaclass=_Generator):
def subquery_sql(self, expression: exp.Subquery, sep: str = " AS ") -> str:
alias = self.sql(expression, "alias")
alias = f"{sep}{alias}" if alias else ""
sample = self.sql(expression, "sample")
if self.dialect.ALIAS_POST_TABLESAMPLE and sample:
alias = f"{sample}{alias}"
# Set to None so it's not generated again by self.query_modifiers()
expression.set("sample", None)
pivots = self.expressions(expression, key="pivots", sep="", flat=True)
sql = self.query_modifiers(expression, self.wrap(expression), alias, pivots)
return self.prepend_ctes(expression, sql)
@@ -2648,11 +2656,13 @@
         trim_type = self.sql(expression, "position")
 
         if trim_type == "LEADING":
-            return self.func("LTRIM", expression.this)
+            func_name = "LTRIM"
         elif trim_type == "TRAILING":
-            return self.func("RTRIM", expression.this)
+            func_name = "RTRIM"
         else:
-            return self.func("TRIM", expression.this, expression.expression)
+            func_name = "TRIM"
+
+        return self.func(func_name, expression.this, expression.expression)
def convert_concat_args(self, expression: exp.Concat | exp.ConcatWs) -> t.List[exp.Expression]:
args = expression.expressions
@@ -2889,7 +2899,12 @@
         return f"REFERENCES {this}{expressions}{options}"
 
     def anonymous_sql(self, expression: exp.Anonymous) -> str:
-        return self.func(self.sql(expression, "this"), *expression.expressions)
+        # We don't normalize qualified functions such as a.b.foo(), because they can be case-sensitive
+        parent = expression.parent
+        is_qualified = isinstance(parent, exp.Dot) and expression is parent.expression
+        return self.func(
+            self.sql(expression, "this"), *expression.expressions, normalize=not is_qualified
+        )
def paren_sql(self, expression: exp.Paren) -> str:
sql = self.seg(self.indent(self.sql(expression, "this")), sep="")
@@ -3398,8 +3413,10 @@
         *args: t.Optional[exp.Expression | str],
         prefix: str = "(",
         suffix: str = ")",
+        normalize: bool = True,
     ) -> str:
-        return f"{self.normalize_func(name)}{prefix}{self.format_args(*args)}{suffix}"
+        name = self.normalize_func(name) if normalize else name
+        return f"{name}{prefix}{self.format_args(*args)}{suffix}"
def format_args(self, *args: t.Optional[str | exp.Expression]) -> str:
arg_sqls = tuple(
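
Illustration of the case-preservation change (mine; exact casing of the unqualified call depends on the dialect's NORMALIZE_FUNCTIONS setting):

import sqlglot

print(sqlglot.transpile("SELECT myFunc(1), a.b.myFunc(2)")[0])
# Expected shape: SELECT MYFUNC(1), a.b.myFunc(2)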
@@ -4137,3 +4154,36 @@ class Generator(metaclass=_Generator):
expr = exp.AtTimeZone(this=timestamp, zone=target_tz)
return self.sql(expr)
def json_sql(self, expression: exp.JSON) -> str:
this = self.sql(expression, "this")
this = f" {this}" if this else ""
_with = expression.args.get("with")
if _with is None:
with_sql = ""
elif not _with:
with_sql = " WITHOUT"
else:
with_sql = " WITH"
unique_sql = " UNIQUE KEYS" if expression.args.get("unique") else ""
return f"JSON{this}{with_sql}{unique_sql}"
def jsonvalue_sql(self, expression: exp.JSONValue) -> str:
def _generate_on_options(arg: t.Any) -> str:
return arg if isinstance(arg, str) else f"DEFAULT {self.sql(arg)}"
path = self.sql(expression, "path")
returning = self.sql(expression, "returning")
returning = f" RETURNING {returning}" if returning else ""
on_empty = expression.args.get("on_empty")
on_empty = f" {_generate_on_options(on_empty)} ON EMPTY" if on_empty else ""
on_error = expression.args.get("on_error")
on_error = f" {_generate_on_options(on_error)} ON ERROR" if on_error else ""
return self.func("JSON_VALUE", expression.this, f"{path}{returning}{on_empty}{on_error}")

sqlglot/optimizer/qualify_tables.py

@@ -84,7 +84,7 @@ def qualify_tables(
for name, source in scope.sources.items():
if isinstance(source, exp.Table):
-            pivots = pivots = source.args.get("pivots")
+            pivots = source.args.get("pivots")
if not source.alias:
# Don't add the pivot's alias to the pivoted table, use the table's name instead
if pivots and pivots[0].alias == name:

sqlglot/optimizer/simplify.py

@@ -267,13 +267,11 @@ def flatten(expression):
 def simplify_connectors(expression, root=True):
     def _simplify_connectors(expression, left, right):
-        if left == right:
-            if isinstance(expression, exp.Xor):
-                return exp.false()
-            return left
         if isinstance(expression, exp.And):
             if is_false(left) or is_false(right):
                 return exp.false()
+            if is_zero(left) or is_zero(right):
+                return exp.false()
             if is_null(left) or is_null(right):
                 return exp.null()
             if always_true(left) and always_true(right):
@@ -286,12 +284,10 @@
         elif isinstance(expression, exp.Or):
             if always_true(left) or always_true(right):
                 return exp.true()
-            if is_false(left) and is_false(right):
-                return exp.false()
             if (
                 (is_null(left) and is_null(right))
-                or (is_null(left) and is_false(right))
-                or (is_false(left) and is_null(right))
+                or (is_null(left) and always_false(right))
+                or (always_false(left) and is_null(right))
             ):
                 return exp.null()
             if is_false(left):
@@ -299,6 +295,9 @@
             if is_false(right):
                 return left
             return _simplify_comparison(expression, left, right, or_=True)
+        elif isinstance(expression, exp.Xor):
+            if left == right:
+                return exp.false()
 
     if isinstance(expression, exp.Connector):
         return _flat_simplify(expression, _simplify_connectors, root)
@@ -1108,13 +1107,17 @@ def remove_where_true(expression):
 def always_true(expression):
-    return (isinstance(expression, exp.Boolean) and expression.this) or isinstance(
-        expression, exp.Literal
+    return (isinstance(expression, exp.Boolean) and expression.this) or (
+        isinstance(expression, exp.Literal) and not is_zero(expression)
     )
 
 def always_false(expression):
-    return is_false(expression) or is_null(expression)
+    return is_false(expression) or is_null(expression) or is_zero(expression)
+
+def is_zero(expression):
+    return isinstance(expression, exp.Literal) and expression.to_py() == 0
def is_complement(a, b):
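
Illustration of the new zero-literal and Xor rules (my example; the expected FALSE outputs follow from the branches above):

from sqlglot import exp, parse_one
from sqlglot.optimizer.simplify import simplify

print(simplify(parse_one("x AND 0")).sql())  # FALSE, since 0 is now always_false
x = parse_one("x")
print(simplify(exp.Xor(this=x, expression=x.copy())).sql())  # FALSE, via the new Xor branch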

sqlglot/parser.py

@@ -140,6 +140,14 @@ def build_convert_timezone(
return exp.ConvertTimezone.from_arg_list(args)
def build_trim(args: t.List, is_left: bool = True):
return exp.Trim(
this=seq_get(args, 0),
expression=seq_get(args, 1),
position="LEADING" if is_left else "TRAILING",
)
class _Parser(type):
def __new__(cls, clsname, bases, attrs):
klass = super().__new__(cls, clsname, bases, attrs)
@@ -200,9 +208,11 @@ class Parser(metaclass=_Parser):
         "LOWER": build_lower,
         "LPAD": lambda args: build_pad(args),
         "LEFTPAD": lambda args: build_pad(args),
+        "LTRIM": lambda args: build_trim(args),
         "MOD": build_mod,
-        "RPAD": lambda args: build_pad(args, is_left=False),
         "RIGHTPAD": lambda args: build_pad(args, is_left=False),
+        "RPAD": lambda args: build_pad(args, is_left=False),
+        "RTRIM": lambda args: build_trim(args, is_left=False),
         "SCOPE_RESOLUTION": lambda args: exp.ScopeResolution(expression=seq_get(args, 0))
         if len(args) != 2
         else exp.ScopeResolution(this=seq_get(args, 0), expression=seq_get(args, 1)),
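
Illustration of the new build_trim registrations (mine): two-argument LTRIM/RTRIM now build exp.Trim with a position, so they can transpile into ANSI TRIM syntax.

import sqlglot

print(sqlglot.transpile("SELECT LTRIM(x, 'a')", write="mysql")[0])
# Expected shape: SELECT TRIM(LEADING 'a' FROM x)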
@@ -1242,6 +1252,8 @@ class Parser(metaclass=_Parser):
COPY_INTO_VARLEN_OPTIONS = {"FILE_FORMAT", "COPY_OPTIONS", "FORMAT_OPTIONS", "CREDENTIAL"}
IS_JSON_PREDICATE_KIND = {"VALUE", "SCALAR", "ARRAY", "OBJECT"}
STRICT_CAST = True
PREFIXED_PIVOT_COLUMNS = False
@@ -2557,6 +2569,7 @@ class Parser(metaclass=_Parser):
overwrite=overwrite,
alternative=alternative,
ignore=ignore,
source=self._match(TokenType.TABLE) and self._parse_table(),
)
def _parse_kill(self) -> exp.Kill:
@@ -2973,6 +2986,7 @@ class Parser(metaclass=_Parser):
this=this,
pivots=self._parse_pivots(),
alias=self._parse_table_alias() if parse_alias else None,
sample=self._parse_table_sample(),
)
def _implicit_unnests_to_explicit(self, this: E) -> E:
@@ -3543,7 +3557,7 @@
             this.set("version", version)
 
         if self.dialect.ALIAS_POST_TABLESAMPLE:
-            table_sample = self._parse_table_sample()
+            this.set("sample", self._parse_table_sample())
 
         alias = self._parse_table_alias(alias_tokens=alias_tokens or self.TABLE_ALIAS_TOKENS)
         if alias:
@@ -3560,11 +3574,7 @@
             this.set("pivots", self._parse_pivots())
 
         if not self.dialect.ALIAS_POST_TABLESAMPLE:
-            table_sample = self._parse_table_sample()
-
-            if table_sample:
-                table_sample.set("this", this)
-                this = table_sample
+            this.set("sample", self._parse_table_sample())
 
         if joins:
             for join in self._parse_joins():
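
Sketch of the structural effect (my example): TABLESAMPLE now lives on the table's "sample" arg instead of wrapping the table node, so aliases and joins stay attached to the table.

from sqlglot import exp, parse_one

table = parse_one("SELECT * FROM t TABLESAMPLE SYSTEM (10)", read="postgres").find(exp.Table)
print(type(table.args.get("sample")))  # exp.TableSample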
@@ -3907,48 +3917,50 @@
             elements["all"] = False
 
         while True:
-            expressions = self._parse_csv(
-                lambda: None
-                if self._match_set((TokenType.CUBE, TokenType.ROLLUP), advance=False)
-                else self._parse_assignment()
-            )
-            if expressions:
-                elements["expressions"].extend(expressions)
-
-            grouping_sets = self._parse_grouping_sets()
-            if grouping_sets:
-                elements["grouping_sets"].extend(grouping_sets)
-
-            rollup = None
-            cube = None
-            totals = None
-
             index = self._index
-            with_ = self._match(TokenType.WITH)
-            if self._match(TokenType.ROLLUP):
-                rollup = with_ or self._parse_wrapped_csv(self._parse_column)
-                elements["rollup"].extend(ensure_list(rollup))
 
-            if self._match(TokenType.CUBE):
-                cube = with_ or self._parse_wrapped_csv(self._parse_column)
-                elements["cube"].extend(ensure_list(cube))
+            elements["expressions"].extend(
+                self._parse_csv(
+                    lambda: None
+                    if self._match_set((TokenType.CUBE, TokenType.ROLLUP), advance=False)
+                    else self._parse_assignment()
+                )
+            )
 
-            if self._match_text_seq("TOTALS"):
-                totals = True
+            before_with_index = self._index
+            with_prefix = self._match(TokenType.WITH)
+
+            if self._match(TokenType.ROLLUP):
+                elements["rollup"].append(
+                    self._parse_cube_or_rollup(exp.Rollup, with_prefix=with_prefix)
+                )
+            elif self._match(TokenType.CUBE):
+                elements["cube"].append(
+                    self._parse_cube_or_rollup(exp.Cube, with_prefix=with_prefix)
+                )
+            elif self._match(TokenType.GROUPING_SETS):
+                elements["grouping_sets"].append(
+                    self.expression(
+                        exp.GroupingSets,
+                        expressions=self._parse_wrapped_csv(self._parse_grouping_set),
+                    )
+                )
+            elif self._match_text_seq("TOTALS"):
                 elements["totals"] = True  # type: ignore
 
-            if not (grouping_sets or rollup or cube or totals):
-                if with_:
-                    self._retreat(index)
+            if before_with_index <= self._index <= before_with_index + 1:
+                self._retreat(before_with_index)
+                break
+
+            if index == self._index:
                 break
 
         return self.expression(exp.Group, **elements)  # type: ignore
 
-    def _parse_grouping_sets(self) -> t.Optional[t.List[exp.Expression]]:
-        if not self._match(TokenType.GROUPING_SETS):
-            return None
-
-        return self._parse_wrapped_csv(self._parse_grouping_set)
+    def _parse_cube_or_rollup(self, kind: t.Type[E], with_prefix: bool = False) -> E:
+        return self.expression(
+            kind, expressions=[] if with_prefix else self._parse_wrapped_csv(self._parse_column)
+        )
def _parse_grouping_set(self) -> t.Optional[exp.Expression]:
if self._match(TokenType.L_PAREN):
@@ -4282,10 +4294,26 @@
             klass = exp.NullSafeEQ if negate else exp.NullSafeNEQ
             return self.expression(klass, this=this, expression=self._parse_bitwise())
 
-        expression = self._parse_null() or self._parse_boolean()
-        if not expression:
-            self._retreat(index)
-            return None
+        if self._match(TokenType.JSON):
+            kind = self._match_texts(self.IS_JSON_PREDICATE_KIND) and self._prev.text.upper()
+
+            if self._match_text_seq("WITH"):
+                _with = True
+            elif self._match_text_seq("WITHOUT"):
+                _with = False
+            else:
+                _with = None
+
+            unique = self._match(TokenType.UNIQUE)
+            self._match_text_seq("KEYS")
+
+            expression: t.Optional[exp.Expression] = self.expression(
+                exp.JSON, **{"this": kind, "with": _with, "unique": unique}
+            )
+        else:
+            expression = self._parse_primary() or self._parse_null()
+
+            if not expression:
+                self._retreat(index)
+                return None
 
         this = self.expression(exp.Is, this=this, expression=expression)
         return self.expression(exp.Not, this=this) if negate else this
@@ -5087,10 +5115,13 @@ class Parser(metaclass=_Parser):
self._match_r_paren(this)
return self._parse_window(this)
def _to_prop_eq(self, expression: exp.Expression, index: int) -> exp.Expression:
return expression
def _kv_to_prop_eq(self, expressions: t.List[exp.Expression]) -> t.List[exp.Expression]:
transformed = []
-        for e in expressions:
+        for index, e in enumerate(expressions):
if isinstance(e, self.KEY_VALUE_DEFINITIONS):
if isinstance(e, exp.Alias):
e = self.expression(exp.PropertyEQ, this=e.args.get("alias"), expression=e.this)
@@ -5102,6 +5133,8 @@ class Parser(metaclass=_Parser):
if isinstance(e.this, exp.Column):
e.this.replace(e.this.this)
else:
e = self._to_prop_eq(e, index)
transformed.append(e)

sqlglot/time.py

@@ -1,4 +1,5 @@
import typing as t
import datetime
# The generic time format is based on python time.strftime.
# https://docs.python.org/3/library/time.html#time.strftime
@@ -661,3 +662,26 @@ TIMEZONES = {
"Zulu",
)
}
def subsecond_precision(timestamp_literal: str) -> int:
    """
    Given an ISO-8601 timestamp literal, e.g. '2023-01-01 12:13:14.123456+00:00',
    figure out its subsecond precision so we can construct types like DATETIME(6)

    Note that in practice, this is either 3 or 6 digits (3 = millisecond precision, 6 = microsecond precision)
    - 6 is the maximum because strftime's '%f' formats to microseconds and almost every database supports microsecond precision in timestamps
    - Except Presto/Trino, which in most cases only supports millisecond precision but will still honour '%f' and format to microseconds (replacing the remaining 3 digits with 0's)
    - Python prior to 3.11 only supports 0, 3 or 6 digits in a timestamp literal. Any other amount will throw a 'ValueError: Invalid isoformat string:' error
    """
    try:
        parsed = datetime.datetime.fromisoformat(timestamp_literal)
        subsecond_digit_count = len(str(parsed.microsecond).rstrip("0"))
        precision = 0
        if subsecond_digit_count > 3:
            precision = 6
        elif subsecond_digit_count > 0:
            precision = 3
        return precision
    except ValueError:
        return 0
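
The helper's docstring in action (my example):

from sqlglot.time import subsecond_precision

print(subsecond_precision("2023-01-01 12:13:14.123456+00:00"))  # 6
print(subsecond_precision("2023-01-01 12:13:14.123"))           # 3
print(subsecond_precision("2023-01-01 12:13:14"))               # 0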

sqlglot/transforms.py

@@ -317,10 +317,14 @@ def unnest_to_explode(expression: exp.Expression) -> exp.Expression:
     )
 
     for join in expression.args.get("joins") or []:
-        unnest = join.this
+        join_expr = join.this
+
+        is_lateral = isinstance(join_expr, exp.Lateral)
+
+        unnest = join_expr.this if is_lateral else join_expr
 
         if isinstance(unnest, exp.Unnest):
-            alias = unnest.args.get("alias")
+            alias = join_expr.args.get("alias") if is_lateral else unnest.args.get("alias")
             udtf = exp.Posexplode if unnest.args.get("offset") else exp.Explode
 
             expression.args["joins"].remove(join)