1
0
Fork 0

Merging upstream version 10.5.10.

Signed-off-by: Daniel Baumann <daniel@debian.org>
This commit is contained in:
Daniel Baumann 2025-02-13 15:07:05 +01:00
parent 8588db6332
commit 4d496b7a6a
Signed by: daniel
GPG key ID: FBB4F0E80A80222F
43 changed files with 1384 additions and 356 deletions

View file

@ -1,5 +1,12 @@
"""
.. include:: ../pdoc/docs/expressions.md
## Expressions
Every AST node in SQLGlot is represented by a subclass of `Expression`.
This module contains the implementation of all supported `Expression` types. Additionally,
it exposes a number of helper functions, which are mainly used to programmatically build
SQL expressions, such as `sqlglot.expressions.select`.
----
"""
from __future__ import annotations
@ -27,35 +34,66 @@ from sqlglot.tokens import Token
if t.TYPE_CHECKING:
from sqlglot.dialects.dialect import Dialect
IntoType = t.Union[
str,
t.Type[Expression],
t.Collection[t.Union[str, t.Type[Expression]]],
]
class _Expression(type):
def __new__(cls, clsname, bases, attrs):
klass = super().__new__(cls, clsname, bases, attrs)
# When an Expression class is created, its key is automatically set to be
# the lowercase version of the class' name.
klass.key = clsname.lower()
# This is so that docstrings are not inherited in pdoc
klass.__doc__ = klass.__doc__ or ""
return klass
class Expression(metaclass=_Expression):
"""
The base class for all expressions in a syntax tree.
The base class for all expressions in a syntax tree. Each Expression encapsulates any necessary
context, such as its child expressions, their names (arg keys), and whether a given child expression
is optional or not.
Attributes:
arg_types (dict): determines arguments supported by this expression.
The key in a dictionary defines a unique key of an argument using
which the argument's value can be retrieved. The value is a boolean
flag which indicates whether the argument's value is required (True)
or optional (False).
key: a unique key for each class in the Expression hierarchy. This is useful for hashing
and representing expressions as strings.
arg_types: determines what arguments (child nodes) are supported by an expression. It
maps arg keys to booleans that indicate whether the corresponding args are optional.
Example:
>>> class Foo(Expression):
... arg_types = {"this": True, "expression": False}
The above definition informs us that Foo is an Expression that requires an argument called
"this" and may also optionally receive an argument called "expression".
Args:
args: a mapping used for retrieving the arguments of an expression, given their arg keys.
parent: a reference to the parent expression (or None, in case of root expressions).
arg_key: the arg key an expression is associated with, i.e. the name its parent expression
uses to refer to it.
comments: a list of comments that are associated with a given expression. This is used in
order to preserve comments when transpiling SQL code.
_type: the `sqlglot.expressions.DataType` type of an expression. This is inferred by the
optimizer, in order to enable some transformations that require type information.
"""
key = "Expression"
key = "expression"
arg_types = {"this": True}
__slots__ = ("args", "parent", "arg_key", "comments", "_type")
def __init__(self, **args):
self.args = args
self.parent = None
self.arg_key = None
self.comments = None
def __init__(self, **args: t.Any):
self.args: t.Dict[str, t.Any] = args
self.parent: t.Optional[Expression] = None
self.arg_key: t.Optional[str] = None
self.comments: t.Optional[t.List[str]] = None
self._type: t.Optional[DataType] = None
for arg_key, value in self.args.items():
@ -76,17 +114,30 @@ class Expression(metaclass=_Expression):
@property
def this(self):
"""
Retrieves the argument with key "this".
"""
return self.args.get("this")
@property
def expression(self):
"""
Retrieves the argument with key "expression".
"""
return self.args.get("expression")
@property
def expressions(self):
"""
Retrieves the argument with key "expressions".
"""
return self.args.get("expressions") or []
def text(self, key):
"""
Returns a textual representation of the argument corresponding to "key". This can only be used
for args that are strings or leaf Expression instances, such as identifiers and literals.
"""
field = self.args.get(key)
if isinstance(field, str):
return field
@ -96,14 +147,23 @@ class Expression(metaclass=_Expression):
@property
def is_string(self):
"""
Checks whether a Literal expression is a string.
"""
return isinstance(self, Literal) and self.args["is_string"]
@property
def is_number(self):
"""
Checks whether a Literal expression is a number.
"""
return isinstance(self, Literal) and not self.args["is_string"]
@property
def is_int(self):
"""
Checks whether a Literal expression is an integer.
"""
if self.is_number:
try:
int(self.name)
@ -114,6 +174,9 @@ class Expression(metaclass=_Expression):
@property
def alias(self):
"""
Returns the alias of the expression, or an empty string if it's not aliased.
"""
if isinstance(self.args.get("alias"), TableAlias):
return self.args["alias"].name
return self.text("alias")
@ -128,6 +191,24 @@ class Expression(metaclass=_Expression):
return "NULL"
return self.alias or self.name
@property
def output_name(self):
"""
Name of the output column if this expression is a selection.
If the Expression has no output name, an empty string is returned.
Example:
>>> from sqlglot import parse_one
>>> parse_one("SELECT a").expressions[0].output_name
'a'
>>> parse_one("SELECT b AS c").expressions[0].output_name
'c'
>>> parse_one("SELECT 1 + 2").expressions[0].output_name
''
"""
return ""
@property
def type(self) -> t.Optional[DataType]:
return self._type
@ -145,6 +226,9 @@ class Expression(metaclass=_Expression):
return copy
def copy(self):
"""
Returns a deep copy of the expression.
"""
new = deepcopy(self)
for item, parent, _ in new.bfs():
if isinstance(item, Expression) and parent:
@ -169,7 +253,7 @@ class Expression(metaclass=_Expression):
Sets `arg_key` to `value`.
Args:
arg_key (str): name of the expression arg
arg_key (str): name of the expression arg.
value: value to set the arg to.
"""
self.args[arg_key] = value
@ -203,8 +287,7 @@ class Expression(metaclass=_Expression):
expression_types (type): the expression type(s) to match.
Returns:
the node which matches the criteria or None if no node matching
the criteria was found.
The node which matches the criteria or None if no such node was found.
"""
return next(self.find_all(*expression_types, bfs=bfs), None)
@ -217,7 +300,7 @@ class Expression(metaclass=_Expression):
expression_types (type): the expression type(s) to match.
Returns:
the generator object.
The generator object.
"""
for expression, _, _ in self.walk(bfs=bfs):
if isinstance(expression, expression_types):
@ -231,7 +314,7 @@ class Expression(metaclass=_Expression):
expression_types (type): the expression type(s) to match.
Returns:
the parent node
The parent node.
"""
ancestor = self.parent
while ancestor and not isinstance(ancestor, expression_types):
@ -269,7 +352,7 @@ class Expression(metaclass=_Expression):
the DFS (Depth-first) order.
Returns:
the generator object.
The generator object.
"""
parent = parent or self.parent
yield self, parent, key
@ -287,7 +370,7 @@ class Expression(metaclass=_Expression):
the BFS (Breadth-first) order.
Returns:
the generator object.
The generator object.
"""
queue = deque([(self, self.parent, None)])
@ -341,32 +424,33 @@ class Expression(metaclass=_Expression):
return self.sql()
def __repr__(self):
return self.to_s()
return self._to_s()
def sql(self, dialect: Dialect | str | None = None, **opts) -> str:
"""
Returns SQL string representation of this tree.
Args
dialect (str): the dialect of the output SQL string
(eg. "spark", "hive", "presto", "mysql").
opts (dict): other :class:`~sqlglot.generator.Generator` options.
Args:
dialect: the dialect of the output SQL string (eg. "spark", "hive", "presto", "mysql").
opts: other `sqlglot.generator.Generator` options.
Returns
the SQL string.
Returns:
The SQL string.
"""
from sqlglot.dialects import Dialect
return Dialect.get_or_raise(dialect)().generate(self, **opts)
def to_s(self, hide_missing: bool = True, level: int = 0) -> str:
def _to_s(self, hide_missing: bool = True, level: int = 0) -> str:
indent = "" if not level else "\n"
indent += "".join([" "] * level)
left = f"({self.key.upper()} "
args: t.Dict[str, t.Any] = {
k: ", ".join(
v.to_s(hide_missing=hide_missing, level=level + 1) if hasattr(v, "to_s") else str(v)
v._to_s(hide_missing=hide_missing, level=level + 1)
if hasattr(v, "_to_s")
else str(v)
for v in ensure_collection(vs)
if v is not None
)
@ -394,7 +478,7 @@ class Expression(metaclass=_Expression):
modified in place.
Returns:
the transformed tree.
The transformed tree.
"""
node = self.copy() if copy else self
new_node = fun(node, *args, **kwargs)
@ -423,8 +507,8 @@ class Expression(metaclass=_Expression):
Args:
expression (Expression|None): new node
Returns :
the new expression or expressions
Returns:
The new expression or expressions.
"""
if not self.parent:
return expression
@ -458,6 +542,40 @@ class Expression(metaclass=_Expression):
assert isinstance(self, type_)
return self
def error_messages(self, args: t.Optional[t.Sequence] = None) -> t.List[str]:
"""
Checks if this expression is valid (e.g. all mandatory args are set).
Args:
args: a sequence of values that were used to instantiate a Func expression. This is used
to check that the provided arguments don't exceed the function argument limit.
Returns:
A list of error messages for all possible errors that were found.
"""
errors: t.List[str] = []
for k in self.args:
if k not in self.arg_types:
errors.append(f"Unexpected keyword: '{k}' for {self.__class__}")
for k, mandatory in self.arg_types.items():
v = self.args.get(k)
if mandatory and (v is None or (isinstance(v, list) and not v)):
errors.append(f"Required keyword: '{k}' missing for {self.__class__}")
if (
args
and isinstance(self, Func)
and len(args) > len(self.arg_types)
and not self.is_var_len_args
):
errors.append(
f"The number of provided arguments ({len(args)}) is greater than "
f"the maximum number of supported arguments ({len(self.arg_types)})"
)
return errors
def dump(self):
"""
Dump this Expression to a JSON-serializable dict.
@ -552,7 +670,7 @@ class DerivedTable(Expression):
@property
def named_selects(self):
return [select.alias_or_name for select in self.selects]
return [select.output_name for select in self.selects]
class Unionable(Expression):
@ -654,6 +772,7 @@ class Create(Expression):
"no_primary_index": False,
"indexes": False,
"no_schema_binding": False,
"begin": False,
}
@ -696,7 +815,7 @@ class Show(Expression):
class UserDefinedFunction(Expression):
arg_types = {"this": True, "expressions": False}
arg_types = {"this": True, "expressions": False, "wrapped": False}
class UserDefinedFunctionKwarg(Expression):
@ -750,6 +869,10 @@ class Column(Condition):
def table(self):
return self.text("table")
@property
def output_name(self):
return self.name
class ColumnDef(Expression):
arg_types = {
@ -865,6 +988,10 @@ class ForeignKey(Expression):
}
class PrimaryKey(Expression):
arg_types = {"expressions": True, "options": False}
class Unique(Expression):
arg_types = {"expressions": True}
@ -904,6 +1031,10 @@ class Identifier(Expression):
def __hash__(self):
return hash((self.key, self.this.lower()))
@property
def output_name(self):
return self.name
class Index(Expression):
arg_types = {
@ -996,6 +1127,10 @@ class Literal(Condition):
def string(cls, string) -> Literal:
return cls(this=str(string), is_string=True)
@property
def output_name(self):
return self.name
class Join(Expression):
arg_types = {
@ -1186,7 +1321,7 @@ class SchemaCommentProperty(Property):
class ReturnsProperty(Property):
arg_types = {"this": True, "is_table": False}
arg_types = {"this": True, "is_table": False, "table": False}
class LanguageProperty(Property):
@ -1262,8 +1397,13 @@ class Qualify(Expression):
pass
# https://www.ibm.com/docs/en/ias?topic=procedures-return-statement-in-sql
class Return(Expression):
pass
class Reference(Expression):
arg_types = {"this": True, "expressions": True}
arg_types = {"this": True, "expressions": False, "options": False}
class Tuple(Expression):
@ -1397,6 +1537,16 @@ class Table(Expression):
"joins": False,
"pivots": False,
"hints": False,
"system_time": False,
}
# See the TSQL "Querying data in a system-versioned temporal table" page
class SystemTime(Expression):
arg_types = {
"this": False,
"expression": False,
"kind": True,
}
@ -2027,7 +2177,7 @@ class Select(Subqueryable):
@property
def named_selects(self) -> t.List[str]:
return [e.alias_or_name for e in self.expressions if e.alias_or_name]
return [e.output_name for e in self.expressions if e.alias_or_name]
@property
def selects(self) -> t.List[Expression]:
@ -2051,6 +2201,10 @@ class Subquery(DerivedTable, Unionable):
expression = expression.this
return expression
@property
def output_name(self):
return self.alias
class TableSample(Expression):
arg_types = {
@ -2066,6 +2220,16 @@ class TableSample(Expression):
}
class Tag(Expression):
"""Tags are used for generating arbitrary sql like SELECT <span>x</span>."""
arg_types = {
"this": False,
"prefix": False,
"postfix": False,
}
class Pivot(Expression):
arg_types = {
"this": False,
@ -2106,6 +2270,10 @@ class Star(Expression):
def name(self):
return "*"
@property
def output_name(self):
return self.name
class Parameter(Expression):
pass
@ -2143,6 +2311,8 @@ class DataType(Expression):
TEXT = auto()
MEDIUMTEXT = auto()
LONGTEXT = auto()
MEDIUMBLOB = auto()
LONGBLOB = auto()
BINARY = auto()
VARBINARY = auto()
INT = auto()
@ -2282,11 +2452,11 @@ class Rollback(Expression):
class AlterTable(Expression):
arg_types = {
"this": True,
"actions": True,
"exists": False,
}
arg_types = {"this": True, "actions": True, "exists": False}
class AddConstraint(Expression):
arg_types = {"this": False, "expression": False, "enforced": False}
# Binary expressions like (ADD a b)
@ -2456,6 +2626,10 @@ class Neg(Unary):
class Alias(Expression):
arg_types = {"this": True, "alias": False}
@property
def output_name(self):
return self.alias
class Aliases(Expression):
arg_types = {"this": True, "expressions": True}
@ -2523,16 +2697,13 @@ class Func(Condition):
"""
The base class for all function expressions.
Attributes
is_var_len_args (bool): if set to True the last argument defined in
arg_types will be treated as a variable length argument and the
argument's value will be stored as a list.
_sql_names (list): determines the SQL name (1st item in the list) and
aliases (subsequent items) for this function expression. These
values are used to map this node to a name during parsing as well
as to provide the function's name during SQL string generation. By
default the SQL name is set to the expression's class name transformed
to snake case.
Attributes:
is_var_len_args (bool): if set to True the last argument defined in arg_types will be
treated as a variable length argument and the argument's value will be stored as a list.
_sql_names (list): determines the SQL name (1st item in the list) and aliases (subsequent items)
for this function expression. These values are used to map this node to a name during parsing
as well as to provide the function's name during SQL string generation. By default the SQL
name is set to the expression's class name transformed to snake case.
"""
is_var_len_args = False
@ -2558,7 +2729,7 @@ class Func(Condition):
raise NotImplementedError(
"SQL name is only supported by concrete function implementations"
)
if not hasattr(cls, "_sql_names"):
if "_sql_names" not in cls.__dict__:
cls._sql_names = [camel_to_snake_case(cls.__name__)]
return cls._sql_names
@ -2658,6 +2829,10 @@ class Cast(Func):
def to(self):
return self.args["to"]
@property
def output_name(self):
return self.name
class Collate(Binary):
pass
@ -2956,6 +3131,14 @@ class Pow(Func):
_sql_names = ["POWER", "POW"]
class PercentileCont(AggFunc):
pass
class PercentileDisc(AggFunc):
pass
class Quantile(AggFunc):
arg_types = {"this": True, "quantile": True}
@ -3213,12 +3396,13 @@ def _norm_arg(arg):
ALL_FUNCTIONS = subclasses(__name__, Func, (AggFunc, Anonymous, Func))
# Helpers
def maybe_parse(
sql_or_expression,
sql_or_expression: str | Expression,
*,
into=None,
dialect=None,
prefix=None,
into: t.Optional[IntoType] = None,
dialect: t.Optional[str] = None,
prefix: t.Optional[str] = None,
**opts,
) -> Expression:
"""Gracefully handle a possible string or expression.
@ -3230,11 +3414,11 @@ def maybe_parse(
(IDENTIFIER this: x, quoted: False)
Args:
sql_or_expression (str | Expression): the SQL code string or an expression
into (Expression): the SQLGlot Expression to parse into
dialect (str): the dialect used to parse the input expressions (in the case that an
sql_or_expression: the SQL code string or an expression
into: the SQLGlot Expression to parse into
dialect: the dialect used to parse the input expressions (in the case that an
input expression is a SQL string).
prefix (str): a string to prefix the sql with before it gets parsed
prefix: a string to prefix the sql with before it gets parsed
(automatically includes a space)
**opts: other options to use to parse the input expressions (again, in the case
that an input expression is a SQL string).
@ -3993,7 +4177,7 @@ def table_name(table) -> str:
"""Get the full name of a table as a string.
Args:
table (exp.Table | str): Table expression node or string.
table (exp.Table | str): table expression node or string.
Examples:
>>> from sqlglot import exp, parse_one
@ -4001,7 +4185,7 @@ def table_name(table) -> str:
'a.b.c'
Returns:
str: the table name
The table name.
"""
table = maybe_parse(table, into=Table)
@ -4024,8 +4208,8 @@ def replace_tables(expression, mapping):
"""Replace all tables in expression according to the mapping.
Args:
expression (sqlglot.Expression): Expression node to be transformed and replaced
mapping (Dict[str, str]): Mapping of table names
expression (sqlglot.Expression): expression node to be transformed and replaced.
mapping (Dict[str, str]): mapping of table names.
Examples:
>>> from sqlglot import exp, parse_one
@ -4033,7 +4217,7 @@ def replace_tables(expression, mapping):
'SELECT * FROM c'
Returns:
The mapped expression
The mapped expression.
"""
def _replace_tables(node):
@ -4053,9 +4237,9 @@ def replace_placeholders(expression, *args, **kwargs):
"""Replace placeholders in an expression.
Args:
expression (sqlglot.Expression): Expression node to be transformed and replaced
args: Positional names that will substitute unnamed placeholders in the given order
kwargs: Keyword arguments that will substitute named placeholders
expression (sqlglot.Expression): expression node to be transformed and replaced.
args: positional names that will substitute unnamed placeholders in the given order.
kwargs: keyword arguments that will substitute named placeholders.
Examples:
>>> from sqlglot import exp, parse_one
@ -4065,7 +4249,7 @@ def replace_placeholders(expression, *args, **kwargs):
'SELECT * FROM foo WHERE a = b'
Returns:
The mapped expression
The mapped expression.
"""
def _replace_placeholders(node, args, **kwargs):
@ -4084,15 +4268,101 @@ def replace_placeholders(expression, *args, **kwargs):
return expression.transform(_replace_placeholders, iter(args), **kwargs)
def expand(expression: Expression, sources: t.Dict[str, Subqueryable], copy=True) -> Expression:
"""Transforms an expression by expanding all referenced sources into subqueries.
Examples:
>>> from sqlglot import parse_one
>>> expand(parse_one("select * from x AS z"), {"x": parse_one("select * from y")}).sql()
'SELECT * FROM (SELECT * FROM y) AS z /* source: x */'
Args:
expression: The expression to expand.
sources: A dictionary of name to Subqueryables.
copy: Whether or not to copy the expression during transformation. Defaults to True.
Returns:
The transformed expression.
"""
def _expand(node: Expression):
if isinstance(node, Table):
name = table_name(node)
source = sources.get(name)
if source:
subquery = source.subquery(node.alias or name)
subquery.comments = [f"source: {name}"]
return subquery
return node
return expression.transform(_expand, copy=copy)
def func(name: str, *args, dialect: t.Optional[Dialect | str] = None, **kwargs) -> Func:
"""
Returns a Func expression.
Examples:
>>> func("abs", 5).sql()
'ABS(5)'
>>> func("cast", this=5, to=DataType.build("DOUBLE")).sql()
'CAST(5 AS DOUBLE)'
Args:
name: the name of the function to build.
args: the args used to instantiate the function of interest.
dialect: the source dialect.
kwargs: the kwargs used to instantiate the function of interest.
Note:
The arguments `args` and `kwargs` are mutually exclusive.
Returns:
An instance of the function of interest, or an anonymous function, if `name` doesn't
correspond to an existing `sqlglot.expressions.Func` class.
"""
if args and kwargs:
raise ValueError("Can't use both args and kwargs to instantiate a function.")
from sqlglot.dialects.dialect import Dialect
args = tuple(convert(arg) for arg in args)
kwargs = {key: convert(value) for key, value in kwargs.items()}
parser = Dialect.get_or_raise(dialect)().parser()
from_args_list = parser.FUNCTIONS.get(name.upper())
if from_args_list:
function = from_args_list(args) if args else from_args_list.__self__(**kwargs) # type: ignore
else:
kwargs = kwargs or {"expressions": args}
function = Anonymous(this=name, **kwargs)
for error_message in function.error_messages(args):
raise ValueError(error_message)
return function
def true():
"""
Returns a true Boolean expression.
"""
return Boolean(this=True)
def false():
"""
Returns a false Boolean expression.
"""
return Boolean(this=False)
def null():
"""
Returns a Null expression.
"""
return Null()