1
0
Fork 0

Merging upstream version 6.1.1.

Signed-off-by: Daniel Baumann <daniel@debian.org>
This commit is contained in:
Daniel Baumann 2025-02-13 08:04:41 +01:00
parent 3c6d649c90
commit 08ecea3adf
Signed by: daniel
GPG key ID: FBB4F0E80A80222F
61 changed files with 1844 additions and 1555 deletions

View file

@ -47,10 +47,7 @@ class Expression(metaclass=_Expression):
return hash(
(
self.key,
tuple(
(k, tuple(v) if isinstance(v, list) else v)
for k, v in _norm_args(self).items()
),
tuple((k, tuple(v) if isinstance(v, list) else v) for k, v in _norm_args(self).items()),
)
)
@ -116,9 +113,22 @@ class Expression(metaclass=_Expression):
item.parent = parent
return new
def append(self, arg_key, value):
"""
Appends value to arg_key if it's a list or sets it as a new list.
Args:
arg_key (str): name of the list expression arg
value (Any): value to append to the list
"""
if not isinstance(self.args.get(arg_key), list):
self.args[arg_key] = []
self.args[arg_key].append(value)
self._set_parent(arg_key, value)
def set(self, arg_key, value):
"""
Sets `arg` to `value`.
Sets `arg_key` to `value`.
Args:
arg_key (str): name of the expression arg
@ -267,6 +277,14 @@ class Expression(metaclass=_Expression):
expression = expression.this
return expression
def unalias(self):
"""
Returns the inner expression if this is an Alias.
"""
if isinstance(self, Alias):
return self.this
return self
def unnest_operands(self):
"""
Returns unnested operands as a tuple.
@ -279,9 +297,7 @@ class Expression(metaclass=_Expression):
A AND B AND C -> [A, B, C]
"""
for node, _, _ in self.dfs(
prune=lambda n, p, *_: p and not isinstance(n, self.__class__)
):
for node, _, _ in self.dfs(prune=lambda n, p, *_: p and not isinstance(n, self.__class__)):
if not isinstance(node, self.__class__):
yield node.unnest() if unnest else node
@ -314,9 +330,7 @@ class Expression(metaclass=_Expression):
args = {
k: ", ".join(
v.to_s(hide_missing=hide_missing, level=level + 1)
if hasattr(v, "to_s")
else str(v)
v.to_s(hide_missing=hide_missing, level=level + 1) if hasattr(v, "to_s") else str(v)
for v in ensure_list(vs)
if v is not None
)
@ -354,9 +368,7 @@ class Expression(metaclass=_Expression):
new_node.parent = node.parent
return new_node
replace_children(
new_node, lambda child: child.transform(fun, *args, copy=False, **kwargs)
)
replace_children(new_node, lambda child: child.transform(fun, *args, copy=False, **kwargs))
return new_node
def replace(self, expression):
@ -546,6 +558,10 @@ class BitString(Condition):
pass
class HexString(Condition):
pass
class Column(Condition):
arg_types = {"this": True, "table": False}
@ -566,35 +582,44 @@ class ColumnConstraint(Expression):
arg_types = {"this": False, "kind": True}
class AutoIncrementColumnConstraint(Expression):
class ColumnConstraintKind(Expression):
pass
class CheckColumnConstraint(Expression):
class AutoIncrementColumnConstraint(ColumnConstraintKind):
pass
class CollateColumnConstraint(Expression):
class CheckColumnConstraint(ColumnConstraintKind):
pass
class CommentColumnConstraint(Expression):
class CollateColumnConstraint(ColumnConstraintKind):
pass
class DefaultColumnConstraint(Expression):
class CommentColumnConstraint(ColumnConstraintKind):
pass
class NotNullColumnConstraint(Expression):
class DefaultColumnConstraint(ColumnConstraintKind):
pass
class PrimaryKeyColumnConstraint(Expression):
class GeneratedAsIdentityColumnConstraint(ColumnConstraintKind):
# this: True -> ALWAYS, this: False -> BY DEFAULT
arg_types = {"this": True, "expression": False}
class NotNullColumnConstraint(ColumnConstraintKind):
pass
class UniqueColumnConstraint(Expression):
class PrimaryKeyColumnConstraint(ColumnConstraintKind):
pass
class UniqueColumnConstraint(ColumnConstraintKind):
pass
@ -651,9 +676,7 @@ class Identifier(Expression):
return bool(self.args.get("quoted"))
def __eq__(self, other):
return isinstance(other, self.__class__) and _norm_arg(self.this) == _norm_arg(
other.this
)
return isinstance(other, self.__class__) and _norm_arg(self.this) == _norm_arg(other.this)
def __hash__(self):
return hash((self.key, self.this.lower()))
@ -709,9 +732,7 @@ class Literal(Condition):
def __eq__(self, other):
return (
isinstance(other, Literal)
and self.this == other.this
and self.args["is_string"] == other.args["is_string"]
isinstance(other, Literal) and self.this == other.this and self.args["is_string"] == other.args["is_string"]
)
def __hash__(self):
@ -733,6 +754,7 @@ class Join(Expression):
"side": False,
"kind": False,
"using": False,
"natural": False,
}
@property
@ -743,6 +765,10 @@ class Join(Expression):
def side(self):
return self.text("side").upper()
@property
def alias_or_name(self):
return self.this.alias_or_name
def on(self, *expressions, append=True, dialect=None, copy=True, **opts):
"""
Append to or set the ON expressions.
@ -873,10 +899,6 @@ class Reference(Expression):
arg_types = {"this": True, "expressions": True}
class Table(Expression):
arg_types = {"this": True, "db": False, "catalog": False}
class Tuple(Expression):
arg_types = {"expressions": False}
@ -986,6 +1008,16 @@ QUERY_MODIFIERS = {
}
class Table(Expression):
arg_types = {
"this": True,
"db": False,
"catalog": False,
"laterals": False,
"joins": False,
}
class Union(Subqueryable, Expression):
arg_types = {
"with": False,
@ -1396,7 +1428,9 @@ class Select(Subqueryable, Expression):
join.this.replace(join.this.subquery())
if join_type:
side, kind = maybe_parse(join_type, into="JOIN_TYPE", **parse_args)
natural, side, kind = maybe_parse(join_type, into="JOIN_TYPE", **parse_args)
if natural:
join.set("natural", True)
if side:
join.set("side", side.text)
if kind:
@ -1529,10 +1563,7 @@ class Select(Subqueryable, Expression):
properties_expression = None
if properties:
properties_str = " ".join(
[
f"{k} = '{v}'" if isinstance(v, str) else f"{k} = {v}"
for k, v in properties.items()
]
[f"{k} = '{v}'" if isinstance(v, str) else f"{k} = {v}" for k, v in properties.items()]
)
properties_expression = maybe_parse(
properties_str,
@ -1654,6 +1685,7 @@ class DataType(Expression):
DECIMAL = auto()
BOOLEAN = auto()
JSON = auto()
INTERVAL = auto()
TIMESTAMP = auto()
TIMESTAMPTZ = auto()
DATE = auto()
@ -1662,15 +1694,19 @@ class DataType(Expression):
MAP = auto()
UUID = auto()
GEOGRAPHY = auto()
GEOMETRY = auto()
STRUCT = auto()
NULLABLE = auto()
HLLSKETCH = auto()
SUPER = auto()
SERIAL = auto()
SMALLSERIAL = auto()
BIGSERIAL = auto()
@classmethod
def build(cls, dtype, **kwargs):
return DataType(
this=dtype
if isinstance(dtype, DataType.Type)
else DataType.Type[dtype.upper()],
this=dtype if isinstance(dtype, DataType.Type) else DataType.Type[dtype.upper()],
**kwargs,
)
@ -1798,6 +1834,14 @@ class Like(Binary, Predicate):
pass
class SimilarTo(Binary, Predicate):
pass
class Distance(Binary):
pass
class LT(Binary, Predicate):
pass
@ -1899,6 +1943,10 @@ class IgnoreNulls(Expression):
pass
class RespectNulls(Expression):
pass
# Functions
class Func(Condition):
"""
@ -1924,9 +1972,7 @@ class Func(Condition):
all_arg_keys = list(cls.arg_types)
# If this function supports variable length argument treat the last argument as such.
non_var_len_arg_keys = (
all_arg_keys[:-1] if cls.is_var_len_args else all_arg_keys
)
non_var_len_arg_keys = all_arg_keys[:-1] if cls.is_var_len_args else all_arg_keys
args_dict = {}
arg_idx = 0
@ -1944,9 +1990,7 @@ class Func(Condition):
@classmethod
def sql_names(cls):
if cls is Func:
raise NotImplementedError(
"SQL name is only supported by concrete function implementations"
)
raise NotImplementedError("SQL name is only supported by concrete function implementations")
if not hasattr(cls, "_sql_names"):
cls._sql_names = [camel_to_snake_case(cls.__name__)]
return cls._sql_names
@ -2178,6 +2222,10 @@ class Greatest(Func):
is_var_len_args = True
class GroupConcat(Func):
arg_types = {"this": True, "separator": False}
class If(Func):
arg_types = {"this": True, "true": True, "false": False}
@ -2274,6 +2322,10 @@ class Quantile(AggFunc):
arg_types = {"this": True, "quantile": True}
class ApproxQuantile(Quantile):
pass
class Reduce(Func):
arg_types = {"this": True, "initial": True, "merge": True, "finish": True}
@ -2306,8 +2358,10 @@ class Split(Func):
arg_types = {"this": True, "expression": True}
# Start may be omitted in the case of postgres
# https://www.postgresql.org/docs/9.1/functions-string.html @ Table 9-6
class Substring(Func):
arg_types = {"this": True, "start": True, "length": False}
arg_types = {"this": True, "start": False, "length": False}
class StrPosition(Func):
@ -2379,6 +2433,15 @@ class TimeStrToUnix(Func):
pass
class Trim(Func):
arg_types = {
"this": True,
"position": False,
"expression": False,
"collation": False,
}
class TsOrDsAdd(Func, TimeUnit):
arg_types = {"this": True, "expression": True, "unit": False}
@ -2455,9 +2518,7 @@ def _all_functions():
obj
for _, obj in inspect.getmembers(
sys.modules[__name__],
lambda obj: inspect.isclass(obj)
and issubclass(obj, Func)
and obj not in (AggFunc, Anonymous, Func),
lambda obj: inspect.isclass(obj) and issubclass(obj, Func) and obj not in (AggFunc, Anonymous, Func),
)
]
@ -2633,9 +2694,7 @@ def _apply_conjunction_builder(
def _combine(expressions, operator, dialect=None, **opts):
expressions = [
condition(expression, dialect=dialect, **opts) for expression in expressions
]
expressions = [condition(expression, dialect=dialect, **opts) for expression in expressions]
this = expressions[0]
if expressions[1:]:
this = _wrap_operator(this)
@ -2809,9 +2868,7 @@ def to_identifier(alias, quoted=None):
quoted = not re.match(SAFE_IDENTIFIER_RE, alias)
identifier = Identifier(this=alias, quoted=quoted)
else:
raise ValueError(
f"Alias needs to be a string or an Identifier, got: {alias.__class__}"
)
raise ValueError(f"Alias needs to be a string or an Identifier, got: {alias.__class__}")
return identifier