
Merging upstream version 11.3.0.

Signed-off-by: Daniel Baumann <daniel@debian.org>
Author: Daniel Baumann, 2025-02-13 15:42:13 +01:00
Parent: f223c02081
Commit: 1c10961499
Signed by: daniel (GPG key ID: FBB4F0E80A80222F)
62 changed files with 26499 additions and 24781 deletions

sqlglot/__init__.py

@@ -47,7 +47,7 @@ if t.TYPE_CHECKING:
     T = t.TypeVar("T", bound=Expression)

-__version__ = "11.2.3"
+__version__ = "11.3.0"

 pretty = False
 """Whether to format generated SQL by default."""

sqlglot/dataframe/sql/column.py

@@ -67,10 +67,10 @@ class Column:
         return self.binary_op(exp.Mul, other)

     def __truediv__(self, other: ColumnOrLiteral) -> Column:
-        return self.binary_op(exp.Div, other)
+        return self.binary_op(exp.FloatDiv, other)

     def __div__(self, other: ColumnOrLiteral) -> Column:
-        return self.binary_op(exp.Div, other)
+        return self.binary_op(exp.FloatDiv, other)

     def __neg__(self) -> Column:
         return self.unary_op(exp.Neg)
@@ -85,10 +85,10 @@ class Column:
         return self.inverse_binary_op(exp.Mul, other)

     def __rdiv__(self, other: ColumnOrLiteral) -> Column:
-        return self.inverse_binary_op(exp.Div, other)
+        return self.inverse_binary_op(exp.FloatDiv, other)

     def __rtruediv__(self, other: ColumnOrLiteral) -> Column:
-        return self.inverse_binary_op(exp.Div, other)
+        return self.inverse_binary_op(exp.FloatDiv, other)

     def __rmod__(self, other: ColumnOrLiteral) -> Column:
         return self.inverse_binary_op(exp.Mod, other)
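
A quick sketch of what this buys, with behavior inferred from the hunks in this commit rather than verified against a live install:

    from sqlglot import exp
    from sqlglot.dataframe.sql.column import Column

    # "/" on dataframe columns now builds exp.FloatDiv rather than exp.Div, so
    # generators that treat "/" as integer division know a float result is meant.
    c = Column("a") / Column("b")
    print(type(c.expression).__name__)        # FloatDiv
    print(c.expression.sql(dialect="spark"))  # a / b  (Spark "/" already floats)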

sqlglot/dataframe/sql/dataframe.py

@@ -260,7 +260,7 @@ class DataFrame:
     @classmethod
     def _get_outer_select_columns(cls, item: t.Union[exp.Expression, DataFrame]) -> t.List[Column]:
         expression = item.expression if isinstance(item, DataFrame) else item
-        return [Column(x) for x in expression.find(exp.Select).expressions]
+        return [Column(x) for x in (expression.find(exp.Select) or exp.Select()).expressions]

     @classmethod
     def _create_hash_from_expression(cls, expression: exp.Select):
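
The `or exp.Select()` fallback avoids an AttributeError when the tree contains no SELECT at all; a minimal illustration:

    from sqlglot import exp, parse_one

    tree = parse_one("CREATE TABLE t (a INT)")   # no SELECT anywhere in this tree
    select = tree.find(exp.Select) or exp.Select()
    print(select.expressions)  # [] instead of crashing on None.expressions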

sqlglot/dataframe/sql/functions.py

@@ -954,10 +954,12 @@ def array_join(
     col: ColumnOrName, delimiter: str, null_replacement: t.Optional[str] = None
 ) -> Column:
     if null_replacement is not None:
-        return Column.invoke_anonymous_function(
-            col, "ARRAY_JOIN", lit(delimiter), lit(null_replacement)
+        return Column.invoke_expression_over_column(
+            col, expression.ArrayJoin, expression=lit(delimiter), null=lit(null_replacement)
         )
-    return Column.invoke_anonymous_function(col, "ARRAY_JOIN", lit(delimiter))
+    return Column.invoke_expression_over_column(
+        col, expression.ArrayJoin, expression=lit(delimiter)
+    )

 def concat(*cols: ColumnOrName) -> Column:
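
Emitting the typed node lets dialects rename the function. A hedged sketch (the Snowflake spelling comes from the ArrayJoin transform added later in this commit; output inferred, not verified):

    from sqlglot.dataframe.sql import functions as F

    col = F.array_join(F.col("xs"), ",", null_replacement="?")
    # Snowflake renders exp.ArrayJoin as ARRAY_TO_STRING (see snowflake.py below);
    # expected output: ARRAY_TO_STRING(xs, ',', '?')
    print(col.expression.sql(dialect="snowflake"))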

sqlglot/dialects/bigquery.py

@@ -213,7 +213,11 @@ class BigQuery(Dialect):
             ),
         }

+        INTEGER_DIVISION = False
+
     class Generator(generator.Generator):
+        INTEGER_DIVISION = False
+
         TRANSFORMS = {
             **generator.Generator.TRANSFORMS,  # type: ignore
             **transforms.REMOVE_PRECISION_PARAMETERIZED_TYPES,  # type: ignore
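
What the two flags mean in practice, sketched with outputs inferred from this commit's div/floatdiv logic:

    import sqlglot

    # BigQuery "/" is float division, so it now parses to exp.FloatDiv and
    # round-trips untouched:
    print(sqlglot.transpile("SELECT a / b FROM t", read="bigquery")[0])
    # SELECT a / b FROM t

    # Presto "/" truncates integers, so its exp.Div gains a cast when rendered
    # for BigQuery (whose TYPE_MAPPING spells INT as INT64):
    print(sqlglot.transpile("SELECT a / b FROM t", read="presto", write="bigquery")[0])
    # expected: SELECT CAST(a / b AS INT64) FROM t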

sqlglot/dialects/clickhouse.py

@@ -56,6 +56,8 @@ class ClickHouse(Dialect):
         TABLE_ALIAS_TOKENS = {*parser.Parser.TABLE_ALIAS_TOKENS} - {TokenType.ANY}  # type: ignore

+        INTEGER_DIVISION = False
+
         def _parse_in(
             self, this: t.Optional[exp.Expression], is_global: bool = False
         ) -> exp.Expression:
@@ -94,6 +96,7 @@ class ClickHouse(Dialect):
     class Generator(generator.Generator):
         STRUCT_DELIMITER = ("(", ")")
+        INTEGER_DIVISION = False

         TYPE_MAPPING = {
             **generator.Generator.TYPE_MAPPING,  # type: ignore

sqlglot/dialects/dialect.py

@@ -360,10 +360,9 @@ def create_with_partitions_sql(self: Generator, expression: exp.Create) -> str:
     if has_schema and is_partitionable:
         expression = expression.copy()
         prop = expression.find(exp.PartitionedByProperty)
-        this = prop and prop.this
-        if prop and not isinstance(this, exp.Schema):
+        if prop and prop.this and not isinstance(prop.this, exp.Schema):
             schema = expression.this
-            columns = {v.name.upper() for v in this.expressions}
+            columns = {v.name.upper() for v in prop.this.expressions}
             partitions = [col for col in schema.expressions if col.name.upper() in columns]
             schema.set("expressions", [e for e in schema.expressions if e not in partitions])
             prop.replace(exp.PartitionedByProperty(this=exp.Schema(expressions=partitions)))

sqlglot/dialects/duckdb.py

@@ -83,6 +83,7 @@ class DuckDB(Dialect):
             ":=": TokenType.EQ,
             "ATTACH": TokenType.COMMAND,
             "CHARACTER VARYING": TokenType.VARCHAR,
+            "EXCLUDE": TokenType.EXCEPT,
         }

     class Parser(parser.Parser):
@@ -173,3 +174,8 @@ class DuckDB(Dialect):
             exp.DataType.Type.VARCHAR: "TEXT",
             exp.DataType.Type.NVARCHAR: "TEXT",
         }
+
+        STAR_MAPPING = {
+            **generator.Generator.STAR_MAPPING,
+            "except": "EXCLUDE",
+        }
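
A hedged sketch of the round trip this enables (outputs inferred from the hunks, not verified):

    import sqlglot

    # EXCLUDE now tokenizes to the EXCEPT star option, and STAR_MAPPING spells
    # it back out, so both of these should produce DuckDB's syntax:
    print(sqlglot.transpile("SELECT * EXCLUDE (a) FROM t", read="duckdb")[0])
    print(sqlglot.transpile("SELECT * EXCEPT (a) FROM t", read="bigquery", write="duckdb")[0])
    # expected in both cases: SELECT * EXCLUDE (a) FROM t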

sqlglot/dialects/hive.py

@@ -256,7 +256,11 @@ class Hive(Dialect):
             ),
         }

+        INTEGER_DIVISION = False
+
     class Generator(generator.Generator):
+        INTEGER_DIVISION = False
+
         TYPE_MAPPING = {
             **generator.Generator.TYPE_MAPPING,  # type: ignore
             exp.DataType.Type.TEXT: "STRING",

sqlglot/dialects/mysql.py

@@ -300,6 +300,8 @@ class MySQL(Dialect):
             "READ ONLY",
         }

+        INTEGER_DIVISION = False
+
         def _parse_show_mysql(self, this, target=False, full=None, global_=None):
             if target:
                 if isinstance(target, str):
@@ -432,6 +434,7 @@ class MySQL(Dialect):
     class Generator(generator.Generator):
         LOCKING_READS_SUPPORTED = True
         NULL_ORDERING_SUPPORTED = False
+        INTEGER_DIVISION = False

         TRANSFORMS = {
             **generator.Generator.TRANSFORMS,  # type: ignore

sqlglot/dialects/oracle.py

@@ -82,8 +82,17 @@ class Oracle(Dialect):
             "XMLTABLE": _parse_xml_table,
         }

+        INTEGER_DIVISION = False
+
+        def _parse_column(self) -> t.Optional[exp.Expression]:
+            column = super()._parse_column()
+            if column:
+                column.set("join_mark", self._match(TokenType.JOIN_MARKER))
+            return column
+
     class Generator(generator.Generator):
         LOCKING_READS_SUPPORTED = True
+        INTEGER_DIVISION = False

         TYPE_MAPPING = {
             **generator.Generator.TYPE_MAPPING,  # type: ignore
@@ -108,6 +117,8 @@ class Oracle(Dialect):
             exp.Trim: trim_sql,
             exp.Matches: rename_func("DECODE"),
             exp.StrToTime: lambda self, e: f"TO_TIMESTAMP({self.sql(e, 'this')}, {self.format_time(e)})",
+            exp.Subquery: lambda self, e: self.subquery_sql(e, sep=" "),
+            exp.Table: lambda self, e: self.table_sql(e, sep=" "),
             exp.TimeToStr: lambda self, e: f"TO_CHAR({self.sql(e, 'this')}, {self.format_time(e)})",
             exp.UnixToTime: lambda self, e: f"TO_DATE('1970-01-01','YYYY-MM-DD') + ({self.sql(e, 'this')} / 86400)",
             exp.Substring: rename_func("SUBSTR"),
@@ -139,8 +150,9 @@ class Oracle(Dialect):
         def offset_sql(self, expression: exp.Offset) -> str:
             return f"{super().offset_sql(expression)} ROWS"

-        def table_sql(self, expression: exp.Table, sep: str = " ") -> str:
-            return super().table_sql(expression, sep=sep)
+        def column_sql(self, expression: exp.Column) -> str:
+            column = super().column_sql(expression)
+            return f"{column} (+)" if expression.args.get("join_mark") else column

         def xmltable_sql(self, expression: exp.XMLTable) -> str:
             this = self.sql(expression, "this")
@@ -156,6 +168,7 @@ class Oracle(Dialect):
     class Tokenizer(tokens.Tokenizer):
         KEYWORDS = {
             **tokens.Tokenizer.KEYWORDS,
+            "(+)": TokenType.JOIN_MARKER,
             "COLUMNS": TokenType.COLUMN,
             "MATCH_RECOGNIZE": TokenType.MATCH_RECOGNIZE,
             "MINUS": TokenType.EXCEPT,

sqlglot/dialects/postgres.py

@@ -222,10 +222,8 @@ class Postgres(Dialect):
             "BEGIN TRANSACTION": TokenType.BEGIN,
             "BIGSERIAL": TokenType.BIGSERIAL,
             "CHARACTER VARYING": TokenType.VARCHAR,
-            "COMMENT ON": TokenType.COMMAND,
             "DECLARE": TokenType.COMMAND,
             "DO": TokenType.COMMAND,
-            "GRANT": TokenType.COMMAND,
             "HSTORE": TokenType.HSTORE,
             "JSONB": TokenType.JSONB,
             "REFRESH": TokenType.COMMAND,
@@ -260,10 +258,7 @@ class Postgres(Dialect):
             TokenType.HASH: exp.BitwiseXor,
         }

-        FACTOR = {
-            **parser.Parser.FACTOR,  # type: ignore
-            TokenType.CARET: exp.Pow,
-        }
+        FACTOR = {**parser.Parser.FACTOR, TokenType.CARET: exp.Pow}

     class Generator(generator.Generator):
         LOCKING_READS_SUPPORTED = True

sqlglot/dialects/snowflake.py

@@ -1,5 +1,7 @@
 from __future__ import annotations

+import typing as t
+
 from sqlglot import exp, generator, parser, tokens
 from sqlglot.dialects.dialect import (
     Dialect,
@@ -104,6 +106,20 @@ def _parse_date_part(self):
     return self.expression(exp.Extract, this=this, expression=expression)

+# https://docs.snowflake.com/en/sql-reference/functions/div0
+def _div0_to_if(args):
+    cond = exp.EQ(this=seq_get(args, 1), expression=exp.Literal.number(0))
+    true = exp.Literal.number(0)
+    false = exp.FloatDiv(this=seq_get(args, 0), expression=seq_get(args, 1))
+    return exp.If(this=cond, true=true, false=false)
+
+# https://docs.snowflake.com/en/sql-reference/functions/zeroifnull
+def _zeroifnull_to_if(args):
+    cond = exp.EQ(this=seq_get(args, 0), expression=exp.Null())
+    return exp.If(this=cond, true=exp.Literal.number(0), false=seq_get(args, 0))
+
 def _datatype_sql(self, expression):
     if expression.this == exp.DataType.Type.ARRAY:
         return "ARRAY"
@@ -150,16 +166,20 @@ class Snowflake(Dialect):
         FUNCTIONS = {
             **parser.Parser.FUNCTIONS,
             "ARRAYAGG": exp.ArrayAgg.from_arg_list,
+            "ARRAY_TO_STRING": exp.ArrayJoin.from_arg_list,
             "DATE_TRUNC": lambda args: exp.DateTrunc(
                 unit=exp.Literal.string(seq_get(args, 0).name),  # type: ignore
                 this=seq_get(args, 1),
             ),
+            "DIV0": _div0_to_if,
             "IFF": exp.If.from_arg_list,
             "TO_ARRAY": exp.Array.from_arg_list,
             "TO_TIMESTAMP": _snowflake_to_timestamp,
             "ARRAY_CONSTRUCT": exp.Array.from_arg_list,
             "RLIKE": exp.RegexpLike.from_arg_list,
             "DECODE": exp.Matches.from_arg_list,
             "OBJECT_CONSTRUCT": parser.parse_var_map,
+            "ZEROIFNULL": _zeroifnull_to_if,
         }

         FUNCTION_PARSERS = {
@@ -193,6 +213,19 @@ class Snowflake(Dialect):
             ),
         }

+        ALTER_PARSERS = {
+            **parser.Parser.ALTER_PARSERS,  # type: ignore
+            "UNSET": lambda self: self._parse_alter_table_set_tag(unset=True),
+            "SET": lambda self: self._parse_alter_table_set_tag(),
+        }
+
+        INTEGER_DIVISION = False
+
+        def _parse_alter_table_set_tag(self, unset: bool = False) -> exp.Expression:
+            self._match_text_seq("TAG")
+            parser = t.cast(t.Callable, self._parse_id_var if unset else self._parse_conjunction)
+            return self.expression(exp.SetTag, expressions=self._parse_csv(parser), unset=unset)
+
     class Tokenizer(tokens.Tokenizer):
         QUOTES = ["'", "$$"]
         STRING_ESCAPES = ["\\", "'"]
@@ -220,12 +253,14 @@ class Snowflake(Dialect):
     class Generator(generator.Generator):
         PARAMETER_TOKEN = "$"
+        INTEGER_DIVISION = False

         TRANSFORMS = {
             **generator.Generator.TRANSFORMS,  # type: ignore
             exp.Array: inline_array_sql,
             exp.ArrayConcat: rename_func("ARRAY_CAT"),
-            exp.DateAdd: rename_func("DATEADD"),
+            exp.ArrayJoin: rename_func("ARRAY_TO_STRING"),
+            exp.DateAdd: lambda self, e: self.func("DATEADD", e.text("unit"), e.expression, e.this),
             exp.DateStrToDate: datestrtodate_sql,
             exp.DataType: _datatype_sql,
             exp.If: rename_func("IFF"),
@@ -294,6 +329,10 @@ class Snowflake(Dialect):
                 return self.no_identify(lambda: super(self.__class__, self).values_sql(expression))
             return super().values_sql(expression)

+        def settag_sql(self, expression: exp.SetTag) -> str:
+            action = "UNSET" if expression.args.get("unset") else "SET"
+            return f"{action} TAG {self.expressions(expression)}"
+
         def select_sql(self, expression: exp.Select) -> str:
             """Due to a bug in Snowflake we want to make sure that all columns in a VALUES table alias are unquoted and also
             that all columns in a SELECT are unquoted. We also want to make sure that after we find matches where we need
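
A hedged sketch of the new Snowflake behavior (expected outputs are inferred from these hunks, not verified against a live install):

    import sqlglot

    # DIV0 becomes a typed IF at parse time (rendered back as IFF on Snowflake),
    # and the new ALTER_PARSERS entries handle tag statements:
    print(sqlglot.transpile("SELECT DIV0(a, b)", read="snowflake")[0])
    # expected: SELECT IFF(b = 0, 0, a / b)
    print(sqlglot.transpile("ALTER TABLE t SET TAG dept = 'finance'", read="snowflake")[0])
    # expected: ALTER TABLE t SET TAG dept = 'finance'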

sqlglot/dialects/teradata.py

@@ -74,6 +74,7 @@ class Teradata(Dialect):
         FUNCTION_PARSERS = {
             **parser.Parser.FUNCTION_PARSERS,  # type: ignore
+            "RANGE_N": lambda self: self._parse_rangen(),
             "TRANSLATE": lambda self: self._parse_translate(self.STRICT_CAST),
         }

@@ -105,6 +106,15 @@ class Teradata(Dialect):
             },
         )

+        def _parse_rangen(self):
+            this = self._parse_id_var()
+            self._match(TokenType.BETWEEN)
+
+            expressions = self._parse_csv(self._parse_conjunction)
+            each = self._match_text_seq("EACH") and self._parse_conjunction()
+
+            return self.expression(exp.RangeN, this=this, expressions=expressions, each=each)
+
     class Generator(generator.Generator):
         TYPE_MAPPING = {
             **generator.Generator.TYPE_MAPPING,  # type: ignore
@@ -114,7 +124,6 @@ class Teradata(Dialect):
         PROPERTIES_LOCATION = {
             **generator.Generator.PROPERTIES_LOCATION,  # type: ignore
             exp.PartitionedByProperty: exp.Properties.Location.POST_INDEX,
-            exp.VolatilityProperty: exp.Properties.Location.POST_CREATE,
         }

         def partitionedbyproperty_sql(self, expression: exp.PartitionedByProperty) -> str:
@@ -137,3 +146,11 @@ class Teradata(Dialect):
             type_sql = super().datatype_sql(expression)
             prefix_sql = expression.args.get("prefix")
             return f"SYSUDTLIB.{type_sql}" if prefix_sql else type_sql
+
+        def rangen_sql(self, expression: exp.RangeN) -> str:
+            this = self.sql(expression, "this")
+            expressions_sql = self.expressions(expression)
+            each_sql = self.sql(expression, "each")
+            each_sql = f" EACH {each_sql}" if each_sql else ""
+
+            return f"RANGE_N({this} BETWEEN {expressions_sql}{each_sql})"

sqlglot/expressions.py

@@ -35,6 +35,8 @@ from sqlglot.tokens import Token
 if t.TYPE_CHECKING:
     from sqlglot.dialects.dialect import DialectType

+    E = t.TypeVar("E", bound="Expression")
+
 class _Expression(type):
     def __new__(cls, clsname, bases, attrs):
@@ -293,7 +295,7 @@ class Expression(metaclass=_Expression):
             return self.parent.depth + 1
         return 0

-    def find(self, *expression_types, bfs=True):
+    def find(self, *expression_types: t.Type[E], bfs=True) -> E | None:
         """
         Returns the first node in this tree which matches at least one of
         the specified types.
@@ -306,7 +308,7 @@ class Expression(metaclass=_Expression):
         """
         return next(self.find_all(*expression_types, bfs=bfs), None)

-    def find_all(self, *expression_types, bfs=True):
+    def find_all(self, *expression_types: t.Type[E], bfs=True) -> t.Iterator[E]:
         """
         Returns a generator object which visits all nodes in this tree and only
         yields those that match at least one of the specified expression types.
@@ -321,7 +323,7 @@ class Expression(metaclass=_Expression):
             if isinstance(expression, expression_types):
                 yield expression

-    def find_ancestor(self, *expression_types):
+    def find_ancestor(self, *expression_types: t.Type[E]) -> E | None:
         """
         Returns a nearest parent matching expression_types.

@@ -334,7 +336,8 @@ class Expression(metaclass=_Expression):
         ancestor = self.parent
         while ancestor and not isinstance(ancestor, expression_types):
             ancestor = ancestor.parent
-        return ancestor
+        # ignore type because mypy doesn't know that we're checking type in the loop
+        return ancestor  # type: ignore[return-value]

     @property
     def parent_select(self):
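
A small illustration of what the new generics give callers:

    from sqlglot import exp, parse_one

    # find() is now generic over the requested type: the result is typed as
    # Optional[exp.Table] here rather than a bare Expression.
    table = parse_one("SELECT a FROM t").find(exp.Table)
    if table:
        print(table.name)  # t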
@@ -794,6 +797,7 @@ class Create(Expression):
         "properties": False,
         "replace": False,
         "unique": False,
+        "volatile": False,
         "indexes": False,
         "no_schema_binding": False,
         "begin": False,
@@ -883,7 +887,7 @@ class ByteString(Condition):

 class Column(Condition):
-    arg_types = {"this": True, "table": False, "db": False, "catalog": False}
+    arg_types = {"this": True, "table": False, "db": False, "catalog": False, "join_mark": False}

     @property
     def table(self) -> str:
@@ -926,6 +930,14 @@ class RenameTable(Expression):
     pass

+class SetTag(Expression):
+    arg_types = {"expressions": True, "unset": False}
+
+class Comment(Expression):
+    arg_types = {"this": True, "kind": True, "expression": True, "exists": False}
+
 class ColumnConstraint(Expression):
     arg_types = {"this": False, "kind": True}
@@ -2829,6 +2841,14 @@ class Div(Binary):
     pass

+class FloatDiv(Binary):
+    pass
+
+class Overlaps(Binary):
+    pass
+
 class Dot(Binary):
     @property
     def name(self) -> str:
@@ -3125,6 +3145,10 @@ class ArrayFilter(Func):
     _sql_names = ["FILTER", "ARRAY_FILTER"]

+class ArrayJoin(Func):
+    arg_types = {"this": True, "expression": True, "null": False}
+
 class ArraySize(Func):
     arg_types = {"this": True, "expression": False}
@@ -3510,6 +3534,10 @@ class ApproxQuantile(Quantile):
     arg_types = {"this": True, "quantile": True, "accuracy": False, "weight": False}

+class RangeN(Func):
+    arg_types = {"this": True, "expressions": True, "each": False}
+
 class ReadCSV(Func):
     _sql_names = ["READ_CSV"]
     is_var_len_args = True

sqlglot/generator.py

@@ -109,6 +109,9 @@ class Generator:
     # Whether or not create function uses an AS before the RETURN
     CREATE_FUNCTION_RETURN_AS = True

+    # Whether or not to treat the division operator "/" as integer division
+    INTEGER_DIVISION = True
+
     TYPE_MAPPING = {
         exp.DataType.Type.NCHAR: "CHAR",
         exp.DataType.Type.NVARCHAR: "VARCHAR",
@@ -550,14 +553,17 @@ class Generator:
         else:
             expression_sql = f" AS{expression_sql}"

-        replace = " OR REPLACE" if expression.args.get("replace") else ""
-        unique = " UNIQUE" if expression.args.get("unique") else ""
-        exists_sql = " IF NOT EXISTS" if expression.args.get("exists") else ""
+        postindex_props_sql = ""
+        if properties_locs.get(exp.Properties.Location.POST_INDEX):
+            postindex_props_sql = self.properties(
+                exp.Properties(expressions=properties_locs[exp.Properties.Location.POST_INDEX]),
+                wrapped=False,
+                prefix=" ",
+            )

         indexes = expression.args.get("indexes")
         index_sql = ""
         if indexes:
-            indexes_sql = []
+            indexes_sql: t.List[str] = []
             for index in indexes:
                 ind_unique = " UNIQUE" if index.args.get("unique") else ""
                 ind_primary = " PRIMARY" if index.args.get("primary") else ""
@@ -568,21 +574,24 @@ class Generator:
                     if index.args.get("columns")
                     else ""
                 )
-                if index.args.get("primary") and properties_locs.get(
-                    exp.Properties.Location.POST_INDEX
-                ):
-                    postindex_props_sql = self.properties(
-                        exp.Properties(
-                            expressions=properties_locs[exp.Properties.Location.POST_INDEX]
-                        ),
-                        wrapped=False,
-                    )
-                    ind_columns = f"{ind_columns} {postindex_props_sql}"
-                ind_sql = f"{ind_unique}{ind_primary}{ind_amp} INDEX{ind_name}{ind_columns}"
-                if indexes_sql:
-                    indexes_sql.append(ind_sql)
-                else:
-                    indexes_sql.append(
-                        f"{ind_sql}{postindex_props_sql}"
-                        if index.args.get("primary")
-                        else f"{postindex_props_sql}{ind_sql}"
-                    )
+                indexes_sql.append(
+                    f"{ind_unique}{ind_primary}{ind_amp} INDEX{ind_name}{ind_columns}"
+                )

             index_sql = "".join(indexes_sql)
+        else:
+            index_sql = postindex_props_sql
+
+        replace = " OR REPLACE" if expression.args.get("replace") else ""
+        unique = " UNIQUE" if expression.args.get("unique") else ""
+        volatile = " VOLATILE" if expression.args.get("volatile") else ""

         postcreate_props_sql = ""
         if properties_locs.get(exp.Properties.Location.POST_CREATE):
@@ -593,7 +602,7 @@ class Generator:
                 wrapped=False,
             )

-        modifiers = "".join((replace, unique, postcreate_props_sql))
+        modifiers = "".join((replace, unique, volatile, postcreate_props_sql))

         postexpression_props_sql = ""
         if properties_locs.get(exp.Properties.Location.POST_EXPRESSION):
@@ -606,6 +615,7 @@ class Generator:
                 wrapped=False,
             )

+        exists_sql = " IF NOT EXISTS" if expression.args.get("exists") else ""
         no_schema_binding = (
             " WITH NO SCHEMA BINDING" if expression.args.get("no_schema_binding") else ""
         )
@@ -1335,14 +1345,15 @@ class Generator:
     def placeholder_sql(self, expression: exp.Placeholder) -> str:
         return f":{expression.name}" if expression.name else "?"

-    def subquery_sql(self, expression: exp.Subquery) -> str:
+    def subquery_sql(self, expression: exp.Subquery, sep: str = " AS ") -> str:
         alias = self.sql(expression, "alias")
+        alias = f"{sep}{alias}" if alias else ""

         sql = self.query_modifiers(
             expression,
             self.wrap(expression),
             self.expressions(expression, key="pivots", sep=" "),
-            f" AS {alias}" if alias else "",
+            alias,
         )

         return self.prepend_ctes(expression, sql)
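
The new `sep` parameter exists so dialects can drop the AS; a hedged sketch using the Oracle transform added earlier in this commit (output inferred, not verified):

    import sqlglot

    # Oracle rejects AS between a derived table and its alias, so its dialect
    # now routes exp.Subquery through subquery_sql(..., sep=" "):
    print(sqlglot.transpile("SELECT * FROM (SELECT 1 AS x) AS t", write="oracle")[0])
    # expected: SELECT * FROM (SELECT 1 AS x) t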
@@ -1643,6 +1654,13 @@ class Generator:
     def command_sql(self, expression: exp.Command) -> str:
         return f"{self.sql(expression, 'this').upper()} {expression.text('expression').strip()}"

+    def comment_sql(self, expression: exp.Comment) -> str:
+        this = self.sql(expression, "this")
+        kind = expression.args["kind"]
+        exists_sql = " IF EXISTS " if expression.args.get("exists") else " "
+        expression_sql = self.sql(expression, "expression")
+        return f"COMMENT{exists_sql}ON {kind} {this} IS {expression_sql}"
+
     def transaction_sql(self, *_) -> str:
         return "BEGIN"

@@ -1728,19 +1746,30 @@ class Generator:
         return f"{self.sql(expression, 'this')} RESPECT NULLS"

     def intdiv_sql(self, expression: exp.IntDiv) -> str:
-        return self.sql(
-            exp.Cast(
-                this=exp.Div(this=expression.this, expression=expression.expression),
-                to=exp.DataType(this=exp.DataType.Type.INT),
-            )
-        )
+        div = self.binary(expression, "/")
+        return self.sql(exp.Cast(this=div, to=exp.DataType.build("INT")))

     def dpipe_sql(self, expression: exp.DPipe) -> str:
         return self.binary(expression, "||")

+    def div_sql(self, expression: exp.Div) -> str:
+        div = self.binary(expression, "/")
+
+        if not self.INTEGER_DIVISION:
+            return self.sql(exp.Cast(this=div, to=exp.DataType.build("INT")))
+
+        return div
+
+    def floatdiv_sql(self, expression: exp.FloatDiv) -> str:
+        if self.INTEGER_DIVISION:
+            this = exp.Cast(this=expression.this, to=exp.DataType.build("DOUBLE"))
+            return self.div_sql(exp.Div(this=this, expression=expression.expression))
+
+        return self.binary(expression, "/")
+
+    def overlaps_sql(self, expression: exp.Overlaps) -> str:
+        return self.binary(expression, "OVERLAPS")
+
     def distance_sql(self, expression: exp.Distance) -> str:
         return self.binary(expression, "<->")
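
How div_sql and floatdiv_sql interact across dialects, with expected outputs inferred from the logic above:

    import sqlglot

    # Each generator consults INTEGER_DIVISION to keep "/" semantics stable:
    print(sqlglot.transpile("SELECT a / b", read="hive", write="presto")[0])
    # expected: SELECT CAST(a AS DOUBLE) / b   -- float semantics preserved
    print(sqlglot.transpile("SELECT a / b", read="presto", write="hive")[0])
    # expected: SELECT CAST(a / b AS INT)      -- truncating semantics preserved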

sqlglot/optimizer/merge_subqueries.py

@@ -314,13 +314,27 @@ def _merge_where(outer_scope, inner_scope, from_or_join):
     if not where or not where.this:
         return

+    expression = outer_scope.expression
+
     if isinstance(from_or_join, exp.Join):
+        # Merge predicates from an outer join to the ON clause
-        from_or_join.on(where.this, copy=False)
-        from_or_join.set("on", simplify(from_or_join.args.get("on")))
-    else:
-        outer_scope.expression.where(where.this, copy=False)
-        outer_scope.expression.set("where", simplify(outer_scope.expression.args.get("where")))
+        # if it only has columns that are already joined
+        from_ = expression.args.get("from")
+        sources = {table.alias_or_name for table in from_.expressions} if from_ else {}
+
+        for join in expression.args["joins"]:
+            source = join.alias_or_name
+            sources.add(source)
+            if source == from_or_join.alias_or_name:
+                break
+
+        if set(exp.column_table_names(where.this)) <= sources:
+            from_or_join.on(where.this, copy=False)
+            from_or_join.set("on", simplify(from_or_join.args.get("on")))
+            return
+
+    expression.where(where.this, copy=False)
+    expression.set("where", simplify(expression.args.get("where")))

 def _merge_order(outer_scope, inner_scope):
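
The guard itself is easy to demonstrate in isolation; this mirrors the check in the hunk above:

    from sqlglot import exp, parse_one

    # A merged predicate only moves into JOIN ... ON when every table it
    # references is already joined at that point:
    where = parse_one("x.a = y.b AND z.c > 0")
    sources = {"x", "y"}  # tables joined so far
    print(set(exp.column_table_names(where)))             # {'x', 'y', 'z'}
    print(set(exp.column_table_names(where)) <= sources)  # False -> stays in WHERE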

sqlglot/optimizer/pushdown_projections.py

@@ -13,7 +13,7 @@ SELECT_ALL = object()
 DEFAULT_SELECTION = lambda: alias("1", "_")

-def pushdown_projections(expression, schema=None):
+def pushdown_projections(expression, schema=None, remove_unused_selections=True):
     """
     Rewrite sqlglot AST to remove unused columns projections.
@@ -26,6 +26,7 @@ def pushdown_projections(expression, schema=None):
     Args:
         expression (sqlglot.Expression): expression to optimize
+        remove_unused_selections (bool): remove selects that are unused
     Returns:
         sqlglot.Expression: optimized expression
     """
@@ -57,7 +58,8 @@ def pushdown_projections(expression, schema=None):
             ]

             if isinstance(scope.expression, exp.Select):
-                _remove_unused_selections(scope, parent_selections, schema)
+                if remove_unused_selections:
+                    _remove_unused_selections(scope, parent_selections, schema)

                 # Group columns by source name
                 selects = defaultdict(set)
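
A hedged usage sketch of the new flag (the pruning comments state the expected behavior, not verified output):

    import sqlglot
    from sqlglot.optimizer.pushdown_projections import pushdown_projections

    sql = "SELECT x.a FROM (SELECT a, b FROM t) AS x"
    print(pushdown_projections(sqlglot.parse_one(sql)).sql())
    # the unused inner column b is expected to be pruned
    print(pushdown_projections(sqlglot.parse_one(sql), remove_unused_selections=False).sql())
    # the inner column b is expected to survive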

sqlglot/parser.py

@@ -36,6 +36,10 @@ class _Parser(type):
         klass = super().__new__(cls, clsname, bases, attrs)
         klass._show_trie = new_trie(key.split(" ") for key in klass.SHOW_PARSERS)
         klass._set_trie = new_trie(key.split(" ") for key in klass.SET_PARSERS)
+
+        if not klass.INTEGER_DIVISION:
+            klass.FACTOR = {**klass.FACTOR, TokenType.SLASH: exp.FloatDiv}
+
         return klass

@@ -157,6 +161,21 @@ class Parser(metaclass=_Parser):
     RESERVED_KEYWORDS = {*Tokenizer.SINGLE_TOKENS.values(), TokenType.SELECT}

+    DB_CREATABLES = {
+        TokenType.DATABASE,
+        TokenType.SCHEMA,
+        TokenType.TABLE,
+        TokenType.VIEW,
+    }
+
+    CREATABLES = {
+        TokenType.COLUMN,
+        TokenType.FUNCTION,
+        TokenType.INDEX,
+        TokenType.PROCEDURE,
+        *DB_CREATABLES,
+    }
+
     ID_VAR_TOKENS = {
         TokenType.VAR,
         TokenType.ANTI,
@@ -168,8 +187,8 @@ class Parser(metaclass=_Parser):
         TokenType.CACHE,
         TokenType.CASCADE,
         TokenType.COLLATE,
-        TokenType.COLUMN,
         TokenType.COMMAND,
+        TokenType.COMMENT,
         TokenType.COMMIT,
         TokenType.COMPOUND,
         TokenType.CONSTRAINT,
@@ -186,9 +205,7 @@ class Parser(metaclass=_Parser):
         TokenType.FILTER,
         TokenType.FOLLOWING,
         TokenType.FORMAT,
-        TokenType.FUNCTION,
         TokenType.IF,
-        TokenType.INDEX,
         TokenType.ISNULL,
         TokenType.INTERVAL,
         TokenType.LAZY,
@@ -211,13 +228,11 @@ class Parser(metaclass=_Parser):
         TokenType.RIGHT,
         TokenType.ROW,
         TokenType.ROWS,
-        TokenType.SCHEMA,
         TokenType.SEED,
         TokenType.SEMI,
         TokenType.SET,
         TokenType.SHOW,
         TokenType.SORTKEY,
-        TokenType.TABLE,
         TokenType.TEMPORARY,
         TokenType.TOP,
         TokenType.TRAILING,
@@ -226,10 +241,9 @@ class Parser(metaclass=_Parser):
         TokenType.UNIQUE,
         TokenType.UNLOGGED,
         TokenType.UNPIVOT,
-        TokenType.PROCEDURE,
-        TokenType.VIEW,
         TokenType.VOLATILE,
         TokenType.WINDOW,
+        *CREATABLES,
         *SUBQUERY_PREDICATES,
         *TYPE_TOKENS,
         *NO_PAREN_FUNCTIONS,
@@ -428,6 +442,7 @@ class Parser(metaclass=_Parser):
         TokenType.BEGIN: lambda self: self._parse_transaction(),
         TokenType.CACHE: lambda self: self._parse_cache(),
         TokenType.COMMIT: lambda self: self._parse_commit_or_rollback(),
+        TokenType.COMMENT: lambda self: self._parse_comment(),
         TokenType.CREATE: lambda self: self._parse_create(),
         TokenType.DELETE: lambda self: self._parse_delete(),
         TokenType.DESC: lambda self: self._parse_describe(),
@@ -490,6 +505,9 @@ class Parser(metaclass=_Parser):
         TokenType.GLOB: lambda self, this: self._parse_escape(
             self.expression(exp.Glob, this=this, expression=self._parse_bitwise())
         ),
+        TokenType.OVERLAPS: lambda self, this: self._parse_escape(
+            self.expression(exp.Overlaps, this=this, expression=self._parse_bitwise())
+        ),
         TokenType.IN: lambda self, this: self._parse_in(this),
         TokenType.IS: lambda self, this: self._parse_is(this),
         TokenType.LIKE: lambda self, this: self._parse_escape(
@@ -628,6 +646,14 @@ class Parser(metaclass=_Parser):
         "UPPERCASE": lambda self: self.expression(exp.UppercaseColumnConstraint),
     }

+    ALTER_PARSERS = {
+        "ADD": lambda self: self._parse_alter_table_add(),
+        "ALTER": lambda self: self._parse_alter_table_alter(),
+        "DELETE": lambda self: self.expression(exp.Delete, where=self._parse_where()),
+        "DROP": lambda self: self._parse_alter_table_drop(),
+        "RENAME": lambda self: self._parse_alter_table_rename(),
+    }
+
     SCHEMA_UNNAMED_CONSTRAINTS = {"CHECK", "FOREIGN KEY", "LIKE", "PRIMARY KEY", "UNIQUE"}

     NO_PAREN_FUNCTION_PARSERS = {
@@ -669,16 +695,6 @@ class Parser(metaclass=_Parser):
     MODIFIABLES = (exp.Subquery, exp.Subqueryable, exp.Table)

-    CREATABLES = {
-        TokenType.COLUMN,
-        TokenType.FUNCTION,
-        TokenType.INDEX,
-        TokenType.PROCEDURE,
-        TokenType.SCHEMA,
-        TokenType.TABLE,
-        TokenType.VIEW,
-    }
-
     TRANSACTION_KIND = {"DEFERRED", "IMMEDIATE", "EXCLUSIVE"}

     INSERT_ALTERNATIVES = {"ABORT", "FAIL", "IGNORE", "REPLACE", "ROLLBACK"}
@@ -689,6 +705,8 @@ class Parser(metaclass=_Parser):
     STRICT_CAST = True

+    INTEGER_DIVISION = True
+
     __slots__ = (
         "error_level",
         "error_message_context",
@@ -940,6 +958,32 @@ class Parser(metaclass=_Parser):
     def _parse_command(self) -> exp.Expression:
         return self.expression(exp.Command, this=self._prev.text, expression=self._parse_string())

+    def _parse_comment(self, allow_exists: bool = True) -> exp.Expression:
+        start = self._prev
+        exists = self._parse_exists() if allow_exists else None
+
+        self._match(TokenType.ON)
+
+        kind = self._match_set(self.CREATABLES) and self._prev
+
+        if not kind:
+            return self._parse_as_command(start)
+
+        if kind.token_type in (TokenType.FUNCTION, TokenType.PROCEDURE):
+            this = self._parse_user_defined_function(kind=kind.token_type)
+        elif kind.token_type == TokenType.TABLE:
+            this = self._parse_table()
+        elif kind.token_type == TokenType.COLUMN:
+            this = self._parse_column()
+        else:
+            this = self._parse_id_var()
+
+        self._match(TokenType.IS)
+
+        return self.expression(
+            exp.Comment, this=this, kind=kind.text, expression=self._parse_string(), exists=exists
+        )
+
     def _parse_statement(self) -> t.Optional[exp.Expression]:
         if self._curr is None:
             return None
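
A hedged sketch of the new statement parsing (the printed values follow from the code above; they have not been checked against a live install):

    import sqlglot

    # COMMENT ON now parses into a structured exp.Comment for CREATABLE kinds
    # (TABLE, COLUMN, FUNCTION, ...) and falls back to exp.Command otherwise.
    node = sqlglot.parse_one("COMMENT ON TABLE db.t IS 'fact table'")
    print(type(node).__name__, node.args["kind"])  # expected: Comment TABLE
    print(node.sql())  # expected: COMMENT ON TABLE db.t IS 'fact table'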
@@ -990,6 +1034,7 @@ class Parser(metaclass=_Parser):
             TokenType.OR, TokenType.REPLACE
         )
         unique = self._match(TokenType.UNIQUE)
+        volatile = self._match(TokenType.VOLATILE)

         if self._match_pair(TokenType.TABLE, TokenType.FUNCTION, advance=False):
             self._match(TokenType.TABLE)
@@ -1028,11 +1073,7 @@ class Parser(metaclass=_Parser):
                 expression = self.expression(exp.Return, this=expression)
         elif create_token.token_type == TokenType.INDEX:
             this = self._parse_index()
-        elif create_token.token_type in (
-            TokenType.TABLE,
-            TokenType.VIEW,
-            TokenType.SCHEMA,
-        ):
+        elif create_token.token_type in self.DB_CREATABLES:
             table_parts = self._parse_table_parts(schema=True)

             # exp.Properties.Location.POST_NAME
@@ -1100,11 +1141,12 @@ class Parser(metaclass=_Parser):
             exp.Create,
             this=this,
             kind=create_token.text,
+            replace=replace,
+            unique=unique,
+            volatile=volatile,
             expression=expression,
             exists=exists,
             properties=properties,
-            replace=replace,
             indexes=indexes,
             no_schema_binding=no_schema_binding,
             begin=begin,
@@ -3648,6 +3690,47 @@ class Parser(metaclass=_Parser):
         return self.expression(exp.AddConstraint, this=this, expression=expression)

+    def _parse_alter_table_add(self) -> t.List[t.Optional[exp.Expression]]:
+        index = self._index - 1
+
+        if self._match_set(self.ADD_CONSTRAINT_TOKENS):
+            return self._parse_csv(self._parse_add_constraint)
+
+        self._retreat(index)
+        return self._parse_csv(self._parse_add_column)
+
+    def _parse_alter_table_alter(self) -> exp.Expression:
+        self._match(TokenType.COLUMN)
+        column = self._parse_field(any_token=True)
+
+        if self._match_pair(TokenType.DROP, TokenType.DEFAULT):
+            return self.expression(exp.AlterColumn, this=column, drop=True)
+        if self._match_pair(TokenType.SET, TokenType.DEFAULT):
+            return self.expression(exp.AlterColumn, this=column, default=self._parse_conjunction())
+
+        self._match_text_seq("SET", "DATA")
+        return self.expression(
+            exp.AlterColumn,
+            this=column,
+            dtype=self._match_text_seq("TYPE") and self._parse_types(),
+            collate=self._match(TokenType.COLLATE) and self._parse_term(),
+            using=self._match(TokenType.USING) and self._parse_conjunction(),
+        )
+
+    def _parse_alter_table_drop(self) -> t.List[t.Optional[exp.Expression]]:
+        index = self._index - 1
+
+        partition_exists = self._parse_exists()
+        if self._match(TokenType.PARTITION, advance=False):
+            return self._parse_csv(lambda: self._parse_drop_partition(exists=partition_exists))
+
+        self._retreat(index)
+        return self._parse_csv(self._parse_drop_column)
+
+    def _parse_alter_table_rename(self) -> exp.Expression:
+        self._match_text_seq("TO")
+        return self.expression(exp.RenameTable, this=self._parse_table(schema=True))
+
     def _parse_alter(self) -> t.Optional[exp.Expression]:
         if not self._match(TokenType.TABLE):
             return self._parse_as_command(self._prev)
@@ -3655,50 +3738,12 @@ class Parser(metaclass=_Parser):
         exists = self._parse_exists()
         this = self._parse_table(schema=True)

-        actions: t.Optional[exp.Expression | t.List[t.Optional[exp.Expression]]] = None
+        if not self._curr:
+            return None

-        index = self._index
-        if self._match(TokenType.DELETE):
-            actions = [self.expression(exp.Delete, where=self._parse_where())]
-        elif self._match_text_seq("ADD"):
-            if self._match_set(self.ADD_CONSTRAINT_TOKENS):
-                actions = self._parse_csv(self._parse_add_constraint)
-            else:
-                self._retreat(index)
-                actions = self._parse_csv(self._parse_add_column)
-        elif self._match_text_seq("DROP"):
-            partition_exists = self._parse_exists()
+        parser = self.ALTER_PARSERS.get(self._curr.text.upper())
+        actions = ensure_list(self._advance() or parser(self)) if parser else []  # type: ignore
-            if self._match(TokenType.PARTITION, advance=False):
-                actions = self._parse_csv(
-                    lambda: self._parse_drop_partition(exists=partition_exists)
-                )
-            else:
-                self._retreat(index)
-                actions = self._parse_csv(self._parse_drop_column)
-        elif self._match_text_seq("RENAME", "TO"):
-            actions = self.expression(exp.RenameTable, this=self._parse_table(schema=True))
-        elif self._match_text_seq("ALTER"):
-            self._match(TokenType.COLUMN)
-            column = self._parse_field(any_token=True)
-
-            if self._match_pair(TokenType.DROP, TokenType.DEFAULT):
-                actions = self.expression(exp.AlterColumn, this=column, drop=True)
-            elif self._match_pair(TokenType.SET, TokenType.DEFAULT):
-                actions = self.expression(
-                    exp.AlterColumn, this=column, default=self._parse_conjunction()
-                )
-            else:
-                self._match_text_seq("SET", "DATA")
-                actions = self.expression(
-                    exp.AlterColumn,
-                    this=column,
-                    dtype=self._match_text_seq("TYPE") and self._parse_types(),
-                    collate=self._match(TokenType.COLLATE) and self._parse_term(),
-                    using=self._match(TokenType.USING) and self._parse_conjunction(),
-                )
-
-        actions = ensure_list(actions)
         return self.expression(exp.AlterTable, this=this, exists=exists, actions=actions)

     def _parse_show(self) -> t.Optional[exp.Expression]:
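
A hedged sketch of the dispatch refactor (the action type is inferred from _parse_add_column's usual result):

    import sqlglot

    # ALTER TABLE actions now dispatch through ALTER_PARSERS, so a dialect can
    # extend the table (as Snowflake does for SET/UNSET TAG) instead of
    # overriding the whole if/elif chain.
    tree = sqlglot.parse_one("ALTER TABLE t ADD COLUMN c INT")
    print([type(a).__name__ for a in tree.args["actions"]])  # expected: ['ColumnDef']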
@@ -3772,7 +3817,9 @@ class Parser(metaclass=_Parser):
     def _parse_as_command(self, start: Token) -> exp.Command:
         while self._curr:
             self._advance()
-        return exp.Command(this=self._find_sql(start, self._prev))
+        text = self._find_sql(start, self._prev)
+        size = len(start.text)
+        return exp.Command(this=text[:size], expression=text[size:])

     def _find_parser(
         self, parsers: t.Dict[str, t.Callable], trie: t.Dict

sqlglot/tokens.py

@@ -60,6 +60,7 @@ class TokenType(AutoName):
     STRING = auto()
     NUMBER = auto()
     IDENTIFIER = auto()
+    DATABASE = auto()
     COLUMN = auto()
     COLUMN_DEF = auto()
     SCHEMA = auto()
@@ -203,6 +204,7 @@ class TokenType(AutoName):
     IS = auto()
     ISNULL = auto()
     JOIN = auto()
+    JOIN_MARKER = auto()
     LANGUAGE = auto()
     LATERAL = auto()
     LAZY = auto()
@@ -235,6 +237,7 @@ class TokenType(AutoName):
     OUTER = auto()
     OUT_OF = auto()
     OVER = auto()
+    OVERLAPS = auto()
     OVERWRITE = auto()
     PARTITION = auto()
     PARTITION_BY = auto()
@@ -491,6 +494,7 @@ class Tokenizer(metaclass=_Tokenizer):
         "CURRENT_DATE": TokenType.CURRENT_DATE,
         "CURRENT ROW": TokenType.CURRENT_ROW,
         "CURRENT_TIMESTAMP": TokenType.CURRENT_TIMESTAMP,
+        "DATABASE": TokenType.DATABASE,
         "DEFAULT": TokenType.DEFAULT,
         "DELETE": TokenType.DELETE,
         "DESC": TokenType.DESC,
@@ -564,6 +568,7 @@ class Tokenizer(metaclass=_Tokenizer):
         "OUTER": TokenType.OUTER,
         "OUT OF": TokenType.OUT_OF,
         "OVER": TokenType.OVER,
+        "OVERLAPS": TokenType.OVERLAPS,
         "OVERWRITE": TokenType.OVERWRITE,
         "PARTITION": TokenType.PARTITION,
         "PARTITION BY": TokenType.PARTITION_BY,
@@ -652,6 +657,7 @@ class Tokenizer(metaclass=_Tokenizer):
         "DOUBLE PRECISION": TokenType.DOUBLE,
         "JSON": TokenType.JSON,
         "CHAR": TokenType.CHAR,
+        "CHARACTER": TokenType.CHAR,
         "NCHAR": TokenType.NCHAR,
         "VARCHAR": TokenType.VARCHAR,
         "VARCHAR2": TokenType.VARCHAR,
@@ -687,8 +693,10 @@ class Tokenizer(metaclass=_Tokenizer):
         "ALTER VIEW": TokenType.COMMAND,
         "ANALYZE": TokenType.COMMAND,
         "CALL": TokenType.COMMAND,
+        "COMMENT": TokenType.COMMENT,
         "COPY": TokenType.COMMAND,
         "EXPLAIN": TokenType.COMMAND,
+        "GRANT": TokenType.COMMAND,
         "OPTIMIZE": TokenType.COMMAND,
         "PREPARE": TokenType.COMMAND,
         "TRUNCATE": TokenType.COMMAND,