1
0
Fork 0

Merging upstream version 10.0.1.

Signed-off-by: Daniel Baumann <daniel@debian.org>
This commit is contained in:
Daniel Baumann 2025-02-13 14:53:05 +01:00
parent 528822bfd4
commit b7d21c45b7
Signed by: daniel
GPG key ID: FBB4F0E80A80222F
98 changed files with 4080 additions and 1666 deletions

View file

@ -1,6 +1,9 @@
from __future__ import annotations
import datetime
import numbers
import re
import typing as t
from collections import deque
from copy import deepcopy
from enum import auto
@ -9,12 +12,15 @@ from sqlglot.errors import ParseError
from sqlglot.helper import (
AutoName,
camel_to_snake_case,
ensure_list,
list_get,
ensure_collection,
seq_get,
split_num_words,
subclasses,
)
if t.TYPE_CHECKING:
from sqlglot.dialects.dialect import Dialect
class _Expression(type):
def __new__(cls, clsname, bases, attrs):
@ -35,27 +41,30 @@ class Expression(metaclass=_Expression):
or optional (False).
"""
key = None
key = "Expression"
arg_types = {"this": True}
__slots__ = ("args", "parent", "arg_key", "type")
__slots__ = ("args", "parent", "arg_key", "type", "comment")
def __init__(self, **args):
self.args = args
self.parent = None
self.arg_key = None
self.type = None
self.comment = None
for arg_key, value in self.args.items():
self._set_parent(arg_key, value)
def __eq__(self, other):
def __eq__(self, other) -> bool:
return type(self) is type(other) and _norm_args(self) == _norm_args(other)
def __hash__(self):
def __hash__(self) -> int:
return hash(
(
self.key,
tuple((k, tuple(v) if isinstance(v, list) else v) for k, v in _norm_args(self).items()),
tuple(
(k, tuple(v) if isinstance(v, list) else v) for k, v in _norm_args(self).items()
),
)
)
@ -79,6 +88,19 @@ class Expression(metaclass=_Expression):
return field.this
return ""
def find_comment(self, key: str) -> str:
"""
Finds the comment that is attached to a specified child node.
Args:
key: the key of the target child node (e.g. "this", "expression", etc).
Returns:
The comment attached to the child node, or the empty string, if it doesn't exist.
"""
field = self.args.get(key)
return field.comment if isinstance(field, Expression) else ""
@property
def is_string(self):
return isinstance(self, Literal) and self.args["is_string"]
@ -114,7 +136,10 @@ class Expression(metaclass=_Expression):
return self.alias or self.name
def __deepcopy__(self, memo):
return self.__class__(**deepcopy(self.args))
copy = self.__class__(**deepcopy(self.args))
copy.comment = self.comment
copy.type = self.type
return copy
def copy(self):
new = deepcopy(self)
@ -249,9 +274,7 @@ class Expression(metaclass=_Expression):
return
for k, v in self.args.items():
nodes = ensure_list(v)
for node in nodes:
for node in ensure_collection(v):
if isinstance(node, Expression):
yield from node.dfs(self, k, prune)
@ -274,9 +297,7 @@ class Expression(metaclass=_Expression):
if isinstance(item, Expression):
for k, v in item.args.items():
nodes = ensure_list(v)
for node in nodes:
for node in ensure_collection(v):
if isinstance(node, Expression):
queue.append((node, item, k))
@ -319,7 +340,7 @@ class Expression(metaclass=_Expression):
def __repr__(self):
return self.to_s()
def sql(self, dialect=None, **opts):
def sql(self, dialect: Dialect | str | None = None, **opts) -> str:
"""
Returns SQL string representation of this tree.
@ -335,7 +356,7 @@ class Expression(metaclass=_Expression):
return Dialect.get_or_raise(dialect)().generate(self, **opts)
def to_s(self, hide_missing=True, level=0):
def to_s(self, hide_missing: bool = True, level: int = 0) -> str:
indent = "" if not level else "\n"
indent += "".join([" "] * level)
left = f"({self.key.upper()} "
@ -343,11 +364,13 @@ class Expression(metaclass=_Expression):
args = {
k: ", ".join(
v.to_s(hide_missing=hide_missing, level=level + 1) if hasattr(v, "to_s") else str(v)
for v in ensure_list(vs)
for v in ensure_collection(vs)
if v is not None
)
for k, vs in self.args.items()
}
args["comment"] = self.comment
args["type"] = self.type
args = {k: v for k, v in args.items() if v or not hide_missing}
right = ", ".join(f"{k}: {v}" for k, v in args.items())
@ -578,17 +601,6 @@ class UDTF(DerivedTable, Unionable):
pass
class Annotation(Expression):
arg_types = {
"this": True,
"expression": True,
}
@property
def alias(self):
return self.expression.alias_or_name
class Cache(Expression):
arg_types = {
"with": False,
@ -623,6 +635,38 @@ class Describe(Expression):
pass
class Set(Expression):
arg_types = {"expressions": True}
class SetItem(Expression):
arg_types = {
"this": True,
"kind": False,
"collate": False, # MySQL SET NAMES statement
}
class Show(Expression):
arg_types = {
"this": True,
"target": False,
"offset": False,
"limit": False,
"like": False,
"where": False,
"db": False,
"full": False,
"mutex": False,
"query": False,
"channel": False,
"global": False,
"log": False,
"position": False,
"types": False,
}
class UserDefinedFunction(Expression):
arg_types = {"this": True, "expressions": False}
@ -864,18 +908,20 @@ class Literal(Condition):
def __eq__(self, other):
return (
isinstance(other, Literal) and self.this == other.this and self.args["is_string"] == other.args["is_string"]
isinstance(other, Literal)
and self.this == other.this
and self.args["is_string"] == other.args["is_string"]
)
def __hash__(self):
return hash((self.key, self.this, self.args["is_string"]))
@classmethod
def number(cls, number):
def number(cls, number) -> Literal:
return cls(this=str(number), is_string=False)
@classmethod
def string(cls, string):
def string(cls, string) -> Literal:
return cls(this=str(string), is_string=True)
@ -1087,7 +1133,7 @@ class Properties(Expression):
}
@classmethod
def from_dict(cls, properties_dict):
def from_dict(cls, properties_dict) -> Properties:
expressions = []
for key, value in properties_dict.items():
property_cls = cls.PROPERTY_KEY_MAPPING.get(key.upper(), AnonymousProperty)
@ -1323,7 +1369,7 @@ class Select(Subqueryable):
**QUERY_MODIFIERS,
}
def from_(self, *expressions, append=True, dialect=None, copy=True, **opts):
def from_(self, *expressions, append=True, dialect=None, copy=True, **opts) -> Select:
"""
Set the FROM expression.
@ -1356,7 +1402,7 @@ class Select(Subqueryable):
**opts,
)
def group_by(self, *expressions, append=True, dialect=None, copy=True, **opts):
def group_by(self, *expressions, append=True, dialect=None, copy=True, **opts) -> Select:
"""
Set the GROUP BY expression.
@ -1392,7 +1438,7 @@ class Select(Subqueryable):
**opts,
)
def order_by(self, *expressions, append=True, dialect=None, copy=True, **opts):
def order_by(self, *expressions, append=True, dialect=None, copy=True, **opts) -> Select:
"""
Set the ORDER BY expression.
@ -1425,7 +1471,7 @@ class Select(Subqueryable):
**opts,
)
def sort_by(self, *expressions, append=True, dialect=None, copy=True, **opts):
def sort_by(self, *expressions, append=True, dialect=None, copy=True, **opts) -> Select:
"""
Set the SORT BY expression.
@ -1458,7 +1504,7 @@ class Select(Subqueryable):
**opts,
)
def cluster_by(self, *expressions, append=True, dialect=None, copy=True, **opts):
def cluster_by(self, *expressions, append=True, dialect=None, copy=True, **opts) -> Select:
"""
Set the CLUSTER BY expression.
@ -1491,7 +1537,7 @@ class Select(Subqueryable):
**opts,
)
def limit(self, expression, dialect=None, copy=True, **opts):
def limit(self, expression, dialect=None, copy=True, **opts) -> Select:
"""
Set the LIMIT expression.
@ -1522,7 +1568,7 @@ class Select(Subqueryable):
**opts,
)
def offset(self, expression, dialect=None, copy=True, **opts):
def offset(self, expression, dialect=None, copy=True, **opts) -> Select:
"""
Set the OFFSET expression.
@ -1553,7 +1599,7 @@ class Select(Subqueryable):
**opts,
)
def select(self, *expressions, append=True, dialect=None, copy=True, **opts):
def select(self, *expressions, append=True, dialect=None, copy=True, **opts) -> Select:
"""
Append to or set the SELECT expressions.
@ -1583,7 +1629,7 @@ class Select(Subqueryable):
**opts,
)
def lateral(self, *expressions, append=True, dialect=None, copy=True, **opts):
def lateral(self, *expressions, append=True, dialect=None, copy=True, **opts) -> Select:
"""
Append to or set the LATERAL expressions.
@ -1626,7 +1672,7 @@ class Select(Subqueryable):
dialect=None,
copy=True,
**opts,
):
) -> Select:
"""
Append to or set the JOIN expressions.
@ -1672,7 +1718,7 @@ class Select(Subqueryable):
join.this.replace(join.this.subquery())
if join_type:
natural, side, kind = maybe_parse(join_type, into="JOIN_TYPE", **parse_args)
natural, side, kind = maybe_parse(join_type, into="JOIN_TYPE", **parse_args) # type: ignore
if natural:
join.set("natural", True)
if side:
@ -1681,12 +1727,12 @@ class Select(Subqueryable):
join.set("kind", kind.text)
if on:
on = and_(*ensure_list(on), dialect=dialect, **opts)
on = and_(*ensure_collection(on), dialect=dialect, **opts)
join.set("on", on)
if using:
join = _apply_list_builder(
*ensure_list(using),
*ensure_collection(using),
instance=join,
arg="using",
append=append,
@ -1705,7 +1751,7 @@ class Select(Subqueryable):
**opts,
)
def where(self, *expressions, append=True, dialect=None, copy=True, **opts):
def where(self, *expressions, append=True, dialect=None, copy=True, **opts) -> Select:
"""
Append to or set the WHERE expressions.
@ -1737,7 +1783,7 @@ class Select(Subqueryable):
**opts,
)
def having(self, *expressions, append=True, dialect=None, copy=True, **opts):
def having(self, *expressions, append=True, dialect=None, copy=True, **opts) -> Select:
"""
Append to or set the HAVING expressions.
@ -1769,7 +1815,7 @@ class Select(Subqueryable):
**opts,
)
def distinct(self, distinct=True, copy=True):
def distinct(self, distinct=True, copy=True) -> Select:
"""
Set the OFFSET expression.
@ -1788,7 +1834,7 @@ class Select(Subqueryable):
instance.set("distinct", Distinct() if distinct else None)
return instance
def ctas(self, table, properties=None, dialect=None, copy=True, **opts):
def ctas(self, table, properties=None, dialect=None, copy=True, **opts) -> Create:
"""
Convert this expression to a CREATE TABLE AS statement.
@ -1826,11 +1872,11 @@ class Select(Subqueryable):
)
@property
def named_selects(self):
def named_selects(self) -> t.List[str]:
return [e.alias_or_name for e in self.expressions if e.alias_or_name]
@property
def selects(self):
def selects(self) -> t.List[Expression]:
return self.expressions
@ -1910,12 +1956,16 @@ class Parameter(Expression):
pass
class SessionParameter(Expression):
arg_types = {"this": True, "kind": False}
class Placeholder(Expression):
arg_types = {"this": False}
class Null(Condition):
arg_types = {}
arg_types: t.Dict[str, t.Any] = {}
class Boolean(Condition):
@ -1936,6 +1986,7 @@ class DataType(Expression):
NVARCHAR = auto()
TEXT = auto()
BINARY = auto()
VARBINARY = auto()
INT = auto()
TINYINT = auto()
SMALLINT = auto()
@ -1975,7 +2026,7 @@ class DataType(Expression):
UNKNOWN = auto() # Sentinel value, useful for type annotation
@classmethod
def build(cls, dtype, **kwargs):
def build(cls, dtype, **kwargs) -> DataType:
return DataType(
this=dtype if isinstance(dtype, DataType.Type) else DataType.Type[dtype.upper()],
**kwargs,
@ -2077,6 +2128,18 @@ class EQ(Binary, Predicate):
pass
class NullSafeEQ(Binary, Predicate):
pass
class NullSafeNEQ(Binary, Predicate):
pass
class Distance(Binary):
pass
class Escape(Binary):
pass
@ -2101,18 +2164,14 @@ class Is(Binary, Predicate):
pass
class Kwarg(Binary):
"""Kwarg in special functions like func(kwarg => y)."""
class Like(Binary, Predicate):
pass
class SimilarTo(Binary, Predicate):
pass
class Distance(Binary):
pass
class LT(Binary, Predicate):
pass
@ -2133,6 +2192,10 @@ class NEQ(Binary, Predicate):
pass
class SimilarTo(Binary, Predicate):
pass
class Sub(Binary):
pass
@ -2189,7 +2252,13 @@ class Distinct(Expression):
class In(Predicate):
arg_types = {"this": True, "expressions": False, "query": False, "unnest": False, "field": False}
arg_types = {
"this": True,
"expressions": False,
"query": False,
"unnest": False,
"field": False,
}
class TimeUnit(Expression):
@ -2255,7 +2324,9 @@ class Func(Condition):
@classmethod
def sql_names(cls):
if cls is Func:
raise NotImplementedError("SQL name is only supported by concrete function implementations")
raise NotImplementedError(
"SQL name is only supported by concrete function implementations"
)
if not hasattr(cls, "_sql_names"):
cls._sql_names = [camel_to_snake_case(cls.__name__)]
return cls._sql_names
@ -2408,8 +2479,8 @@ class DateDiff(Func, TimeUnit):
arg_types = {"this": True, "expression": True, "unit": False}
class DateTrunc(Func, TimeUnit):
arg_types = {"this": True, "unit": True, "zone": False}
class DateTrunc(Func):
arg_types = {"this": True, "expression": True, "zone": False}
class DatetimeAdd(Func, TimeUnit):
@ -2791,6 +2862,10 @@ class Year(Func):
pass
class Use(Expression):
pass
def _norm_args(expression):
args = {}
@ -2822,7 +2897,7 @@ def maybe_parse(
dialect=None,
prefix=None,
**opts,
):
) -> t.Optional[Expression]:
"""Gracefully handle a possible string or expression.
Example:
@ -3073,7 +3148,7 @@ def except_(left, right, distinct=True, dialect=None, **opts):
return Except(this=left, expression=right, distinct=distinct)
def select(*expressions, dialect=None, **opts):
def select(*expressions, dialect=None, **opts) -> Select:
"""
Initializes a syntax tree from one or multiple SELECT expressions.
@ -3095,7 +3170,7 @@ def select(*expressions, dialect=None, **opts):
return Select().select(*expressions, dialect=dialect, **opts)
def from_(*expressions, dialect=None, **opts):
def from_(*expressions, dialect=None, **opts) -> Select:
"""
Initializes a syntax tree from a FROM expression.
@ -3117,7 +3192,7 @@ def from_(*expressions, dialect=None, **opts):
return Select().from_(*expressions, dialect=dialect, **opts)
def update(table, properties, where=None, from_=None, dialect=None, **opts):
def update(table, properties, where=None, from_=None, dialect=None, **opts) -> Update:
"""
Creates an update statement.
@ -3139,7 +3214,10 @@ def update(table, properties, where=None, from_=None, dialect=None, **opts):
update = Update(this=maybe_parse(table, into=Table, dialect=dialect))
update.set(
"expressions",
[EQ(this=maybe_parse(k, dialect=dialect, **opts), expression=convert(v)) for k, v in properties.items()],
[
EQ(this=maybe_parse(k, dialect=dialect, **opts), expression=convert(v))
for k, v in properties.items()
],
)
if from_:
update.set("from", maybe_parse(from_, into=From, dialect=dialect, prefix="FROM", **opts))
@ -3150,7 +3228,7 @@ def update(table, properties, where=None, from_=None, dialect=None, **opts):
return update
def delete(table, where=None, dialect=None, **opts):
def delete(table, where=None, dialect=None, **opts) -> Delete:
"""
Builds a delete statement.
@ -3174,7 +3252,7 @@ def delete(table, where=None, dialect=None, **opts):
)
def condition(expression, dialect=None, **opts):
def condition(expression, dialect=None, **opts) -> Condition:
"""
Initialize a logical condition expression.
@ -3199,7 +3277,7 @@ def condition(expression, dialect=None, **opts):
Returns:
Condition: the expression
"""
return maybe_parse(
return maybe_parse( # type: ignore
expression,
into=Condition,
dialect=dialect,
@ -3207,7 +3285,7 @@ def condition(expression, dialect=None, **opts):
)
def and_(*expressions, dialect=None, **opts):
def and_(*expressions, dialect=None, **opts) -> And:
"""
Combine multiple conditions with an AND logical operator.
@ -3227,7 +3305,7 @@ def and_(*expressions, dialect=None, **opts):
return _combine(expressions, And, dialect, **opts)
def or_(*expressions, dialect=None, **opts):
def or_(*expressions, dialect=None, **opts) -> Or:
"""
Combine multiple conditions with an OR logical operator.
@ -3247,7 +3325,7 @@ def or_(*expressions, dialect=None, **opts):
return _combine(expressions, Or, dialect, **opts)
def not_(expression, dialect=None, **opts):
def not_(expression, dialect=None, **opts) -> Not:
"""
Wrap a condition with a NOT operator.
@ -3272,14 +3350,14 @@ def not_(expression, dialect=None, **opts):
return Not(this=_wrap_operator(this))
def paren(expression):
def paren(expression) -> Paren:
return Paren(this=expression)
SAFE_IDENTIFIER_RE = re.compile(r"^[a-zA-Z][\w]*$")
def to_identifier(alias, quoted=None):
def to_identifier(alias, quoted=None) -> t.Optional[Identifier]:
if alias is None:
return None
if isinstance(alias, Identifier):
@ -3293,16 +3371,16 @@ def to_identifier(alias, quoted=None):
return identifier
def to_table(sql_path: str, **kwargs) -> Table:
def to_table(sql_path: t.Optional[str | Table], **kwargs) -> t.Optional[Table]:
"""
Create a table expression from a `[catalog].[schema].[table]` sql path. Catalog and schema are optional.
If a table is passed in then that table is returned.
Args:
sql_path(str|Table): `[catalog].[schema].[table]` string
sql_path: a `[catalog].[schema].[table]` string.
Returns:
Table: A table expression
A table expression.
"""
if sql_path is None or isinstance(sql_path, Table):
return sql_path
@ -3393,7 +3471,7 @@ def subquery(expression, alias=None, dialect=None, **opts):
return Select().from_(expression, dialect=dialect, **opts)
def column(col, table=None, quoted=None):
def column(col, table=None, quoted=None) -> Column:
"""
Build a Column.
Args:
@ -3408,7 +3486,7 @@ def column(col, table=None, quoted=None):
)
def table_(table, db=None, catalog=None, quoted=None, alias=None):
def table_(table, db=None, catalog=None, quoted=None, alias=None) -> Table:
"""Build a Table.
Args:
@ -3427,7 +3505,7 @@ def table_(table, db=None, catalog=None, quoted=None, alias=None):
)
def values(values, alias=None):
def values(values, alias=None) -> Values:
"""Build VALUES statement.
Example:
@ -3449,7 +3527,7 @@ def values(values, alias=None):
)
def convert(value):
def convert(value) -> Expression:
"""Convert a python value into an expression object.
Raises an error if a conversion is not possible.
@ -3500,15 +3578,14 @@ def replace_children(expression, fun):
for cn in child_nodes:
if isinstance(cn, Expression):
cns = ensure_list(fun(cn))
for child_node in cns:
for child_node in ensure_collection(fun(cn)):
new_child_nodes.append(child_node)
child_node.parent = expression
child_node.arg_key = k
else:
new_child_nodes.append(cn)
expression.args[k] = new_child_nodes if is_list_arg else list_get(new_child_nodes, 0)
expression.args[k] = new_child_nodes if is_list_arg else seq_get(new_child_nodes, 0)
def column_table_names(expression):
@ -3529,7 +3606,7 @@ def column_table_names(expression):
return list(dict.fromkeys(column.table for column in expression.find_all(Column)))
def table_name(table):
def table_name(table) -> str:
"""Get the full name of a table as a string.
Args:
@ -3546,6 +3623,9 @@ def table_name(table):
table = maybe_parse(table, into=Table)
if not table:
raise ValueError(f"Cannot parse {table}")
return ".".join(
part
for part in (