1
0
Fork 0

Merging upstream version 20.9.0.

Signed-off-by: Daniel Baumann <daniel@debian.org>
This commit is contained in:
Daniel Baumann 2025-02-13 21:19:14 +01:00
parent 9421b254ec
commit 37a231f554
Signed by: daniel
GPG key ID: FBB4F0E80A80222F
144 changed files with 78309 additions and 59609 deletions

View file

@ -16,6 +16,7 @@ import datetime
import math
import numbers
import re
import textwrap
import typing as t
from collections import deque
from copy import deepcopy
@ -35,6 +36,8 @@ from sqlglot.helper import (
from sqlglot.tokens import Token
if t.TYPE_CHECKING:
from typing_extensions import Literal as Lit
from sqlglot.dialects.dialect import DialectType
@ -242,6 +245,9 @@ class Expression(metaclass=_Expression):
def is_type(self, *dtypes) -> bool:
return self.type is not None and self.type.is_type(*dtypes)
def is_leaf(self) -> bool:
return not any(isinstance(v, (Expression, list)) for v in self.args.values())
@property
def meta(self) -> t.Dict[str, t.Any]:
if self._meta is None:
@ -497,7 +503,14 @@ class Expression(metaclass=_Expression):
return self.sql()
def __repr__(self) -> str:
return self._to_s()
return _to_s(self)
def to_s(self) -> str:
"""
Same as __repr__, but includes additional information which can be useful
for debugging, like empty or missing args and the AST nodes' object IDs.
"""
return _to_s(self, verbose=True)
def sql(self, dialect: DialectType = None, **opts) -> str:
"""
@ -514,30 +527,6 @@ class Expression(metaclass=_Expression):
return Dialect.get_or_raise(dialect).generate(self, **opts)
def _to_s(self, hide_missing: bool = True, level: int = 0) -> str:
indent = "" if not level else "\n"
indent += "".join([" "] * level)
left = f"({self.key.upper()} "
args: t.Dict[str, t.Any] = {
k: ", ".join(
v._to_s(hide_missing=hide_missing, level=level + 1)
if hasattr(v, "_to_s")
else str(v)
for v in ensure_list(vs)
if v is not None
)
for k, vs in self.args.items()
}
args["comments"] = self.comments
args["type"] = self.type
args = {k: v for k, v in args.items() if v or not hide_missing}
right = ", ".join(f"{k}: {v}" for k, v in args.items())
right += ")"
return indent + left + right
def transform(self, fun, *args, copy=True, **kwargs):
"""
Recursively visits all tree nodes (excluding already transformed ones)
@ -580,8 +569,9 @@ class Expression(metaclass=_Expression):
For example::
>>> tree = Select().select("x").from_("tbl")
>>> tree.find(Column).replace(Column(this="y"))
(COLUMN this: y)
>>> tree.find(Column).replace(column("y"))
Column(
this=Identifier(this=y, quoted=False))
>>> tree.sql()
'SELECT y FROM tbl'
@ -831,6 +821,9 @@ class Expression(metaclass=_Expression):
div.args["safe"] = safe
return div
def desc(self, nulls_first: bool = False) -> Ordered:
return Ordered(this=self.copy(), desc=True, nulls_first=nulls_first)
def __lt__(self, other: t.Any) -> LT:
return self._binop(LT, other)
@ -1109,7 +1102,7 @@ class Clone(Expression):
class Describe(Expression):
arg_types = {"this": True, "kind": False, "expressions": False}
arg_types = {"this": True, "extended": False, "kind": False, "expressions": False}
class Kill(Expression):
@ -1124,6 +1117,10 @@ class Set(Expression):
arg_types = {"expressions": False, "unset": False, "tag": False}
class Heredoc(Expression):
arg_types = {"this": True, "tag": False}
class SetItem(Expression):
arg_types = {
"this": False,
@ -1937,7 +1934,13 @@ class Join(Expression):
class Lateral(UDTF):
arg_types = {"this": True, "view": False, "outer": False, "alias": False}
arg_types = {
"this": True,
"view": False,
"outer": False,
"alias": False,
"cross_apply": False, # True -> CROSS APPLY, False -> OUTER APPLY
}
class MatchRecognize(Expression):
@ -1964,7 +1967,12 @@ class Offset(Expression):
class Order(Expression):
arg_types = {"this": False, "expressions": True, "interpolate": False}
arg_types = {
"this": False,
"expressions": True,
"interpolate": False,
"siblings": False,
}
# https://clickhouse.com/docs/en/sql-reference/statements/select/order-by#order-by-expr-with-fill-modifier
@ -2002,6 +2010,11 @@ class AutoIncrementProperty(Property):
arg_types = {"this": True}
# https://docs.aws.amazon.com/prescriptive-guidance/latest/materialized-views-redshift/refreshing-materialized-views.html
class AutoRefreshProperty(Property):
arg_types = {"this": True}
class BlockCompressionProperty(Property):
arg_types = {"autotemp": False, "always": False, "default": True, "manual": True, "never": True}
@ -2259,6 +2272,10 @@ class SortKeyProperty(Property):
arg_types = {"this": True, "compound": False}
class SqlReadWriteProperty(Property):
arg_types = {"this": True}
class SqlSecurityProperty(Property):
arg_types = {"definer": True}
@ -2543,7 +2560,6 @@ class Table(Expression):
"version": False,
"format": False,
"pattern": False,
"index": False,
"ordinality": False,
"when": False,
}
@ -2585,6 +2601,14 @@ class Table(Expression):
return parts
def to_column(self, copy: bool = True) -> Alias | Column | Dot:
    """
    Convert this table reference into a column-like expression.

    The first four entries of `self.parts` are handed to `column()` in reverse
    order, and any remaining entries are passed along as `fields`. If the table
    has an `alias` arg, the result is wrapped with `alias_()` using `alias.this`.

    Args:
        copy: Forwarded to `column()` and `alias_()`.

    Returns:
        The expression built by `column()`, aliased when the table has an alias.
    """
    name_parts = self.parts
    leading, trailing = name_parts[0:4], name_parts[4:]
    result = column(*reversed(leading), fields=trailing, copy=copy)  # type: ignore

    table_alias = self.args.get("alias")
    if not table_alias:
        return result

    return alias_(result, table_alias.this, copy=copy)
class Union(Subqueryable):
arg_types = {
@ -2694,6 +2718,14 @@ class Unnest(UDTF):
"offset": False,
}
@property
def selects(self) -> t.List[Expression]:
    """
    Return the parent's `selects`, plus the offset column when one is set.

    When the `offset` arg is exactly `True`, an identifier named "offset" is
    appended; any other truthy value is appended as-is. A new list is built
    with `+`, so the list returned by the parent property is not mutated.
    """
    base = super().selects
    offset_arg = self.args.get("offset")
    if not offset_arg:
        return base

    extra = to_identifier("offset") if offset_arg is True else offset_arg
    return base + [extra]
class Update(Expression):
arg_types = {
@ -3368,7 +3400,7 @@ class Select(Subqueryable):
return Create(
this=table_expression,
kind="table",
kind="TABLE",
expression=instance,
properties=properties_expression,
)
@ -3488,7 +3520,6 @@ class TableSample(Expression):
"rows": False,
"size": False,
"seed": False,
"kind": False,
}
@ -3517,6 +3548,10 @@ class Pivot(Expression):
"include_nulls": False,
}
@property
def unpivot(self) -> bool:
return bool(self.args.get("unpivot"))
class Window(Condition):
arg_types = {
@ -3604,6 +3639,7 @@ class DataType(Expression):
BOOLEAN = auto()
CHAR = auto()
DATE = auto()
DATE32 = auto()
DATEMULTIRANGE = auto()
DATERANGE = auto()
DATETIME = auto()
@ -3631,6 +3667,8 @@ class DataType(Expression):
INTERVAL = auto()
IPADDRESS = auto()
IPPREFIX = auto()
IPV4 = auto()
IPV6 = auto()
JSON = auto()
JSONB = auto()
LONGBLOB = auto()
@ -3729,6 +3767,7 @@ class DataType(Expression):
Type.TIMESTAMP_MS,
Type.TIMESTAMP_NS,
Type.DATE,
Type.DATE32,
Type.DATETIME,
Type.DATETIME64,
}
@ -4100,6 +4139,12 @@ class Alias(Expression):
return self.alias
# BigQuery requires the UNPIVOT column list aliases to be either strings or ints, but
# other dialects require identifiers. This enables us to transpile between them easily.
class PivotAlias(Alias):
pass
class Aliases(Expression):
arg_types = {"this": True, "expressions": True}
@ -4108,6 +4153,11 @@ class Aliases(Expression):
return self.expressions
# https://docs.aws.amazon.com/redshift/latest/dg/query-super.html
class AtIndex(Expression):
arg_types = {"this": True, "expression": True}
class AtTimeZone(Expression):
arg_types = {"this": True, "zone": True}
@ -4154,16 +4204,16 @@ class TimeUnit(Expression):
arg_types = {"unit": False}
UNABBREVIATED_UNIT_NAME = {
"d": "day",
"h": "hour",
"m": "minute",
"ms": "millisecond",
"ns": "nanosecond",
"q": "quarter",
"s": "second",
"us": "microsecond",
"w": "week",
"y": "year",
"D": "DAY",
"H": "HOUR",
"M": "MINUTE",
"MS": "MILLISECOND",
"NS": "NANOSECOND",
"Q": "QUARTER",
"S": "SECOND",
"US": "MICROSECOND",
"W": "WEEK",
"Y": "YEAR",
}
VAR_LIKE = (Column, Literal, Var)
@ -4171,9 +4221,11 @@ class TimeUnit(Expression):
def __init__(self, **args):
unit = args.get("unit")
if isinstance(unit, self.VAR_LIKE):
args["unit"] = Var(this=self.UNABBREVIATED_UNIT_NAME.get(unit.name) or unit.name)
args["unit"] = Var(
this=(self.UNABBREVIATED_UNIT_NAME.get(unit.name) or unit.name).upper()
)
elif isinstance(unit, Week):
unit.set("this", Var(this=unit.this.name))
unit.set("this", Var(this=unit.this.name.upper()))
super().__init__(**args)
@ -4301,6 +4353,20 @@ class Anonymous(Func):
is_var_len_args = True
class AnonymousAggFunc(AggFunc):
arg_types = {"this": True, "expressions": False}
is_var_len_args = True
# https://clickhouse.com/docs/en/sql-reference/aggregate-functions/combinators
class CombinedAggFunc(AnonymousAggFunc):
arg_types = {"this": True, "expressions": False, "parts": True}
class CombinedParameterizedAgg(ParameterizedAgg):
arg_types = {"this": True, "expressions": True, "params": True, "parts": True}
# https://docs.snowflake.com/en/sql-reference/functions/hll
# https://docs.aws.amazon.com/redshift/latest/dg/r_HLL_function.html
class Hll(AggFunc):
@ -4381,7 +4447,7 @@ class ArraySort(Func):
class ArraySum(Func):
pass
arg_types = {"this": True, "expression": False}
class ArrayUnionAgg(AggFunc):
@ -4498,7 +4564,7 @@ class Count(AggFunc):
class CountIf(AggFunc):
pass
_sql_names = ["COUNT_IF", "COUNTIF"]
class CurrentDate(Func):
@ -4537,6 +4603,17 @@ class DateDiff(Func, TimeUnit):
class DateTrunc(Func):
arg_types = {"unit": True, "this": True, "zone": False}
def __init__(self, **args):
    """
    Normalize the `unit` arg before delegating to the parent constructor.

    - A unit that is an instance of one of `TimeUnit.VAR_LIKE` is replaced by a
      string literal holding its upper-cased name, after looking the name up in
      `TimeUnit.UNABBREVIATED_UNIT_NAME` (the raw name is used on a miss).
    - A `Week` unit has its `this` arg replaced, in place, by a string literal
      of its upper-cased name.
    - Any other unit (including a missing one) is passed through unchanged.
    """
    unit = args.get("unit")

    if isinstance(unit, TimeUnit.VAR_LIKE):
        raw_name = unit.name
        full_name = TimeUnit.UNABBREVIATED_UNIT_NAME.get(raw_name) or raw_name
        args["unit"] = Literal.string(full_name.upper())
    elif isinstance(unit, Week):
        week_part = unit.this.name.upper()
        unit.set("this", Literal.string(week_part))

    super().__init__(**args)
