Merging upstream version 21.0.1.
Signed-off-by: Daniel Baumann <daniel@debian.org>
This commit is contained in:
parent
3759c601a7
commit
96b10de29a
115 changed files with 66603 additions and 60920 deletions
|
@ -29,6 +29,7 @@ from sqlglot.helper import (
|
|||
camel_to_snake_case,
|
||||
ensure_collection,
|
||||
ensure_list,
|
||||
is_int,
|
||||
seq_get,
|
||||
subclasses,
|
||||
)
|
||||
|
@ -175,13 +176,7 @@ class Expression(metaclass=_Expression):
|
|||
"""
|
||||
Checks whether a Literal expression is an integer.
|
||||
"""
|
||||
if self.is_number:
|
||||
try:
|
||||
int(self.name)
|
||||
return True
|
||||
except ValueError:
|
||||
pass
|
||||
return False
|
||||
return self.is_number and is_int(self.name)
|
||||
|
||||
@property
|
||||
def is_star(self) -> bool:
|
||||
|
@ -493,8 +488,8 @@ class Expression(metaclass=_Expression):
|
|||
|
||||
A AND B AND C -> [A, B, C]
|
||||
"""
|
||||
for node, _, _ in self.dfs(prune=lambda n, p, *_: p and not type(n) is self.__class__):
|
||||
if not type(node) is self.__class__:
|
||||
for node, _, _ in self.dfs(prune=lambda n, p, *_: p and type(n) is not self.__class__):
|
||||
if type(node) is not self.__class__:
|
||||
yield node.unnest() if unnest and not isinstance(node, Subquery) else node
|
||||
|
||||
def __str__(self) -> str:
|
||||
|
@ -553,10 +548,12 @@ class Expression(metaclass=_Expression):
|
|||
return new_node
|
||||
|
||||
@t.overload
|
||||
def replace(self, expression: E) -> E: ...
|
||||
def replace(self, expression: E) -> E:
|
||||
...
|
||||
|
||||
@t.overload
|
||||
def replace(self, expression: None) -> None: ...
|
||||
def replace(self, expression: None) -> None:
|
||||
...
|
||||
|
||||
def replace(self, expression):
|
||||
"""
|
||||
|
@ -610,7 +607,8 @@ class Expression(metaclass=_Expression):
|
|||
>>> sqlglot.parse_one("SELECT x from y").assert_is(Select).select("z").sql()
|
||||
'SELECT x, z FROM y'
|
||||
"""
|
||||
assert isinstance(self, type_)
|
||||
if not isinstance(self, type_):
|
||||
raise AssertionError(f"{self} is not {type_}.")
|
||||
return self
|
||||
|
||||
def error_messages(self, args: t.Optional[t.Sequence] = None) -> t.List[str]:
|
||||
|
@ -1133,6 +1131,7 @@ class SetItem(Expression):
|
|||
class Show(Expression):
|
||||
arg_types = {
|
||||
"this": True,
|
||||
"history": False,
|
||||
"terse": False,
|
||||
"target": False,
|
||||
"offset": False,
|
||||
|
@ -1676,7 +1675,6 @@ class Index(Expression):
|
|||
"amp": False, # teradata
|
||||
"include": False,
|
||||
"partition_by": False, # teradata
|
||||
"where": False, # postgres partial indexes
|
||||
}
|
||||
|
||||
|
||||
|
@ -2573,7 +2571,7 @@ class HistoricalData(Expression):
|
|||
|
||||
class Table(Expression):
|
||||
arg_types = {
|
||||
"this": True,
|
||||
"this": False,
|
||||
"alias": False,
|
||||
"db": False,
|
||||
"catalog": False,
|
||||
|
@ -3664,6 +3662,7 @@ class DataType(Expression):
|
|||
BINARY = auto()
|
||||
BIT = auto()
|
||||
BOOLEAN = auto()
|
||||
BPCHAR = auto()
|
||||
CHAR = auto()
|
||||
DATE = auto()
|
||||
DATE32 = auto()
|
||||
|
@ -3805,6 +3804,7 @@ class DataType(Expression):
|
|||
dtype: DATA_TYPE,
|
||||
dialect: DialectType = None,
|
||||
udt: bool = False,
|
||||
copy: bool = True,
|
||||
**kwargs,
|
||||
) -> DataType:
|
||||
"""
|
||||
|
@ -3815,7 +3815,8 @@ class DataType(Expression):
|
|||
dialect: the dialect to use for parsing `dtype`, in case it's a string.
|
||||
udt: when set to True, `dtype` will be used as-is if it can't be parsed into a
|
||||
DataType, thus creating a user-defined type.
|
||||
kawrgs: additional arguments to pass in the constructor of DataType.
|
||||
copy: whether or not to copy the data type.
|
||||
kwargs: additional arguments to pass in the constructor of DataType.
|
||||
|
||||
Returns:
|
||||
The constructed DataType object.
|
||||
|
@ -3837,7 +3838,7 @@ class DataType(Expression):
|
|||
elif isinstance(dtype, DataType.Type):
|
||||
data_type_exp = DataType(this=dtype)
|
||||
elif isinstance(dtype, DataType):
|
||||
return dtype
|
||||
return maybe_copy(dtype, copy)
|
||||
else:
|
||||
raise ValueError(f"Invalid data type: {type(dtype)}. Expected str or DataType.Type")
|
||||
|
||||
|
@ -3855,7 +3856,7 @@ class DataType(Expression):
|
|||
True, if and only if there is a type in `dtypes` which is equal to this DataType.
|
||||
"""
|
||||
for dtype in dtypes:
|
||||
other = DataType.build(dtype, udt=True)
|
||||
other = DataType.build(dtype, copy=False, udt=True)
|
||||
|
||||
if (
|
||||
other.expressions
|
||||
|
@ -4001,7 +4002,7 @@ class Dot(Binary):
|
|||
def build(self, expressions: t.Sequence[Expression]) -> Dot:
|
||||
"""Build a Dot object with a sequence of expressions."""
|
||||
if len(expressions) < 2:
|
||||
raise ValueError(f"Dot requires >= 2 expressions.")
|
||||
raise ValueError("Dot requires >= 2 expressions.")
|
||||
|
||||
return t.cast(Dot, reduce(lambda x, y: Dot(this=x, expression=y), expressions))
|
||||
|
||||
|
@ -4128,10 +4129,6 @@ class Sub(Binary):
|
|||
pass
|
||||
|
||||
|
||||
class ArrayOverlaps(Binary):
|
||||
pass
|
||||
|
||||
|
||||
# Unary Expressions
|
||||
# (NOT a)
|
||||
class Unary(Condition):
|
||||
|
@ -4469,6 +4466,10 @@ class ArrayJoin(Func):
|
|||
arg_types = {"this": True, "expression": True, "null": False}
|
||||
|
||||
|
||||
class ArrayOverlaps(Binary, Func):
|
||||
pass
|
||||
|
||||
|
||||
class ArraySize(Func):
|
||||
arg_types = {"this": True, "expression": False}
|
||||
|
||||
|
@ -4490,15 +4491,37 @@ class Avg(AggFunc):
|
|||
|
||||
|
||||
class AnyValue(AggFunc):
|
||||
arg_types = {"this": True, "having": False, "max": False, "ignore_nulls": False}
|
||||
arg_types = {"this": True, "having": False, "max": False}
|
||||
|
||||
|
||||
class First(Func):
|
||||
arg_types = {"this": True, "ignore_nulls": False}
|
||||
class Lag(AggFunc):
|
||||
arg_types = {"this": True, "offset": False, "default": False}
|
||||
|
||||
|
||||
class Last(Func):
|
||||
arg_types = {"this": True, "ignore_nulls": False}
|
||||
class Lead(AggFunc):
|
||||
arg_types = {"this": True, "offset": False, "default": False}
|
||||
|
||||
|
||||
# some dialects have a distinction between first and first_value, usually first is an aggregate func
|
||||
# and first_value is a window func
|
||||
class First(AggFunc):
|
||||
pass
|
||||
|
||||
|
||||
class Last(AggFunc):
|
||||
pass
|
||||
|
||||
|
||||
class FirstValue(AggFunc):
|
||||
pass
|
||||
|
||||
|
||||
class LastValue(AggFunc):
|
||||
pass
|
||||
|
||||
|
||||
class NthValue(AggFunc):
|
||||
arg_types = {"this": True, "offset": True}
|
||||
|
||||
|
||||
class Case(Func):
|
||||
|
@ -4611,7 +4634,7 @@ class CurrentTime(Func):
|
|||
|
||||
|
||||
class CurrentTimestamp(Func):
|
||||
arg_types = {"this": False}
|
||||
arg_types = {"this": False, "transaction": False}
|
||||
|
||||
|
||||
class CurrentUser(Func):
|
||||
|
@ -4712,6 +4735,7 @@ class TimestampSub(Func, TimeUnit):
|
|||
|
||||
|
||||
class TimestampDiff(Func, TimeUnit):
|
||||
_sql_names = ["TIMESTAMPDIFF", "TIMESTAMP_DIFF"]
|
||||
arg_types = {"this": True, "expression": True, "unit": False}
|
||||
|
||||
|
||||
|
@ -4857,6 +4881,59 @@ class IsInf(Func):
|
|||
_sql_names = ["IS_INF", "ISINF"]
|
||||
|
||||
|
||||
class JSONPath(Expression):
|
||||
arg_types = {"expressions": True}
|
||||
|
||||
@property
|
||||
def output_name(self) -> str:
|
||||
last_segment = self.expressions[-1].this
|
||||
return last_segment if isinstance(last_segment, str) else ""
|
||||
|
||||
|
||||
class JSONPathPart(Expression):
|
||||
arg_types = {}
|
||||
|
||||
|
||||
class JSONPathFilter(JSONPathPart):
|
||||
arg_types = {"this": True}
|
||||
|
||||
|
||||
class JSONPathKey(JSONPathPart):
|
||||
arg_types = {"this": True}
|
||||
|
||||
|
||||
class JSONPathRecursive(JSONPathPart):
|
||||
arg_types = {"this": False}
|
||||
|
||||
|
||||
class JSONPathRoot(JSONPathPart):
|
||||
pass
|
||||
|
||||
|
||||
class JSONPathScript(JSONPathPart):
|
||||
arg_types = {"this": True}
|
||||
|
||||
|
||||
class JSONPathSlice(JSONPathPart):
|
||||
arg_types = {"start": False, "end": False, "step": False}
|
||||
|
||||
|
||||
class JSONPathSelector(JSONPathPart):
|
||||
arg_types = {"this": True}
|
||||
|
||||
|
||||
class JSONPathSubscript(JSONPathPart):
|
||||
arg_types = {"this": True}
|
||||
|
||||
|
||||
class JSONPathUnion(JSONPathPart):
|
||||
arg_types = {"expressions": True}
|
||||
|
||||
|
||||
class JSONPathWildcard(JSONPathPart):
|
||||
pass
|
||||
|
||||
|
||||
class FormatJson(Expression):
|
||||
pass
|
||||
|
||||
|
@ -4940,18 +5017,30 @@ class JSONBContains(Binary):
|
|||
|
||||
|
||||
class JSONExtract(Binary, Func):
|
||||
arg_types = {"this": True, "expression": True, "expressions": False}
|
||||
_sql_names = ["JSON_EXTRACT"]
|
||||
is_var_len_args = True
|
||||
|
||||
@property
|
||||
def output_name(self) -> str:
|
||||
return self.expression.output_name if not self.expressions else ""
|
||||
|
||||
|
||||
class JSONExtractScalar(JSONExtract):
|
||||
class JSONExtractScalar(Binary, Func):
|
||||
arg_types = {"this": True, "expression": True, "expressions": False}
|
||||
_sql_names = ["JSON_EXTRACT_SCALAR"]
|
||||
is_var_len_args = True
|
||||
|
||||
@property
|
||||
def output_name(self) -> str:
|
||||
return self.expression.output_name
|
||||
|
||||
|
||||
class JSONBExtract(JSONExtract):
|
||||
class JSONBExtract(Binary, Func):
|
||||
_sql_names = ["JSONB_EXTRACT"]
|
||||
|
||||
|
||||
class JSONBExtractScalar(JSONExtract):
|
||||
class JSONBExtractScalar(Binary, Func):
|
||||
_sql_names = ["JSONB_EXTRACT_SCALAR"]
|
||||
|
||||
|
||||
|
@ -4972,15 +5061,6 @@ class ParseJSON(Func):
|
|||
is_var_len_args = True
|
||||
|
||||
|
||||
# https://docs.snowflake.com/en/sql-reference/functions/get_path
|
||||
class GetPath(Func):
|
||||
arg_types = {"this": True, "expression": True}
|
||||
|
||||
@property
|
||||
def output_name(self) -> str:
|
||||
return self.expression.output_name
|
||||
|
||||
|
||||
class Least(Func):
|
||||
arg_types = {"this": True, "expressions": False}
|
||||
is_var_len_args = True
|
||||
|
@ -5476,6 +5556,8 @@ def _norm_arg(arg):
|
|||
ALL_FUNCTIONS = subclasses(__name__, Func, (AggFunc, Anonymous, Func))
|
||||
FUNCTION_BY_NAME = {name: func for func in ALL_FUNCTIONS for name in func.sql_names()}
|
||||
|
||||
JSON_PATH_PARTS = subclasses(__name__, JSONPathPart, (JSONPathPart,))
|
||||
|
||||
|
||||
# Helpers
|
||||
@t.overload
|
||||
|
@ -5487,7 +5569,8 @@ def maybe_parse(
|
|||
prefix: t.Optional[str] = None,
|
||||
copy: bool = False,
|
||||
**opts,
|
||||
) -> E: ...
|
||||
) -> E:
|
||||
...
|
||||
|
||||
|
||||
@t.overload
|
||||
|
@ -5499,7 +5582,8 @@ def maybe_parse(
|
|||
prefix: t.Optional[str] = None,
|
||||
copy: bool = False,
|
||||
**opts,
|
||||
) -> E: ...
|
||||
) -> E:
|
||||
...
|
||||
|
||||
|
||||
def maybe_parse(
|
||||
|
@ -5539,7 +5623,7 @@ def maybe_parse(
|
|||
return sql_or_expression
|
||||
|
||||
if sql_or_expression is None:
|
||||
raise ParseError(f"SQL cannot be None")
|
||||
raise ParseError("SQL cannot be None")
|
||||
|
||||
import sqlglot
|
||||
|
||||
|
@ -5551,11 +5635,13 @@ def maybe_parse(
|
|||
|
||||
|
||||
@t.overload
|
||||
def maybe_copy(instance: None, copy: bool = True) -> None: ...
|
||||
def maybe_copy(instance: None, copy: bool = True) -> None:
|
||||
...
|
||||
|
||||
|
||||
@t.overload
|
||||
def maybe_copy(instance: E, copy: bool = True) -> E: ...
|
||||
def maybe_copy(instance: E, copy: bool = True) -> E:
|
||||
...
|
||||
|
||||
|
||||
def maybe_copy(instance, copy=True):
|
||||
|
@ -6174,17 +6260,19 @@ def paren(expression: ExpOrStr, copy: bool = True) -> Paren:
|
|||
return Paren(this=maybe_parse(expression, copy=copy))
|
||||
|
||||
|
||||
SAFE_IDENTIFIER_RE = re.compile(r"^[_a-zA-Z][\w]*$")
|
||||
SAFE_IDENTIFIER_RE: t.Pattern[str] = re.compile(r"^[_a-zA-Z][\w]*$")
|
||||
|
||||
|
||||
@t.overload
|
||||
def to_identifier(name: None, quoted: t.Optional[bool] = None, copy: bool = True) -> None: ...
|
||||
def to_identifier(name: None, quoted: t.Optional[bool] = None, copy: bool = True) -> None:
|
||||
...
|
||||
|
||||
|
||||
@t.overload
|
||||
def to_identifier(
|
||||
name: str | Identifier, quoted: t.Optional[bool] = None, copy: bool = True
|
||||
) -> Identifier: ...
|
||||
) -> Identifier:
|
||||
...
|
||||
|
||||
|
||||
def to_identifier(name, quoted=None, copy=True):
|
||||
|
@ -6256,11 +6344,13 @@ def to_interval(interval: str | Literal) -> Interval:
|
|||
|
||||
|
||||
@t.overload
|
||||
def to_table(sql_path: str | Table, **kwargs) -> Table: ...
|
||||
def to_table(sql_path: str | Table, **kwargs) -> Table:
|
||||
...
|
||||
|
||||
|
||||
@t.overload
|
||||
def to_table(sql_path: None, **kwargs) -> None: ...
|
||||
def to_table(sql_path: None, **kwargs) -> None:
|
||||
...
|
||||
|
||||
|
||||
def to_table(
|
||||
|
@ -6460,7 +6550,7 @@ def column(
|
|||
return this
|
||||
|
||||
|
||||
def cast(expression: ExpOrStr, to: DATA_TYPE, **opts) -> Cast:
|
||||
def cast(expression: ExpOrStr, to: DATA_TYPE, copy: bool = True, **opts) -> Cast:
|
||||
"""Cast an expression to a data type.
|
||||
|
||||
Example:
|
||||
|
@ -6470,12 +6560,13 @@ def cast(expression: ExpOrStr, to: DATA_TYPE, **opts) -> Cast:
|
|||
Args:
|
||||
expression: The expression to cast.
|
||||
to: The datatype to cast to.
|
||||
copy: Whether or not to copy the supplied expressions.
|
||||
|
||||
Returns:
|
||||
The new Cast instance.
|
||||
"""
|
||||
expression = maybe_parse(expression, **opts)
|
||||
data_type = DataType.build(to, **opts)
|
||||
expression = maybe_parse(expression, copy=copy, **opts)
|
||||
data_type = DataType.build(to, copy=copy, **opts)
|
||||
expression = Cast(this=expression, to=data_type)
|
||||
expression.type = data_type
|
||||
return expression
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue