Merging upstream version 10.6.3.
Signed-off-by: Daniel Baumann <daniel@debian.org>
This commit is contained in:
parent
d03a55eda6
commit
ece6881255
48 changed files with 906 additions and 266 deletions
|
@ -32,13 +32,7 @@ from sqlglot.helper import (
|
|||
from sqlglot.tokens import Token
|
||||
|
||||
if t.TYPE_CHECKING:
|
||||
from sqlglot.dialects.dialect import Dialect
|
||||
|
||||
IntoType = t.Union[
|
||||
str,
|
||||
t.Type[Expression],
|
||||
t.Collection[t.Union[str, t.Type[Expression]]],
|
||||
]
|
||||
from sqlglot.dialects.dialect import DialectType
|
||||
|
||||
|
||||
class _Expression(type):
|
||||
|
@ -427,7 +421,7 @@ class Expression(metaclass=_Expression):
|
|||
def __repr__(self):
|
||||
return self._to_s()
|
||||
|
||||
def sql(self, dialect: Dialect | str | None = None, **opts) -> str:
|
||||
def sql(self, dialect: DialectType = None, **opts) -> str:
|
||||
"""
|
||||
Returns SQL string representation of this tree.
|
||||
|
||||
|
@ -595,6 +589,14 @@ class Expression(metaclass=_Expression):
|
|||
return load(obj)
|
||||
|
||||
|
||||
if t.TYPE_CHECKING:
|
||||
IntoType = t.Union[
|
||||
str,
|
||||
t.Type[Expression],
|
||||
t.Collection[t.Union[str, t.Type[Expression]]],
|
||||
]
|
||||
|
||||
|
||||
class Condition(Expression):
|
||||
def and_(self, *expressions, dialect=None, **opts):
|
||||
"""
|
||||
|
@ -1285,6 +1287,18 @@ class Property(Expression):
|
|||
arg_types = {"this": True, "value": True}
|
||||
|
||||
|
||||
class AlgorithmProperty(Property):
|
||||
arg_types = {"this": True}
|
||||
|
||||
|
||||
class DefinerProperty(Property):
|
||||
arg_types = {"this": True}
|
||||
|
||||
|
||||
class SqlSecurityProperty(Property):
|
||||
arg_types = {"definer": True}
|
||||
|
||||
|
||||
class TableFormatProperty(Property):
|
||||
arg_types = {"this": True}
|
||||
|
||||
|
@ -1425,13 +1439,15 @@ class IsolatedLoadingProperty(Property):
|
|||
|
||||
|
||||
class Properties(Expression):
|
||||
arg_types = {"expressions": True, "before": False}
|
||||
arg_types = {"expressions": True}
|
||||
|
||||
NAME_TO_PROPERTY = {
|
||||
"ALGORITHM": AlgorithmProperty,
|
||||
"AUTO_INCREMENT": AutoIncrementProperty,
|
||||
"CHARACTER SET": CharacterSetProperty,
|
||||
"COLLATE": CollateProperty,
|
||||
"COMMENT": SchemaCommentProperty,
|
||||
"DEFINER": DefinerProperty,
|
||||
"DISTKEY": DistKeyProperty,
|
||||
"DISTSTYLE": DistStyleProperty,
|
||||
"ENGINE": EngineProperty,
|
||||
|
@ -1447,6 +1463,14 @@ class Properties(Expression):
|
|||
|
||||
PROPERTY_TO_NAME = {v: k for k, v in NAME_TO_PROPERTY.items()}
|
||||
|
||||
class Location(AutoName):
|
||||
POST_CREATE = auto()
|
||||
PRE_SCHEMA = auto()
|
||||
POST_INDEX = auto()
|
||||
POST_SCHEMA_ROOT = auto()
|
||||
POST_SCHEMA_WITH = auto()
|
||||
UNSUPPORTED = auto()
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, properties_dict) -> Properties:
|
||||
expressions = []
|
||||
|
@ -1592,6 +1616,7 @@ QUERY_MODIFIERS = {
|
|||
"order": False,
|
||||
"limit": False,
|
||||
"offset": False,
|
||||
"lock": False,
|
||||
}
|
||||
|
||||
|
||||
|
@ -1713,6 +1738,12 @@ class Schema(Expression):
|
|||
arg_types = {"this": False, "expressions": False}
|
||||
|
||||
|
||||
# Used to represent the FOR UPDATE and FOR SHARE locking read types.
|
||||
# https://dev.mysql.com/doc/refman/8.0/en/innodb-locking-reads.html
|
||||
class Lock(Expression):
|
||||
arg_types = {"update": True}
|
||||
|
||||
|
||||
class Select(Subqueryable):
|
||||
arg_types = {
|
||||
"with": False,
|
||||
|
@ -2243,6 +2274,30 @@ class Select(Subqueryable):
|
|||
properties=properties_expression,
|
||||
)
|
||||
|
||||
def lock(self, update: bool = True, copy: bool = True) -> Select:
|
||||
"""
|
||||
Set the locking read mode for this expression.
|
||||
|
||||
Examples:
|
||||
>>> Select().select("x").from_("tbl").where("x = 'a'").lock().sql("mysql")
|
||||
"SELECT x FROM tbl WHERE x = 'a' FOR UPDATE"
|
||||
|
||||
>>> Select().select("x").from_("tbl").where("x = 'a'").lock(update=False).sql("mysql")
|
||||
"SELECT x FROM tbl WHERE x = 'a' FOR SHARE"
|
||||
|
||||
Args:
|
||||
update: if `True`, the locking type will be `FOR UPDATE`, else it will be `FOR SHARE`.
|
||||
copy: if `False`, modify this expression instance in-place.
|
||||
|
||||
Returns:
|
||||
The modified expression.
|
||||
"""
|
||||
|
||||
inst = _maybe_copy(self, copy)
|
||||
inst.set("lock", Lock(update=update))
|
||||
|
||||
return inst
|
||||
|
||||
@property
|
||||
def named_selects(self) -> t.List[str]:
|
||||
return [e.output_name for e in self.expressions if e.alias_or_name]
|
||||
|
@ -2456,24 +2511,28 @@ class DataType(Expression):
|
|||
|
||||
@classmethod
|
||||
def build(
|
||||
cls, dtype: str | DataType.Type, dialect: t.Optional[str | Dialect] = None, **kwargs
|
||||
cls, dtype: str | DataType | DataType.Type, dialect: DialectType = None, **kwargs
|
||||
) -> DataType:
|
||||
from sqlglot import parse_one
|
||||
|
||||
if isinstance(dtype, str):
|
||||
data_type_exp: t.Optional[Expression]
|
||||
if dtype.upper() in cls.Type.__members__:
|
||||
data_type_exp = DataType(this=DataType.Type[dtype.upper()])
|
||||
data_type_exp: t.Optional[Expression] = DataType(this=DataType.Type[dtype.upper()])
|
||||
else:
|
||||
data_type_exp = parse_one(dtype, read=dialect, into=DataType)
|
||||
if data_type_exp is None:
|
||||
raise ValueError(f"Unparsable data type value: {dtype}")
|
||||
elif isinstance(dtype, DataType.Type):
|
||||
data_type_exp = DataType(this=dtype)
|
||||
elif isinstance(dtype, DataType):
|
||||
return dtype
|
||||
else:
|
||||
raise ValueError(f"Invalid data type: {type(dtype)}. Expected str or DataType.Type")
|
||||
return DataType(**{**data_type_exp.args, **kwargs})
|
||||
|
||||
def is_type(self, dtype: DataType.Type) -> bool:
|
||||
return self.this == dtype
|
||||
|
||||
|
||||
# https://www.postgresql.org/docs/15/datatype-pseudo.html
|
||||
class PseudoType(Expression):
|
||||
|
@ -2840,6 +2899,10 @@ class Array(Func):
|
|||
is_var_len_args = True
|
||||
|
||||
|
||||
class GenerateSeries(Func):
|
||||
arg_types = {"start": True, "end": True, "step": False}
|
||||
|
||||
|
||||
class ArrayAgg(AggFunc):
|
||||
pass
|
||||
|
||||
|
@ -2909,6 +2972,9 @@ class Cast(Func):
|
|||
def output_name(self):
|
||||
return self.name
|
||||
|
||||
def is_type(self, dtype: DataType.Type) -> bool:
|
||||
return self.to.is_type(dtype)
|
||||
|
||||
|
||||
class Collate(Binary):
|
||||
pass
|
||||
|
@ -2989,6 +3055,22 @@ class DatetimeTrunc(Func, TimeUnit):
|
|||
arg_types = {"this": True, "unit": True, "zone": False}
|
||||
|
||||
|
||||
class DayOfWeek(Func):
|
||||
_sql_names = ["DAY_OF_WEEK", "DAYOFWEEK"]
|
||||
|
||||
|
||||
class DayOfMonth(Func):
|
||||
_sql_names = ["DAY_OF_MONTH", "DAYOFMONTH"]
|
||||
|
||||
|
||||
class DayOfYear(Func):
|
||||
_sql_names = ["DAY_OF_YEAR", "DAYOFYEAR"]
|
||||
|
||||
|
||||
class WeekOfYear(Func):
|
||||
_sql_names = ["WEEK_OF_YEAR", "WEEKOFYEAR"]
|
||||
|
||||
|
||||
class LastDateOfMonth(Func):
|
||||
pass
|
||||
|
||||
|
@ -3239,7 +3321,7 @@ class ReadCSV(Func):
|
|||
|
||||
|
||||
class Reduce(Func):
|
||||
arg_types = {"this": True, "initial": True, "merge": True, "finish": True}
|
||||
arg_types = {"this": True, "initial": True, "merge": True, "finish": False}
|
||||
|
||||
|
||||
class RegexpLike(Func):
|
||||
|
@ -3476,7 +3558,7 @@ def maybe_parse(
|
|||
sql_or_expression: str | Expression,
|
||||
*,
|
||||
into: t.Optional[IntoType] = None,
|
||||
dialect: t.Optional[str] = None,
|
||||
dialect: DialectType = None,
|
||||
prefix: t.Optional[str] = None,
|
||||
**opts,
|
||||
) -> Expression:
|
||||
|
@ -3959,6 +4041,28 @@ def to_identifier(alias, quoted=None) -> t.Optional[Identifier]:
|
|||
return identifier
|
||||
|
||||
|
||||
INTERVAL_STRING_RE = re.compile(r"\s*([0-9]+)\s*([a-zA-Z]+)\s*")
|
||||
|
||||
|
||||
def to_interval(interval: str | Literal) -> Interval:
|
||||
"""Builds an interval expression from a string like '1 day' or '5 months'."""
|
||||
if isinstance(interval, Literal):
|
||||
if not interval.is_string:
|
||||
raise ValueError("Invalid interval string.")
|
||||
|
||||
interval = interval.this
|
||||
|
||||
interval_parts = INTERVAL_STRING_RE.match(interval) # type: ignore
|
||||
|
||||
if not interval_parts:
|
||||
raise ValueError("Invalid interval string.")
|
||||
|
||||
return Interval(
|
||||
this=Literal.string(interval_parts.group(1)),
|
||||
unit=Var(this=interval_parts.group(2)),
|
||||
)
|
||||
|
||||
|
||||
@t.overload
|
||||
def to_table(sql_path: str | Table, **kwargs) -> Table:
|
||||
...
|
||||
|
@ -4050,7 +4154,8 @@ def alias_(expression, alias, table=False, dialect=None, quoted=None, **opts):
|
|||
def subquery(expression, alias=None, dialect=None, **opts):
|
||||
"""
|
||||
Build a subquery expression.
|
||||
Expample:
|
||||
|
||||
Example:
|
||||
>>> subquery('select x from tbl', 'bar').select('x').sql()
|
||||
'SELECT x FROM (SELECT x FROM tbl) AS bar'
|
||||
|
||||
|
@ -4072,6 +4177,7 @@ def subquery(expression, alias=None, dialect=None, **opts):
|
|||
def column(col, table=None, quoted=None) -> Column:
|
||||
"""
|
||||
Build a Column.
|
||||
|
||||
Args:
|
||||
col (str | Expression): column name
|
||||
table (str | Expression): table name
|
||||
|
@ -4084,6 +4190,24 @@ def column(col, table=None, quoted=None) -> Column:
|
|||
)
|
||||
|
||||
|
||||
def cast(expression: str | Expression, to: str | DataType | DataType.Type, **opts) -> Cast:
|
||||
"""Cast an expression to a data type.
|
||||
|
||||
Example:
|
||||
>>> cast('x + 1', 'int').sql()
|
||||
'CAST(x + 1 AS INT)'
|
||||
|
||||
Args:
|
||||
expression: The expression to cast.
|
||||
to: The datatype to cast to.
|
||||
|
||||
Returns:
|
||||
A cast node.
|
||||
"""
|
||||
expression = maybe_parse(expression, **opts)
|
||||
return Cast(this=expression, to=DataType.build(to, **opts))
|
||||
|
||||
|
||||
def table_(table, db=None, catalog=None, quoted=None, alias=None) -> Table:
|
||||
"""Build a Table.
|
||||
|
||||
|
@ -4137,7 +4261,7 @@ def values(
|
|||
types = list(columns.values())
|
||||
expressions[0].set(
|
||||
"expressions",
|
||||
[Cast(this=x, to=types[i]) for i, x in enumerate(expressions[0].expressions)],
|
||||
[cast(x, types[i]) for i, x in enumerate(expressions[0].expressions)],
|
||||
)
|
||||
return Values(
|
||||
expressions=expressions,
|
||||
|
@ -4373,7 +4497,7 @@ def expand(expression: Expression, sources: t.Dict[str, Subqueryable], copy=True
|
|||
return expression.transform(_expand, copy=copy)
|
||||
|
||||
|
||||
def func(name: str, *args, dialect: t.Optional[Dialect | str] = None, **kwargs) -> Func:
|
||||
def func(name: str, *args, dialect: DialectType = None, **kwargs) -> Func:
|
||||
"""
|
||||
Returns a Func expression.
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue