Merging upstream version 11.4.5.
Signed-off-by: Daniel Baumann <daniel@debian.org>
This commit is contained in:
parent
0a06643852
commit
88f99e1c27
131 changed files with 53004 additions and 37079 deletions
|
@ -26,6 +26,7 @@ from sqlglot.helper import (
|
|||
AutoName,
|
||||
camel_to_snake_case,
|
||||
ensure_collection,
|
||||
ensure_list,
|
||||
seq_get,
|
||||
split_num_words,
|
||||
subclasses,
|
||||
|
@ -84,7 +85,7 @@ class Expression(metaclass=_Expression):
|
|||
|
||||
key = "expression"
|
||||
arg_types = {"this": True}
|
||||
__slots__ = ("args", "parent", "arg_key", "comments", "_type", "_meta")
|
||||
__slots__ = ("args", "parent", "arg_key", "comments", "_type", "_meta", "_hash")
|
||||
|
||||
def __init__(self, **args: t.Any):
|
||||
self.args: t.Dict[str, t.Any] = args
|
||||
|
@ -93,22 +94,30 @@ class Expression(metaclass=_Expression):
|
|||
self.comments: t.Optional[t.List[str]] = None
|
||||
self._type: t.Optional[DataType] = None
|
||||
self._meta: t.Optional[t.Dict[str, t.Any]] = None
|
||||
self._hash: t.Optional[int] = None
|
||||
|
||||
for arg_key, value in self.args.items():
|
||||
self._set_parent(arg_key, value)
|
||||
|
||||
def __eq__(self, other) -> bool:
|
||||
return type(self) is type(other) and _norm_args(self) == _norm_args(other)
|
||||
return type(self) is type(other) and hash(self) == hash(other)
|
||||
|
||||
@property
|
||||
def hashable_args(self) -> t.Any:
|
||||
args = (self.args.get(k) for k in self.arg_types)
|
||||
|
||||
return tuple(
|
||||
(tuple(_norm_arg(a) for a in arg) if arg else None)
|
||||
if type(arg) is list
|
||||
else (_norm_arg(arg) if arg is not None and arg is not False else None)
|
||||
for arg in args
|
||||
)
|
||||
|
||||
def __hash__(self) -> int:
|
||||
return hash(
|
||||
(
|
||||
self.key,
|
||||
tuple(
|
||||
(k, tuple(v) if isinstance(v, list) else v) for k, v in _norm_args(self).items()
|
||||
),
|
||||
)
|
||||
)
|
||||
if self._hash is not None:
|
||||
return self._hash
|
||||
|
||||
return hash((self.__class__, self.hashable_args))
|
||||
|
||||
@property
|
||||
def this(self):
|
||||
|
@ -247,9 +256,6 @@ class Expression(metaclass=_Expression):
|
|||
"""
|
||||
new = deepcopy(self)
|
||||
new.parent = self.parent
|
||||
for item, parent, _ in new.bfs():
|
||||
if isinstance(item, Expression) and parent:
|
||||
item.parent = parent
|
||||
return new
|
||||
|
||||
def append(self, arg_key, value):
|
||||
|
@ -277,12 +283,12 @@ class Expression(metaclass=_Expression):
|
|||
self._set_parent(arg_key, value)
|
||||
|
||||
def _set_parent(self, arg_key, value):
|
||||
if isinstance(value, Expression):
|
||||
if hasattr(value, "parent"):
|
||||
value.parent = self
|
||||
value.arg_key = arg_key
|
||||
elif isinstance(value, list):
|
||||
elif type(value) is list:
|
||||
for v in value:
|
||||
if isinstance(v, Expression):
|
||||
if hasattr(v, "parent"):
|
||||
v.parent = self
|
||||
v.arg_key = arg_key
|
||||
|
||||
|
@ -295,6 +301,17 @@ class Expression(metaclass=_Expression):
|
|||
return self.parent.depth + 1
|
||||
return 0
|
||||
|
||||
def iter_expressions(self) -> t.Iterator[t.Tuple[str, Expression]]:
|
||||
"""Yields the key and expression for all arguments, exploding list args."""
|
||||
for k, vs in self.args.items():
|
||||
if type(vs) is list:
|
||||
for v in vs:
|
||||
if hasattr(v, "parent"):
|
||||
yield k, v
|
||||
else:
|
||||
if hasattr(vs, "parent"):
|
||||
yield k, vs
|
||||
|
||||
def find(self, *expression_types: t.Type[E], bfs=True) -> E | None:
|
||||
"""
|
||||
Returns the first node in this tree which matches at least one of
|
||||
|
@ -319,7 +336,7 @@ class Expression(metaclass=_Expression):
|
|||
Returns:
|
||||
The generator object.
|
||||
"""
|
||||
for expression, _, _ in self.walk(bfs=bfs):
|
||||
for expression, *_ in self.walk(bfs=bfs):
|
||||
if isinstance(expression, expression_types):
|
||||
yield expression
|
||||
|
||||
|
@ -345,6 +362,11 @@ class Expression(metaclass=_Expression):
|
|||
"""
|
||||
return self.find_ancestor(Select)
|
||||
|
||||
@property
|
||||
def same_parent(self):
|
||||
"""Returns if the parent is the same class as itself."""
|
||||
return type(self.parent) is self.__class__
|
||||
|
||||
def root(self) -> Expression:
|
||||
"""
|
||||
Returns the root expression of this tree.
|
||||
|
@ -385,10 +407,8 @@ class Expression(metaclass=_Expression):
|
|||
if prune and prune(self, parent, key):
|
||||
return
|
||||
|
||||
for k, v in self.args.items():
|
||||
for node in ensure_collection(v):
|
||||
if isinstance(node, Expression):
|
||||
yield from node.dfs(self, k, prune)
|
||||
for k, v in self.iter_expressions():
|
||||
yield from v.dfs(self, k, prune)
|
||||
|
||||
def bfs(self, prune=None):
|
||||
"""
|
||||
|
@ -407,18 +427,15 @@ class Expression(metaclass=_Expression):
|
|||
if prune and prune(item, parent, key):
|
||||
continue
|
||||
|
||||
if isinstance(item, Expression):
|
||||
for k, v in item.args.items():
|
||||
for node in ensure_collection(v):
|
||||
if isinstance(node, Expression):
|
||||
queue.append((node, item, k))
|
||||
for k, v in item.iter_expressions():
|
||||
queue.append((v, item, k))
|
||||
|
||||
def unnest(self):
|
||||
"""
|
||||
Returns the first non parenthesis child or self.
|
||||
"""
|
||||
expression = self
|
||||
while isinstance(expression, Paren):
|
||||
while type(expression) is Paren:
|
||||
expression = expression.this
|
||||
return expression
|
||||
|
||||
|
@ -434,7 +451,7 @@ class Expression(metaclass=_Expression):
|
|||
"""
|
||||
Returns unnested operands as a tuple.
|
||||
"""
|
||||
return tuple(arg.unnest() for arg in self.args.values() if arg)
|
||||
return tuple(arg.unnest() for _, arg in self.iter_expressions())
|
||||
|
||||
def flatten(self, unnest=True):
|
||||
"""
|
||||
|
@ -442,8 +459,8 @@ class Expression(metaclass=_Expression):
|
|||
|
||||
A AND B AND C -> [A, B, C]
|
||||
"""
|
||||
for node, _, _ in self.dfs(prune=lambda n, p, *_: p and not isinstance(n, self.__class__)):
|
||||
if not isinstance(node, self.__class__):
|
||||
for node, _, _ in self.dfs(prune=lambda n, p, *_: p and not type(n) is self.__class__):
|
||||
if not type(node) is self.__class__:
|
||||
yield node.unnest() if unnest else node
|
||||
|
||||
def __str__(self):
|
||||
|
@ -477,7 +494,7 @@ class Expression(metaclass=_Expression):
|
|||
v._to_s(hide_missing=hide_missing, level=level + 1)
|
||||
if hasattr(v, "_to_s")
|
||||
else str(v)
|
||||
for v in ensure_collection(vs)
|
||||
for v in ensure_list(vs)
|
||||
if v is not None
|
||||
)
|
||||
for k, vs in self.args.items()
|
||||
|
@ -812,6 +829,10 @@ class Describe(Expression):
|
|||
arg_types = {"this": True, "kind": False}
|
||||
|
||||
|
||||
class Pragma(Expression):
|
||||
pass
|
||||
|
||||
|
||||
class Set(Expression):
|
||||
arg_types = {"expressions": False}
|
||||
|
||||
|
@ -1170,6 +1191,7 @@ class Drop(Expression):
|
|||
"temporary": False,
|
||||
"materialized": False,
|
||||
"cascade": False,
|
||||
"constraints": False,
|
||||
}
|
||||
|
||||
|
||||
|
@ -1232,11 +1254,11 @@ class Identifier(Expression):
|
|||
def quoted(self):
|
||||
return bool(self.args.get("quoted"))
|
||||
|
||||
def __eq__(self, other):
|
||||
return isinstance(other, self.__class__) and _norm_arg(self.this) == _norm_arg(other.this)
|
||||
|
||||
def __hash__(self):
|
||||
return hash((self.key, self.this.lower()))
|
||||
@property
|
||||
def hashable_args(self) -> t.Any:
|
||||
if self.quoted and any(char.isupper() for char in self.this):
|
||||
return (self.this, self.quoted)
|
||||
return self.this.lower()
|
||||
|
||||
@property
|
||||
def output_name(self):
|
||||
|
@ -1322,15 +1344,9 @@ class Limit(Expression):
|
|||
class Literal(Condition):
|
||||
arg_types = {"this": True, "is_string": True}
|
||||
|
||||
def __eq__(self, other):
|
||||
return (
|
||||
isinstance(other, Literal)
|
||||
and self.this == other.this
|
||||
and self.args["is_string"] == other.args["is_string"]
|
||||
)
|
||||
|
||||
def __hash__(self):
|
||||
return hash((self.key, self.this, self.args["is_string"]))
|
||||
@property
|
||||
def hashable_args(self) -> t.Any:
|
||||
return (self.this, self.args.get("is_string"))
|
||||
|
||||
@classmethod
|
||||
def number(cls, number) -> Literal:
|
||||
|
@ -1784,7 +1800,7 @@ class Subqueryable(Unionable):
|
|||
instance = _maybe_copy(self, copy)
|
||||
return Subquery(
|
||||
this=instance,
|
||||
alias=TableAlias(this=to_identifier(alias)),
|
||||
alias=TableAlias(this=to_identifier(alias)) if alias else None,
|
||||
)
|
||||
|
||||
def limit(self, expression, dialect=None, copy=True, **opts) -> Select:
|
||||
|
@ -2058,6 +2074,7 @@ class Lock(Expression):
|
|||
class Select(Subqueryable):
|
||||
arg_types = {
|
||||
"with": False,
|
||||
"kind": False,
|
||||
"expressions": False,
|
||||
"hint": False,
|
||||
"distinct": False,
|
||||
|
@ -3595,6 +3612,21 @@ class Initcap(Func):
|
|||
pass
|
||||
|
||||
|
||||
class JSONKeyValue(Expression):
|
||||
arg_types = {"this": True, "expression": True}
|
||||
|
||||
|
||||
class JSONObject(Func):
|
||||
arg_types = {
|
||||
"expressions": False,
|
||||
"null_handling": False,
|
||||
"unique_keys": False,
|
||||
"return_type": False,
|
||||
"format_json": False,
|
||||
"encoding": False,
|
||||
}
|
||||
|
||||
|
||||
class JSONBContains(Binary):
|
||||
_sql_names = ["JSONB_CONTAINS"]
|
||||
|
||||
|
@ -3766,8 +3798,10 @@ class RegexpILike(Func):
|
|||
arg_types = {"this": True, "expression": True, "flag": False}
|
||||
|
||||
|
||||
# https://spark.apache.org/docs/latest/api/python/reference/pyspark.sql/api/pyspark.sql.functions.split.html
|
||||
# limit is the number of times a pattern is applied
|
||||
class RegexpSplit(Func):
|
||||
arg_types = {"this": True, "expression": True}
|
||||
arg_types = {"this": True, "expression": True, "limit": False}
|
||||
|
||||
|
||||
class Repeat(Func):
|
||||
|
@ -3967,25 +4001,8 @@ class When(Func):
|
|||
arg_types = {"matched": True, "source": False, "condition": False, "then": True}
|
||||
|
||||
|
||||
def _norm_args(expression):
|
||||
args = {}
|
||||
|
||||
for k, arg in expression.args.items():
|
||||
if isinstance(arg, list):
|
||||
arg = [_norm_arg(a) for a in arg]
|
||||
if not arg:
|
||||
arg = None
|
||||
else:
|
||||
arg = _norm_arg(arg)
|
||||
|
||||
if arg is not None and arg is not False:
|
||||
args[k] = arg
|
||||
|
||||
return args
|
||||
|
||||
|
||||
def _norm_arg(arg):
|
||||
return arg.lower() if isinstance(arg, str) else arg
|
||||
return arg.lower() if type(arg) is str else arg
|
||||
|
||||
|
||||
ALL_FUNCTIONS = subclasses(__name__, Func, (AggFunc, Anonymous, Func))
|
||||
|
@ -4512,7 +4529,7 @@ def to_identifier(name, quoted=None):
|
|||
elif isinstance(name, str):
|
||||
identifier = Identifier(
|
||||
this=name,
|
||||
quoted=not re.match(SAFE_IDENTIFIER_RE, name) if quoted is None else quoted,
|
||||
quoted=not SAFE_IDENTIFIER_RE.match(name) if quoted is None else quoted,
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Name needs to be a string or an Identifier, got: {name.__class__}")
|
||||
|
@ -4586,8 +4603,7 @@ def to_column(sql_path: str | Column, **kwargs) -> Column:
|
|||
return sql_path
|
||||
if not isinstance(sql_path, str):
|
||||
raise ValueError(f"Invalid type provided for column: {type(sql_path)}")
|
||||
table_name, column_name = (to_identifier(x) for x in split_num_words(sql_path, ".", 2))
|
||||
return Column(this=column_name, table=table_name, **kwargs)
|
||||
return column(*reversed(sql_path.split(".")), **kwargs) # type: ignore
|
||||
|
||||
|
||||
def alias_(
|
||||
|
@ -4672,7 +4688,8 @@ def subquery(expression, alias=None, dialect=None, **opts):
|
|||
def column(
|
||||
col: str | Identifier,
|
||||
table: t.Optional[str | Identifier] = None,
|
||||
schema: t.Optional[str | Identifier] = None,
|
||||
db: t.Optional[str | Identifier] = None,
|
||||
catalog: t.Optional[str | Identifier] = None,
|
||||
quoted: t.Optional[bool] = None,
|
||||
) -> Column:
|
||||
"""
|
||||
|
@ -4681,7 +4698,8 @@ def column(
|
|||
Args:
|
||||
col: column name
|
||||
table: table name
|
||||
schema: schema name
|
||||
db: db name
|
||||
catalog: catalog name
|
||||
quoted: whether or not to force quote each part
|
||||
Returns:
|
||||
Column: column instance
|
||||
|
@ -4689,7 +4707,8 @@ def column(
|
|||
return Column(
|
||||
this=to_identifier(col, quoted=quoted),
|
||||
table=to_identifier(table, quoted=quoted),
|
||||
schema=to_identifier(schema, quoted=quoted),
|
||||
db=to_identifier(db, quoted=quoted),
|
||||
catalog=to_identifier(catalog, quoted=quoted),
|
||||
)
|
||||
|
||||
|
||||
|
@ -4864,7 +4883,7 @@ def replace_children(expression, fun, *args, **kwargs):
|
|||
Replace children of an expression with the result of a lambda fun(child) -> exp.
|
||||
"""
|
||||
for k, v in expression.args.items():
|
||||
is_list_arg = isinstance(v, list)
|
||||
is_list_arg = type(v) is list
|
||||
|
||||
child_nodes = v if is_list_arg else [v]
|
||||
new_child_nodes = []
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue