
Merging upstream version 6.2.6.

Signed-off-by: Daniel Baumann <daniel@debian.org>
Daniel Baumann 2025-02-13 14:40:43 +01:00
parent 0f5b9ddee1
commit 66e2d714bf
Signed by: daniel
GPG key ID: FBB4F0E80A80222F
49 changed files with 1741 additions and 566 deletions

View file

@ -20,7 +20,7 @@ from sqlglot.generator import Generator
from sqlglot.parser import Parser
from sqlglot.tokens import Tokenizer, TokenType
__version__ = "6.2.1"
__version__ = "6.2.6"
pretty = False

View file

@ -33,6 +33,49 @@ def _date_add_sql(data_type, kind):
return func
def _subquery_to_unnest_if_values(self, expression):
if not isinstance(expression.this, exp.Values):
return self.subquery_sql(expression)
rows = [list(tuple_exp.find_all(exp.Literal)) for tuple_exp in expression.this.find_all(exp.Tuple)]
structs = []
for row in rows:
aliases = [
exp.alias_(value, column_name) for value, column_name in zip(row, expression.args["alias"].args["columns"])
]
structs.append(exp.Struct(expressions=aliases))
unnest_exp = exp.Unnest(expressions=[exp.Array(expressions=structs)])
return self.unnest_sql(unnest_exp)
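For context, a minimal sketch of the rewrite this helper enables, assuming a sqlglot build containing this change (the exact output string is illustrative):

import sqlglot

# VALUES with a column-alias list becomes UNNEST of an ARRAY of STRUCTs in BigQuery.
sql = "SELECT a, b FROM (VALUES (1, 'x'), (2, 'y')) AS t(a, b)"
print(sqlglot.transpile(sql, read="presto", write="bigquery")[0])
# e.g. SELECT a, b FROM UNNEST([STRUCT(1 AS a, 'x' AS b), STRUCT(2 AS a, 'y' AS b)])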
def _returnsproperty_sql(self, expression):
value = expression.args.get("value")
if isinstance(value, exp.Schema):
value = f"{value.this} <{self.expressions(value)}>"
else:
value = self.sql(value)
return f"RETURNS {value}"
def _create_sql(self, expression):
kind = expression.args.get("kind")
returns = expression.find(exp.ReturnsProperty)
if kind.upper() == "FUNCTION" and returns and returns.args.get("is_table"):
expression = expression.copy()
expression.set("kind", "TABLE FUNCTION")
if isinstance(
expression.expression,
(
exp.Subquery,
exp.Literal,
),
):
expression.set("expression", expression.expression.this)
return self.create_sql(expression)
return self.create_sql(expression)
class BigQuery(Dialect):
unnest_column_only = True
@ -77,8 +120,14 @@ class BigQuery(Dialect):
TokenType.CURRENT_TIME: exp.CurrentTime,
}
NESTED_TYPE_TOKENS = {
*Parser.NESTED_TYPE_TOKENS,
TokenType.TABLE,
}
class Generator(Generator):
TRANSFORMS = {
**Generator.TRANSFORMS,
exp.Array: inline_array_sql,
exp.ArraySize: rename_func("ARRAY_LENGTH"),
exp.DateAdd: _date_add_sql("DATE", "ADD"),
@ -91,6 +140,9 @@ class BigQuery(Dialect):
exp.TimestampAdd: _date_add_sql("TIMESTAMP", "ADD"),
exp.TimestampSub: _date_add_sql("TIMESTAMP", "SUB"),
exp.VariancePop: rename_func("VAR_POP"),
exp.Subquery: _subquery_to_unnest_if_values,
exp.ReturnsProperty: _returnsproperty_sql,
exp.Create: _create_sql,
}
TYPE_MAPPING = {

View file

@ -245,6 +245,11 @@ def no_tablesample_sql(self, expression):
return self.sql(expression.this)
def no_pivot_sql(self, expression):
self.unsupported("PIVOT unsupported")
return self.sql(expression)
def no_trycast_sql(self, expression):
return self.cast_sql(expression)
@ -282,3 +287,30 @@ def format_time_lambda(exp_class, dialect, default=None):
)
return _format_time
def create_with_partitions_sql(self, expression):
"""
In Hive and Spark, the PARTITIONED BY property acts as an extension of a table's schema. When the
PARTITIONED BY value is an array of column names, they are transformed into a schema. The corresponding
columns are removed from the create statement.
"""
has_schema = isinstance(expression.this, exp.Schema)
is_partitionable = expression.args.get("kind") in ("TABLE", "VIEW")
if has_schema and is_partitionable:
expression = expression.copy()
prop = expression.find(exp.PartitionedByProperty)
value = prop and prop.args.get("value")
if prop and not isinstance(value, exp.Schema):
schema = expression.this
columns = {v.name.upper() for v in value.expressions}
partitions = [col for col in schema.expressions if col.name.upper() in columns]
schema.set(
"expressions",
[e for e in schema.expressions if e not in partitions],
)
prop.replace(exp.PartitionedByProperty(this=prop.this, value=exp.Schema(expressions=partitions)))
expression.set("this", schema)
return self.create_sql(expression)
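A hedged sketch of the effect, assuming a Presto-style PARTITIONED_BY array property as input (output shape is illustrative):

import sqlglot

# The named partition columns are removed from the column list and re-attached
# as a typed PARTITIONED BY schema when generating Hive DDL.
sql = "CREATE TABLE t (a INT, b INT) WITH (PARTITIONED_BY=ARRAY['b'])"
print(sqlglot.transpile(sql, read="presto", write="hive")[0])
# e.g. CREATE TABLE t (a INT) PARTITIONED BY (b INT)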

View file

@ -5,6 +5,7 @@ from sqlglot.dialects.dialect import (
arrow_json_extract_scalar_sql,
arrow_json_extract_sql,
format_time_lambda,
no_pivot_sql,
no_safe_divide_sql,
no_tablesample_sql,
rename_func,
@ -122,6 +123,7 @@ class DuckDB(Dialect):
exp.JSONExtractScalar: arrow_json_extract_scalar_sql,
exp.JSONBExtract: arrow_json_extract_sql,
exp.JSONBExtractScalar: arrow_json_extract_scalar_sql,
exp.Pivot: no_pivot_sql,
exp.RegexpLike: rename_func("REGEXP_MATCHES"),
exp.RegexpSplit: rename_func("STR_SPLIT_REGEX"),
exp.SafeDivide: no_safe_divide_sql,

View file

@ -2,6 +2,7 @@ from sqlglot import exp, transforms
from sqlglot.dialects.dialect import (
Dialect,
approx_count_distinct_sql,
create_with_partitions_sql,
format_time_lambda,
if_sql,
no_ilike_sql,
@ -53,7 +54,7 @@ def _array_sort(self, expression):
def _property_sql(self, expression):
key = expression.name
value = self.sql(expression, "value")
return f"'{key}' = {value}"
return f"'{key}'={value}"
def _str_to_unix(self, expression):
@ -218,15 +219,6 @@ class Hive(Dialect):
}
class Generator(Generator):
ROOT_PROPERTIES = [
exp.PartitionedByProperty,
exp.FileFormatProperty,
exp.SchemaCommentProperty,
exp.LocationProperty,
exp.TableFormatProperty,
]
WITH_PROPERTIES = [exp.AnonymousProperty]
TYPE_MAPPING = {
**Generator.TYPE_MAPPING,
exp.DataType.Type.TEXT: "STRING",
@ -255,13 +247,13 @@ class Hive(Dialect):
exp.JSONExtractScalar: rename_func("GET_JSON_OBJECT"),
exp.Map: _map_sql,
HiveMap: _map_sql,
exp.PartitionedByProperty: lambda self, e: f"PARTITIONED BY {self.sql(e.args['value'])}",
exp.Create: create_with_partitions_sql,
exp.Quantile: rename_func("PERCENTILE"),
exp.ApproxQuantile: rename_func("PERCENTILE_APPROX"),
exp.RegexpLike: lambda self, e: self.binary(e, "RLIKE"),
exp.RegexpSplit: rename_func("SPLIT"),
exp.SafeDivide: no_safe_divide_sql,
exp.SchemaCommentProperty: lambda self, e: f"COMMENT {self.sql(e.args['value'])}",
exp.SchemaCommentProperty: lambda self, e: self.naked_property(e),
exp.SetAgg: rename_func("COLLECT_SET"),
exp.Split: lambda self, e: f"SPLIT({self.sql(e, 'this')}, CONCAT('\\\\Q', {self.sql(e, 'expression')}))",
exp.StrPosition: lambda self, e: f"LOCATE({csv(self.sql(e, 'substr'), self.sql(e, 'this'), self.sql(e, 'position'))})",
@ -282,6 +274,17 @@ class Hive(Dialect):
exp.UnixToStr: lambda self, e: f"FROM_UNIXTIME({csv(self.sql(e, 'this'), _time_format(self, e))})",
exp.UnixToTime: rename_func("FROM_UNIXTIME"),
exp.UnixToTimeStr: rename_func("FROM_UNIXTIME"),
exp.PartitionedByProperty: lambda self, e: f"PARTITIONED BY {self.sql(e, 'value')}",
}
WITH_PROPERTIES = {exp.AnonymousProperty}
ROOT_PROPERTIES = {
exp.PartitionedByProperty,
exp.FileFormatProperty,
exp.SchemaCommentProperty,
exp.LocationProperty,
exp.TableFormatProperty,
}
def with_properties(self, properties):

View file

@ -172,6 +172,11 @@ class MySQL(Dialect):
),
}
PROPERTY_PARSERS = {
**Parser.PROPERTY_PARSERS,
TokenType.ENGINE: lambda self: self._parse_property_assignment(exp.EngineProperty),
}
class Generator(Generator):
NULL_ORDERING_SUPPORTED = False
@ -190,3 +195,13 @@ class MySQL(Dialect):
exp.StrToTime: _str_to_date_sql,
exp.Trim: _trim_sql,
}
ROOT_PROPERTIES = {
exp.EngineProperty,
exp.AutoIncrementProperty,
exp.CharacterSetProperty,
exp.CollateProperty,
exp.SchemaCommentProperty,
}
WITH_PROPERTIES = {}
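A small sketch of what the new ENGINE plumbing allows, assuming this sqlglot version (the round trip shown is illustrative):

import sqlglot

# ENGINE is parsed as a property and rendered back at the root of the statement.
print(sqlglot.transpile("CREATE TABLE t (a INT) ENGINE=InnoDB", read="mysql", write="mysql")[0])
# e.g. CREATE TABLE t (a INT) ENGINE=InnoDB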

View file

@ -7,6 +7,7 @@ from sqlglot.dialects.dialect import (
no_paren_current_date_sql,
no_tablesample_sql,
no_trycast_sql,
str_position_sql,
)
from sqlglot.generator import Generator
from sqlglot.parser import Parser
@ -158,7 +159,6 @@ class Postgres(Dialect):
"ALWAYS": TokenType.ALWAYS,
"BY DEFAULT": TokenType.BY_DEFAULT,
"IDENTITY": TokenType.IDENTITY,
"FOR": TokenType.FOR,
"GENERATED": TokenType.GENERATED,
"DOUBLE PRECISION": TokenType.DOUBLE,
"BIGSERIAL": TokenType.BIGSERIAL,
@ -204,6 +204,7 @@ class Postgres(Dialect):
exp.DateAdd: _date_add_sql("+"),
exp.DateSub: _date_add_sql("-"),
exp.Lateral: _lateral_sql,
exp.StrPosition: str_position_sql,
exp.StrToTime: lambda self, e: f"TO_TIMESTAMP({self.sql(e, 'this')}, {self.format_time(e)})",
exp.Substring: _substring_sql,
exp.TimeToStr: lambda self, e: f"TO_CHAR({self.sql(e, 'this')}, {self.format_time(e)})",

View file

@ -146,13 +146,16 @@ class Presto(Dialect):
STRUCT_DELIMITER = ("(", ")")
WITH_PROPERTIES = [
ROOT_PROPERTIES = {
exp.SchemaCommentProperty,
}
WITH_PROPERTIES = {
exp.PartitionedByProperty,
exp.FileFormatProperty,
exp.SchemaCommentProperty,
exp.AnonymousProperty,
exp.TableFormatProperty,
]
}
TYPE_MAPPING = {
**Generator.TYPE_MAPPING,
@ -184,13 +187,11 @@ class Presto(Dialect):
exp.DateStrToDate: lambda self, e: f"CAST(DATE_PARSE({self.sql(e, 'this')}, {Presto.date_format}) AS DATE)",
exp.DateToDi: lambda self, e: f"CAST(DATE_FORMAT({self.sql(e, 'this')}, {Presto.dateint_format}) AS INT)",
exp.DiToDate: lambda self, e: f"CAST(DATE_PARSE(CAST({self.sql(e, 'this')} AS VARCHAR), {Presto.dateint_format}) AS DATE)",
exp.FileFormatProperty: lambda self, e: self.property_sql(e),
exp.If: if_sql,
exp.ILike: no_ilike_sql,
exp.Initcap: _initcap_sql,
exp.Lateral: _explode_to_unnest_sql,
exp.Levenshtein: rename_func("LEVENSHTEIN_DISTANCE"),
exp.PartitionedByProperty: lambda self, e: f"PARTITIONED_BY = {self.sql(e.args['value'])}",
exp.Quantile: _quantile_sql,
exp.ApproxQuantile: rename_func("APPROX_PERCENTILE"),
exp.SafeDivide: no_safe_divide_sql,

View file

@ -1,5 +1,10 @@
from sqlglot import exp
from sqlglot.dialects.dialect import Dialect, format_time_lambda, rename_func
from sqlglot.dialects.dialect import (
Dialect,
format_time_lambda,
inline_array_sql,
rename_func,
)
from sqlglot.expressions import Literal
from sqlglot.generator import Generator
from sqlglot.helper import list_get
@ -104,6 +109,8 @@ class Snowflake(Dialect):
"ARRAYAGG": exp.ArrayAgg.from_arg_list,
"IFF": exp.If.from_arg_list,
"TO_TIMESTAMP": _snowflake_to_timestamp,
"ARRAY_CONSTRUCT": exp.Array.from_arg_list,
"RLIKE": exp.RegexpLike.from_arg_list,
}
FUNCTION_PARSERS = {
@ -111,6 +118,11 @@ class Snowflake(Dialect):
"DATE_PART": lambda self: self._parse_extract(),
}
FUNC_TOKENS = {
*Parser.FUNC_TOKENS,
TokenType.RLIKE,
}
COLUMN_OPERATORS = {
**Parser.COLUMN_OPERATORS,
TokenType.COLON: lambda self, this, path: self.expression(
@ -120,6 +132,11 @@ class Snowflake(Dialect):
),
}
PROPERTY_PARSERS = {
**Parser.PROPERTY_PARSERS,
TokenType.PARTITION_BY: lambda self: self._parse_partitioned_by(),
}
class Tokenizer(Tokenizer):
QUOTES = ["'", "$$"]
ESCAPE = "\\"
@ -137,6 +154,7 @@ class Snowflake(Dialect):
"TIMESTAMP_NTZ": TokenType.TIMESTAMP,
"TIMESTAMP_TZ": TokenType.TIMESTAMPTZ,
"TIMESTAMPNTZ": TokenType.TIMESTAMP,
"SAMPLE": TokenType.TABLE_SAMPLE,
}
class Generator(Generator):
@ -145,6 +163,8 @@ class Snowflake(Dialect):
exp.If: rename_func("IFF"),
exp.StrToTime: lambda self, e: f"TO_TIMESTAMP({self.sql(e, 'this')}, {self.format_time(e)})",
exp.UnixToTime: _unix_to_time,
exp.Array: inline_array_sql,
exp.PartitionedByProperty: lambda self, e: f"PARTITION BY {self.sql(e, 'value')}",
}
TYPE_MAPPING = {
@ -152,6 +172,13 @@ class Snowflake(Dialect):
exp.DataType.Type.TIMESTAMP: "TIMESTAMPNTZ",
}
ROOT_PROPERTIES = {
exp.PartitionedByProperty,
exp.ReturnsProperty,
exp.LanguageProperty,
exp.SchemaCommentProperty,
}
def except_op(self, expression):
if not expression.args.get("distinct", False):
self.unsupported("EXCEPT with All is not supported in Snowflake")

View file

@ -1,5 +1,9 @@
from sqlglot import exp
from sqlglot.dialects.dialect import no_ilike_sql, rename_func
from sqlglot.dialects.dialect import (
create_with_partitions_sql,
no_ilike_sql,
rename_func,
)
from sqlglot.dialects.hive import Hive, HiveMap
from sqlglot.helper import list_get
@ -10,7 +14,7 @@ def _create_sql(self, e):
if kind.upper() == "TABLE" and temporary is True:
return f"CREATE TEMPORARY VIEW {self.sql(e, 'this')} AS {self.sql(e, 'expression')}"
return self.create_sql(e)
return create_with_partitions_sql(self, e)
def _map_sql(self, expression):
@ -73,6 +77,7 @@ class Spark(Hive):
}
class Generator(Hive.Generator):
TYPE_MAPPING = {
**Hive.Generator.TYPE_MAPPING,
exp.DataType.Type.TINYINT: "BYTE",

View file

@ -1,4 +1,5 @@
from sqlglot import exp
from sqlglot.dialects.dialect import rename_func
from sqlglot.dialects.mysql import MySQL
@ -10,3 +11,12 @@ class StarRocks(MySQL):
exp.DataType.Type.TIMESTAMP: "DATETIME",
exp.DataType.Type.TIMESTAMPTZ: "DATETIME",
}
TRANSFORMS = {
**MySQL.Generator.TRANSFORMS,
exp.DateDiff: rename_func("DATEDIFF"),
exp.StrToUnix: lambda self, e: f"UNIX_TIMESTAMP({self.sql(e, 'this')}, {self.format_time(e)})",
exp.TimeStrToDate: rename_func("TO_DATE"),
exp.UnixToStr: lambda self, e: f"FROM_UNIXTIME({self.sql(e, 'this')}, {self.format_time(e)})",
exp.UnixToTime: rename_func("FROM_UNIXTIME"),
}

View file

@ -1,6 +1,7 @@
from sqlglot import exp
from sqlglot.dialects.dialect import Dialect
from sqlglot.generator import Generator
from sqlglot.parser import Parser
from sqlglot.tokens import Tokenizer, TokenType
@ -17,6 +18,7 @@ class TSQL(Dialect):
"REAL": TokenType.FLOAT,
"NTEXT": TokenType.TEXT,
"SMALLDATETIME": TokenType.DATETIME,
"DATETIME2": TokenType.DATETIME,
"DATETIMEOFFSET": TokenType.TIMESTAMPTZ,
"TIME": TokenType.TIMESTAMP,
"VARBINARY": TokenType.BINARY,
@ -24,15 +26,24 @@ class TSQL(Dialect):
"MONEY": TokenType.MONEY,
"SMALLMONEY": TokenType.SMALLMONEY,
"ROWVERSION": TokenType.ROWVERSION,
"SQL_VARIANT": TokenType.SQL_VARIANT,
"UNIQUEIDENTIFIER": TokenType.UNIQUEIDENTIFIER,
"XML": TokenType.XML,
"SQL_VARIANT": TokenType.VARIANT,
}
class Parser(Parser):
def _parse_convert(self):
to = self._parse_types()
self._match(TokenType.COMMA)
this = self._parse_field()
return self.expression(exp.Cast, this=this, to=to)
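A brief sketch of the parser hook above (output is illustrative):

import sqlglot

# T-SQL CONVERT(type, expr) is parsed into a plain Cast node.
print(sqlglot.transpile("SELECT CONVERT(VARCHAR(10), x)", read="tsql", write="tsql")[0])
# e.g. SELECT CAST(x AS VARCHAR(10))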
class Generator(Generator):
TYPE_MAPPING = {
**Generator.TYPE_MAPPING,
exp.DataType.Type.BOOLEAN: "BIT",
exp.DataType.Type.INT: "INTEGER",
exp.DataType.Type.DECIMAL: "NUMERIC",
exp.DataType.Type.DATETIME: "DATETIME2",
exp.DataType.Type.VARIANT: "SQL_VARIANT",
}

View file

@ -3,17 +3,11 @@ import time
from sqlglot import parse_one
from sqlglot.executor.python import PythonExecutor
from sqlglot.optimizer import RULES, optimize
from sqlglot.optimizer.merge_derived_tables import merge_derived_tables
from sqlglot.optimizer import optimize
from sqlglot.planner import Plan
logger = logging.getLogger("sqlglot")
OPTIMIZER_RULES = list(RULES)
# The executor needs isolated table selects
OPTIMIZER_RULES.remove(merge_derived_tables)
def execute(sql, schema, read=None):
"""
@ -34,7 +28,7 @@ def execute(sql, schema, read=None):
"""
expression = parse_one(sql, read=read)
now = time.time()
expression = optimize(expression, schema, rules=OPTIMIZER_RULES)
expression = optimize(expression, schema, leave_tables_isolated=True)
logger.debug("Optimization finished: %f", time.time() - now)
logger.debug("Optimized SQL: %s", expression.sql(pretty=True))
plan = Plan(expression)

View file

@ -1,13 +1,17 @@
import inspect
import numbers
import re
import sys
from collections import deque
from copy import deepcopy
from enum import auto
from sqlglot.errors import ParseError
from sqlglot.helper import AutoName, camel_to_snake_case, ensure_list, list_get
from sqlglot.helper import (
AutoName,
camel_to_snake_case,
ensure_list,
list_get,
subclasses,
)
class _Expression(type):
@ -31,12 +35,13 @@ class Expression(metaclass=_Expression):
key = None
arg_types = {"this": True}
__slots__ = ("args", "parent", "arg_key")
__slots__ = ("args", "parent", "arg_key", "type")
def __init__(self, **args):
self.args = args
self.parent = None
self.arg_key = None
self.type = None
for arg_key, value in self.args.items():
self._set_parent(arg_key, value)
@ -384,7 +389,7 @@ class Expression(metaclass=_Expression):
'SELECT y FROM tbl'
Args:
expression (Expression): new node
expression (Expression|None): new node
Returns:
the new expression or expressions
@ -398,6 +403,12 @@ class Expression(metaclass=_Expression):
replace_children(parent, lambda child: expression if child is self else child)
return expression
def pop(self):
"""
Remove this expression from its AST.
"""
self.replace(None)
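A tiny usage sketch of the new pop() helper (hypothetical table and column names):

import sqlglot
from sqlglot import exp

tree = sqlglot.parse_one("SELECT a FROM t WHERE a > 1")
tree.find(exp.Where).pop()  # detach the WHERE clause from the AST
print(tree.sql())  # SELECT a FROM t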
def assert_is(self, type_):
"""
Assert that this `Expression` is an instance of `type_`.
@ -527,9 +538,18 @@ class Create(Expression):
"temporary": False,
"replace": False,
"unique": False,
"materialized": False,
}
class UserDefinedFunction(Expression):
arg_types = {"this": True, "expressions": False}
class UserDefinedFunctionKwarg(Expression):
arg_types = {"this": True, "kind": True, "default": False}
class CharacterSet(Expression):
arg_types = {"this": True, "default": False}
@ -887,6 +907,14 @@ class AnonymousProperty(Property):
pass
class ReturnsProperty(Property):
arg_types = {"this": True, "value": True, "is_table": False}
class LanguageProperty(Property):
pass
class Properties(Expression):
arg_types = {"expressions": True}
@ -907,25 +935,9 @@ class Properties(Expression):
expressions = []
for key, value in properties_dict.items():
property_cls = cls.PROPERTY_KEY_MAPPING.get(key.upper(), AnonymousProperty)
expressions.append(property_cls(this=Literal.string(key), value=cls._convert_value(value)))
expressions.append(property_cls(this=Literal.string(key), value=convert(value)))
return cls(expressions=expressions)
@staticmethod
def _convert_value(value):
if value is None:
return NULL
if isinstance(value, Expression):
return value
if isinstance(value, bool):
return Boolean(this=value)
if isinstance(value, str):
return Literal.string(value)
if isinstance(value, numbers.Number):
return Literal.number(value)
if isinstance(value, list):
return Tuple(expressions=[Properties._convert_value(v) for v in value])
raise ValueError(f"Unsupported type '{type(value)}' for value '{value}'")
class Qualify(Expression):
pass
@ -1030,6 +1042,7 @@ class Subqueryable:
QUERY_MODIFIERS = {
"laterals": False,
"joins": False,
"pivots": False,
"where": False,
"group": False,
"having": False,
@ -1051,6 +1064,7 @@ class Table(Expression):
"catalog": False,
"laterals": False,
"joins": False,
"pivots": False,
}
@ -1643,6 +1657,16 @@ class TableSample(Expression):
"percent": False,
"rows": False,
"size": False,
"seed": False,
}
class Pivot(Expression):
arg_types = {
"this": False,
"expressions": True,
"field": True,
"unpivot": True,
}
@ -1741,7 +1765,8 @@ class DataType(Expression):
SMALLMONEY = auto()
ROWVERSION = auto()
IMAGE = auto()
SQL_VARIANT = auto()
VARIANT = auto()
OBJECT = auto()
@classmethod
def build(cls, dtype, **kwargs):
@ -2124,6 +2149,7 @@ class TryCast(Cast):
class Ceil(Func):
arg_types = {"this": True, "decimals": False}
_sql_names = ["CEIL", "CEILING"]
@ -2254,7 +2280,7 @@ class Explode(Func):
class Floor(Func):
pass
arg_types = {"this": True, "decimals": False}
class Greatest(Func):
@ -2371,7 +2397,7 @@ class Reduce(Func):
class RegexpLike(Func):
arg_types = {"this": True, "expression": True}
arg_types = {"this": True, "expression": True, "flag": False}
class RegexpSplit(Func):
@ -2540,6 +2566,8 @@ def _norm_args(expression):
for k, arg in expression.args.items():
if isinstance(arg, list):
arg = [_norm_arg(a) for a in arg]
if not arg:
arg = None
else:
arg = _norm_arg(arg)
@ -2553,17 +2581,7 @@ def _norm_arg(arg):
return arg.lower() if isinstance(arg, str) else arg
def _all_functions():
return [
obj
for _, obj in inspect.getmembers(
sys.modules[__name__],
lambda obj: inspect.isclass(obj) and issubclass(obj, Func) and obj not in (AggFunc, Anonymous, Func),
)
]
ALL_FUNCTIONS = _all_functions()
ALL_FUNCTIONS = subclasses(__name__, Func, (AggFunc, Anonymous, Func))
def maybe_parse(
@ -2793,6 +2811,37 @@ def from_(*expressions, dialect=None, **opts):
return Select().from_(*expressions, dialect=dialect, **opts)
def update(table, properties, where=None, from_=None, dialect=None, **opts):
"""
Creates an update statement.
Example:
>>> update("my_table", {"x": 1, "y": "2", "z": None}, from_="baz", where="id > 1").sql()
"UPDATE my_table SET x = 1, y = '2', z = NULL FROM baz WHERE id > 1"
Args:
table (str|Expression): the table to update.
properties (dict): dictionary of properties to set, which are
auto-converted into SQL objects (e.g. None -> NULL)
where (str): sql conditional parsed into a WHERE statement
from_ (str): sql statement parsed into a FROM statement
dialect (str): the dialect used to parse the input expressions.
**opts: other options to use to parse the input expressions.
Returns:
Update: the syntax tree for the UPDATE statement.
"""
update = Update(this=maybe_parse(table, into=Table, dialect=dialect))
update.set(
"expressions",
[EQ(this=maybe_parse(k, dialect=dialect, **opts), expression=convert(v)) for k, v in properties.items()],
)
if from_:
update.set("from", maybe_parse(from_, into=From, dialect=dialect, prefix="FROM", **opts))
if where:
update.set("where", maybe_parse(where, into=Where, dialect=dialect, prefix="WHERE", **opts))
return update
def condition(expression, dialect=None, **opts):
"""
Initialize a logical condition expression.
@ -2980,12 +3029,13 @@ def column(col, table=None, quoted=None):
def table_(table, db=None, catalog=None, quoted=None):
"""
Build a Table.
"""Build a Table.
Args:
table (str or Expression): table name
db (str or Expression): db name
catalog (str or Expression): catalog name
Returns:
Table: table instance
"""
@ -2996,6 +3046,39 @@ def table_(table, db=None, catalog=None, quoted=None):
)
def convert(value):
"""Convert a python value into an expression object.
Raises an error if a conversion is not possible.
Args:
value (Any): a python object
Returns:
Expression: the equivalent expression object
"""
if isinstance(value, Expression):
return value
if value is None:
return NULL
if isinstance(value, bool):
return Boolean(this=value)
if isinstance(value, str):
return Literal.string(value)
if isinstance(value, numbers.Number):
return Literal.number(value)
if isinstance(value, tuple):
return Tuple(expressions=[convert(v) for v in value])
if isinstance(value, list):
return Array(expressions=[convert(v) for v in value])
if isinstance(value, dict):
return Map(
keys=[convert(k) for k in value.keys()],
values=[convert(v) for v in value.values()],
)
raise ValueError(f"Cannot convert {value}")
def replace_children(expression, fun):
"""
Replace children of an expression with the result of a lambda fun(child) -> exp.

View file

@ -46,18 +46,12 @@ class Generator:
"""
TRANSFORMS = {
exp.AnonymousProperty: lambda self, e: self.property_sql(e),
exp.AutoIncrementProperty: lambda self, e: f"AUTO_INCREMENT={self.sql(e, 'value')}",
exp.CharacterSetProperty: lambda self, e: f"{'DEFAULT ' if e.args['default'] else ''}CHARACTER SET={self.sql(e, 'value')}",
exp.CollateProperty: lambda self, e: f"COLLATE={self.sql(e, 'value')}",
exp.DateAdd: lambda self, e: f"DATE_ADD({self.sql(e, 'this')}, {self.sql(e, 'expression')}, {self.sql(e, 'unit')})",
exp.DateDiff: lambda self, e: f"DATE_DIFF({self.sql(e, 'this')}, {self.sql(e, 'expression')})",
exp.EngineProperty: lambda self, e: f"ENGINE={self.sql(e, 'value')}",
exp.FileFormatProperty: lambda self, e: f"FORMAT={self.sql(e, 'value')}",
exp.LocationProperty: lambda self, e: f"LOCATION {self.sql(e, 'value')}",
exp.PartitionedByProperty: lambda self, e: f"PARTITIONED_BY={self.sql(e.args['value'])}",
exp.SchemaCommentProperty: lambda self, e: f"COMMENT={self.sql(e, 'value')}",
exp.TableFormatProperty: lambda self, e: f"TABLE_FORMAT={self.sql(e, 'value')}",
exp.LanguageProperty: lambda self, e: self.naked_property(e),
exp.LocationProperty: lambda self, e: self.naked_property(e),
exp.ReturnsProperty: lambda self, e: self.naked_property(e),
exp.TsOrDsAdd: lambda self, e: f"TS_OR_DS_ADD({self.sql(e, 'this')}, {self.sql(e, 'expression')}, {self.sql(e, 'unit')})",
}
@ -72,19 +66,17 @@ class Generator:
STRUCT_DELIMITER = ("<", ">")
ROOT_PROPERTIES = [
exp.AutoIncrementProperty,
exp.CharacterSetProperty,
exp.CollateProperty,
exp.EngineProperty,
exp.SchemaCommentProperty,
]
WITH_PROPERTIES = [
ROOT_PROPERTIES = {
exp.ReturnsProperty,
exp.LanguageProperty,
}
WITH_PROPERTIES = {
exp.AnonymousProperty,
exp.FileFormatProperty,
exp.PartitionedByProperty,
exp.TableFormatProperty,
]
}
__slots__ = (
"time_mapping",
@ -188,6 +180,7 @@ class Generator:
return sql
def unsupported(self, message):
if self.unsupported_level == ErrorLevel.IMMEDIATE:
raise UnsupportedError(message)
self.unsupported_messages.append(message)
@ -261,6 +254,9 @@ class Generator:
if isinstance(expression, exp.Func):
return self.function_fallback_sql(expression)
if isinstance(expression, exp.Property):
return self.property_sql(expression)
raise ValueError(f"Unsupported expression type {expression.__class__.__name__}")
def annotation_sql(self, expression):
@ -352,9 +348,12 @@ class Generator:
replace = " OR REPLACE" if expression.args.get("replace") else ""
exists_sql = " IF NOT EXISTS" if expression.args.get("exists") else ""
unique = " UNIQUE" if expression.args.get("unique") else ""
materialized = " MATERIALIZED" if expression.args.get("materialized") else ""
properties = self.sql(expression, "properties")
expression_sql = f"CREATE{replace}{temporary}{unique} {kind}{exists_sql} {this}{properties} {expression_sql}"
expression_sql = (
f"CREATE{replace}{temporary}{unique}{materialized} {kind}{exists_sql} {this}{properties} {expression_sql}"
)
return self.prepend_ctes(expression, expression_sql)
def prepend_ctes(self, expression, sql):
@ -461,10 +460,10 @@ class Generator:
for p in expression.expressions:
p_class = p.__class__
if p_class in self.ROOT_PROPERTIES:
root_properties.append(p)
elif p_class in self.WITH_PROPERTIES:
if p_class in self.WITH_PROPERTIES:
with_properties.append(p)
elif p_class in self.ROOT_PROPERTIES:
root_properties.append(p)
return self.root_properties(exp.Properties(expressions=root_properties)) + self.with_properties(
exp.Properties(expressions=with_properties)
@ -496,9 +495,12 @@ class Generator:
)
def property_sql(self, expression):
key = expression.name
if isinstance(expression.this, exp.Literal):
key = expression.this.this
else:
key = expression.name
value = self.sql(expression, "value")
return f"{key} = {value}"
return f"{key}={value}"
def insert_sql(self, expression):
kind = "OVERWRITE TABLE" if expression.args.get("overwrite") else "INTO"
@ -535,7 +537,8 @@ class Generator:
laterals = self.expressions(expression, key="laterals", sep="")
joins = self.expressions(expression, key="joins", sep="")
return f"{table}{laterals}{joins}"
pivots = self.expressions(expression, key="pivots", sep="")
return f"{table}{laterals}{joins}{pivots}"
def tablesample_sql(self, expression):
if self.alias_post_tablesample and isinstance(expression.this, exp.Alias):
@ -556,7 +559,17 @@ class Generator:
rows = self.sql(expression, "rows")
rows = f"{rows} ROWS" if rows else ""
size = self.sql(expression, "size")
return f"{this} TABLESAMPLE{method}({bucket}{percent}{rows}{size}){alias}"
seed = self.sql(expression, "seed")
seed = f" SEED ({seed})" if seed else ""
return f"{this} TABLESAMPLE{method}({bucket}{percent}{rows}{size}){seed}{alias}"
def pivot_sql(self, expression):
this = self.sql(expression, "this")
unpivot = expression.args.get("unpivot")
direction = "UNPIVOT" if unpivot else "PIVOT"
expressions = self.expressions(expression, key="expressions")
field = self.sql(expression, "field")
return f"{this} {direction}({expressions} FOR {field})"
def tuple_sql(self, expression):
return f"({self.expressions(expression, flat=True)})"
@ -681,6 +694,7 @@ class Generator:
def ordered_sql(self, expression):
desc = expression.args.get("desc")
asc = not desc
nulls_first = expression.args.get("nulls_first")
nulls_last = not nulls_first
nulls_are_large = self.null_ordering == "nulls_are_large"
@ -760,6 +774,7 @@ class Generator:
return self.query_modifiers(
expression,
self.wrap(expression),
self.expressions(expression, key="pivots", sep=" "),
f" AS {alias}" if alias else "",
)
@ -1129,6 +1144,9 @@ class Generator:
return f"{op} {expressions_sql}"
return f"{self.seg(op)}{self.sep() if expressions_sql else ''}{expressions_sql}"
def naked_property(self, expression):
return f"{expression.name} {self.sql(expression, 'value')}"
def set_operation(self, expression, op):
this = self.sql(expression, "this")
op = self.seg(op)
@ -1136,3 +1154,13 @@ class Generator:
def token_sql(self, token_type):
return self.TOKEN_MAPPING.get(token_type, token_type.name)
def userdefinedfunction_sql(self, expression):
this = self.sql(expression, "this")
expressions = self.no_identify(lambda: self.expressions(expression))
return f"{this}({expressions})"
def userdefinedfunctionkwarg_sql(self, expression):
this = self.sql(expression, "this")
kind = self.sql(expression, "kind")
return f"{this} {kind}"

View file

@ -1,5 +1,7 @@
import inspect
import logging
import re
import sys
from contextlib import contextmanager
from enum import Enum
@ -29,6 +31,26 @@ def csv(*args, sep=", "):
return sep.join(arg for arg in args if arg)
def subclasses(module_name, classes, exclude=()):
"""
Returns a list of all subclasses for a specified class set, possibly excluding some of them.
Args:
module_name (str): The name of the module to search for subclasses in.
classes (type|tuple[type]): Class(es) we want to find the subclasses of.
exclude (type|tuple[type]): Class(es) we want to exclude from the returned list.
Returns:
A list of all the target subclasses.
"""
return [
obj
for _, obj in inspect.getmembers(
sys.modules[module_name],
lambda obj: inspect.isclass(obj) and issubclass(obj, classes) and obj not in exclude,
)
]
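For example, this is how the expressions module builds ALL_FUNCTIONS (see the exp diff above):

from sqlglot import expressions as exp
from sqlglot.helper import subclasses

# Every concrete Func subclass, excluding the abstract bases.
funcs = subclasses(exp.__name__, exp.Func, (exp.AggFunc, exp.Anonymous, exp.Func))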
def apply_index_offset(expressions, offset):
if not offset or len(expressions) != 1:
return expressions
@ -100,7 +122,7 @@ def csv_reader(table):
Returns a csv reader given the expression READ_CSV(name, ['delimiter', '|', ...])
Args:
expression (Expression): An anonymous function READ_CSV
table (exp.Table): A table expression with an anonymous function READ_CSV in it
Returns:
A python csv reader.
@ -121,3 +143,22 @@ def csv_reader(table):
yield csv_.reader(file, delimiter=delimiter)
finally:
file.close()
def find_new_name(taken, base):
"""
Searches for a new name.
Args:
taken (Sequence[str]): set of taken names
base (str): base name to alter
"""
if base not in taken:
return base
i = 2
new = f"{base}_{i}"
while new in taken:
i += 1
new = f"{base}_{i}"
return new
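Two quick examples, following the logic above directly:

from sqlglot.helper import find_new_name

print(find_new_name({"y"}, "x"))         # x   (the base name is free)
print(find_new_name({"x", "x_2"}, "x"))  # x_3 (first free numbered suffix)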

View file

@ -0,0 +1,162 @@
from sqlglot import exp
from sqlglot.helper import ensure_list, subclasses
def annotate_types(expression, schema=None, annotators=None, coerces_to=None):
"""
Recursively infer & annotate types in an expression syntax tree against a schema.
(TODO -- replace this with a better example after adding some functionality)
Example:
>>> import sqlglot
>>> annotated_expression = annotate_types(sqlglot.parse_one('5 + 5.3'))
>>> annotated_expression.type
<Type.DOUBLE: 'DOUBLE'>
Args:
expression (sqlglot.Expression): Expression to annotate.
schema (dict|sqlglot.optimizer.Schema): Database schema.
annotators (dict): Maps expression type to corresponding annotation function.
coerces_to (dict): Maps expression type to set of types that it can be coerced into.
Returns:
sqlglot.Expression: expression annotated with types
"""
return TypeAnnotator(schema, annotators, coerces_to).annotate(expression)
class TypeAnnotator:
ANNOTATORS = {
**{
expr_type: lambda self, expr: self._annotate_unary(expr)
for expr_type in subclasses(exp.__name__, exp.Unary)
},
**{
expr_type: lambda self, expr: self._annotate_binary(expr)
for expr_type in subclasses(exp.__name__, exp.Binary)
},
exp.Cast: lambda self, expr: self._annotate_cast(expr),
exp.DataType: lambda self, expr: self._annotate_data_type(expr),
exp.Literal: lambda self, expr: self._annotate_literal(expr),
exp.Boolean: lambda self, expr: self._annotate_boolean(expr),
}
# Reference: https://spark.apache.org/docs/3.2.0/sql-ref-ansi-compliance.html
COERCES_TO = {
# CHAR < NCHAR < VARCHAR < NVARCHAR < TEXT
exp.DataType.Type.TEXT: set(),
exp.DataType.Type.NVARCHAR: {exp.DataType.Type.TEXT},
exp.DataType.Type.VARCHAR: {exp.DataType.Type.NVARCHAR, exp.DataType.Type.TEXT},
exp.DataType.Type.NCHAR: {exp.DataType.Type.VARCHAR, exp.DataType.Type.NVARCHAR, exp.DataType.Type.TEXT},
exp.DataType.Type.CHAR: {
exp.DataType.Type.NCHAR,
exp.DataType.Type.VARCHAR,
exp.DataType.Type.NVARCHAR,
exp.DataType.Type.TEXT,
},
# TINYINT < SMALLINT < INT < BIGINT < DECIMAL < FLOAT < DOUBLE
exp.DataType.Type.DOUBLE: set(),
exp.DataType.Type.FLOAT: {exp.DataType.Type.DOUBLE},
exp.DataType.Type.DECIMAL: {exp.DataType.Type.FLOAT, exp.DataType.Type.DOUBLE},
exp.DataType.Type.BIGINT: {exp.DataType.Type.DECIMAL, exp.DataType.Type.FLOAT, exp.DataType.Type.DOUBLE},
exp.DataType.Type.INT: {
exp.DataType.Type.BIGINT,
exp.DataType.Type.DECIMAL,
exp.DataType.Type.FLOAT,
exp.DataType.Type.DOUBLE,
},
exp.DataType.Type.SMALLINT: {
exp.DataType.Type.INT,
exp.DataType.Type.BIGINT,
exp.DataType.Type.DECIMAL,
exp.DataType.Type.FLOAT,
exp.DataType.Type.DOUBLE,
},
exp.DataType.Type.TINYINT: {
exp.DataType.Type.SMALLINT,
exp.DataType.Type.INT,
exp.DataType.Type.BIGINT,
exp.DataType.Type.DECIMAL,
exp.DataType.Type.FLOAT,
exp.DataType.Type.DOUBLE,
},
# DATE < DATETIME < TIMESTAMP < TIMESTAMPTZ < TIMESTAMPLTZ
exp.DataType.Type.TIMESTAMPLTZ: set(),
exp.DataType.Type.TIMESTAMPTZ: {exp.DataType.Type.TIMESTAMPLTZ},
exp.DataType.Type.TIMESTAMP: {exp.DataType.Type.TIMESTAMPTZ, exp.DataType.Type.TIMESTAMPLTZ},
exp.DataType.Type.DATETIME: {
exp.DataType.Type.TIMESTAMP,
exp.DataType.Type.TIMESTAMPTZ,
exp.DataType.Type.TIMESTAMPLTZ,
},
exp.DataType.Type.DATE: {
exp.DataType.Type.DATETIME,
exp.DataType.Type.TIMESTAMP,
exp.DataType.Type.TIMESTAMPTZ,
exp.DataType.Type.TIMESTAMPLTZ,
},
}
def __init__(self, schema=None, annotators=None, coerces_to=None):
self.schema = schema
self.annotators = annotators or self.ANNOTATORS
self.coerces_to = coerces_to or self.COERCES_TO
def annotate(self, expression):
if not isinstance(expression, exp.Expression):
return None
annotator = self.annotators.get(expression.__class__)
return annotator(self, expression) if annotator else self._annotate_args(expression)
def _annotate_args(self, expression):
for value in expression.args.values():
for v in ensure_list(value):
self.annotate(v)
return expression
def _annotate_cast(self, expression):
expression.type = expression.args["to"].this
return self._annotate_args(expression)
def _annotate_data_type(self, expression):
expression.type = expression.this
return self._annotate_args(expression)
def _maybe_coerce(self, type1, type2):
return type2 if type2 in self.coerces_to[type1] else type1
def _annotate_binary(self, expression):
self._annotate_args(expression)
if isinstance(expression, (exp.Condition, exp.Predicate)):
expression.type = exp.DataType.Type.BOOLEAN
else:
expression.type = self._maybe_coerce(expression.left.type, expression.right.type)
return expression
def _annotate_unary(self, expression):
self._annotate_args(expression)
if isinstance(expression, exp.Condition) and not isinstance(expression, exp.Paren):
expression.type = exp.DataType.Type.BOOLEAN
else:
expression.type = expression.this.type
return expression
def _annotate_literal(self, expression):
if expression.is_string:
expression.type = exp.DataType.Type.VARCHAR
elif expression.is_int:
expression.type = exp.DataType.Type.INT
else:
expression.type = exp.DataType.Type.DOUBLE
return expression
def _annotate_boolean(self, expression):
expression.type = exp.DataType.Type.BOOLEAN
return expression
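A couple of quick checks of the annotator, assuming the new module lands at sqlglot.optimizer.annotate_types alongside the other optimizer rules:

import sqlglot
from sqlglot.optimizer.annotate_types import annotate_types

print(annotate_types(sqlglot.parse_one("5 + 5")).type)    # Type.INT: both sides are INT, no coercion needed
print(annotate_types(sqlglot.parse_one("5 = 5.3")).type)  # Type.BOOLEAN: predicates are always BOOLEAN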

View file

@ -1,48 +1,144 @@
import itertools
from sqlglot import alias, exp, select, table
from sqlglot.optimizer.scope import traverse_scope
from sqlglot import expressions as exp
from sqlglot.helper import find_new_name
from sqlglot.optimizer.scope import build_scope
from sqlglot.optimizer.simplify import simplify
def eliminate_subqueries(expression):
"""
Rewrite duplicate subqueries from sqlglot AST.
Rewrite subqueries as CTES, deduplicating if possible.
Example:
>>> import sqlglot
>>> expression = sqlglot.parse_one("SELECT 1 AS x, 2 AS y UNION ALL SELECT 1 AS x, 2 AS y")
>>> expression = sqlglot.parse_one("SELECT a FROM (SELECT * FROM x) AS y")
>>> eliminate_subqueries(expression).sql()
'WITH _e_0 AS (SELECT 1 AS x, 2 AS y) SELECT * FROM _e_0 UNION ALL SELECT * FROM _e_0'
'WITH y AS (SELECT * FROM x) SELECT a FROM y AS y'
This also deduplicates common subqueries:
>>> expression = sqlglot.parse_one("SELECT a FROM (SELECT * FROM x) AS y JOIN (SELECT * FROM x) AS z")
>>> eliminate_subqueries(expression).sql()
'WITH y AS (SELECT * FROM x) SELECT a FROM y AS y JOIN y AS z'
Args:
expression (sqlglot.Expression): expression to qualify
schema (dict|sqlglot.optimizer.Schema): Database schema
expression (sqlglot.Expression): expression
Returns:
sqlglot.Expression: qualified expression
sqlglot.Expression: expression
"""
if isinstance(expression, exp.Subquery):
# It's possible to have subqueries at the root, e.g. (SELECT * FROM x) LIMIT 1
eliminate_subqueries(expression.this)
return expression
expression = simplify(expression)
queries = {}
root = build_scope(expression)
for scope in traverse_scope(expression):
query = scope.expression
queries[query] = queries.get(query, []) + [query]
# Map of alias->Scope|Table
# These are all aliases that are already used in the expression.
# We don't want to create new CTEs that conflict with these names.
taken = {}
sequence = itertools.count()
# All CTE aliases in the root scope are taken
for scope in root.cte_scopes:
taken[scope.expression.parent.alias] = scope
for query, duplicates in queries.items():
if len(duplicates) == 1:
continue
# All table names are taken
for scope in root.traverse():
taken.update({source.name: source for _, source in scope.sources.items() if isinstance(source, exp.Table)})
alias_ = f"_e_{next(sequence)}"
# Map of Expression->alias
# Existing CTEs in the root expression. We'll use this for deduplication.
existing_ctes = {}
for dup in duplicates:
parent = dup.parent
if isinstance(parent, exp.Subquery):
parent.replace(alias(table(alias_), parent.alias_or_name, table=True))
elif isinstance(parent, exp.Union):
dup.replace(select("*").from_(alias_))
with_ = root.expression.args.get("with")
if with_:
for cte in with_.expressions:
existing_ctes[cte.this] = cte.alias
new_ctes = []
expression.with_(alias_, as_=query, copy=False)
# We're adding more CTEs, but we want to maintain the DAG order.
# Derived tables within an existing CTE need to come before the existing CTE.
for cte_scope in root.cte_scopes:
# Append all the new CTEs from this existing CTE
for scope in cte_scope.traverse():
new_cte = _eliminate(scope, existing_ctes, taken)
if new_cte:
new_ctes.append(new_cte)
# Append the existing CTE itself
new_ctes.append(cte_scope.expression.parent)
# Now append the rest
for scope in itertools.chain(root.union_scopes, root.subquery_scopes, root.derived_table_scopes):
for child_scope in scope.traverse():
new_cte = _eliminate(child_scope, existing_ctes, taken)
if new_cte:
new_ctes.append(new_cte)
if new_ctes:
expression.set("with", exp.With(expressions=new_ctes))
return expression
def _eliminate(scope, existing_ctes, taken):
if scope.is_union:
return _eliminate_union(scope, existing_ctes, taken)
if scope.is_derived_table and not isinstance(scope.expression, (exp.Unnest, exp.Lateral)):
return _eliminate_derived_table(scope, existing_ctes, taken)
def _eliminate_union(scope, existing_ctes, taken):
duplicate_cte_alias = existing_ctes.get(scope.expression)
alias = duplicate_cte_alias or find_new_name(taken=taken, base="cte")
taken[alias] = scope
# Try to maintain the selections
expressions = scope.expression.args.get("expressions")
selects = [
exp.alias_(exp.column(e.alias_or_name, table=alias), alias=e.alias_or_name)
for e in expressions
if e.alias_or_name
]
# If not all selections have an alias, just select *
if len(selects) != len(expressions):
selects = ["*"]
scope.expression.replace(exp.select(*selects).from_(exp.alias_(exp.table_(alias), alias=alias)))
if not duplicate_cte_alias:
existing_ctes[scope.expression] = alias
return exp.CTE(
this=scope.expression,
alias=exp.TableAlias(this=exp.to_identifier(alias)),
)
def _eliminate_derived_table(scope, existing_ctes, taken):
duplicate_cte_alias = existing_ctes.get(scope.expression)
parent = scope.expression.parent
name = alias = parent.alias
if not alias:
name = alias = find_new_name(taken=taken, base="cte")
if duplicate_cte_alias:
name = duplicate_cte_alias
elif taken.get(alias):
name = find_new_name(taken=taken, base=alias)
taken[name] = scope
table = exp.alias_(exp.table_(name), alias=alias)
parent.replace(table)
if not duplicate_cte_alias:
existing_ctes[scope.expression] = name
return exp.CTE(
this=scope.expression,
alias=exp.TableAlias(this=exp.to_identifier(name)),
)

View file

@ -1,45 +1,39 @@
from collections import defaultdict
from sqlglot import expressions as exp
from sqlglot.optimizer.scope import traverse_scope
from sqlglot.helper import find_new_name
from sqlglot.optimizer.scope import Scope, traverse_scope
from sqlglot.optimizer.simplify import simplify
def merge_derived_tables(expression):
def merge_subqueries(expression, leave_tables_isolated=False):
"""
Rewrite sqlglot AST to merge derived tables into the outer query.
This also merges CTEs if they are selected from only once.
Example:
>>> import sqlglot
>>> expression = sqlglot.parse_one("SELECT a FROM (SELECT x.a FROM x)")
>>> merge_derived_tables(expression).sql()
'SELECT x.a FROM x'
>>> expression = sqlglot.parse_one("SELECT a FROM (SELECT x.a FROM x) JOIN y")
>>> merge_subqueries(expression).sql()
'SELECT x.a FROM x JOIN y'
If `leave_tables_isolated` is True, this will not merge inner queries into outer
queries if it would result in multiple table selects in a single query:
>>> expression = sqlglot.parse_one("SELECT a FROM (SELECT x.a FROM x) JOIN y")
>>> merge_subqueries(expression, leave_tables_isolated=True).sql()
'SELECT a FROM (SELECT x.a FROM x) JOIN y'
Inspired by https://dev.mysql.com/doc/refman/8.0/en/derived-table-optimization.html
Args:
expression (sqlglot.Expression): expression to optimize
leave_tables_isolated (bool): when True, do not merge an inner query into an outer query if it would result in multiple table selects in a single query
Returns:
sqlglot.Expression: optimized expression
"""
for outer_scope in traverse_scope(expression):
for subquery in outer_scope.derived_tables:
inner_select = subquery.unnest()
if (
isinstance(outer_scope.expression, exp.Select)
and isinstance(inner_select, exp.Select)
and _mergeable(inner_select)
):
alias = subquery.alias_or_name
from_or_join = subquery.find_ancestor(exp.From, exp.Join)
inner_scope = outer_scope.sources[alias]
_rename_inner_sources(outer_scope, inner_scope, alias)
_merge_from(outer_scope, inner_scope, subquery)
_merge_joins(outer_scope, inner_scope, from_or_join)
_merge_expressions(outer_scope, inner_scope, alias)
_merge_where(outer_scope, inner_scope, from_or_join)
_merge_order(outer_scope, inner_scope)
merge_ctes(expression, leave_tables_isolated)
merge_derived_tables(expression, leave_tables_isolated)
return expression
@ -53,20 +47,81 @@ UNMERGABLE_ARGS = set(exp.Select.arg_types) - {
}
def _mergeable(inner_select):
def merge_ctes(expression, leave_tables_isolated=False):
scopes = traverse_scope(expression)
# All places where we select from CTEs.
# We key on the CTE scope so we can detect CTES that are selected from multiple times.
cte_selections = defaultdict(list)
for outer_scope in scopes:
for table, inner_scope in outer_scope.selected_sources.values():
if isinstance(inner_scope, Scope) and inner_scope.is_cte:
cte_selections[id(inner_scope)].append(
(
outer_scope,
inner_scope,
table,
)
)
singular_cte_selections = [v[0] for k, v in cte_selections.items() if len(v) == 1]
for outer_scope, inner_scope, table in singular_cte_selections:
inner_select = inner_scope.expression.unnest()
if _mergeable(outer_scope, inner_select, leave_tables_isolated):
from_or_join = table.find_ancestor(exp.From, exp.Join)
node_to_replace = table
if isinstance(node_to_replace.parent, exp.Alias):
node_to_replace = node_to_replace.parent
alias = node_to_replace.alias
else:
alias = table.name
_rename_inner_sources(outer_scope, inner_scope, alias)
_merge_from(outer_scope, inner_scope, node_to_replace, alias)
_merge_joins(outer_scope, inner_scope, from_or_join)
_merge_expressions(outer_scope, inner_scope, alias)
_merge_where(outer_scope, inner_scope, from_or_join)
_merge_order(outer_scope, inner_scope)
_pop_cte(inner_scope)
def merge_derived_tables(expression, leave_tables_isolated=False):
for outer_scope in traverse_scope(expression):
for subquery in outer_scope.derived_tables:
inner_select = subquery.unnest()
if _mergeable(outer_scope, inner_select, leave_tables_isolated):
alias = subquery.alias_or_name
from_or_join = subquery.find_ancestor(exp.From, exp.Join)
inner_scope = outer_scope.sources[alias]
_rename_inner_sources(outer_scope, inner_scope, alias)
_merge_from(outer_scope, inner_scope, subquery, alias)
_merge_joins(outer_scope, inner_scope, from_or_join)
_merge_expressions(outer_scope, inner_scope, alias)
_merge_where(outer_scope, inner_scope, from_or_join)
_merge_order(outer_scope, inner_scope)
def _mergeable(outer_scope, inner_select, leave_tables_isolated):
"""
Return True if `inner_select` can be merged into outer query.
Args:
outer_scope (Scope)
inner_select (exp.Select)
leave_tables_isolated (bool)
Returns:
bool: True if can be merged
"""
return (
isinstance(inner_select, exp.Select)
isinstance(outer_scope.expression, exp.Select)
and isinstance(inner_select, exp.Select)
and not any(inner_select.args.get(arg) for arg in UNMERGABLE_ARGS)
and inner_select.args.get("from")
and not any(e.find(exp.AggFunc, exp.Select) for e in inner_select.expressions)
and not (leave_tables_isolated and len(outer_scope.selected_sources) > 1)
)
@ -84,7 +139,7 @@ def _rename_inner_sources(outer_scope, inner_scope, alias):
conflicts = conflicts - {alias}
for conflict in conflicts:
new_name = _find_new_name(taken, conflict)
new_name = find_new_name(taken, conflict)
source, _ = inner_scope.selected_sources[conflict]
new_alias = exp.to_identifier(new_name)
@ -102,34 +157,19 @@ def _rename_inner_sources(outer_scope, inner_scope, alias):
inner_scope.rename_source(conflict, new_name)
def _find_new_name(taken, base):
"""
Searches for a new source name.
Args:
taken (set[str]): set of taken names
base (str): base name to alter
"""
i = 2
new = f"{base}_{i}"
while new in taken:
i += 1
new = f"{base}_{i}"
return new
def _merge_from(outer_scope, inner_scope, subquery):
def _merge_from(outer_scope, inner_scope, node_to_replace, alias):
"""
Merge FROM clause of inner query into outer query.
Args:
outer_scope (sqlglot.optimizer.scope.Scope)
inner_scope (sqlglot.optimizer.scope.Scope)
subquery (exp.Subquery)
node_to_replace (exp.Subquery|exp.Table)
alias (str)
"""
new_subquery = inner_scope.expression.args.get("from").expressions[0]
subquery.replace(new_subquery)
outer_scope.remove_source(subquery.alias_or_name)
node_to_replace.replace(new_subquery)
outer_scope.remove_source(alias)
outer_scope.add_source(new_subquery.alias_or_name, inner_scope.sources[new_subquery.alias_or_name])
@ -176,7 +216,7 @@ def _merge_expressions(outer_scope, inner_scope, alias):
inner_scope (sqlglot.optimizer.scope.Scope)
alias (str)
"""
# Collect all columns that for the alias of the inner query
# Collect all columns that reference the alias of the inner query
outer_columns = defaultdict(list)
for column in outer_scope.columns:
if column.table == alias:
@ -205,7 +245,7 @@ def _merge_where(outer_scope, inner_scope, from_or_join):
if not where or not where.this:
return
if isinstance(from_or_join, exp.Join) and from_or_join.side:
if isinstance(from_or_join, exp.Join):
# Merge predicates from an outer join to the ON clause
from_or_join.on(where.this, copy=False)
from_or_join.set("on", simplify(from_or_join.args.get("on")))
@ -230,3 +270,18 @@ def _merge_order(outer_scope, inner_scope):
return
outer_scope.expression.set("order", inner_scope.expression.args.get("order"))
def _pop_cte(inner_scope):
"""
Remove CTE from the AST.
Args:
inner_scope (sqlglot.optimizer.scope.Scope)
"""
cte = inner_scope.expression.parent
with_ = cte.parent
if len(with_.expressions) == 1:
with_.pop()
else:
cte.pop()
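A hedged sketch of the new CTE merging, assuming the module path sqlglot.optimizer.merge_subqueries (output is illustrative):

import sqlglot
from sqlglot.optimizer.merge_subqueries import merge_subqueries

# A CTE that is selected from exactly once is merged into the outer query.
e = sqlglot.parse_one("WITH y AS (SELECT x.a FROM x) SELECT y.a FROM y")
print(merge_subqueries(e).sql())  # e.g. SELECT x.a FROM x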

View file

@ -1,7 +1,7 @@
from sqlglot.optimizer.eliminate_subqueries import eliminate_subqueries
from sqlglot.optimizer.expand_multi_table_selects import expand_multi_table_selects
from sqlglot.optimizer.isolate_table_selects import isolate_table_selects
from sqlglot.optimizer.merge_derived_tables import merge_derived_tables
from sqlglot.optimizer.merge_subqueries import merge_subqueries
from sqlglot.optimizer.normalize import normalize
from sqlglot.optimizer.optimize_joins import optimize_joins
from sqlglot.optimizer.pushdown_predicates import pushdown_predicates
@ -22,7 +22,7 @@ RULES = (
pushdown_predicates,
optimize_joins,
eliminate_subqueries,
merge_derived_tables,
merge_subqueries,
quote_identities,
)

View file

@ -37,7 +37,7 @@ def pushdown_projections(expression):
parent_selections = {SELECT_ALL}
if isinstance(scope.expression, exp.Union):
left, right = scope.union
left, right = scope.union_scopes
referenced_columns[left] = parent_selections
referenced_columns[right] = parent_selections

View file

@ -69,7 +69,7 @@ def ensure_schema(schema):
def fs_get(table):
name = table.this.name.upper()
name = table.this.name
if name.upper() == "READ_CSV":
with csv_reader(table) as reader:

View file

@ -1,3 +1,4 @@
import itertools
from copy import copy
from enum import Enum, auto
@ -32,10 +33,11 @@ class Scope:
The inner query would have `["col1", "col2"]` for its `outer_column_list`
parent (Scope): Parent scope
scope_type (ScopeType): Type of this scope, relative to it's parent
subquery_scopes (list[Scope]): List of all child scopes for subqueries.
This does not include derived tables or CTEs.
union (tuple[Scope, Scope]): If this Scope is for a Union expression, this will be
a tuple of the left and right child scopes.
subquery_scopes (list[Scope]): List of all child scopes for subqueries
cte_scopes (list[Scope]): List of all child scopes for CTEs
derived_table_scopes (list[Scope]): List of all child scopes for derived tables
union_scopes (list[Scope, Scope]): If this Scope is for a Union expression, this will be
a list of the left and right child scopes.
"""
def __init__(
@ -52,7 +54,9 @@ class Scope:
self.parent = parent
self.scope_type = scope_type
self.subquery_scopes = []
self.union = None
self.derived_table_scopes = []
self.cte_scopes = []
self.union_scopes = []
self.clear_cache()
def clear_cache(self):
@ -197,11 +201,16 @@ class Scope:
named_outputs = {e.alias_or_name for e in self.expression.expressions}
self._columns = [
c
for c in columns + external_columns
if not (c.find_ancestor(exp.Qualify, exp.Order) and not c.table and c.name in named_outputs)
]
self._columns = []
for column in columns + external_columns:
ancestor = column.find_ancestor(exp.Qualify, exp.Order, exp.Hint)
if (
not ancestor
or column.table
or (column.name not in named_outputs and not isinstance(ancestor, exp.Hint))
):
self._columns.append(column)
return self._columns
@property
@ -283,6 +292,26 @@ class Scope:
"""Determine if this scope is a subquery"""
return self.scope_type == ScopeType.SUBQUERY
@property
def is_derived_table(self):
"""Determine if this scope is a derived table"""
return self.scope_type == ScopeType.DERIVED_TABLE
@property
def is_union(self):
"""Determine if this scope is a union"""
return self.scope_type == ScopeType.UNION
@property
def is_cte(self):
"""Determine if this scope is a common table expression"""
return self.scope_type == ScopeType.CTE
@property
def is_root(self):
"""Determine if this is the root scope"""
return self.scope_type == ScopeType.ROOT
@property
def is_unnest(self):
"""Determine if this scope is an unnest"""
@ -308,6 +337,22 @@ class Scope:
self.sources.pop(name, None)
self.clear_cache()
def __repr__(self):
return f"Scope<{self.expression.sql()}>"
def traverse(self):
"""
Traverse the scope tree from this node.
Yields:
Scope: scope instances in depth-first-search post-order
"""
for child_scope in itertools.chain(
self.cte_scopes, self.union_scopes, self.subquery_scopes, self.derived_table_scopes
):
yield from child_scope.traverse()
yield self
def traverse_scope(expression):
"""
@ -337,6 +382,18 @@ def traverse_scope(expression):
return list(_traverse_scope(Scope(expression)))
def build_scope(expression):
"""
Build a scope tree.
Args:
expression (exp.Expression): expression to build the scope tree for
Returns:
Scope: root scope
"""
return traverse_scope(expression)[-1]
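A short usage sketch of build_scope together with the new traverse method:

import sqlglot
from sqlglot.optimizer.scope import build_scope

root = build_scope(sqlglot.parse_one("SELECT a FROM (SELECT a FROM x) AS y"))
for scope in root.traverse():  # depth-first post-order: the derived table comes before the root
    print(scope)               # e.g. Scope<SELECT a FROM x>, then the root scope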
def _traverse_scope(scope):
if isinstance(scope.expression, exp.Select):
yield from _traverse_select(scope)
@ -370,13 +427,14 @@ def _traverse_union(scope):
for right in _traverse_scope(scope.branch(scope.expression.right, scope_type=ScopeType.UNION)):
yield right
scope.union = (left, right)
scope.union_scopes = [left, right]
def _traverse_derived_tables(derived_tables, scope, scope_type):
sources = {}
for derived_table in derived_tables:
top = None
for child_scope in _traverse_scope(
scope.branch(
derived_table if isinstance(derived_table, (exp.Unnest, exp.Lateral)) else derived_table.this,
@ -386,11 +444,16 @@ def _traverse_derived_tables(derived_tables, scope, scope_type):
)
):
yield child_scope
top = child_scope
# Tables without aliases will be set as ""
# This shouldn't be a problem once qualify_columns runs, as it adds aliases on everything.
# Until then, this means that only a single, unaliased derived table is allowed (rather,
# the latest one wins).
sources[derived_table.alias] = child_scope
if scope_type == ScopeType.CTE:
scope.cte_scopes.append(top)
else:
scope.derived_table_scopes.append(top)
scope.sources.update(sources)
@ -407,8 +470,6 @@ def _add_table_sources(scope):
if table_name in scope.sources:
# This is a reference to a parent source (e.g. a CTE), not an actual table.
scope.sources[source_name] = scope.sources[table_name]
elif source_name in scope.sources:
raise OptimizeError(f"Duplicate table name: {source_name}")
else:
sources[source_name] = table

View file

@ -99,7 +99,8 @@ class Parser:
TokenType.SMALLMONEY,
TokenType.ROWVERSION,
TokenType.IMAGE,
TokenType.SQL_VARIANT,
TokenType.VARIANT,
TokenType.OBJECT,
*NESTED_TYPE_TOKENS,
}
@ -131,7 +132,6 @@ class Parser:
TokenType.FALSE,
TokenType.FIRST,
TokenType.FOLLOWING,
TokenType.FOR,
TokenType.FORMAT,
TokenType.FUNCTION,
TokenType.GENERATED,
@ -141,20 +141,26 @@ class Parser:
TokenType.ISNULL,
TokenType.INTERVAL,
TokenType.LAZY,
TokenType.LANGUAGE,
TokenType.LEADING,
TokenType.LOCATION,
TokenType.MATERIALIZED,
TokenType.NATURAL,
TokenType.NEXT,
TokenType.ONLY,
TokenType.OPTIMIZE,
TokenType.OPTIONS,
TokenType.ORDINALITY,
TokenType.PARTITIONED_BY,
TokenType.PERCENT,
TokenType.PIVOT,
TokenType.PRECEDING,
TokenType.RANGE,
TokenType.REFERENCES,
TokenType.RETURNS,
TokenType.ROWS,
TokenType.SCHEMA_COMMENT,
TokenType.SEED,
TokenType.SET,
TokenType.SHOW,
TokenType.STORED,
@ -167,6 +173,7 @@ class Parser:
TokenType.TRUE,
TokenType.UNBOUNDED,
TokenType.UNIQUE,
TokenType.UNPIVOT,
TokenType.PROPERTIES,
*SUBQUERY_PREDICATES,
*TYPE_TOKENS,
@ -303,6 +310,8 @@ class Parser:
exp.Condition: lambda self: self._parse_conjunction(),
exp.Expression: lambda self: self._parse_statement(),
exp.Properties: lambda self: self._parse_properties(),
exp.Where: lambda self: self._parse_where(),
exp.Ordered: lambda self: self._parse_ordered(),
"JOIN_TYPE": lambda self: self._parse_join_side_and_kind(),
}
@ -355,23 +364,21 @@ class Parser:
PROPERTY_PARSERS = {
TokenType.AUTO_INCREMENT: lambda self: self._parse_auto_increment(),
TokenType.CHARACTER_SET: lambda self: self._parse_character_set(),
TokenType.COLLATE: lambda self: self._parse_collate(),
TokenType.ENGINE: lambda self: self._parse_engine(),
TokenType.FORMAT: lambda self: self._parse_format(),
TokenType.LOCATION: lambda self: self.expression(
exp.LocationProperty,
this=exp.Literal.string("LOCATION"),
value=self._parse_string(),
),
TokenType.PARTITIONED_BY: lambda self: self.expression(
exp.PartitionedByProperty,
this=exp.Literal.string("PARTITIONED_BY"),
value=self._parse_schema(),
),
TokenType.PARTITIONED_BY: lambda self: self._parse_partitioned_by(),
TokenType.SCHEMA_COMMENT: lambda self: self._parse_schema_comment(),
TokenType.STORED: lambda self: self._parse_stored(),
TokenType.TABLE_FORMAT: lambda self: self._parse_table_format(),
TokenType.USING: lambda self: self._parse_table_format(),
TokenType.RETURNS: lambda self: self._parse_returns(),
TokenType.COLLATE: lambda self: self._parse_property_assignment(exp.CollateProperty),
TokenType.COMMENT: lambda self: self._parse_property_assignment(exp.SchemaCommentProperty),
TokenType.FORMAT: lambda self: self._parse_property_assignment(exp.FileFormatProperty),
TokenType.TABLE_FORMAT: lambda self: self._parse_property_assignment(exp.TableFormatProperty),
TokenType.USING: lambda self: self._parse_property_assignment(exp.TableFormatProperty),
TokenType.LANGUAGE: lambda self: self._parse_property_assignment(exp.LanguageProperty),
}
CONSTRAINT_PARSERS = {
@ -388,6 +395,7 @@ class Parser:
FUNCTION_PARSERS = {
"CONVERT": lambda self: self._parse_convert(),
"EXTRACT": lambda self: self._parse_extract(),
"POSITION": lambda self: self._parse_position(),
"SUBSTRING": lambda self: self._parse_substring(),
"TRIM": lambda self: self._parse_trim(),
"CAST": lambda self: self._parse_cast(self.STRICT_CAST),
@ -628,6 +636,10 @@ class Parser:
replace = self._match(TokenType.OR) and self._match(TokenType.REPLACE)
temporary = self._match(TokenType.TEMPORARY)
unique = self._match(TokenType.UNIQUE)
materialized = self._match(TokenType.MATERIALIZED)
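# Peek for "TABLE FUNCTION" without advancing; consuming TABLE here lets the
# FUNCTION token be matched below as the creatable kind.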
if self._match_pair(TokenType.TABLE, TokenType.FUNCTION, advance=False):
self._match(TokenType.TABLE)
create_token = self._match_set(self.CREATABLES) and self._prev
@ -640,14 +652,15 @@ class Parser:
properties = None
if create_token.token_type == TokenType.FUNCTION:
this = self._parse_var()
this = self._parse_user_defined_function()
properties = self._parse_properties()
if self._match(TokenType.ALIAS):
expression = self._parse_string()
expression = self._parse_select_or_expression()
elif create_token.token_type == TokenType.INDEX:
this = self._parse_index()
elif create_token.token_type in (TokenType.TABLE, TokenType.VIEW):
this = self._parse_table(schema=True)
properties = self._parse_properties(this if isinstance(this, exp.Schema) else None)
properties = self._parse_properties()
if self._match(TokenType.ALIAS):
expression = self._parse_select(nested=True)
@ -661,9 +674,10 @@ class Parser:
temporary=temporary,
replace=replace,
unique=unique,
materialized=materialized,
)
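A minimal sketch of the new MATERIALIZED flag in use; parse_one and the exact round-trip output are assumptions about the public API, not part of this hunk:

import sqlglot

# The parser now records materialized=True on exp.Create; assuming the
# generator mirrors the flag, the statement should round-trip unchanged.
expr = sqlglot.parse_one("CREATE MATERIALIZED VIEW v AS SELECT 1")
print(expr.sql())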
def _parse_property(self, schema):
def _parse_property(self):
if self._match_set(self.PROPERTY_PARSERS):
return self.PROPERTY_PARSERS[self._prev.token_type](self)
if self._match_pair(TokenType.DEFAULT, TokenType.CHARACTER_SET):
@ -673,31 +687,27 @@ class Parser:
key = self._parse_var().this
self._match(TokenType.EQ)
if key.upper() == "PARTITIONED_BY":
expression = exp.PartitionedByProperty
value = self._parse_schema() or self._parse_bracket(self._parse_field())
if schema and not isinstance(value, exp.Schema):
columns = {v.name.upper() for v in value.expressions}
partitions = [
expression for expression in schema.expressions if expression.this.name.upper() in columns
]
schema.set(
"expressions",
[e for e in schema.expressions if e not in partitions],
)
value = self.expression(exp.Schema, expressions=partitions)
else:
value = self._parse_column()
expression = exp.AnonymousProperty
return self.expression(
expression,
exp.AnonymousProperty,
this=exp.Literal.string(key),
value=value,
value=self._parse_column(),
)
return None
def _parse_property_assignment(self, exp_class):
prop = self._prev.text
self._match(TokenType.EQ)
return self.expression(exp_class, this=prop, value=self._parse_var_or_string())
def _parse_partitioned_by(self):
self._match(TokenType.EQ)
return self.expression(
exp.PartitionedByProperty,
this=exp.Literal.string("PARTITIONED_BY"),
value=self._parse_schema() or self._parse_bracket(self._parse_field()),
)
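A hedged sketch of the consolidated property parsing; the LANGUAGE clause below is illustrative BigQuery-style syntax, not taken from this commit:

import sqlglot

# COLLATE, COMMENT, FORMAT, TABLE_FORMAT, USING and LANGUAGE now all route
# through _parse_property_assignment, each yielding its own property node.
expr = sqlglot.parse_one("CREATE FUNCTION f(x INT) LANGUAGE SQL AS 'SELECT 1'")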
def _parse_stored(self):
self._match(TokenType.ALIAS)
self._match(TokenType.EQ)
@ -707,22 +717,6 @@ class Parser:
value=exp.Literal.string(self._parse_var().name),
)
def _parse_format(self):
self._match(TokenType.EQ)
return self.expression(
exp.FileFormatProperty,
this=exp.Literal.string("FORMAT"),
value=self._parse_string() or self._parse_var(),
)
def _parse_engine(self):
self._match(TokenType.EQ)
return self.expression(
exp.EngineProperty,
this=exp.Literal.string("ENGINE"),
value=self._parse_var_or_string(),
)
def _parse_auto_increment(self):
self._match(TokenType.EQ)
return self.expression(
@ -731,14 +725,6 @@ class Parser:
value=self._parse_var() or self._parse_number(),
)
def _parse_collate(self):
self._match(TokenType.EQ)
return self.expression(
exp.CollateProperty,
this=exp.Literal.string("COLLATE"),
value=self._parse_var_or_string(),
)
def _parse_schema_comment(self):
self._match(TokenType.EQ)
return self.expression(
@ -756,26 +742,34 @@ class Parser:
default=default,
)
def _parse_table_format(self):
self._match(TokenType.EQ)
def _parse_returns(self):
is_table = self._match(TokenType.TABLE)
if is_table:
if self._match(TokenType.LT):
value = self.expression(
exp.Schema, this="TABLE", expressions=self._parse_csv(self._parse_struct_kwargs)
)
if not self._match(TokenType.GT):
self.raise_error("Expecting >")
else:
value = self._parse_schema("TABLE")
else:
value = self._parse_types()
return self.expression(
exp.TableFormatProperty,
this=exp.Literal.string("TABLE_FORMAT"),
value=self._parse_var_or_string(),
exp.ReturnsProperty,
this=exp.Literal.string("RETURNS"),
value=value,
is_table=is_table,
)
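A sketch of the RETURNS TABLE parsing this enables, using BigQuery-flavored syntax (the read dialect and type names are assumptions):

import sqlglot

# RETURNS TABLE <...> produces a ReturnsProperty with is_table=True wrapping
# a Schema built from _parse_struct_kwargs; BigQuery's _create_sql then
# re-emits the statement as CREATE TABLE FUNCTION.
sql = "CREATE TABLE FUNCTION f(x INT64) RETURNS TABLE <y INT64> AS SELECT x AS y"
expr = sqlglot.parse_one(sql, read="bigquery")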
def _parse_properties(self, schema=None):
"""
The schema is passed in because, when the table schema is defined and a
PARTITIONED BY expression follows, those columns are declared in the
partition section rather than with the rest of the columns.
"""
def _parse_properties(self):
properties = []
while True:
if self._match(TokenType.WITH):
self._match_l_paren()
properties.extend(self._parse_csv(lambda: self._parse_property(schema)))
properties.extend(self._parse_csv(lambda: self._parse_property()))
self._match_r_paren()
elif self._match(TokenType.PROPERTIES):
self._match_l_paren()
@ -790,7 +784,7 @@ class Parser:
)
self._match_r_paren()
else:
identified_property = self._parse_property(schema)
identified_property = self._parse_property()
if not identified_property:
break
properties.append(identified_property)
@ -1003,7 +997,7 @@ class Parser:
)
def _parse_subquery(self, this):
return self.expression(exp.Subquery, this=this, alias=self._parse_table_alias())
return self.expression(exp.Subquery, this=this, pivots=self._parse_pivots(), alias=self._parse_table_alias())
def _parse_query_modifiers(self, this):
if not isinstance(this, self.MODIFIABLES):
@ -1134,14 +1128,18 @@ class Parser:
table = (not schema and self._parse_function()) or self._parse_id_var(False)
while self._match(TokenType.DOT):
catalog = db
db = table
table = self._parse_id_var()
if catalog:
# This allows nesting the table in arbitrarily many dot expressions if needed
table = self.expression(exp.Dot, this=table, expression=self._parse_id_var())
else:
catalog = db
db = table
table = self._parse_id_var()
if not table:
self.raise_error("Expected table name")
this = self.expression(exp.Table, this=table, db=db, catalog=catalog)
this = self.expression(exp.Table, this=table, db=db, catalog=catalog, pivots=self._parse_pivots())
if schema:
return self._parse_schema(this=this)
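A small sketch of the deeper dot-nesting (identifiers are illustrative):

import sqlglot

# With four dot-separated parts, the first two become catalog and db, and the
# remainder nests as exp.Dot under the table name instead of raising.
expr = sqlglot.parse_one("SELECT * FROM a.b.c.d")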
@ -1199,6 +1197,7 @@ class Parser:
percent = None
rows = None
size = None
seed = None
self._match_l_paren()
@ -1220,6 +1219,11 @@ class Parser:
self._match_r_paren()
if self._match(TokenType.SEED):
self._match_l_paren()
seed = self._parse_number()
self._match_r_paren()
return self.expression(
exp.TableSample,
method=method,
@ -1229,6 +1233,51 @@ class Parser:
percent=percent,
rows=rows,
size=size,
seed=seed,
)
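A hedged example of the new seed argument; Snowflake-style SEED syntax is assumed to reach this parser unchanged:

import sqlglot

# The trailing SEED (n) clause is captured into the TableSample's seed arg.
expr = sqlglot.parse_one("SELECT * FROM t TABLESAMPLE (50) SEED (1234)")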
def _parse_pivots(self):
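# iter(callable, sentinel) keeps calling _parse_pivot until it returns None,
# collecting consecutive PIVOT/UNPIVOT clauses.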
return list(iter(self._parse_pivot, None))
def _parse_pivot(self):
index = self._index
if self._match(TokenType.PIVOT):
unpivot = False
elif self._match(TokenType.UNPIVOT):
unpivot = True
else:
return None
expressions = []
field = None
if not self._match(TokenType.L_PAREN):
self._retreat(index)
return None
if unpivot:
expressions = self._parse_csv(self._parse_column)
else:
expressions = self._parse_csv(lambda: self._parse_alias(self._parse_function()))
if not self._match(TokenType.FOR):
self.raise_error("Expecting FOR")
value = self._parse_column()
if not self._match(TokenType.IN):
self.raise_error("Expecting IN")
field = self._parse_in(value)
self._match_r_paren()
return self.expression(
exp.Pivot,
expressions=expressions,
field=field,
unpivot=unpivot,
)
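A sketch of the PIVOT syntax this parses (table and column names are illustrative):

import sqlglot

# Aggregations (optionally aliased), a FOR column and an IN list become an
# exp.Pivot attached to the preceding table or subquery.
sql = "SELECT * FROM sales PIVOT (SUM(amount) FOR quarter IN ('Q1', 'Q2'))"
expr = sqlglot.parse_one(sql)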
def _parse_where(self):
@ -1384,7 +1433,7 @@ class Parser:
this = self.expression(exp.In, this=this, unnest=unnest)
else:
self._match_l_paren()
expressions = self._parse_csv(lambda: self._parse_select() or self._parse_expression())
expressions = self._parse_csv(self._parse_select_or_expression)
if len(expressions) == 1 and isinstance(expressions[0], exp.Subqueryable):
this = self.expression(exp.In, this=this, query=expressions[0])
@ -1577,6 +1626,9 @@ class Parser:
if self._match_set(self.PRIMARY_PARSERS):
return self.PRIMARY_PARSERS[self._prev.token_type](self, self._prev)
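# ".5" arrives as DOT followed by NUMBER; reassemble the pair into the
# numeric literal 0.5.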
if self._match_pair(TokenType.DOT, TokenType.NUMBER):
return exp.Literal.number(f"0.{self._prev.text}")
if self._match(TokenType.L_PAREN):
query = self._parse_select()
@ -1647,6 +1699,23 @@ class Parser:
self._match_r_paren()
return self._parse_window(this)
def _parse_user_defined_function(self):
this = self._parse_var()
if not self._match(TokenType.L_PAREN):
return this
expressions = self._parse_csv(self._parse_udf_kwarg)
self._match_r_paren()
return self.expression(exp.UserDefinedFunction, this=this, expressions=expressions)
def _parse_udf_kwarg(self):
this = self._parse_id_var()
kind = self._parse_types()
if not kind:
return this
return self.expression(exp.UserDefinedFunctionKwarg, this=this, kind=kind)
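A minimal sketch of the typed-parameter parsing for user-defined functions (the SQL below is illustrative):

import sqlglot

# f's typed parameters become UserDefinedFunctionKwarg nodes under a
# UserDefinedFunction, replacing the old bare _parse_var() name.
expr = sqlglot.parse_one("CREATE FUNCTION f(x INT, y TEXT) AS SELECT x")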
def _parse_lambda(self):
index = self._index
@ -1672,9 +1741,10 @@ class Parser:
return self._parse_alias(self._parse_limit(self._parse_order(this)))
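# Parse the lambda body, then replace references to the lambda's parameters
# with plain identifiers via _replace_lambda.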
conjunction = self._parse_conjunction().transform(self._replace_lambda, {node.name for node in expressions})
return self.expression(
exp.Lambda,
this=self._parse_conjunction(),
this=conjunction,
expressions=expressions,
)
@ -1896,6 +1966,12 @@ class Parser:
to = None
return self.expression(exp.Cast, this=this, to=to)
def _parse_position(self):
substr = self._parse_bitwise()
# POSITION(substr IN string): initialize string so a malformed call without
# IN degrades gracefully instead of raising NameError.
string = None
if self._match(TokenType.IN):
string = self._parse_bitwise()
return self.expression(exp.StrPosition, this=string, substr=substr)
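The POSITION form in action (a sketch; generator output not shown):

import sqlglot

# POSITION(substr IN string) maps to exp.StrPosition(this=string, substr=substr).
expr = sqlglot.parse_one("SELECT POSITION('x' IN 'box')")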
def _parse_substring(self):
# Postgres supports the form: substring(string [from int] [for int])
# https://www.postgresql.org/docs/9.1/functions-string.html @ Table 9-6
@ -2155,6 +2231,9 @@ class Parser:
self._match_r_paren()
return expressions
def _parse_select_or_expression(self):
return self._parse_select() or self._parse_expression()
def _match(self, token_type):
if not self._curr:
return None
@ -2208,3 +2287,9 @@ class Parser:
elif isinstance(this, exp.Identifier):
this = self.expression(exp.Var, this=this.name)
return this
def _replace_lambda(self, node, lambda_variables):
if isinstance(node, exp.Column):
if node.name in lambda_variables:
return node.this
return node
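A hedged sketch of lambda-variable replacement, assuming the default dialect accepts the Spark-style -> lambda form:

import sqlglot

# Inside the lambda body, the bare column x shadows the lambda parameter, so
# _replace_lambda swaps it for a plain identifier rather than a table column.
expr = sqlglot.parse_one("SELECT FILTER(arr, x -> x > 0)")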

View file

@ -94,7 +94,8 @@ class TokenType(AutoName):
SMALLMONEY = auto()
ROWVERSION = auto()
IMAGE = auto()
SQL_VARIANT = auto()
VARIANT = auto()
OBJECT = auto()
# keywords
ADD_FILE = auto()
@ -177,6 +178,7 @@ class TokenType(AutoName):
IS = auto()
ISNULL = auto()
JOIN = auto()
LANGUAGE = auto()
LATERAL = auto()
LAZY = auto()
LEADING = auto()
@ -185,6 +187,7 @@ class TokenType(AutoName):
LIMIT = auto()
LOCATION = auto()
MAP = auto()
MATERIALIZED = auto()
MOD = auto()
NATURAL = auto()
NEXT = auto()
@ -208,6 +211,7 @@ class TokenType(AutoName):
PARTITION_BY = auto()
PARTITIONED_BY = auto()
PERCENT = auto()
PIVOT = auto()
PLACEHOLDER = auto()
PRECEDING = auto()
PRIMARY_KEY = auto()
@ -219,12 +223,14 @@ class TokenType(AutoName):
REPLACE = auto()
RESPECT_NULLS = auto()
REFERENCES = auto()
RETURNS = auto()
RIGHT = auto()
RLIKE = auto()
ROLLUP = auto()
ROW = auto()
ROWS = auto()
SCHEMA_COMMENT = auto()
SEED = auto()
SELECT = auto()
SEPARATOR = auto()
SET = auto()
@ -246,6 +252,7 @@ class TokenType(AutoName):
UNCACHE = auto()
UNION = auto()
UNNEST = auto()
UNPIVOT = auto()
UPDATE = auto()
USE = auto()
USING = auto()
@ -440,6 +447,7 @@ class Tokenizer(metaclass=_Tokenizer):
"FULL": TokenType.FULL,
"FUNCTION": TokenType.FUNCTION,
"FOLLOWING": TokenType.FOLLOWING,
"FOR": TokenType.FOR,
"FOREIGN KEY": TokenType.FOREIGN_KEY,
"FORMAT": TokenType.FORMAT,
"FROM": TokenType.FROM,
@ -459,6 +467,7 @@ class Tokenizer(metaclass=_Tokenizer):
"IS": TokenType.IS,
"ISNULL": TokenType.ISNULL,
"JOIN": TokenType.JOIN,
"LANGUAGE": TokenType.LANGUAGE,
"LATERAL": TokenType.LATERAL,
"LAZY": TokenType.LAZY,
"LEADING": TokenType.LEADING,
@ -466,6 +475,7 @@ class Tokenizer(metaclass=_Tokenizer):
"LIKE": TokenType.LIKE,
"LIMIT": TokenType.LIMIT,
"LOCATION": TokenType.LOCATION,
"MATERIALIZED": TokenType.MATERIALIZED,
"NATURAL": TokenType.NATURAL,
"NEXT": TokenType.NEXT,
"NO ACTION": TokenType.NO_ACTION,
@ -473,6 +483,7 @@ class Tokenizer(metaclass=_Tokenizer):
"NULL": TokenType.NULL,
"NULLS FIRST": TokenType.NULLS_FIRST,
"NULLS LAST": TokenType.NULLS_LAST,
"OBJECT": TokenType.OBJECT,
"OFFSET": TokenType.OFFSET,
"ON": TokenType.ON,
"ONLY": TokenType.ONLY,
@ -488,7 +499,9 @@ class Tokenizer(metaclass=_Tokenizer):
"PARTITION": TokenType.PARTITION,
"PARTITION BY": TokenType.PARTITION_BY,
"PARTITIONED BY": TokenType.PARTITIONED_BY,
"PARTITIONED_BY": TokenType.PARTITIONED_BY,
"PERCENT": TokenType.PERCENT,
"PIVOT": TokenType.PIVOT,
"PRECEDING": TokenType.PRECEDING,
"PRIMARY KEY": TokenType.PRIMARY_KEY,
"RANGE": TokenType.RANGE,
@ -497,11 +510,13 @@ class Tokenizer(metaclass=_Tokenizer):
"REPLACE": TokenType.REPLACE,
"RESPECT NULLS": TokenType.RESPECT_NULLS,
"REFERENCES": TokenType.REFERENCES,
"RETURNS": TokenType.RETURNS,
"RIGHT": TokenType.RIGHT,
"RLIKE": TokenType.RLIKE,
"ROLLUP": TokenType.ROLLUP,
"ROW": TokenType.ROW,
"ROWS": TokenType.ROWS,
"SEED": TokenType.SEED,
"SELECT": TokenType.SELECT,
"SET": TokenType.SET,
"SHOW": TokenType.SHOW,
@ -520,6 +535,7 @@ class Tokenizer(metaclass=_Tokenizer):
"TRUNCATE": TokenType.TRUNCATE,
"UNBOUNDED": TokenType.UNBOUNDED,
"UNION": TokenType.UNION,
"UNPIVOT": TokenType.UNPIVOT,
"UNNEST": TokenType.UNNEST,
"UPDATE": TokenType.UPDATE,
"USE": TokenType.USE,
@ -577,6 +593,7 @@ class Tokenizer(metaclass=_Tokenizer):
"DATETIME": TokenType.DATETIME,
"UNIQUE": TokenType.UNIQUE,
"STRUCT": TokenType.STRUCT,
"VARIANT": TokenType.VARIANT,
}
WHITE_SPACE = {

View file

@ -12,15 +12,20 @@ def unalias_group(expression):
"""
if isinstance(expression, exp.Group) and isinstance(expression.parent, exp.Select):
aliased_selects = {
e.alias: i for i, e in enumerate(expression.parent.expressions, start=1) if isinstance(e, exp.Alias)
e.alias: (i, e.this)
for i, e in enumerate(expression.parent.expressions, start=1)
if isinstance(e, exp.Alias)
}
expression = expression.copy()
for col in expression.find_all(exp.Column):
alias_index = aliased_selects.get(col.name)
if not col.table and alias_index:
col.replace(exp.Literal.number(alias_index))
top_level_expression = None
for item, parent, _ in expression.walk(bfs=False):
top_level_expression = item if isinstance(parent, exp.Group) else top_level_expression
if isinstance(item, exp.Column) and not item.table:
alias_index, col_expression = aliased_selects.get(item.name, (None, None))
if alias_index and top_level_expression != col_expression:
item.replace(exp.Literal.number(alias_index))
return expression
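A sketch of the refined rule, assuming it lives at sqlglot.optimizer.unalias_group as the function name suggests:

from sqlglot import parse_one
from sqlglot.optimizer.unalias_group import unalias_group

# An alias referenced in GROUP BY is rewritten to its 1-based select position...
expr = parse_one("SELECT a AS b, SUM(c) FROM t GROUP BY b")
print(expr.transform(unalias_group).sql())  # ... GROUP BY 1

# ...but when the group item is the aliased expression itself (here the alias
# shadows the column name), the new guard leaves it untouched.
expr = parse_one("SELECT a AS a, SUM(c) FROM t GROUP BY a")
print(expr.transform(unalias_group).sql())  # ... GROUP BY a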