1
0
Fork 0

Merging upstream version 11.7.1.

Signed-off-by: Daniel Baumann <daniel@debian.org>
This commit is contained in:
Daniel Baumann 2025-02-13 15:52:09 +01:00
parent 0c053462ae
commit 8d96084fad
Signed by: daniel
GPG key ID: FBB4F0E80A80222F
144 changed files with 44104 additions and 39367 deletions

View file

@ -701,6 +701,119 @@ class Condition(Expression):
"""
return not_(self)
def _binop(self, klass: t.Type[E], other: ExpOrStr, reverse=False) -> E:
this = self
other = convert(other)
if not isinstance(this, klass) and not isinstance(other, klass):
this = _wrap(this, Binary)
other = _wrap(other, Binary)
if reverse:
return klass(this=other, expression=this)
return klass(this=this, expression=other)
def __getitem__(self, other: ExpOrStr | slice | t.Tuple[ExpOrStr]):
if isinstance(other, slice):
return Between(
this=self,
low=convert(other.start),
high=convert(other.stop),
)
return Bracket(this=self, expressions=[convert(e) for e in ensure_list(other)])
def isin(self, *expressions: ExpOrStr, query: t.Optional[ExpOrStr] = None, **opts) -> In:
return In(
this=self,
expressions=[convert(e) for e in expressions],
query=maybe_parse(query, **opts) if query else None,
)
def like(self, other: ExpOrStr) -> Like:
return self._binop(Like, other)
def ilike(self, other: ExpOrStr) -> ILike:
return self._binop(ILike, other)
def eq(self, other: ExpOrStr) -> EQ:
return self._binop(EQ, other)
def neq(self, other: ExpOrStr) -> NEQ:
return self._binop(NEQ, other)
def rlike(self, other: ExpOrStr) -> RegexpLike:
return self._binop(RegexpLike, other)
def __lt__(self, other: ExpOrStr) -> LT:
return self._binop(LT, other)
def __le__(self, other: ExpOrStr) -> LTE:
return self._binop(LTE, other)
def __gt__(self, other: ExpOrStr) -> GT:
return self._binop(GT, other)
def __ge__(self, other: ExpOrStr) -> GTE:
return self._binop(GTE, other)
def __add__(self, other: ExpOrStr) -> Add:
return self._binop(Add, other)
def __radd__(self, other: ExpOrStr) -> Add:
return self._binop(Add, other, reverse=True)
def __sub__(self, other: ExpOrStr) -> Sub:
return self._binop(Sub, other)
def __rsub__(self, other: ExpOrStr) -> Sub:
return self._binop(Sub, other, reverse=True)
def __mul__(self, other: ExpOrStr) -> Mul:
return self._binop(Mul, other)
def __rmul__(self, other: ExpOrStr) -> Mul:
return self._binop(Mul, other, reverse=True)
def __truediv__(self, other: ExpOrStr) -> Div:
return self._binop(Div, other)
def __rtruediv__(self, other: ExpOrStr) -> Div:
return self._binop(Div, other, reverse=True)
def __floordiv__(self, other: ExpOrStr) -> IntDiv:
return self._binop(IntDiv, other)
def __rfloordiv__(self, other: ExpOrStr) -> IntDiv:
return self._binop(IntDiv, other, reverse=True)
def __mod__(self, other: ExpOrStr) -> Mod:
return self._binop(Mod, other)
def __rmod__(self, other: ExpOrStr) -> Mod:
return self._binop(Mod, other, reverse=True)
def __pow__(self, other: ExpOrStr) -> Pow:
return self._binop(Pow, other)
def __rpow__(self, other: ExpOrStr) -> Pow:
return self._binop(Pow, other, reverse=True)
def __and__(self, other: ExpOrStr) -> And:
return self._binop(And, other)
def __rand__(self, other: ExpOrStr) -> And:
return self._binop(And, other, reverse=True)
def __or__(self, other: ExpOrStr) -> Or:
return self._binop(Or, other)
def __ror__(self, other: ExpOrStr) -> Or:
return self._binop(Or, other, reverse=True)
def __neg__(self) -> Neg:
return Neg(this=_wrap(self, Binary))
def __invert__(self) -> Not:
return not_(self)
class Predicate(Condition):
"""Relationships like x = y, x > 1, x >= y."""
@ -818,7 +931,6 @@ class Create(Expression):
"properties": False,
"replace": False,
"unique": False,
"volatile": False,
"indexes": False,
"no_schema_binding": False,
"begin": False,
@ -1053,6 +1165,11 @@ class NotNullColumnConstraint(ColumnConstraintKind):
arg_types = {"allow_null": False}
# https://dev.mysql.com/doc/refman/5.7/en/timestamp-initialization.html
class OnUpdateColumnConstraint(ColumnConstraintKind):
pass
class PrimaryKeyColumnConstraint(ColumnConstraintKind):
arg_types = {"desc": False}
@ -1197,6 +1314,7 @@ class Drop(Expression):
"materialized": False,
"cascade": False,
"constraints": False,
"purge": False,
}
@ -1287,6 +1405,7 @@ class Insert(Expression):
"with": False,
"this": True,
"expression": False,
"conflict": False,
"returning": False,
"overwrite": False,
"exists": False,
@ -1295,6 +1414,16 @@ class Insert(Expression):
}
class OnConflict(Expression):
arg_types = {
"duplicate": False,
"expressions": False,
"nothing": False,
"key": False,
"constraint": False,
}
class Returning(Expression):
arg_types = {"expressions": True}
@ -1326,7 +1455,12 @@ class Partition(Expression):
class Fetch(Expression):
arg_types = {"direction": False, "count": False}
arg_types = {
"direction": False,
"count": False,
"percent": False,
"with_ties": False,
}
class Group(Expression):
@ -1374,6 +1508,7 @@ class Join(Expression):
"kind": False,
"using": False,
"natural": False,
"hint": False,
}
@property
@ -1384,6 +1519,10 @@ class Join(Expression):
def side(self):
return self.text("side").upper()
@property
def hint(self):
return self.text("hint").upper()
@property
def alias_or_name(self):
return self.this.alias_or_name
@ -1475,6 +1614,7 @@ class MatchRecognize(Expression):
"after": False,
"pattern": False,
"define": False,
"alias": False,
}
@ -1582,6 +1722,10 @@ class FreespaceProperty(Property):
arg_types = {"this": True, "percent": False}
class InputOutputFormat(Expression):
arg_types = {"input_format": False, "output_format": False}
class IsolatedLoadingProperty(Property):
arg_types = {
"no": True,
@ -1646,6 +1790,10 @@ class ReturnsProperty(Property):
arg_types = {"this": True, "is_table": False, "table": False}
class RowFormatProperty(Property):
arg_types = {"this": True}
class RowFormatDelimitedProperty(Property):
# https://cwiki.apache.org/confluence/display/hive/languagemanual+dml
arg_types = {
@ -1683,6 +1831,10 @@ class SqlSecurityProperty(Property):
arg_types = {"definer": True}
class StabilityProperty(Property):
arg_types = {"this": True}
class TableFormatProperty(Property):
arg_types = {"this": True}
@ -1695,8 +1847,8 @@ class TransientProperty(Property):
arg_types = {"this": False}
class VolatilityProperty(Property):
arg_types = {"this": True}
class VolatileProperty(Property):
arg_types = {"this": False}
class WithDataProperty(Property):
@ -1726,6 +1878,7 @@ class Properties(Expression):
"LOCATION": LocationProperty,
"PARTITIONED_BY": PartitionedByProperty,
"RETURNS": ReturnsProperty,
"ROW_FORMAT": RowFormatProperty,
"SORTKEY": SortKeyProperty,
"TABLE_FORMAT": TableFormatProperty,
}
@ -2721,6 +2874,7 @@ class Pivot(Expression):
"expressions": True,
"field": True,
"unpivot": True,
"columns": False,
}
@ -2731,6 +2885,8 @@ class Window(Expression):
"order": False,
"spec": False,
"alias": False,
"over": False,
"first": False,
}
@ -2816,6 +2972,7 @@ class DataType(Expression):
FLOAT = auto()
DOUBLE = auto()
DECIMAL = auto()
BIGDECIMAL = auto()
BIT = auto()
BOOLEAN = auto()
JSON = auto()
@ -2964,7 +3121,7 @@ class DropPartition(Expression):
# Binary expressions like (ADD a b)
class Binary(Expression):
class Binary(Condition):
arg_types = {"this": True, "expression": True}
@property
@ -2980,7 +3137,7 @@ class Add(Binary):
pass
class Connector(Binary, Condition):
class Connector(Binary):
pass
@ -3142,7 +3299,7 @@ class ArrayOverlaps(Binary):
# Unary Expressions
# (NOT a)
class Unary(Expression):
class Unary(Condition):
pass
@ -3150,11 +3307,11 @@ class BitwiseNot(Unary):
pass
class Not(Unary, Condition):
class Not(Unary):
pass
class Paren(Unary, Condition):
class Paren(Unary):
arg_types = {"this": True, "with": False}
@ -3162,7 +3319,6 @@ class Neg(Unary):
pass
# Special Functions
class Alias(Expression):
arg_types = {"this": True, "alias": False}
@ -3381,6 +3537,16 @@ class AnyValue(AggFunc):
class Case(Func):
arg_types = {"this": False, "ifs": True, "default": False}
def when(self, condition: ExpOrStr, then: ExpOrStr, copy: bool = True, **opts) -> Case:
this = self.copy() if copy else self
this.append("ifs", If(this=maybe_parse(condition, **opts), true=maybe_parse(then, **opts)))
return this
def else_(self, condition: ExpOrStr, copy: bool = True, **opts) -> Case:
this = self.copy() if copy else self
this.set("default", maybe_parse(condition, **opts))
return this
class Cast(Func):
arg_types = {"this": True, "to": True}
@ -3719,6 +3885,10 @@ class Map(Func):
arg_types = {"keys": False, "values": False}
class StarMap(Func):
pass
class VarMap(Func):
arg_types = {"keys": True, "values": True}
is_var_len_args = True
@ -3734,6 +3904,10 @@ class Max(AggFunc):
is_var_len_args = True
class MD5(Func):
_sql_names = ["MD5"]
class Min(AggFunc):
arg_types = {"this": True, "expressions": False}
is_var_len_args = True
@ -3840,6 +4014,15 @@ class SetAgg(AggFunc):
pass
class SHA(Func):
_sql_names = ["SHA", "SHA1"]
class SHA2(Func):
_sql_names = ["SHA2"]
arg_types = {"this": True, "length": False}
class SortArray(Func):
arg_types = {"this": True, "asc": False}
@ -4017,6 +4200,12 @@ class When(Func):
arg_types = {"matched": True, "source": False, "condition": False, "then": True}
# https://docs.oracle.com/javadb/10.8.3.0/ref/rrefsqljnextvaluefor.html
# https://learn.microsoft.com/en-us/sql/t-sql/functions/next-value-for-transact-sql?view=sql-server-ver16
class NextValueFor(Func):
arg_types = {"this": True, "order": False}
def _norm_arg(arg):
return arg.lower() if type(arg) is str else arg
@ -4025,6 +4214,32 @@ ALL_FUNCTIONS = subclasses(__name__, Func, (AggFunc, Anonymous, Func))
# Helpers
@t.overload
def maybe_parse(
sql_or_expression: ExpOrStr,
*,
into: t.Type[E],
dialect: DialectType = None,
prefix: t.Optional[str] = None,
copy: bool = False,
**opts,
) -> E:
...
@t.overload
def maybe_parse(
sql_or_expression: str | E,
*,
into: t.Optional[IntoType] = None,
dialect: DialectType = None,
prefix: t.Optional[str] = None,
copy: bool = False,
**opts,
) -> E:
...
def maybe_parse(
sql_or_expression: ExpOrStr,
*,
@ -4200,15 +4415,15 @@ def _combine(expressions, operator, dialect=None, **opts):
expressions = [condition(expression, dialect=dialect, **opts) for expression in expressions]
this = expressions[0]
if expressions[1:]:
this = _wrap_operator(this)
this = _wrap(this, Connector)
for expression in expressions[1:]:
this = operator(this=this, expression=_wrap_operator(expression))
this = operator(this=this, expression=_wrap(expression, Connector))
return this
def _wrap_operator(expression):
if isinstance(expression, (And, Or, Not)):
expression = Paren(this=expression)
def _wrap(expression: E, kind: t.Type[Expression]) -> E | Paren:
if isinstance(expression, kind):
return Paren(this=expression)
return expression
@ -4506,7 +4721,7 @@ def not_(expression, dialect=None, **opts) -> Not:
dialect=dialect,
**opts,
)
return Not(this=_wrap_operator(this))
return Not(this=_wrap(this, Connector))
def paren(expression) -> Paren:
@ -4657,6 +4872,8 @@ def alias_(
if table:
table_alias = TableAlias(this=alias)
exp = exp.copy() if isinstance(expression, Expression) else exp
exp.set("alias", table_alias)
if not isinstance(table, bool):
@ -4864,16 +5081,22 @@ def convert(value) -> Expression:
"""
if isinstance(value, Expression):
return value
if value is None:
return NULL
if isinstance(value, bool):
return Boolean(this=value)
if isinstance(value, str):
return Literal.string(value)
if isinstance(value, float) and math.isnan(value):
if isinstance(value, bool):
return Boolean(this=value)
if value is None or (isinstance(value, float) and math.isnan(value)):
return NULL
if isinstance(value, numbers.Number):
return Literal.number(value)
if isinstance(value, datetime.datetime):
datetime_literal = Literal.string(
(value if value.tzinfo else value.replace(tzinfo=datetime.timezone.utc)).isoformat()
)
return TimeStrToTime(this=datetime_literal)
if isinstance(value, datetime.date):
date_literal = Literal.string(value.strftime("%Y-%m-%d"))
return DateStrToDate(this=date_literal)
if isinstance(value, tuple):
return Tuple(expressions=[convert(v) for v in value])
if isinstance(value, list):
@ -4883,14 +5106,6 @@ def convert(value) -> Expression:
keys=[convert(k) for k in value],
values=[convert(v) for v in value.values()],
)
if isinstance(value, datetime.datetime):
datetime_literal = Literal.string(
(value if value.tzinfo else value.replace(tzinfo=datetime.timezone.utc)).isoformat()
)
return TimeStrToTime(this=datetime_literal)
if isinstance(value, datetime.date):
date_literal = Literal.string(value.strftime("%Y-%m-%d"))
return DateStrToDate(this=date_literal)
raise ValueError(f"Cannot convert {value}")
@ -5030,7 +5245,9 @@ def replace_placeholders(expression, *args, **kwargs):
return expression.transform(_replace_placeholders, iter(args), **kwargs)
def expand(expression: Expression, sources: t.Dict[str, Subqueryable], copy=True) -> Expression:
def expand(
expression: Expression, sources: t.Dict[str, Subqueryable], copy: bool = True
) -> Expression:
"""Transforms an expression by expanding all referenced sources into subqueries.
Examples:
@ -5038,6 +5255,9 @@ def expand(expression: Expression, sources: t.Dict[str, Subqueryable], copy=True
>>> expand(parse_one("select * from x AS z"), {"x": parse_one("select * from y")}).sql()
'SELECT * FROM (SELECT * FROM y) AS z /* source: x */'
>>> expand(parse_one("select * from x AS z"), {"x": parse_one("select * from y"), "y": parse_one("select * from z")}).sql()
'SELECT * FROM (SELECT * FROM (SELECT * FROM z) AS y /* source: y */) AS z /* source: x */'
Args:
expression: The expression to expand.
sources: A dictionary of name to Subqueryables.
@ -5054,7 +5274,7 @@ def expand(expression: Expression, sources: t.Dict[str, Subqueryable], copy=True
if source:
subquery = source.subquery(node.alias or name)
subquery.comments = [f"source: {name}"]
return subquery
return subquery.transform(_expand, copy=False)
return node
return expression.transform(_expand, copy=copy)
@ -5089,8 +5309,8 @@ def func(name: str, *args, dialect: DialectType = None, **kwargs) -> Func:
from sqlglot.dialects.dialect import Dialect
converted = [convert(arg) for arg in args]
kwargs = {key: convert(value) for key, value in kwargs.items()}
converted: t.List[Expression] = [maybe_parse(arg, dialect=dialect) for arg in args]
kwargs = {key: maybe_parse(value, dialect=dialect) for key, value in kwargs.items()}
parser = Dialect.get_or_raise(dialect)().parser()
from_args_list = parser.FUNCTIONS.get(name.upper())