1
0
Fork 0

Merging upstream version 11.4.5.

Signed-off-by: Daniel Baumann <daniel@debian.org>
This commit is contained in:
Daniel Baumann 2025-02-13 15:48:10 +01:00
parent 0a06643852
commit 88f99e1c27
Signed by: daniel
GPG key ID: FBB4F0E80A80222F
131 changed files with 53004 additions and 37079 deletions

View file

@ -26,6 +26,7 @@ from sqlglot.helper import (
AutoName,
camel_to_snake_case,
ensure_collection,
ensure_list,
seq_get,
split_num_words,
subclasses,
@ -84,7 +85,7 @@ class Expression(metaclass=_Expression):
key = "expression"
arg_types = {"this": True}
__slots__ = ("args", "parent", "arg_key", "comments", "_type", "_meta")
__slots__ = ("args", "parent", "arg_key", "comments", "_type", "_meta", "_hash")
def __init__(self, **args: t.Any):
self.args: t.Dict[str, t.Any] = args
@ -93,22 +94,30 @@ class Expression(metaclass=_Expression):
self.comments: t.Optional[t.List[str]] = None
self._type: t.Optional[DataType] = None
self._meta: t.Optional[t.Dict[str, t.Any]] = None
self._hash: t.Optional[int] = None
for arg_key, value in self.args.items():
self._set_parent(arg_key, value)
def __eq__(self, other) -> bool:
return type(self) is type(other) and _norm_args(self) == _norm_args(other)
return type(self) is type(other) and hash(self) == hash(other)
@property
def hashable_args(self) -> t.Any:
args = (self.args.get(k) for k in self.arg_types)
return tuple(
(tuple(_norm_arg(a) for a in arg) if arg else None)
if type(arg) is list
else (_norm_arg(arg) if arg is not None and arg is not False else None)
for arg in args
)
def __hash__(self) -> int:
return hash(
(
self.key,
tuple(
(k, tuple(v) if isinstance(v, list) else v) for k, v in _norm_args(self).items()
),
)
)
if self._hash is not None:
return self._hash
return hash((self.__class__, self.hashable_args))
@property
def this(self):
@ -247,9 +256,6 @@ class Expression(metaclass=_Expression):
"""
new = deepcopy(self)
new.parent = self.parent
for item, parent, _ in new.bfs():
if isinstance(item, Expression) and parent:
item.parent = parent
return new
def append(self, arg_key, value):
@ -277,12 +283,12 @@ class Expression(metaclass=_Expression):
self._set_parent(arg_key, value)
def _set_parent(self, arg_key, value):
if isinstance(value, Expression):
if hasattr(value, "parent"):
value.parent = self
value.arg_key = arg_key
elif isinstance(value, list):
elif type(value) is list:
for v in value:
if isinstance(v, Expression):
if hasattr(v, "parent"):
v.parent = self
v.arg_key = arg_key
@ -295,6 +301,17 @@ class Expression(metaclass=_Expression):
return self.parent.depth + 1
return 0
def iter_expressions(self) -> t.Iterator[t.Tuple[str, Expression]]:
"""Yields the key and expression for all arguments, exploding list args."""
for k, vs in self.args.items():
if type(vs) is list:
for v in vs:
if hasattr(v, "parent"):
yield k, v
else:
if hasattr(vs, "parent"):
yield k, vs
def find(self, *expression_types: t.Type[E], bfs=True) -> E | None:
"""
Returns the first node in this tree which matches at least one of
@ -319,7 +336,7 @@ class Expression(metaclass=_Expression):
Returns:
The generator object.
"""
for expression, _, _ in self.walk(bfs=bfs):
for expression, *_ in self.walk(bfs=bfs):
if isinstance(expression, expression_types):
yield expression
@ -345,6 +362,11 @@ class Expression(metaclass=_Expression):
"""
return self.find_ancestor(Select)
@property
def same_parent(self):
"""Returns if the parent is the same class as itself."""
return type(self.parent) is self.__class__
def root(self) -> Expression:
"""
Returns the root expression of this tree.
@ -385,10 +407,8 @@ class Expression(metaclass=_Expression):
if prune and prune(self, parent, key):
return
for k, v in self.args.items():
for node in ensure_collection(v):
if isinstance(node, Expression):
yield from node.dfs(self, k, prune)
for k, v in self.iter_expressions():
yield from v.dfs(self, k, prune)
def bfs(self, prune=None):
"""
@ -407,18 +427,15 @@ class Expression(metaclass=_Expression):
if prune and prune(item, parent, key):
continue
if isinstance(item, Expression):
for k, v in item.args.items():
for node in ensure_collection(v):
if isinstance(node, Expression):
queue.append((node, item, k))
for k, v in item.iter_expressions():
queue.append((v, item, k))
def unnest(self):
"""
Returns the first non parenthesis child or self.
"""
expression = self
while isinstance(expression, Paren):
while type(expression) is Paren:
expression = expression.this
return expression
@ -434,7 +451,7 @@ class Expression(metaclass=_Expression):
"""
Returns unnested operands as a tuple.
"""
return tuple(arg.unnest() for arg in self.args.values() if arg)
return tuple(arg.unnest() for _, arg in self.iter_expressions())
def flatten(self, unnest=True):
"""
@ -442,8 +459,8 @@ class Expression(metaclass=_Expression):
A AND B AND C -> [A, B, C]
"""
for node, _, _ in self.dfs(prune=lambda n, p, *_: p and not isinstance(n, self.__class__)):
if not isinstance(node, self.__class__):
for node, _, _ in self.dfs(prune=lambda n, p, *_: p and not type(n) is self.__class__):
if not type(node) is self.__class__:
yield node.unnest() if unnest else node
def __str__(self):
@ -477,7 +494,7 @@ class Expression(metaclass=_Expression):
v._to_s(hide_missing=hide_missing, level=level + 1)
if hasattr(v, "_to_s")
else str(v)
for v in ensure_collection(vs)
for v in ensure_list(vs)
if v is not None
)
for k, vs in self.args.items()
@ -812,6 +829,10 @@ class Describe(Expression):
arg_types = {"this": True, "kind": False}
class Pragma(Expression):
pass
class Set(Expression):
arg_types = {"expressions": False}
@ -1170,6 +1191,7 @@ class Drop(Expression):
"temporary": False,
"materialized": False,
"cascade": False,
"constraints": False,
}
@ -1232,11 +1254,11 @@ class Identifier(Expression):
def quoted(self):
return bool(self.args.get("quoted"))
def __eq__(self, other):
return isinstance(other, self.__class__) and _norm_arg(self.this) == _norm_arg(other.this)
def __hash__(self):
return hash((self.key, self.this.lower()))
@property
def hashable_args(self) -> t.Any:
if self.quoted and any(char.isupper() for char in self.this):
return (self.this, self.quoted)
return self.this.lower()
@property
def output_name(self):
@ -1322,15 +1344,9 @@ class Limit(Expression):
class Literal(Condition):
arg_types = {"this": True, "is_string": True}
def __eq__(self, other):
return (
isinstance(other, Literal)
and self.this == other.this
and self.args["is_string"] == other.args["is_string"]
)
def __hash__(self):
return hash((self.key, self.this, self.args["is_string"]))
@property
def hashable_args(self) -> t.Any:
return (self.this, self.args.get("is_string"))
@classmethod
def number(cls, number) -> Literal:
@ -1784,7 +1800,7 @@ class Subqueryable(Unionable):
instance = _maybe_copy(self, copy)
return Subquery(
this=instance,
alias=TableAlias(this=to_identifier(alias)),
alias=TableAlias(this=to_identifier(alias)) if alias else None,
)
def limit(self, expression, dialect=None, copy=True, **opts) -> Select:
@ -2058,6 +2074,7 @@ class Lock(Expression):
class Select(Subqueryable):
arg_types = {
"with": False,
"kind": False,
"expressions": False,
"hint": False,
"distinct": False,
@ -3595,6 +3612,21 @@ class Initcap(Func):
pass
class JSONKeyValue(Expression):
arg_types = {"this": True, "expression": True}
class JSONObject(Func):
arg_types = {
"expressions": False,
"null_handling": False,
"unique_keys": False,
"return_type": False,
"format_json": False,
"encoding": False,
}
class JSONBContains(Binary):
_sql_names = ["JSONB_CONTAINS"]
@ -3766,8 +3798,10 @@ class RegexpILike(Func):
arg_types = {"this": True, "expression": True, "flag": False}
# https://spark.apache.org/docs/latest/api/python/reference/pyspark.sql/api/pyspark.sql.functions.split.html
# limit is the number of times a pattern is applied
class RegexpSplit(Func):
arg_types = {"this": True, "expression": True}
arg_types = {"this": True, "expression": True, "limit": False}
class Repeat(Func):
@ -3967,25 +4001,8 @@ class When(Func):
arg_types = {"matched": True, "source": False, "condition": False, "then": True}
def _norm_args(expression):
args = {}
for k, arg in expression.args.items():
if isinstance(arg, list):
arg = [_norm_arg(a) for a in arg]
if not arg:
arg = None
else:
arg = _norm_arg(arg)
if arg is not None and arg is not False:
args[k] = arg
return args
def _norm_arg(arg):
return arg.lower() if isinstance(arg, str) else arg
return arg.lower() if type(arg) is str else arg
ALL_FUNCTIONS = subclasses(__name__, Func, (AggFunc, Anonymous, Func))
@ -4512,7 +4529,7 @@ def to_identifier(name, quoted=None):
elif isinstance(name, str):
identifier = Identifier(
this=name,
quoted=not re.match(SAFE_IDENTIFIER_RE, name) if quoted is None else quoted,
quoted=not SAFE_IDENTIFIER_RE.match(name) if quoted is None else quoted,
)
else:
raise ValueError(f"Name needs to be a string or an Identifier, got: {name.__class__}")
@ -4586,8 +4603,7 @@ def to_column(sql_path: str | Column, **kwargs) -> Column:
return sql_path
if not isinstance(sql_path, str):
raise ValueError(f"Invalid type provided for column: {type(sql_path)}")
table_name, column_name = (to_identifier(x) for x in split_num_words(sql_path, ".", 2))
return Column(this=column_name, table=table_name, **kwargs)
return column(*reversed(sql_path.split(".")), **kwargs) # type: ignore
def alias_(
@ -4672,7 +4688,8 @@ def subquery(expression, alias=None, dialect=None, **opts):
def column(
col: str | Identifier,
table: t.Optional[str | Identifier] = None,
schema: t.Optional[str | Identifier] = None,
db: t.Optional[str | Identifier] = None,
catalog: t.Optional[str | Identifier] = None,
quoted: t.Optional[bool] = None,
) -> Column:
"""
@ -4681,7 +4698,8 @@ def column(
Args:
col: column name
table: table name
schema: schema name
db: db name
catalog: catalog name
quoted: whether or not to force quote each part
Returns:
Column: column instance
@ -4689,7 +4707,8 @@ def column(
return Column(
this=to_identifier(col, quoted=quoted),
table=to_identifier(table, quoted=quoted),
schema=to_identifier(schema, quoted=quoted),
db=to_identifier(db, quoted=quoted),
catalog=to_identifier(catalog, quoted=quoted),
)
@ -4864,7 +4883,7 @@ def replace_children(expression, fun, *args, **kwargs):
Replace children of an expression with the result of a lambda fun(child) -> exp.
"""
for k, v in expression.args.items():
is_list_arg = isinstance(v, list)
is_list_arg = type(v) is list
child_nodes = v if is_list_arg else [v]
new_child_nodes = []