1
0
Fork 0

Merging upstream version 10.6.3.

Signed-off-by: Daniel Baumann <daniel@debian.org>
This commit is contained in:
Daniel Baumann 2025-02-13 15:09:58 +01:00
parent d03a55eda6
commit ece6881255
Signed by: daniel
GPG key ID: FBB4F0E80A80222F
48 changed files with 906 additions and 266 deletions

View file

@ -32,13 +32,7 @@ from sqlglot.helper import (
from sqlglot.tokens import Token
if t.TYPE_CHECKING:
from sqlglot.dialects.dialect import Dialect
IntoType = t.Union[
str,
t.Type[Expression],
t.Collection[t.Union[str, t.Type[Expression]]],
]
from sqlglot.dialects.dialect import DialectType
class _Expression(type):
@ -427,7 +421,7 @@ class Expression(metaclass=_Expression):
def __repr__(self):
return self._to_s()
def sql(self, dialect: Dialect | str | None = None, **opts) -> str:
def sql(self, dialect: DialectType = None, **opts) -> str:
"""
Returns SQL string representation of this tree.
@ -595,6 +589,14 @@ class Expression(metaclass=_Expression):
return load(obj)
if t.TYPE_CHECKING:
IntoType = t.Union[
str,
t.Type[Expression],
t.Collection[t.Union[str, t.Type[Expression]]],
]
class Condition(Expression):
def and_(self, *expressions, dialect=None, **opts):
"""
@ -1285,6 +1287,18 @@ class Property(Expression):
arg_types = {"this": True, "value": True}
class AlgorithmProperty(Property):
arg_types = {"this": True}
class DefinerProperty(Property):
arg_types = {"this": True}
class SqlSecurityProperty(Property):
arg_types = {"definer": True}
class TableFormatProperty(Property):
arg_types = {"this": True}
@ -1425,13 +1439,15 @@ class IsolatedLoadingProperty(Property):
class Properties(Expression):
arg_types = {"expressions": True, "before": False}
arg_types = {"expressions": True}
NAME_TO_PROPERTY = {
"ALGORITHM": AlgorithmProperty,
"AUTO_INCREMENT": AutoIncrementProperty,
"CHARACTER SET": CharacterSetProperty,
"COLLATE": CollateProperty,
"COMMENT": SchemaCommentProperty,
"DEFINER": DefinerProperty,
"DISTKEY": DistKeyProperty,
"DISTSTYLE": DistStyleProperty,
"ENGINE": EngineProperty,
@ -1447,6 +1463,14 @@ class Properties(Expression):
PROPERTY_TO_NAME = {v: k for k, v in NAME_TO_PROPERTY.items()}
class Location(AutoName):
POST_CREATE = auto()
PRE_SCHEMA = auto()
POST_INDEX = auto()
POST_SCHEMA_ROOT = auto()
POST_SCHEMA_WITH = auto()
UNSUPPORTED = auto()
@classmethod
def from_dict(cls, properties_dict) -> Properties:
expressions = []
@ -1592,6 +1616,7 @@ QUERY_MODIFIERS = {
"order": False,
"limit": False,
"offset": False,
"lock": False,
}
@ -1713,6 +1738,12 @@ class Schema(Expression):
arg_types = {"this": False, "expressions": False}
# Used to represent the FOR UPDATE and FOR SHARE locking read types.
# https://dev.mysql.com/doc/refman/8.0/en/innodb-locking-reads.html
class Lock(Expression):
arg_types = {"update": True}
class Select(Subqueryable):
arg_types = {
"with": False,
@ -2243,6 +2274,30 @@ class Select(Subqueryable):
properties=properties_expression,
)
def lock(self, update: bool = True, copy: bool = True) -> Select:
"""
Set the locking read mode for this expression.
Examples:
>>> Select().select("x").from_("tbl").where("x = 'a'").lock().sql("mysql")
"SELECT x FROM tbl WHERE x = 'a' FOR UPDATE"
>>> Select().select("x").from_("tbl").where("x = 'a'").lock(update=False).sql("mysql")
"SELECT x FROM tbl WHERE x = 'a' FOR SHARE"
Args:
update: if `True`, the locking type will be `FOR UPDATE`, else it will be `FOR SHARE`.
copy: if `False`, modify this expression instance in-place.
Returns:
The modified expression.
"""
inst = _maybe_copy(self, copy)
inst.set("lock", Lock(update=update))
return inst
@property
def named_selects(self) -> t.List[str]:
return [e.output_name for e in self.expressions if e.alias_or_name]
@ -2456,24 +2511,28 @@ class DataType(Expression):
@classmethod
def build(
cls, dtype: str | DataType.Type, dialect: t.Optional[str | Dialect] = None, **kwargs
cls, dtype: str | DataType | DataType.Type, dialect: DialectType = None, **kwargs
) -> DataType:
from sqlglot import parse_one
if isinstance(dtype, str):
data_type_exp: t.Optional[Expression]
if dtype.upper() in cls.Type.__members__:
data_type_exp = DataType(this=DataType.Type[dtype.upper()])
data_type_exp: t.Optional[Expression] = DataType(this=DataType.Type[dtype.upper()])
else:
data_type_exp = parse_one(dtype, read=dialect, into=DataType)
if data_type_exp is None:
raise ValueError(f"Unparsable data type value: {dtype}")
elif isinstance(dtype, DataType.Type):
data_type_exp = DataType(this=dtype)
elif isinstance(dtype, DataType):
return dtype
else:
raise ValueError(f"Invalid data type: {type(dtype)}. Expected str or DataType.Type")
return DataType(**{**data_type_exp.args, **kwargs})
def is_type(self, dtype: DataType.Type) -> bool:
return self.this == dtype
# https://www.postgresql.org/docs/15/datatype-pseudo.html
class PseudoType(Expression):
@ -2840,6 +2899,10 @@ class Array(Func):
is_var_len_args = True
class GenerateSeries(Func):
arg_types = {"start": True, "end": True, "step": False}
class ArrayAgg(AggFunc):
pass
@ -2909,6 +2972,9 @@ class Cast(Func):
def output_name(self):
return self.name
def is_type(self, dtype: DataType.Type) -> bool:
return self.to.is_type(dtype)
class Collate(Binary):
pass
@ -2989,6 +3055,22 @@ class DatetimeTrunc(Func, TimeUnit):
arg_types = {"this": True, "unit": True, "zone": False}
class DayOfWeek(Func):
_sql_names = ["DAY_OF_WEEK", "DAYOFWEEK"]
class DayOfMonth(Func):
_sql_names = ["DAY_OF_MONTH", "DAYOFMONTH"]
class DayOfYear(Func):
_sql_names = ["DAY_OF_YEAR", "DAYOFYEAR"]
class WeekOfYear(Func):
_sql_names = ["WEEK_OF_YEAR", "WEEKOFYEAR"]
class LastDateOfMonth(Func):
pass
@ -3239,7 +3321,7 @@ class ReadCSV(Func):
class Reduce(Func):
arg_types = {"this": True, "initial": True, "merge": True, "finish": True}
arg_types = {"this": True, "initial": True, "merge": True, "finish": False}
class RegexpLike(Func):
@ -3476,7 +3558,7 @@ def maybe_parse(
sql_or_expression: str | Expression,
*,
into: t.Optional[IntoType] = None,
dialect: t.Optional[str] = None,
dialect: DialectType = None,
prefix: t.Optional[str] = None,
**opts,
) -> Expression:
@ -3959,6 +4041,28 @@ def to_identifier(alias, quoted=None) -> t.Optional[Identifier]:
return identifier
INTERVAL_STRING_RE = re.compile(r"\s*([0-9]+)\s*([a-zA-Z]+)\s*")
def to_interval(interval: str | Literal) -> Interval:
"""Builds an interval expression from a string like '1 day' or '5 months'."""
if isinstance(interval, Literal):
if not interval.is_string:
raise ValueError("Invalid interval string.")
interval = interval.this
interval_parts = INTERVAL_STRING_RE.match(interval) # type: ignore
if not interval_parts:
raise ValueError("Invalid interval string.")
return Interval(
this=Literal.string(interval_parts.group(1)),
unit=Var(this=interval_parts.group(2)),
)
@t.overload
def to_table(sql_path: str | Table, **kwargs) -> Table:
...
@ -4050,7 +4154,8 @@ def alias_(expression, alias, table=False, dialect=None, quoted=None, **opts):
def subquery(expression, alias=None, dialect=None, **opts):
"""
Build a subquery expression.
Expample:
Example:
>>> subquery('select x from tbl', 'bar').select('x').sql()
'SELECT x FROM (SELECT x FROM tbl) AS bar'
@ -4072,6 +4177,7 @@ def subquery(expression, alias=None, dialect=None, **opts):
def column(col, table=None, quoted=None) -> Column:
"""
Build a Column.
Args:
col (str | Expression): column name
table (str | Expression): table name
@ -4084,6 +4190,24 @@ def column(col, table=None, quoted=None) -> Column:
)
def cast(expression: str | Expression, to: str | DataType | DataType.Type, **opts) -> Cast:
"""Cast an expression to a data type.
Example:
>>> cast('x + 1', 'int').sql()
'CAST(x + 1 AS INT)'
Args:
expression: The expression to cast.
to: The datatype to cast to.
Returns:
A cast node.
"""
expression = maybe_parse(expression, **opts)
return Cast(this=expression, to=DataType.build(to, **opts))
def table_(table, db=None, catalog=None, quoted=None, alias=None) -> Table:
"""Build a Table.
@ -4137,7 +4261,7 @@ def values(
types = list(columns.values())
expressions[0].set(
"expressions",
[Cast(this=x, to=types[i]) for i, x in enumerate(expressions[0].expressions)],
[cast(x, types[i]) for i, x in enumerate(expressions[0].expressions)],
)
return Values(
expressions=expressions,
@ -4373,7 +4497,7 @@ def expand(expression: Expression, sources: t.Dict[str, Subqueryable], copy=True
return expression.transform(_expand, copy=copy)
def func(name: str, *args, dialect: t.Optional[Dialect | str] = None, **kwargs) -> Func:
def func(name: str, *args, dialect: DialectType = None, **kwargs) -> Func:
"""
Returns a Func expression.