1
0
Fork 0

Merging upstream version 21.0.1.

Signed-off-by: Daniel Baumann <daniel@debian.org>
This commit is contained in:
Daniel Baumann 2025-02-13 21:20:36 +01:00
parent 3759c601a7
commit 96b10de29a
Signed by: daniel
GPG key ID: FBB4F0E80A80222F
115 changed files with 66603 additions and 60920 deletions

View file

@ -29,6 +29,7 @@ from sqlglot.helper import (
camel_to_snake_case,
ensure_collection,
ensure_list,
is_int,
seq_get,
subclasses,
)
@ -175,13 +176,7 @@ class Expression(metaclass=_Expression):
"""
Checks whether a Literal expression is an integer.
"""
if self.is_number:
try:
int(self.name)
return True
except ValueError:
pass
return False
return self.is_number and is_int(self.name)
@property
def is_star(self) -> bool:
@ -493,8 +488,8 @@ class Expression(metaclass=_Expression):
A AND B AND C -> [A, B, C]
"""
for node, _, _ in self.dfs(prune=lambda n, p, *_: p and not type(n) is self.__class__):
if not type(node) is self.__class__:
for node, _, _ in self.dfs(prune=lambda n, p, *_: p and type(n) is not self.__class__):
if type(node) is not self.__class__:
yield node.unnest() if unnest and not isinstance(node, Subquery) else node
def __str__(self) -> str:
@ -553,10 +548,12 @@ class Expression(metaclass=_Expression):
return new_node
@t.overload
def replace(self, expression: E) -> E: ...
def replace(self, expression: E) -> E:
...
@t.overload
def replace(self, expression: None) -> None: ...
def replace(self, expression: None) -> None:
...
def replace(self, expression):
"""
@ -610,7 +607,8 @@ class Expression(metaclass=_Expression):
>>> sqlglot.parse_one("SELECT x from y").assert_is(Select).select("z").sql()
'SELECT x, z FROM y'
"""
assert isinstance(self, type_)
if not isinstance(self, type_):
raise AssertionError(f"{self} is not {type_}.")
return self
def error_messages(self, args: t.Optional[t.Sequence] = None) -> t.List[str]:
@ -1133,6 +1131,7 @@ class SetItem(Expression):
class Show(Expression):
arg_types = {
"this": True,
"history": False,
"terse": False,
"target": False,
"offset": False,
@ -1676,7 +1675,6 @@ class Index(Expression):
"amp": False, # teradata
"include": False,
"partition_by": False, # teradata
"where": False, # postgres partial indexes
}
@ -2573,7 +2571,7 @@ class HistoricalData(Expression):
class Table(Expression):
arg_types = {
"this": True,
"this": False,
"alias": False,
"db": False,
"catalog": False,
@ -3664,6 +3662,7 @@ class DataType(Expression):
BINARY = auto()
BIT = auto()
BOOLEAN = auto()
BPCHAR = auto()
CHAR = auto()
DATE = auto()
DATE32 = auto()
@ -3805,6 +3804,7 @@ class DataType(Expression):
dtype: DATA_TYPE,
dialect: DialectType = None,
udt: bool = False,
copy: bool = True,
**kwargs,
) -> DataType:
"""
@ -3815,7 +3815,8 @@ class DataType(Expression):
dialect: the dialect to use for parsing `dtype`, in case it's a string.
udt: when set to True, `dtype` will be used as-is if it can't be parsed into a
DataType, thus creating a user-defined type.
kawrgs: additional arguments to pass in the constructor of DataType.
copy: whether or not to copy the data type.
kwargs: additional arguments to pass in the constructor of DataType.
Returns:
The constructed DataType object.
@ -3837,7 +3838,7 @@ class DataType(Expression):
elif isinstance(dtype, DataType.Type):
data_type_exp = DataType(this=dtype)
elif isinstance(dtype, DataType):
return dtype
return maybe_copy(dtype, copy)
else:
raise ValueError(f"Invalid data type: {type(dtype)}. Expected str or DataType.Type")
@ -3855,7 +3856,7 @@ class DataType(Expression):
True, if and only if there is a type in `dtypes` which is equal to this DataType.
"""
for dtype in dtypes:
other = DataType.build(dtype, udt=True)
other = DataType.build(dtype, copy=False, udt=True)
if (
other.expressions
@ -4001,7 +4002,7 @@ class Dot(Binary):
def build(self, expressions: t.Sequence[Expression]) -> Dot:
"""Build a Dot object with a sequence of expressions."""
if len(expressions) < 2:
raise ValueError(f"Dot requires >= 2 expressions.")
raise ValueError("Dot requires >= 2 expressions.")
return t.cast(Dot, reduce(lambda x, y: Dot(this=x, expression=y), expressions))
@ -4128,10 +4129,6 @@ class Sub(Binary):
pass
class ArrayOverlaps(Binary):
pass
# Unary Expressions
# (NOT a)
class Unary(Condition):
@ -4469,6 +4466,10 @@ class ArrayJoin(Func):
arg_types = {"this": True, "expression": True, "null": False}
class ArrayOverlaps(Binary, Func):
pass
class ArraySize(Func):
arg_types = {"this": True, "expression": False}
@ -4490,15 +4491,37 @@ class Avg(AggFunc):
class AnyValue(AggFunc):
arg_types = {"this": True, "having": False, "max": False, "ignore_nulls": False}
arg_types = {"this": True, "having": False, "max": False}
class First(Func):
arg_types = {"this": True, "ignore_nulls": False}
class Lag(AggFunc):
arg_types = {"this": True, "offset": False, "default": False}
class Last(Func):
arg_types = {"this": True, "ignore_nulls": False}
class Lead(AggFunc):
arg_types = {"this": True, "offset": False, "default": False}
# some dialects have a distinction between first and first_value, usually first is an aggregate func
# and first_value is a window func
class First(AggFunc):
pass
class Last(AggFunc):
pass
class FirstValue(AggFunc):
pass
class LastValue(AggFunc):
pass
class NthValue(AggFunc):
arg_types = {"this": True, "offset": True}
class Case(Func):
@ -4611,7 +4634,7 @@ class CurrentTime(Func):
class CurrentTimestamp(Func):
arg_types = {"this": False}
arg_types = {"this": False, "transaction": False}
class CurrentUser(Func):
@ -4712,6 +4735,7 @@ class TimestampSub(Func, TimeUnit):
class TimestampDiff(Func, TimeUnit):
_sql_names = ["TIMESTAMPDIFF", "TIMESTAMP_DIFF"]
arg_types = {"this": True, "expression": True, "unit": False}
@ -4857,6 +4881,59 @@ class IsInf(Func):
_sql_names = ["IS_INF", "ISINF"]
class JSONPath(Expression):
arg_types = {"expressions": True}
@property
def output_name(self) -> str:
last_segment = self.expressions[-1].this
return last_segment if isinstance(last_segment, str) else ""
class JSONPathPart(Expression):
arg_types = {}
class JSONPathFilter(JSONPathPart):
arg_types = {"this": True}
class JSONPathKey(JSONPathPart):
arg_types = {"this": True}
class JSONPathRecursive(JSONPathPart):
arg_types = {"this": False}
class JSONPathRoot(JSONPathPart):
pass
class JSONPathScript(JSONPathPart):
arg_types = {"this": True}
class JSONPathSlice(JSONPathPart):
arg_types = {"start": False, "end": False, "step": False}
class JSONPathSelector(JSONPathPart):
arg_types = {"this": True}
class JSONPathSubscript(JSONPathPart):
arg_types = {"this": True}
class JSONPathUnion(JSONPathPart):
arg_types = {"expressions": True}
class JSONPathWildcard(JSONPathPart):
pass
class FormatJson(Expression):
pass
@ -4940,18 +5017,30 @@ class JSONBContains(Binary):
class JSONExtract(Binary, Func):
arg_types = {"this": True, "expression": True, "expressions": False}
_sql_names = ["JSON_EXTRACT"]
is_var_len_args = True
@property
def output_name(self) -> str:
return self.expression.output_name if not self.expressions else ""
class JSONExtractScalar(JSONExtract):
class JSONExtractScalar(Binary, Func):
arg_types = {"this": True, "expression": True, "expressions": False}
_sql_names = ["JSON_EXTRACT_SCALAR"]
is_var_len_args = True
@property
def output_name(self) -> str:
return self.expression.output_name
class JSONBExtract(JSONExtract):
class JSONBExtract(Binary, Func):
_sql_names = ["JSONB_EXTRACT"]
class JSONBExtractScalar(JSONExtract):
class JSONBExtractScalar(Binary, Func):
_sql_names = ["JSONB_EXTRACT_SCALAR"]
@ -4972,15 +5061,6 @@ class ParseJSON(Func):
is_var_len_args = True
# https://docs.snowflake.com/en/sql-reference/functions/get_path
class GetPath(Func):
arg_types = {"this": True, "expression": True}
@property
def output_name(self) -> str:
return self.expression.output_name
class Least(Func):
arg_types = {"this": True, "expressions": False}
is_var_len_args = True
@ -5476,6 +5556,8 @@ def _norm_arg(arg):
ALL_FUNCTIONS = subclasses(__name__, Func, (AggFunc, Anonymous, Func))
FUNCTION_BY_NAME = {name: func for func in ALL_FUNCTIONS for name in func.sql_names()}
JSON_PATH_PARTS = subclasses(__name__, JSONPathPart, (JSONPathPart,))
# Helpers
@t.overload
@ -5487,7 +5569,8 @@ def maybe_parse(
prefix: t.Optional[str] = None,
copy: bool = False,
**opts,
) -> E: ...
) -> E:
...
@t.overload
@ -5499,7 +5582,8 @@ def maybe_parse(
prefix: t.Optional[str] = None,
copy: bool = False,
**opts,
) -> E: ...
) -> E:
...
def maybe_parse(
@ -5539,7 +5623,7 @@ def maybe_parse(
return sql_or_expression
if sql_or_expression is None:
raise ParseError(f"SQL cannot be None")
raise ParseError("SQL cannot be None")
import sqlglot
@ -5551,11 +5635,13 @@ def maybe_parse(
@t.overload
def maybe_copy(instance: None, copy: bool = True) -> None: ...
def maybe_copy(instance: None, copy: bool = True) -> None:
...
@t.overload
def maybe_copy(instance: E, copy: bool = True) -> E: ...
def maybe_copy(instance: E, copy: bool = True) -> E:
...
def maybe_copy(instance, copy=True):
@ -6174,17 +6260,19 @@ def paren(expression: ExpOrStr, copy: bool = True) -> Paren:
return Paren(this=maybe_parse(expression, copy=copy))
SAFE_IDENTIFIER_RE = re.compile(r"^[_a-zA-Z][\w]*$")
SAFE_IDENTIFIER_RE: t.Pattern[str] = re.compile(r"^[_a-zA-Z][\w]*$")
@t.overload
def to_identifier(name: None, quoted: t.Optional[bool] = None, copy: bool = True) -> None: ...
def to_identifier(name: None, quoted: t.Optional[bool] = None, copy: bool = True) -> None:
...
@t.overload
def to_identifier(
name: str | Identifier, quoted: t.Optional[bool] = None, copy: bool = True
) -> Identifier: ...
) -> Identifier:
...
def to_identifier(name, quoted=None, copy=True):
@ -6256,11 +6344,13 @@ def to_interval(interval: str | Literal) -> Interval:
@t.overload
def to_table(sql_path: str | Table, **kwargs) -> Table: ...
def to_table(sql_path: str | Table, **kwargs) -> Table:
...
@t.overload
def to_table(sql_path: None, **kwargs) -> None: ...
def to_table(sql_path: None, **kwargs) -> None:
...
def to_table(
@ -6460,7 +6550,7 @@ def column(
return this
def cast(expression: ExpOrStr, to: DATA_TYPE, **opts) -> Cast:
def cast(expression: ExpOrStr, to: DATA_TYPE, copy: bool = True, **opts) -> Cast:
"""Cast an expression to a data type.
Example:
@ -6470,12 +6560,13 @@ def cast(expression: ExpOrStr, to: DATA_TYPE, **opts) -> Cast:
Args:
expression: The expression to cast.
to: The datatype to cast to.
copy: Whether or not to copy the supplied expressions.
Returns:
The new Cast instance.
"""
expression = maybe_parse(expression, **opts)
data_type = DataType.build(to, **opts)
expression = maybe_parse(expression, copy=copy, **opts)
data_type = DataType.build(to, copy=copy, **opts)
expression = Cast(this=expression, to=data_type)
expression.type = data_type
return expression