@property
def unit(self) -> Expression:
return self.args["unit"]
@ -4582,8 +4659,9 @@ class MonthsBetween(Func):
arg_types = {"this": True, "expression": True, "roundoff": False}
class LastDateOfMonth(Func):
pass
class LastDay(Func, TimeUnit):
_sql_names = ["LAST_DAY", "LAST_DAY_OF_MONTH"]
arg_types = {"this": True, "unit": False}
class Extract(Func):
@ -4627,10 +4705,22 @@ class TimeTrunc(Func, TimeUnit):
class DateFromParts(Func):
_sql_names = ["DATEFROMPARTS"]
_sql_names = ["DATE_FROM_PARTS", "DATEFROMPARTS"]
arg_types = {"year": True, "month": True, "day": True}
class TimeFromParts(Func):
_sql_names = ["TIME_FROM_PARTS", "TIMEFROMPARTS"]
arg_types = {
"hour": True,
"min": True,
"sec": True,
"nano": False,
"fractions": False,
"precision": False,
}
class DateStrToDate(Func):
pass
@ -4754,6 +4844,16 @@ class JSONObject(Func):
}
class JSONObjectAgg(AggFunc):
arg_types = {
"expressions": False,
"null_handling": False,
"unique_keys": False,
"return_type": False,
"encoding": False,
}
# https://docs.oracle.com/en/database/oracle/oracle-database/19/sqlrf/JSON_ARRAY.html
class JSONArray(Func):
arg_types = {
@ -4841,6 +4941,15 @@ class ParseJSON(Func):
is_var_len_args = True
# https://docs.snowflake.com/en/sql-reference/functions/get_path
class GetPath(Func):
arg_types = {"this": True, "expression": True}
@property
def output_name(self) -> str:
return self.expression.output_name
class Least(Func):
arg_types = {"this": True, "expressions": False}
is_var_len_args = True
@ -5026,7 +5135,7 @@ class RegexpReplace(Func):
arg_types = {
"this": True,
"expression": True,
"replacement": True,
"replacement": False,
"position": False,
"occurrence": False,
"parameters": False,
@ -5052,8 +5161,10 @@ class Repeat(Func):
arg_types = {"this": True, "times": True}
# https://learn.microsoft.com/en-us/sql/t-sql/functions/round-transact-sql?view=sql-server-ver16
# T-SQL: ROUND's optional third argument (function) truncates instead of rounding when it is not 0
class Round(Func):
arg_types = {"this": True, "decimals": False}
arg_types = {"this": True, "decimals": False, "truncate": False}
class RowNumber(Func):
@ -5228,6 +5339,10 @@ class TsOrDsToDate(Func):
arg_types = {"this": True, "format": False}
class TsOrDsToTime(Func):
pass
class TsOrDiToDi(Func):
pass
@ -5236,6 +5351,11 @@ class Unhex(Func):
pass
# https://cloud.google.com/bigquery/docs/reference/standard-sql/date_functions#unix_date
class UnixDate(Func):
pass
class UnixToStr(Func):
arg_types = {"this": True, "format": False}
@ -5245,10 +5365,16 @@ class UnixToStr(Func):
class UnixToTime(Func):
arg_types = {"this": True, "scale": False, "zone": False, "hours": False, "minutes": False}
SECONDS = Literal.string("seconds")
MILLIS = Literal.string("millis")
MICROS = Literal.string("micros")
NANOS = Literal.string("nanos")
SECONDS = Literal.number(0)
DECIS = Literal.number(1)
CENTIS = Literal.number(2)
MILLIS = Literal.number(3)
DECIMILLIS = Literal.number(4)
CENTIMILLIS = Literal.number(5)
MICROS = Literal.number(6)
DECIMICROS = Literal.number(7)
CENTIMICROS = Literal.number(8)
NANOS = Literal.number(9)
class UnixToTimeStr(Func):
@ -5256,8 +5382,7 @@ class UnixToTimeStr(Func):
class TimestampFromParts(Func):
"""Constructs a timestamp given its constituent parts."""
_sql_names = ["TIMESTAMP_FROM_PARTS", "TIMESTAMPFROMPARTS"]
arg_types = {
"year": True,
"month": True,
@ -5265,6 +5390,9 @@ class TimestampFromParts(Func):
"hour": True,
"min": True,
"sec": True,
"nano": False,
"zone": False,
"milli": False,
}
@ -5358,9 +5486,9 @@ def maybe_parse(
Example:
>>> maybe_parse("1")
(LITERAL this: 1, is_string: False)
Literal(this=1, is_string=False)
>>> maybe_parse(to_identifier("x"))
(IDENTIFIER this: x, quoted: False)
Identifier(this=x, quoted=False)
Args:
sql_or_expression: the SQL code string or an expression
@ -5407,6 +5535,39 @@ def maybe_copy(instance, copy=True):
return instance.copy() if copy and instance else instance
def _to_s(node: t.Any, verbose: bool = False, level: int = 0) -> str:
    """
    Generate a textual representation of an Expression tree.

    Args:
        node: An Expression, a list of nodes, or any other value (rendered via `str`).
        verbose: When set, args that are `None` or `[]` are kept, and `_type`,
            `_comments` and `_id` (the node's `id()`) are always included.
        level: Current nesting depth; it controls how far nested items are indented.

    Returns:
        The string representation of `node`.
    """
    pad = "\n" + (" " * (level + 1))
    sep = f",{pad}"
    child_level = level + 1

    if isinstance(node, Expression):
        shown: t.Dict[str, t.Any] = {}
        for key, value in node.args.items():
            # Drop missing / empty args unless the caller asked for everything
            if (value is not None and value != []) or verbose:
                shown[key] = value

        if (node.type or verbose) and not isinstance(node, DataType):
            shown["_type"] = node.type
        if node.comments or verbose:
            shown["_comments"] = node.comments
        if verbose:
            shown["_id"] = id(node)

        # Leaves are rendered on a single line for a more compact representation
        if node.is_leaf():
            pad, sep = "", ", "

        rendered_args = [
            f"{key}={_to_s(value, verbose, child_level)}" for key, value in shown.items()
        ]
        body = sep.join(rendered_args)
        return f"{node.__class__.__name__}({pad}{body})"

    if isinstance(node, list):
        body = sep.join(_to_s(item, verbose, child_level) for item in node)
        if body:
            body = f"{pad}{body}"
        return f"[{body}]"

    # Anything else: dedent its string form and re-indent continuation lines to this level
    text = textwrap.dedent(str(node).strip("\n"))
    return pad.join(text.splitlines())
def _is_wrong_expression(expression, into):
return isinstance(expression, Expression) and not isinstance(expression, into)
@ -5816,7 +5977,7 @@ def delete(
def insert(
expression: ExpOrStr,
into: ExpOrStr,
columns: t.Optional[t.Sequence[ExpOrStr]] = None,
columns: t.Optional[t.Sequence[str | Identifier]] = None,
overwrite: t.Optional[bool] = None,
returning: t.Optional[ExpOrStr] = None,
dialect: DialectType = None,
@ -5847,15 +6008,7 @@ def insert(
this: Table | Schema = maybe_parse(into, into=Table, dialect=dialect, copy=copy, **opts)
if columns:
this = _apply_list_builder(
*columns,
instance=Schema(this=this),
arg="expressions",
into=Identifier,
copy=False,
dialect=dialect,
**opts,
)
this = Schema(this=this, expressions=[to_identifier(c, copy=copy) for c in columns])
insert = Insert(this=this, expression=expr, overwrite=overwrite)
@ -6073,7 +6226,7 @@ def to_interval(interval: str | Literal) -> Interval:
return Interval(
this=Literal.string(interval_parts.group(1)),
unit=Var(this=interval_parts.group(2)),
unit=Var(this=interval_parts.group(2).upper()),
)
@ -6219,13 +6372,44 @@ def subquery(
return Select().from_(expression, dialect=dialect, **opts)
@t.overload
def column(
col: str | Identifier,
table: t.Optional[str | Identifier] = None,
db: t.Optional[str | Identifier] = None,
catalog: t.Optional[str | Identifier] = None,
*,
fields: t.Collection[t.Union[str, Identifier]],
quoted: t.Optional[bool] = None,
copy: bool = True,
) -> Dot:
pass
@t.overload
def column(
col: str | Identifier,
table: t.Optional[str | Identifier] = None,
db: t.Optional[str | Identifier] = None,
catalog: t.Optional[str | Identifier] = None,
*,
fields: Lit[None] = None,
quoted: t.Optional[bool] = None,
copy: bool = True,
) -> Column:
pass
def column(
col,
table=None,
db=None,
catalog=None,
*,
fields=None,
quoted=None,
copy=True,
):
"""
Build a Column.
@ -6234,18 +6418,24 @@ def column(
table: Table name.
db: Database name.
catalog: Catalog name.
fields: Additional fields using dots.
quoted: Whether to force quotes on the column's identifiers.
copy: Whether or not to copy identifiers if passed in.
Returns:
The new Column instance.
"""
return Column(
this=to_identifier(col, quoted=quoted),
table=to_identifier(table, quoted=quoted),
db=to_identifier(db, quoted=quoted),
catalog=to_identifier(catalog, quoted=quoted),
this = Column(
this=to_identifier(col, quoted=quoted, copy=copy),
table=to_identifier(table, quoted=quoted, copy=copy),
db=to_identifier(db, quoted=quoted, copy=copy),
catalog=to_identifier(catalog, quoted=quoted, copy=copy),
)
if fields:
this = Dot.build((this, *(to_identifier(field, copy=copy) for field in fields)))
return this
def cast(expression: ExpOrStr, to: DATA_TYPE, **opts) -> Cast:
"""Cast an expression to a data type.
@ -6333,10 +6523,10 @@ def var(name: t.Optional[ExpOrStr]) -> Var:
Example:
>>> repr(var('x'))
'(VAR this: x)'
'Var(this=x)'
>>> repr(var(column('x', table='y')))
'(VAR this: x)'
'Var(this=x)'
Args:
name: The name of the var or an expression whose name will become the var.