
Merging upstream version 10.6.0.

Signed-off-by: Daniel Baumann <daniel@debian.org>
Author: Daniel Baumann, 2025-02-13 15:08:15 +01:00
parent fe1b1057f7
commit 2153103f81
Signed by: daniel (GPG key ID: FBB4F0E80A80222F)
36 changed files with 1007 additions and 270 deletions


@ -33,7 +33,7 @@ from sqlglot.parser import Parser
from sqlglot.schema import MappingSchema, Schema
from sqlglot.tokens import Tokenizer, TokenType
__version__ = "10.5.10"
__version__ = "10.6.0"
pretty = False
"""Whether to format generated SQL by default."""


@ -94,10 +94,10 @@ class Column:
return self.inverse_binary_op(exp.Mod, other)
def __pow__(self, power: ColumnOrLiteral, modulo=None):
return Column(exp.Pow(this=self.expression, power=Column(power).expression))
return Column(exp.Pow(this=self.expression, expression=Column(power).expression))
def __rpow__(self, power: ColumnOrLiteral):
return Column(exp.Pow(this=Column(power).expression, power=self.expression))
return Column(exp.Pow(this=Column(power).expression, expression=self.expression))
def __invert__(self):
return self.unary_op(exp.Not)


@ -311,7 +311,7 @@ def hypot(col1: t.Union[ColumnOrName, float], col2: t.Union[ColumnOrName, float]
def pow(col1: t.Union[ColumnOrName, float], col2: t.Union[ColumnOrName, float]) -> Column:
return Column.invoke_expression_over_column(col1, glotexp.Pow, power=col2)
return Column.invoke_expression_over_column(col1, glotexp.Pow, expression=col2)
def row_number() -> Column:


@ -1,17 +1,14 @@
"""
## Dialects
One of the core abstractions in SQLGlot is the concept of a "dialect". The `Dialect` class essentially implements a
"SQLGlot dialect", which aims to be as generic and ANSI-compliant as possible. It relies on the base `Tokenizer`,
`Parser` and `Generator` classes to achieve this goal, so these need to be very lenient when it comes to consuming
SQL code.
While there is a SQL standard, most SQL engines support a variation of that standard. This makes it difficult
to write portable SQL code. SQLGlot bridges all the different variations, called "dialects", with an extensible
SQL transpilation framework.
However, there are cases where the syntax of different SQL dialects varies wildly, even for common tasks. One such
example is the date/time functions, which can be hard to deal with. For this reason, it's sometimes necessary to
override the base dialect in order to specialize its behavior. This can be easily done in SQLGlot: supporting new
dialects is as simple as subclassing from `Dialect` and overriding its various components (e.g. the `Parser` class),
in order to implement the target behavior.
The base `sqlglot.dialects.dialect.Dialect` class implements a generic dialect that aims to be as universal as possible.
Each SQL variation has its own `Dialect` subclass, extending the corresponding `Tokenizer`, `Parser` and `Generator`
classes as needed.
### Implementing a custom Dialect
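Not part of the diff, but for context: a minimal sketch of such a subclass, assuming the 10.6.0 APIs shown elsewhere in this commit (the `Custom` dialect and its overrides are hypothetical examples, not upstream code):

```python
from sqlglot import exp, generator, tokens
from sqlglot.dialects.dialect import Dialect


class Custom(Dialect):
    """Hypothetical dialect used only to illustrate the subclassing pattern."""

    class Tokenizer(tokens.Tokenizer):
        # Identifiers are quoted with backticks in this made-up dialect.
        IDENTIFIERS = ["`"]

    class Generator(generator.Generator):
        TRANSFORMS = {
            **generator.Generator.TRANSFORMS,  # type: ignore
            # Render arrays with a dialect-specific function name.
            exp.Array: lambda self, e: f"MAKE_ARRAY({self.expressions(e)})",
        }
```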


@ -169,6 +169,13 @@ class BigQuery(Dialect):
TokenType.VALUES,
}
PROPERTY_PARSERS = {
**parser.Parser.PROPERTY_PARSERS, # type: ignore
"NOT DETERMINISTIC": lambda self: self.expression(
exp.VolatilityProperty, this=exp.Literal.string("VOLATILE")
),
}
class Generator(generator.Generator):
TRANSFORMS = {
**generator.Generator.TRANSFORMS, # type: ignore


@ -66,12 +66,11 @@ def _sort_array_reverse(args):
return exp.SortArray(this=seq_get(args, 0), asc=exp.false())
def _struct_pack_sql(self, expression):
def _struct_sql(self, expression):
args = [
self.binary(e, ":=") if isinstance(e, exp.EQ) else self.sql(e)
for e in expression.expressions
f"'{e.name or e.this.name}': {self.sql(e, 'expression')}" for e in expression.expressions
]
return f"STRUCT_PACK({', '.join(args)})"
return f"{{{', '.join(args)}}}"
def _datatype_sql(self, expression):
@ -153,7 +152,7 @@ class DuckDB(Dialect):
exp.StrToDate: lambda self, e: f"CAST({_str_to_time_sql(self, e)} AS DATE)",
exp.StrToTime: _str_to_time_sql,
exp.StrToUnix: lambda self, e: f"EPOCH(STRPTIME({self.sql(e, 'this')}, {self.format_time(e)}))",
exp.Struct: _struct_pack_sql,
exp.Struct: _struct_sql,
exp.TableSample: no_tablesample_sql,
exp.TimeStrToDate: lambda self, e: f"CAST({self.sql(e, 'this')} AS DATE)",
exp.TimeStrToTime: timestrtotime_sql,


@ -251,7 +251,7 @@ class Hive(Dialect):
PROPERTY_PARSERS = {
**parser.Parser.PROPERTY_PARSERS, # type: ignore
TokenType.SERDE_PROPERTIES: lambda self: exp.SerdeProperties(
"WITH SERDEPROPERTIES": lambda self: exp.SerdeProperties(
expressions=self._parse_wrapped_csv(self._parse_property)
),
}


@ -202,7 +202,7 @@ class MySQL(Dialect):
PROPERTY_PARSERS = {
**parser.Parser.PROPERTY_PARSERS, # type: ignore
TokenType.ENGINE: lambda self: self._parse_property_assignment(exp.EngineProperty),
"ENGINE": lambda self: self._parse_property_assignment(exp.EngineProperty),
}
STATEMENT_PARSERS = {
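For context (not part of the diff), a hedged sketch of the effect: property keys are now matched by raw text rather than dedicated token types, so ENGINE still round-trips. The commented output is the expected shape, not a verified transcript:

```python
import sqlglot

# Hedged sketch: ENGINE is looked up as a string key in PROPERTY_PARSERS now.
sql = "CREATE TABLE t (a INT) ENGINE=InnoDB"
print(sqlglot.transpile(sql, read="mysql", write="mysql")[0])
# e.g. CREATE TABLE t (a INT) ENGINE=InnoDB
```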


@ -74,13 +74,16 @@ class Oracle(Dialect):
def query_modifiers(self, expression, *sqls):
return csv(
*sqls,
*[self.sql(sql) for sql in expression.args.get("laterals", [])],
*[self.sql(sql) for sql in expression.args.get("joins", [])],
*[self.sql(sql) for sql in expression.args.get("joins") or []],
self.sql(expression, "match"),
*[self.sql(sql) for sql in expression.args.get("laterals") or []],
self.sql(expression, "where"),
self.sql(expression, "group"),
self.sql(expression, "having"),
self.sql(expression, "qualify"),
self.sql(expression, "window"),
self.seg("WINDOW ") + self.expressions(expression, "windows", flat=True)
if expression.args.get("windows")
else "",
self.sql(expression, "distribute"),
self.sql(expression, "sort"),
self.sql(expression, "cluster"),
@ -99,6 +102,7 @@ class Oracle(Dialect):
class Tokenizer(tokens.Tokenizer):
KEYWORDS = {
**tokens.Tokenizer.KEYWORDS,
"MATCH_RECOGNIZE": TokenType.MATCH_RECOGNIZE,
"MINUS": TokenType.EXCEPT,
"START": TokenType.BEGIN,
"TOP": TokenType.TOP,


@ -9,6 +9,7 @@ from sqlglot.dialects.dialect import (
no_paren_current_date_sql,
no_tablesample_sql,
no_trycast_sql,
rename_func,
str_position_sql,
trim_sql,
)
@ -260,6 +261,16 @@ class Postgres(Dialect):
"TO_CHAR": format_time_lambda(exp.TimeToStr, "postgres"),
}
BITWISE = {
**parser.Parser.BITWISE, # type: ignore
TokenType.HASH: exp.BitwiseXor,
}
FACTOR = {
**parser.Parser.FACTOR, # type: ignore
TokenType.CARET: exp.Pow,
}
class Generator(generator.Generator):
TYPE_MAPPING = {
**generator.Generator.TYPE_MAPPING, # type: ignore
@ -273,6 +284,7 @@ class Postgres(Dialect):
TRANSFORMS = {
**generator.Generator.TRANSFORMS, # type: ignore
exp.BitwiseXor: lambda self, e: self.binary(e, "#"),
exp.ColumnDef: preprocess(
[
_auto_increment_to_serial,
@ -285,11 +297,13 @@ class Postgres(Dialect):
exp.JSONBExtract: lambda self, e: self.binary(e, "#>"),
exp.JSONBExtractScalar: lambda self, e: self.binary(e, "#>>"),
exp.JSONBContains: lambda self, e: self.binary(e, "?"),
exp.Pow: lambda self, e: self.binary(e, "^"),
exp.CurrentDate: no_paren_current_date_sql,
exp.CurrentTimestamp: lambda *_: "CURRENT_TIMESTAMP",
exp.DateAdd: _date_add_sql("+"),
exp.DateSub: _date_add_sql("-"),
exp.DateDiff: _date_diff_sql,
exp.LogicalOr: rename_func("BOOL_OR"),
exp.RegexpLike: lambda self, e: self.binary(e, "~"),
exp.RegexpILike: lambda self, e: self.binary(e, "~*"),
exp.StrPosition: str_position_sql,
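For context (not part of the diff), a hedged sketch of what the new BITWISE/FACTOR entries enable: Postgres `^` parses as exponentiation and `#` as bitwise XOR. The commented output is the expected shape, not a verified transcript:

```python
import sqlglot

print(sqlglot.transpile("SELECT 2 ^ 10, 3 # 5", read="postgres", write="duckdb")[0])
# e.g. SELECT POWER(2, 10), 3 ^ 5
```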


@ -174,6 +174,7 @@ class Presto(Dialect):
"DATE_FORMAT": format_time_lambda(exp.TimeToStr, "presto"),
"DATE_PARSE": format_time_lambda(exp.StrToTime, "presto"),
"FROM_UNIXTIME": _from_unixtime,
"NOW": exp.CurrentTimestamp.from_arg_list,
"STRPOS": lambda args: exp.StrPosition(
this=seq_get(args, 0),
substr=seq_get(args, 1),
@ -194,7 +195,6 @@ class Presto(Dialect):
FUNCTION_PARSERS.pop("TRIM")
class Generator(generator.Generator):
STRUCT_DELIMITER = ("(", ")")
ROOT_PROPERTIES = {exp.SchemaCommentProperty}


@ -93,7 +93,7 @@ class Redshift(Postgres):
rows = [tuple_exp.expressions for tuple_exp in expression.expressions]
selects = []
for i, row in enumerate(rows):
if i == 0:
if i == 0 and expression.alias:
row = [
exp.alias_(value, column_name)
for value, column_name in zip(row, expression.args["alias"].args["columns"])


@ -178,11 +178,6 @@ class Snowflake(Dialect):
),
}
PROPERTY_PARSERS = {
**parser.Parser.PROPERTY_PARSERS,
TokenType.PARTITION_BY: lambda self: self._parse_partitioned_by(),
}
class Tokenizer(tokens.Tokenizer):
QUOTES = ["'", "$$"]
ESCAPES = ["\\", "'"]
@ -195,6 +190,7 @@ class Snowflake(Dialect):
KEYWORDS = {
**tokens.Tokenizer.KEYWORDS,
"EXCLUDE": TokenType.EXCEPT,
"MATCH_RECOGNIZE": TokenType.MATCH_RECOGNIZE,
"RENAME": TokenType.REPLACE,
"TIMESTAMP_LTZ": TokenType.TIMESTAMPLTZ,
"TIMESTAMP_NTZ": TokenType.TIMESTAMP,


@ -1,7 +1,7 @@
from __future__ import annotations
from sqlglot import exp, parser
from sqlglot.dialects.dialect import create_with_partitions_sql, rename_func
from sqlglot.dialects.dialect import create_with_partitions_sql, rename_func, trim_sql
from sqlglot.dialects.hive import Hive
from sqlglot.helper import seq_get
@ -122,6 +122,7 @@ class Spark(Hive):
exp.Reduce: rename_func("AGGREGATE"),
exp.StructKwarg: lambda self, e: f"{self.sql(e, 'this')}: {self.sql(e, 'expression')}",
exp.TimestampTrunc: lambda self, e: f"DATE_TRUNC({self.sql(e, 'unit')}, {self.sql(e, 'this')})",
exp.Trim: trim_sql,
exp.VariancePop: rename_func("VAR_POP"),
exp.DateFromParts: rename_func("MAKE_DATE"),
exp.LogicalOr: rename_func("BOOL_OR"),


@ -230,6 +230,7 @@ class Expression(metaclass=_Expression):
Returns a deep copy of the expression.
"""
new = deepcopy(self)
new.parent = self.parent
for item, parent, _ in new.bfs():
if isinstance(item, Expression) and parent:
item.parent = parent
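Not part of the diff: a minimal sketch of what the added `new.parent = self.parent` line changes — the copy of a node now keeps the original's parent pointer, while children are re-parented inside the copy:

```python
from sqlglot import exp, parse_one

col = parse_one("SELECT a FROM t").find(exp.Column)
assert col.copy().parent is col.parent  # the copy points at the original parent
```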
@ -759,6 +760,10 @@ class Create(Expression):
"this": True,
"kind": True,
"expression": False,
"set": False,
"multiset": False,
"global_temporary": False,
"volatile": False,
"exists": False,
"properties": False,
"temporary": False,
@ -1082,7 +1087,7 @@ class LoadData(Expression):
class Partition(Expression):
pass
arg_types = {"expressions": True}
class Fetch(Expression):
@ -1232,6 +1237,18 @@ class Lateral(UDTF):
arg_types = {"this": True, "view": False, "outer": False, "alias": False}
class MatchRecognize(Expression):
arg_types = {
"partition_by": False,
"order": False,
"measures": False,
"rows": False,
"after": False,
"pattern": False,
"define": False,
}
# Clickhouse FROM FINAL modifier
# https://clickhouse.com/docs/en/sql-reference/statements/select/from/#final-modifier
class Final(Expression):
@ -1357,8 +1374,58 @@ class SerdeProperties(Property):
arg_types = {"expressions": True}
class FallbackProperty(Property):
arg_types = {"no": True, "protection": False}
class WithJournalTableProperty(Property):
arg_types = {"this": True}
class LogProperty(Property):
arg_types = {"no": True}
class JournalProperty(Property):
arg_types = {"no": True, "dual": False, "before": False}
class AfterJournalProperty(Property):
arg_types = {"no": True, "dual": False, "local": False}
class ChecksumProperty(Property):
arg_types = {"on": False, "default": False}
class FreespaceProperty(Property):
arg_types = {"this": True, "percent": False}
class MergeBlockRatioProperty(Property):
arg_types = {"this": False, "no": False, "default": False, "percent": False}
class DataBlocksizeProperty(Property):
arg_types = {"size": False, "units": False, "min": False, "default": False}
class BlockCompressionProperty(Property):
arg_types = {"autotemp": False, "always": False, "default": True, "manual": True, "never": True}
class IsolatedLoadingProperty(Property):
arg_types = {
"no": True,
"concurrent": True,
"for_all": True,
"for_insert": True,
"for_none": True,
}
class Properties(Expression):
arg_types = {"expressions": True}
arg_types = {"expressions": True, "before": False}
NAME_TO_PROPERTY = {
"AUTO_INCREMENT": AutoIncrementProperty,
@ -1510,6 +1577,7 @@ class Subqueryable(Unionable):
QUERY_MODIFIERS = {
"match": False,
"laterals": False,
"joins": False,
"pivots": False,
@ -2459,6 +2527,10 @@ class AddConstraint(Expression):
arg_types = {"this": False, "expression": False, "enforced": False}
class DropPartition(Expression):
arg_types = {"expressions": True, "exists": False}
# Binary expressions like (ADD a b)
class Binary(Expression):
arg_types = {"this": True, "expression": True}
@ -2540,6 +2612,10 @@ class Escape(Binary):
pass
class Glob(Binary, Predicate):
pass
class GT(Binary, Predicate):
pass
@ -3126,8 +3202,7 @@ class Posexplode(Func):
pass
class Pow(Func):
arg_types = {"this": True, "power": True}
class Pow(Binary, Func):
_sql_names = ["POWER", "POW"]
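For context (not part of the diff): with `Pow` now a `Binary`, the exponent moves from the old `power` arg into the standard `expression` slot — a hedged sketch:

```python
from sqlglot import exp

node = exp.Pow(this=exp.column("x"), expression=exp.Literal.number(2))
print(node.sql())  # e.g. POWER(x, 2)
```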
@ -3361,7 +3436,7 @@ class Year(Func):
class Use(Expression):
pass
arg_types = {"this": True, "kind": False}
class Merge(Expression):


@ -65,6 +65,8 @@ class Generator:
exp.ReturnsProperty: lambda self, e: self.naked_property(e),
exp.ExecuteAsProperty: lambda self, e: self.naked_property(e),
exp.VolatilityProperty: lambda self, e: e.name,
exp.WithJournalTableProperty: lambda self, e: f"WITH JOURNAL TABLE={self.sql(e, 'this')}",
exp.LogProperty: lambda self, e: f"{'NO ' if e.args.get('no') else ''}LOG",
}
# Whether 'CREATE ... TRANSIENT ... TABLE' is allowed
@ -97,6 +99,20 @@ class Generator:
STRUCT_DELIMITER = ("<", ">")
BEFORE_PROPERTIES = {
exp.FallbackProperty,
exp.WithJournalTableProperty,
exp.LogProperty,
exp.JournalProperty,
exp.AfterJournalProperty,
exp.ChecksumProperty,
exp.FreespaceProperty,
exp.MergeBlockRatioProperty,
exp.DataBlocksizeProperty,
exp.BlockCompressionProperty,
exp.IsolatedLoadingProperty,
}
ROOT_PROPERTIES = {
exp.ReturnsProperty,
exp.LanguageProperty,
@ -113,8 +129,6 @@ class Generator:
exp.TableFormatProperty,
}
WITH_SINGLE_ALTER_TABLE_ACTION = (exp.AlterColumn, exp.RenameTable, exp.AddConstraint)
WITH_SEPARATED_COMMENTS = (exp.Select, exp.From, exp.Where, exp.Binary)
SENTINEL_LINE_BREAK = "__SQLGLOT__LB__"
@ -122,7 +136,6 @@ class Generator:
"time_mapping",
"time_trie",
"pretty",
"configured_pretty",
"quote_start",
"quote_end",
"identifier_start",
@ -177,7 +190,6 @@ class Generator:
self.time_mapping = time_mapping or {}
self.time_trie = time_trie
self.pretty = pretty if pretty is not None else sqlglot.pretty
self.configured_pretty = self.pretty
self.quote_start = quote_start or "'"
self.quote_end = quote_end or "'"
self.identifier_start = identifier_start or '"'
@ -442,8 +454,20 @@ class Generator:
return "UNIQUE"
def create_sql(self, expression: exp.Create) -> str:
this = self.sql(expression, "this")
kind = self.sql(expression, "kind").upper()
has_before_properties = expression.args.get("properties")
has_before_properties = (
has_before_properties.args.get("before") if has_before_properties else None
)
if kind == "TABLE" and has_before_properties:
this_name = self.sql(expression.this, "this")
this_properties = self.sql(expression, "properties")
this_schema = f"({self.expressions(expression.this)})"
this = f"{this_name}, {this_properties} {this_schema}"
properties = ""
else:
this = self.sql(expression, "this")
properties = self.sql(expression, "properties")
begin = " BEGIN" if expression.args.get("begin") else ""
expression_sql = self.sql(expression, "expression")
expression_sql = f" AS{begin}{self.sep()}{expression_sql}" if expression_sql else ""
@ -456,7 +480,10 @@ class Generator:
exists_sql = " IF NOT EXISTS" if expression.args.get("exists") else ""
unique = " UNIQUE" if expression.args.get("unique") else ""
materialized = " MATERIALIZED" if expression.args.get("materialized") else ""
properties = self.sql(expression, "properties")
set_ = " SET" if expression.args.get("set") else ""
multiset = " MULTISET" if expression.args.get("multiset") else ""
global_temporary = " GLOBAL TEMPORARY" if expression.args.get("global_temporary") else ""
volatile = " VOLATILE" if expression.args.get("volatile") else ""
data = expression.args.get("data")
if data is None:
data = ""
@ -475,7 +502,7 @@ class Generator:
indexes = expression.args.get("indexes")
index_sql = ""
if indexes is not None:
if indexes:
indexes_sql = []
for index in indexes:
ind_unique = " UNIQUE" if index.args.get("unique") else ""
@ -500,6 +527,10 @@ class Generator:
external,
unique,
materialized,
set_,
multiset,
global_temporary,
volatile,
)
)
no_schema_binding = (
@ -569,13 +600,14 @@ class Generator:
def delete_sql(self, expression: exp.Delete) -> str:
this = self.sql(expression, "this")
this = f" FROM {this}" if this else ""
using_sql = (
f" USING {self.expressions(expression, 'using', sep=', USING ')}"
if expression.args.get("using")
else ""
)
where_sql = self.sql(expression, "where")
sql = f"DELETE FROM {this}{using_sql}{where_sql}"
sql = f"DELETE{this}{using_sql}{where_sql}"
return self.prepend_ctes(expression, sql)
def drop_sql(self, expression: exp.Drop) -> str:
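For context (not part of the diff), a hedged sketch of the delete change: the FROM is taken from the parsed target instead of being hard-coded, so USING clauses and targetless deletes survive generation. The commented output is the expected shape, not a verified transcript:

```python
import sqlglot

print(sqlglot.transpile("DELETE FROM t USING u WHERE t.a = u.a", read="postgres")[0])
# e.g. DELETE FROM t USING u WHERE t.a = u.a
```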
@ -630,28 +662,27 @@ class Generator:
return f"N{self.sql(expression, 'this')}"
def partition_sql(self, expression: exp.Partition) -> str:
keys = csv(
*[
f"""{prop.name}='{prop.text("value")}'""" if prop.text("value") else prop.name
for prop in expression.this
]
)
return f"PARTITION({keys})"
return f"PARTITION({self.expressions(expression)})"
def properties_sql(self, expression: exp.Properties) -> str:
before_properties = []
root_properties = []
with_properties = []
for p in expression.expressions:
p_class = p.__class__
if p_class in self.WITH_PROPERTIES:
if p_class in self.BEFORE_PROPERTIES:
before_properties.append(p)
elif p_class in self.WITH_PROPERTIES:
with_properties.append(p)
elif p_class in self.ROOT_PROPERTIES:
root_properties.append(p)
return self.root_properties(
exp.Properties(expressions=root_properties)
) + self.with_properties(exp.Properties(expressions=with_properties))
return (
self.properties(exp.Properties(expressions=before_properties), before=True)
+ self.root_properties(exp.Properties(expressions=root_properties))
+ self.with_properties(exp.Properties(expressions=with_properties))
)
def root_properties(self, properties: exp.Properties) -> str:
if properties.expressions:
@ -659,13 +690,17 @@ class Generator:
return ""
def properties(
self, properties: exp.Properties, prefix: str = "", sep: str = ", ", suffix: str = ""
self,
properties: exp.Properties,
prefix: str = "",
sep: str = ", ",
suffix: str = "",
before: bool = False,
) -> str:
if properties.expressions:
expressions = self.expressions(properties, sep=sep, indent=False)
return (
f"{prefix}{' ' if prefix and prefix != ' ' else ''}{self.wrap(expressions)}{suffix}"
)
expressions = expressions if before else self.wrap(expressions)
return f"{prefix}{' ' if prefix and prefix != ' ' else ''}{expressions}{suffix}"
return ""
def with_properties(self, properties: exp.Properties) -> str:
@ -687,6 +722,98 @@ class Generator:
options = f" {options}" if options else ""
return f"LIKE {self.sql(expression, 'this')}{options}"
def fallbackproperty_sql(self, expression: exp.FallbackProperty) -> str:
no = "NO " if expression.args.get("no") else ""
protection = " PROTECTION" if expression.args.get("protection") else ""
return f"{no}FALLBACK{protection}"
def journalproperty_sql(self, expression: exp.JournalProperty) -> str:
no = "NO " if expression.args.get("no") else ""
dual = "DUAL " if expression.args.get("dual") else ""
before = "BEFORE " if expression.args.get("before") else ""
return f"{no}{dual}{before}JOURNAL"
def freespaceproperty_sql(self, expression: exp.FreespaceProperty) -> str:
freespace = self.sql(expression, "this")
percent = " PERCENT" if expression.args.get("percent") else ""
return f"FREESPACE={freespace}{percent}"
def afterjournalproperty_sql(self, expression: exp.AfterJournalProperty) -> str:
no = "NO " if expression.args.get("no") else ""
dual = "DUAL " if expression.args.get("dual") else ""
local = ""
if expression.args.get("local") is not None:
local = "LOCAL " if expression.args.get("local") else "NOT LOCAL "
return f"{no}{dual}{local}AFTER JOURNAL"
def checksumproperty_sql(self, expression: exp.ChecksumProperty) -> str:
if expression.args.get("default"):
property = "DEFAULT"
elif expression.args.get("on"):
property = "ON"
else:
property = "OFF"
return f"CHECKSUM={property}"
def mergeblockratioproperty_sql(self, expression: exp.MergeBlockRatioProperty) -> str:
if expression.args.get("no"):
return "NO MERGEBLOCKRATIO"
if expression.args.get("default"):
return "DEFAULT MERGEBLOCKRATIO"
percent = " PERCENT" if expression.args.get("percent") else ""
return f"MERGEBLOCKRATIO={self.sql(expression, 'this')}{percent}"
def datablocksizeproperty_sql(self, expression: exp.DataBlocksizeProperty) -> str:
default = expression.args.get("default")
min = expression.args.get("min")
if default is not None or min is not None:
if default:
property = "DEFAULT"
elif min:
property = "MINIMUM"
else:
property = "MAXIMUM"
return f"{property} DATABLOCKSIZE"
else:
units = expression.args.get("units")
units = f" {units}" if units else ""
return f"DATABLOCKSIZE={self.sql(expression, 'size')}{units}"
def blockcompressionproperty_sql(self, expression: exp.BlockCompressionProperty) -> str:
autotemp = expression.args.get("autotemp")
always = expression.args.get("always")
default = expression.args.get("default")
manual = expression.args.get("manual")
never = expression.args.get("never")
if autotemp is not None:
property = f"AUTOTEMP({self.expressions(autotemp)})"
elif always:
property = "ALWAYS"
elif default:
property = "DEFAULT"
elif manual:
property = "MANUAL"
elif never:
property = "NEVER"
return f"BLOCKCOMPRESSION={property}"
def isolatedloadingproperty_sql(self, expression: exp.IsolatedLoadingProperty) -> str:
no = expression.args.get("no")
no = " NO" if no else ""
concurrent = expression.args.get("concurrent")
concurrent = " CONCURRENT" if concurrent else ""
for_ = ""
if expression.args.get("for_all"):
for_ = " FOR ALL"
elif expression.args.get("for_insert"):
for_ = " FOR INSERT"
elif expression.args.get("for_none"):
for_ = " FOR NONE"
return f"WITH{no}{concurrent} ISOLATED LOADING{for_}"
def insert_sql(self, expression: exp.Insert) -> str:
overwrite = expression.args.get("overwrite")
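Not part of the diff: a hedged sketch exercising two of the new Teradata-style property generators directly, bypassing dialect parsing:

```python
from sqlglot import exp
from sqlglot.generator import Generator

gen = Generator()
props = [exp.FallbackProperty(no=True), exp.JournalProperty(no=True, dual=False)]
print(", ".join(gen.sql(p) for p in props))  # e.g. NO FALLBACK, NO JOURNAL
```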
@ -833,10 +960,21 @@ class Generator:
grouping_sets = (
f"{self.seg('GROUPING SETS')} {self.wrap(grouping_sets)}" if grouping_sets else ""
)
cube = self.expressions(expression, key="cube", indent=False)
cube = f"{self.seg('CUBE')} {self.wrap(cube)}" if cube else ""
rollup = self.expressions(expression, key="rollup", indent=False)
rollup = f"{self.seg('ROLLUP')} {self.wrap(rollup)}" if rollup else ""
cube = expression.args.get("cube")
if cube is True:
cube = self.seg("WITH CUBE")
else:
cube = self.expressions(expression, key="cube", indent=False)
cube = f"{self.seg('CUBE')} {self.wrap(cube)}" if cube else ""
rollup = expression.args.get("rollup")
if rollup is True:
rollup = self.seg("WITH ROLLUP")
else:
rollup = self.expressions(expression, key="rollup", indent=False)
rollup = f"{self.seg('ROLLUP')} {self.wrap(rollup)}" if rollup else ""
return f"{group_by}{grouping_sets}{cube}{rollup}"
def having_sql(self, expression: exp.Having) -> str:
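For context (not part of the diff), a hedged sketch of the MySQL-style form this enables: `cube`/`rollup` set to `True` re-emit the `WITH` spelling. The commented output is the expected shape, not a verified transcript:

```python
import sqlglot

print(sqlglot.transpile("SELECT a, SUM(b) FROM t GROUP BY a WITH ROLLUP")[0])
# e.g. SELECT a, SUM(b) FROM t GROUP BY a WITH ROLLUP
```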
@ -980,10 +1118,37 @@ class Generator:
return f"{self.sql(expression, 'this')}{sort_order}{nulls_sort_change}"
def matchrecognize_sql(self, expression: exp.MatchRecognize) -> str:
partition = self.partition_by_sql(expression)
order = self.sql(expression, "order")
measures = self.sql(expression, "measures")
measures = self.seg(f"MEASURES {measures}") if measures else ""
rows = self.sql(expression, "rows")
rows = self.seg(rows) if rows else ""
after = self.sql(expression, "after")
after = self.seg(after) if after else ""
pattern = self.sql(expression, "pattern")
pattern = self.seg(f"PATTERN ({pattern})") if pattern else ""
define = self.sql(expression, "define")
define = self.seg(f"DEFINE {define}") if define else ""
body = "".join(
(
partition,
order,
measures,
rows,
after,
pattern,
define,
)
)
return f"{self.seg('MATCH_RECOGNIZE')} {self.wrap(body)}"
def query_modifiers(self, expression: exp.Expression, *sqls: str) -> str:
return csv(
*sqls,
*[self.sql(sql) for sql in expression.args.get("joins") or []],
self.sql(expression, "match"),
*[self.sql(sql) for sql in expression.args.get("laterals") or []],
self.sql(expression, "where"),
self.sql(expression, "group"),
@ -1092,8 +1257,7 @@ class Generator:
def window_sql(self, expression: exp.Window) -> str:
this = self.sql(expression, "this")
partition = self.expressions(expression, key="partition_by", flat=True)
partition = f"PARTITION BY {partition}" if partition else ""
partition = self.partition_by_sql(expression)
order = expression.args.get("order")
order_sql = self.order_sql(order, flat=True) if order else ""
@ -1113,6 +1277,10 @@ class Generator:
return f"{this} ({window_args.strip()})"
def partition_by_sql(self, expression: exp.Window | exp.MatchRecognize) -> str:
partition = self.expressions(expression, key="partition_by", flat=True)
return f"PARTITION BY {partition}" if partition else ""
def window_spec_sql(self, expression: exp.WindowSpec) -> str:
kind = self.sql(expression, "kind")
start = csv(self.sql(expression, "start"), self.sql(expression, "start_side"), sep=" ")
@ -1386,16 +1554,19 @@ class Generator:
actions = self.expressions(expression, "actions", prefix="ADD COLUMN ")
elif isinstance(actions[0], exp.Schema):
actions = self.expressions(expression, "actions", prefix="ADD COLUMNS ")
elif isinstance(actions[0], exp.Drop):
actions = self.expressions(expression, "actions")
elif isinstance(actions[0], self.WITH_SINGLE_ALTER_TABLE_ACTION):
actions = self.sql(actions[0])
elif isinstance(actions[0], exp.Delete):
actions = self.expressions(expression, "actions", flat=True)
else:
self.unsupported(f"Unsupported ALTER TABLE action {actions[0].__class__.__name__}")
actions = self.expressions(expression, "actions")
exists = " IF EXISTS" if expression.args.get("exists") else ""
return f"ALTER TABLE{exists} {self.sql(expression, 'this')} {actions}"
def droppartition_sql(self, expression: exp.DropPartition) -> str:
expressions = self.expressions(expression)
exists = " IF EXISTS " if expression.args.get("exists") else " "
return f"DROP{exists}{expressions}"
def addconstraint_sql(self, expression: exp.AddConstraint) -> str:
this = self.sql(expression, "this")
expression_ = self.sql(expression, "expression")
@ -1447,6 +1618,9 @@ class Generator:
def escape_sql(self, expression: exp.Escape) -> str:
return self.binary(expression, "ESCAPE")
def glob_sql(self, expression: exp.Glob) -> str:
return self.binary(expression, "GLOB")
def gt_sql(self, expression: exp.GT) -> str:
return self.binary(expression, ">")
@ -1499,7 +1673,11 @@ class Generator:
return f"TRY_CAST({self.sql(expression, 'this')} AS {self.sql(expression, 'to')})"
def use_sql(self, expression: exp.Use) -> str:
return f"USE {self.sql(expression, 'this')}"
kind = self.sql(expression, "kind")
kind = f" {kind}" if kind else ""
this = self.sql(expression, "this")
this = f" {this}" if this else ""
return f"USE{kind}{this}"
def binary(self, expression: exp.Binary, op: str) -> str:
return f"{self.sql(expression, 'this')} {op} {self.sql(expression, 'expression')}"


@ -2,6 +2,14 @@ from sqlglot import exp
def expand_multi_table_selects(expression):
"""
Replace multiple FROM expressions with JOINs.
Example:
>>> from sqlglot import parse_one
>>> expand_multi_table_selects(parse_one("SELECT * FROM x, y")).sql()
'SELECT * FROM x CROSS JOIN y'
"""
for from_ in expression.find_all(exp.From):
parent = from_.parent


@ -11,7 +11,7 @@ def isolate_table_selects(expression, schema=None):
if len(scope.selected_sources) == 1:
continue
for (_, source) in scope.selected_sources.values():
for _, source in scope.selected_sources.values():
if not isinstance(source, exp.Table) or not schema.column_names(source):
continue


@ -6,6 +6,11 @@ from sqlglot.optimizer.simplify import simplify
def optimize_joins(expression):
"""
Removes cross joins if possible and reorders joins based on predicate dependencies.
Example:
>>> from sqlglot import parse_one
>>> optimize_joins(parse_one("SELECT * FROM x CROSS JOIN y JOIN z ON x.a = z.a AND y.a = z.a")).sql()
'SELECT * FROM x JOIN z ON x.a = z.a AND TRUE JOIN y ON y.a = z.a'
"""
for select in expression.find_all(exp.Select):
references = {}


@ -64,7 +64,6 @@ def optimize(expression, schema=None, db=None, catalog=None, rules=RULES, **kwar
possible_kwargs = {"db": db, "catalog": catalog, "schema": schema, **kwargs}
expression = expression.copy()
for rule in rules:
# Find any additional rule parameters, beyond `expression`
rule_params = rule.__code__.co_varnames
rule_kwargs = {


@ -175,13 +175,9 @@ class Parser(metaclass=_Parser):
TokenType.DEFAULT,
TokenType.DELETE,
TokenType.DESCRIBE,
TokenType.DETERMINISTIC,
TokenType.DIV,
TokenType.DISTKEY,
TokenType.DISTSTYLE,
TokenType.END,
TokenType.EXECUTE,
TokenType.ENGINE,
TokenType.ESCAPE,
TokenType.FALSE,
TokenType.FIRST,
@ -194,13 +190,10 @@ class Parser(metaclass=_Parser):
TokenType.IF,
TokenType.INDEX,
TokenType.ISNULL,
TokenType.IMMUTABLE,
TokenType.INTERVAL,
TokenType.LAZY,
TokenType.LANGUAGE,
TokenType.LEADING,
TokenType.LOCAL,
TokenType.LOCATION,
TokenType.MATERIALIZED,
TokenType.MERGE,
TokenType.NATURAL,
@ -209,13 +202,11 @@ class Parser(metaclass=_Parser):
TokenType.ONLY,
TokenType.OPTIONS,
TokenType.ORDINALITY,
TokenType.PARTITIONED_BY,
TokenType.PERCENT,
TokenType.PIVOT,
TokenType.PRECEDING,
TokenType.RANGE,
TokenType.REFERENCES,
TokenType.RETURNS,
TokenType.ROW,
TokenType.ROWS,
TokenType.SCHEMA,
@ -225,10 +216,7 @@ class Parser(metaclass=_Parser):
TokenType.SET,
TokenType.SHOW,
TokenType.SORTKEY,
TokenType.STABLE,
TokenType.STORED,
TokenType.TABLE,
TokenType.TABLE_FORMAT,
TokenType.TEMPORARY,
TokenType.TOP,
TokenType.TRAILING,
@ -237,7 +225,6 @@ class Parser(metaclass=_Parser):
TokenType.UNIQUE,
TokenType.UNLOGGED,
TokenType.UNPIVOT,
TokenType.PROPERTIES,
TokenType.PROCEDURE,
TokenType.VIEW,
TokenType.VOLATILE,
@ -448,7 +435,12 @@ class Parser(metaclass=_Parser):
TokenType.ROLLBACK: lambda self: self._parse_commit_or_rollback(),
TokenType.UNCACHE: lambda self: self._parse_uncache(),
TokenType.UPDATE: lambda self: self._parse_update(),
TokenType.USE: lambda self: self.expression(exp.Use, this=self._parse_id_var()),
TokenType.USE: lambda self: self.expression(
exp.Use,
kind=self._match_texts(("ROLE", "WAREHOUSE", "DATABASE", "SCHEMA"))
and exp.Var(this=self._prev.text),
this=self._parse_table(schema=False),
),
}
UNARY_PARSERS = {
@ -492,6 +484,9 @@ class Parser(metaclass=_Parser):
RANGE_PARSERS = {
TokenType.BETWEEN: lambda self, this: self._parse_between(this),
TokenType.GLOB: lambda self, this: self._parse_escape(
self.expression(exp.Glob, this=this, expression=self._parse_bitwise())
),
TokenType.IN: lambda self, this: self._parse_in(this),
TokenType.IS: lambda self, this: self._parse_is(this),
TokenType.LIKE: lambda self, this: self._parse_escape(
@ -512,45 +507,66 @@ class Parser(metaclass=_Parser):
}
PROPERTY_PARSERS = {
TokenType.AUTO_INCREMENT: lambda self: self._parse_property_assignment(
exp.AutoIncrementProperty
),
TokenType.CHARACTER_SET: lambda self: self._parse_character_set(),
TokenType.LOCATION: lambda self: self._parse_property_assignment(exp.LocationProperty),
TokenType.PARTITIONED_BY: lambda self: self._parse_partitioned_by(),
TokenType.SCHEMA_COMMENT: lambda self: self._parse_property_assignment(
exp.SchemaCommentProperty
),
TokenType.STORED: lambda self: self._parse_property_assignment(exp.FileFormatProperty),
TokenType.DISTKEY: lambda self: self._parse_distkey(),
TokenType.DISTSTYLE: lambda self: self._parse_property_assignment(exp.DistStyleProperty),
TokenType.SORTKEY: lambda self: self._parse_sortkey(),
TokenType.LIKE: lambda self: self._parse_create_like(),
TokenType.RETURNS: lambda self: self._parse_returns(),
TokenType.ROW: lambda self: self._parse_row(),
TokenType.COLLATE: lambda self: self._parse_property_assignment(exp.CollateProperty),
TokenType.COMMENT: lambda self: self._parse_property_assignment(exp.SchemaCommentProperty),
TokenType.FORMAT: lambda self: self._parse_property_assignment(exp.FileFormatProperty),
TokenType.TABLE_FORMAT: lambda self: self._parse_property_assignment(
exp.TableFormatProperty
),
TokenType.USING: lambda self: self._parse_property_assignment(exp.TableFormatProperty),
TokenType.LANGUAGE: lambda self: self._parse_property_assignment(exp.LanguageProperty),
TokenType.EXECUTE: lambda self: self._parse_property_assignment(exp.ExecuteAsProperty),
TokenType.DETERMINISTIC: lambda self: self.expression(
"AUTO_INCREMENT": lambda self: self._parse_property_assignment(exp.AutoIncrementProperty),
"CHARACTER SET": lambda self: self._parse_character_set(),
"LOCATION": lambda self: self._parse_property_assignment(exp.LocationProperty),
"PARTITION BY": lambda self: self._parse_partitioned_by(),
"PARTITIONED BY": lambda self: self._parse_partitioned_by(),
"PARTITIONED_BY": lambda self: self._parse_partitioned_by(),
"COMMENT": lambda self: self._parse_property_assignment(exp.SchemaCommentProperty),
"STORED": lambda self: self._parse_property_assignment(exp.FileFormatProperty),
"DISTKEY": lambda self: self._parse_distkey(),
"DISTSTYLE": lambda self: self._parse_property_assignment(exp.DistStyleProperty),
"SORTKEY": lambda self: self._parse_sortkey(),
"LIKE": lambda self: self._parse_create_like(),
"RETURNS": lambda self: self._parse_returns(),
"ROW": lambda self: self._parse_row(),
"COLLATE": lambda self: self._parse_property_assignment(exp.CollateProperty),
"FORMAT": lambda self: self._parse_property_assignment(exp.FileFormatProperty),
"TABLE_FORMAT": lambda self: self._parse_property_assignment(exp.TableFormatProperty),
"USING": lambda self: self._parse_property_assignment(exp.TableFormatProperty),
"LANGUAGE": lambda self: self._parse_property_assignment(exp.LanguageProperty),
"EXECUTE": lambda self: self._parse_property_assignment(exp.ExecuteAsProperty),
"DETERMINISTIC": lambda self: self.expression(
exp.VolatilityProperty, this=exp.Literal.string("IMMUTABLE")
),
TokenType.IMMUTABLE: lambda self: self.expression(
"IMMUTABLE": lambda self: self.expression(
exp.VolatilityProperty, this=exp.Literal.string("IMMUTABLE")
),
TokenType.STABLE: lambda self: self.expression(
"STABLE": lambda self: self.expression(
exp.VolatilityProperty, this=exp.Literal.string("STABLE")
),
TokenType.VOLATILE: lambda self: self.expression(
"VOLATILE": lambda self: self.expression(
exp.VolatilityProperty, this=exp.Literal.string("VOLATILE")
),
TokenType.WITH: lambda self: self._parse_wrapped_csv(self._parse_property),
TokenType.PROPERTIES: lambda self: self._parse_wrapped_csv(self._parse_property),
"WITH": lambda self: self._parse_with_property(),
"TBLPROPERTIES": lambda self: self._parse_wrapped_csv(self._parse_property),
"FALLBACK": lambda self: self._parse_fallback(no=self._prev.text.upper() == "NO"),
"LOG": lambda self: self._parse_log(no=self._prev.text.upper() == "NO"),
"BEFORE": lambda self: self._parse_journal(
no=self._prev.text.upper() == "NO", dual=self._prev.text.upper() == "DUAL"
),
"JOURNAL": lambda self: self._parse_journal(
no=self._prev.text.upper() == "NO", dual=self._prev.text.upper() == "DUAL"
),
"AFTER": lambda self: self._parse_afterjournal(
no=self._prev.text.upper() == "NO", dual=self._prev.text.upper() == "DUAL"
),
"LOCAL": lambda self: self._parse_afterjournal(no=False, dual=False, local=True),
"NOT": lambda self: self._parse_afterjournal(no=False, dual=False, local=False),
"CHECKSUM": lambda self: self._parse_checksum(),
"FREESPACE": lambda self: self._parse_freespace(),
"MERGEBLOCKRATIO": lambda self: self._parse_mergeblockratio(
no=self._prev.text.upper() == "NO", default=self._prev.text.upper() == "DEFAULT"
),
"MIN": lambda self: self._parse_datablocksize(),
"MINIMUM": lambda self: self._parse_datablocksize(),
"MAX": lambda self: self._parse_datablocksize(),
"MAXIMUM": lambda self: self._parse_datablocksize(),
"DATABLOCKSIZE": lambda self: self._parse_datablocksize(
default=self._prev.text.upper() == "DEFAULT"
),
"BLOCKCOMPRESSION": lambda self: self._parse_blockcompression(),
}
CONSTRAINT_PARSERS = {
@ -580,6 +596,7 @@ class Parser(metaclass=_Parser):
}
QUERY_MODIFIER_PARSERS = {
"match": lambda self: self._parse_match_recognize(),
"where": lambda self: self._parse_where(),
"group": lambda self: self._parse_group(),
"having": lambda self: self._parse_having(),
@ -627,7 +644,6 @@ class Parser(metaclass=_Parser):
"max_errors",
"null_ordering",
"_tokens",
"_chunks",
"_index",
"_curr",
"_next",
@ -660,7 +676,6 @@ class Parser(metaclass=_Parser):
self.sql = ""
self.errors = []
self._tokens = []
self._chunks = [[]]
self._index = 0
self._curr = None
self._next = None
@ -728,17 +743,18 @@ class Parser(metaclass=_Parser):
self.reset()
self.sql = sql or ""
total = len(raw_tokens)
chunks: t.List[t.List[Token]] = [[]]
for i, token in enumerate(raw_tokens):
if token.token_type == TokenType.SEMICOLON:
if i < total - 1:
self._chunks.append([])
chunks.append([])
else:
self._chunks[-1].append(token)
chunks[-1].append(token)
expressions = []
for tokens in self._chunks:
for tokens in chunks:
self._index = -1
self._tokens = tokens
self._advance()
@ -771,7 +787,7 @@ class Parser(metaclass=_Parser):
error level setting.
"""
token = token or self._curr or self._prev or Token.string("")
start = self._find_token(token, self.sql)
start = self._find_token(token)
end = start + len(token.text)
start_context = self.sql[max(start - self.error_message_context, 0) : start]
highlight = self.sql[start:end]
@ -833,13 +849,16 @@ class Parser(metaclass=_Parser):
for error_message in expression.error_messages(args):
self.raise_error(error_message)
def _find_token(self, token: Token, sql: str) -> int:
def _find_sql(self, start: Token, end: Token) -> str:
return self.sql[self._find_token(start) : self._find_token(end)]
def _find_token(self, token: Token) -> int:
line = 1
col = 1
index = 0
while line < token.line or col < token.col:
if Tokenizer.WHITE_SPACE.get(sql[index]) == TokenType.BREAK:
if Tokenizer.WHITE_SPACE.get(self.sql[index]) == TokenType.BREAK:
line += 1
col = 1
else:
@ -911,6 +930,10 @@ class Parser(metaclass=_Parser):
def _parse_create(self) -> t.Optional[exp.Expression]:
replace = self._match_pair(TokenType.OR, TokenType.REPLACE)
set_ = self._match(TokenType.SET) # Teradata
multiset = self._match_text_seq("MULTISET") # Teradata
global_temporary = self._match_text_seq("GLOBAL", "TEMPORARY") # Teradata
volatile = self._match(TokenType.VOLATILE) # Teradata
temporary = self._match(TokenType.TEMPORARY)
transient = self._match_text_seq("TRANSIENT")
external = self._match_text_seq("EXTERNAL")
@ -954,10 +977,18 @@ class Parser(metaclass=_Parser):
TokenType.VIEW,
TokenType.SCHEMA,
):
this = self._parse_table(schema=True)
properties = self._parse_properties()
if self._match(TokenType.ALIAS):
expression = self._parse_ddl_select()
table_parts = self._parse_table_parts(schema=True)
if self._match(TokenType.COMMA): # comma-separated properties before schema definition
properties = self._parse_properties(before=True)
this = self._parse_schema(this=table_parts)
if not properties: # properties after schema definition
properties = self._parse_properties()
self._match(TokenType.ALIAS)
expression = self._parse_ddl_select()
if create_token.token_type == TokenType.TABLE:
if self._match_text_seq("WITH", "DATA"):
@ -988,6 +1019,10 @@ class Parser(metaclass=_Parser):
this=this,
kind=create_token.text,
expression=expression,
set=set_,
multiset=multiset,
global_temporary=global_temporary,
volatile=volatile,
exists=exists,
properties=properties,
temporary=temporary,
@ -1004,9 +1039,19 @@ class Parser(metaclass=_Parser):
begin=begin,
)
def _parse_property_before(self) -> t.Optional[exp.Expression]:
self._match_text_seq("NO")
self._match_text_seq("DUAL")
self._match_text_seq("DEFAULT")
if self.PROPERTY_PARSERS.get(self._curr.text.upper()):
return self.PROPERTY_PARSERS[self._curr.text.upper()](self)
return None
def _parse_property(self) -> t.Optional[exp.Expression]:
if self._match_set(self.PROPERTY_PARSERS):
return self.PROPERTY_PARSERS[self._prev.token_type](self)
if self._match_texts(self.PROPERTY_PARSERS):
return self.PROPERTY_PARSERS[self._prev.text.upper()](self)
if self._match_pair(TokenType.DEFAULT, TokenType.CHARACTER_SET):
return self._parse_character_set(True)
@ -1033,6 +1078,166 @@ class Parser(metaclass=_Parser):
this=self._parse_var_or_string() or self._parse_number() or self._parse_id_var(),
)
def _parse_properties(self, before=None) -> t.Optional[exp.Expression]:
properties = []
while True:
if before:
self._match(TokenType.COMMA)
identified_property = self._parse_property_before()
else:
identified_property = self._parse_property()
if not identified_property:
break
for p in ensure_collection(identified_property):
properties.append(p)
if properties:
return self.expression(exp.Properties, expressions=properties, before=before)
return None
def _parse_fallback(self, no=False) -> exp.Expression:
self._match_text_seq("FALLBACK")
return self.expression(
exp.FallbackProperty, no=no, protection=self._match_text_seq("PROTECTION")
)
def _parse_with_property(
self,
) -> t.Union[t.Optional[exp.Expression], t.List[t.Optional[exp.Expression]]]:
if self._match(TokenType.L_PAREN, advance=False):
return self._parse_wrapped_csv(self._parse_property)
if not self._next:
return None
if self._next.text.upper() == "JOURNAL":
return self._parse_withjournaltable()
return self._parse_withisolatedloading()
def _parse_withjournaltable(self) -> exp.Expression:
self._match_text_seq("WITH", "JOURNAL", "TABLE")
self._match(TokenType.EQ)
return self.expression(exp.WithJournalTableProperty, this=self._parse_table_parts())
def _parse_log(self, no=False) -> exp.Expression:
self._match_text_seq("LOG")
return self.expression(exp.LogProperty, no=no)
def _parse_journal(self, no=False, dual=False) -> exp.Expression:
before = self._match_text_seq("BEFORE")
self._match_text_seq("JOURNAL")
return self.expression(exp.JournalProperty, no=no, dual=dual, before=before)
def _parse_afterjournal(self, no=False, dual=False, local=None) -> exp.Expression:
self._match_text_seq("NOT")
self._match_text_seq("LOCAL")
self._match_text_seq("AFTER", "JOURNAL")
return self.expression(exp.AfterJournalProperty, no=no, dual=dual, local=local)
def _parse_checksum(self) -> exp.Expression:
self._match_text_seq("CHECKSUM")
self._match(TokenType.EQ)
on = None
if self._match(TokenType.ON):
on = True
elif self._match_text_seq("OFF"):
on = False
default = self._match(TokenType.DEFAULT)
return self.expression(
exp.ChecksumProperty,
on=on,
default=default,
)
def _parse_freespace(self) -> exp.Expression:
self._match_text_seq("FREESPACE")
self._match(TokenType.EQ)
return self.expression(
exp.FreespaceProperty, this=self._parse_number(), percent=self._match(TokenType.PERCENT)
)
def _parse_mergeblockratio(self, no=False, default=False) -> exp.Expression:
self._match_text_seq("MERGEBLOCKRATIO")
if self._match(TokenType.EQ):
return self.expression(
exp.MergeBlockRatioProperty,
this=self._parse_number(),
percent=self._match(TokenType.PERCENT),
)
else:
return self.expression(
exp.MergeBlockRatioProperty,
no=no,
default=default,
)
def _parse_datablocksize(self, default=None) -> exp.Expression:
if default:
self._match_text_seq("DATABLOCKSIZE")
return self.expression(exp.DataBlocksizeProperty, default=True)
elif self._match_texts(("MIN", "MINIMUM")):
self._match_text_seq("DATABLOCKSIZE")
return self.expression(exp.DataBlocksizeProperty, min=True)
elif self._match_texts(("MAX", "MAXIMUM")):
self._match_text_seq("DATABLOCKSIZE")
return self.expression(exp.DataBlocksizeProperty, min=False)
self._match_text_seq("DATABLOCKSIZE")
self._match(TokenType.EQ)
size = self._parse_number()
units = None
if self._match_texts(("BYTES", "KBYTES", "KILOBYTES")):
units = self._prev.text
return self.expression(exp.DataBlocksizeProperty, size=size, units=units)
def _parse_blockcompression(self) -> exp.Expression:
self._match_text_seq("BLOCKCOMPRESSION")
self._match(TokenType.EQ)
always = self._match(TokenType.ALWAYS)
manual = self._match_text_seq("MANUAL")
never = self._match_text_seq("NEVER")
default = self._match_text_seq("DEFAULT")
autotemp = None
if self._match_text_seq("AUTOTEMP"):
autotemp = self._parse_schema()
return self.expression(
exp.BlockCompressionProperty,
always=always,
manual=manual,
never=never,
default=default,
autotemp=autotemp,
)
def _parse_withisolatedloading(self) -> exp.Expression:
self._match(TokenType.WITH)
no = self._match_text_seq("NO")
concurrent = self._match_text_seq("CONCURRENT")
self._match_text_seq("ISOLATED", "LOADING")
for_all = self._match_text_seq("FOR", "ALL")
for_insert = self._match_text_seq("FOR", "INSERT")
for_none = self._match_text_seq("FOR", "NONE")
return self.expression(
exp.IsolatedLoadingProperty,
no=no,
concurrent=concurrent,
for_all=for_all,
for_insert=for_insert,
for_none=for_none,
)
def _parse_partition_by(self) -> t.List[t.Optional[exp.Expression]]:
if self._match(TokenType.PARTITION_BY):
return self._parse_csv(self._parse_conjunction)
return []
def _parse_partitioned_by(self) -> exp.Expression:
self._match(TokenType.EQ)
return self.expression(
@ -1093,21 +1298,6 @@ class Parser(metaclass=_Parser):
return self.expression(exp.ReturnsProperty, this=value, is_table=is_table)
def _parse_properties(self) -> t.Optional[exp.Expression]:
properties = []
while True:
identified_property = self._parse_property()
if not identified_property:
break
for p in ensure_collection(identified_property):
properties.append(p)
if properties:
return self.expression(exp.Properties, expressions=properties)
return None
def _parse_describe(self) -> exp.Expression:
kind = self._match_set(self.CREATABLES) and self._prev.text
this = self._parse_table()
@ -1248,11 +1438,9 @@ class Parser(metaclass=_Parser):
if not self._match(TokenType.PARTITION):
return None
def parse_values() -> exp.Property:
props = self._parse_csv(self._parse_var_or_string, sep=TokenType.EQ)
return exp.Property(this=seq_get(props, 0), value=seq_get(props, 1))
return self.expression(exp.Partition, this=self._parse_wrapped_csv(parse_values))
return self.expression(
exp.Partition, expressions=self._parse_wrapped_csv(self._parse_conjunction)
)
def _parse_value(self) -> exp.Expression:
if self._match(TokenType.L_PAREN):
@ -1360,8 +1548,7 @@ class Parser(metaclass=_Parser):
if not alias or not alias.this:
self.raise_error("Expected CTE to have alias")
if not self._match(TokenType.ALIAS):
self.raise_error("Expected AS in CTE")
self._match(TokenType.ALIAS)
return self.expression(
exp.CTE,
@ -1376,10 +1563,11 @@ class Parser(metaclass=_Parser):
alias = self._parse_id_var(
any_token=any_token, tokens=alias_tokens or self.TABLE_ALIAS_TOKENS
)
index = self._index
if self._match(TokenType.L_PAREN):
columns = self._parse_csv(lambda: self._parse_column_def(self._parse_id_var()))
self._match_r_paren()
self._match_r_paren() if columns else self._retreat(index)
else:
columns = None
@ -1452,6 +1640,87 @@ class Parser(metaclass=_Parser):
exp.From, comments=self._prev_comments, expressions=self._parse_csv(self._parse_table)
)
def _parse_match_recognize(self) -> t.Optional[exp.Expression]:
if not self._match(TokenType.MATCH_RECOGNIZE):
return None
self._match_l_paren()
partition = self._parse_partition_by()
order = self._parse_order()
measures = (
self._parse_alias(self._parse_conjunction())
if self._match_text_seq("MEASURES")
else None
)
if self._match_text_seq("ONE", "ROW", "PER", "MATCH"):
rows = exp.Var(this="ONE ROW PER MATCH")
elif self._match_text_seq("ALL", "ROWS", "PER", "MATCH"):
text = "ALL ROWS PER MATCH"
if self._match_text_seq("SHOW", "EMPTY", "MATCHES"):
text += f" SHOW EMPTY MATCHES"
elif self._match_text_seq("OMIT", "EMPTY", "MATCHES"):
text += f" OMIT EMPTY MATCHES"
elif self._match_text_seq("WITH", "UNMATCHED", "ROWS"):
text += f" WITH UNMATCHED ROWS"
rows = exp.Var(this=text)
else:
rows = None
if self._match_text_seq("AFTER", "MATCH", "SKIP"):
text = "AFTER MATCH SKIP"
if self._match_text_seq("PAST", "LAST", "ROW"):
text += f" PAST LAST ROW"
elif self._match_text_seq("TO", "NEXT", "ROW"):
text += f" TO NEXT ROW"
elif self._match_text_seq("TO", "FIRST"):
text += f" TO FIRST {self._advance_any().text}" # type: ignore
elif self._match_text_seq("TO", "LAST"):
text += f" TO LAST {self._advance_any().text}" # type: ignore
after = exp.Var(this=text)
else:
after = None
if self._match_text_seq("PATTERN"):
self._match_l_paren()
if not self._curr:
self.raise_error("Expecting )", self._curr)
paren = 1
start = self._curr
while self._curr and paren > 0:
if self._curr.token_type == TokenType.L_PAREN:
paren += 1
if self._curr.token_type == TokenType.R_PAREN:
paren -= 1
self._advance()
if paren > 0:
self.raise_error("Expecting )", self._curr)
if not self._curr:
self.raise_error("Expecting pattern", self._curr)
end = self._prev
pattern = exp.Var(this=self._find_sql(start, end))
else:
pattern = None
define = (
self._parse_alias(self._parse_conjunction()) if self._match_text_seq("DEFINE") else None
)
self._match_r_paren()
return self.expression(
exp.MatchRecognize,
partition_by=partition,
order=order,
measures=measures,
rows=rows,
after=after,
pattern=pattern,
define=define,
)
def _parse_lateral(self) -> t.Optional[exp.Expression]:
outer_apply = self._match_pair(TokenType.OUTER, TokenType.APPLY)
cross_apply = self._match_pair(TokenType.CROSS, TokenType.APPLY)
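For context (not part of the diff), a hedged sketch of the new MATCH_RECOGNIZE support: Snowflake and Oracle register the keyword in their tokenizers, and the clause parses as a query modifier. The identifiers here are made up, and output formatting is unverified:

```python
import sqlglot

sql = (
    "SELECT * FROM t MATCH_RECOGNIZE ("
    "PARTITION BY sym ORDER BY ts "
    "MEASURES x AS m "
    "ONE ROW PER MATCH "
    "AFTER MATCH SKIP PAST LAST ROW "
    "PATTERN (A B*) "
    "DEFINE a AS b)"
)
print(sqlglot.parse_one(sql, read="snowflake").sql("snowflake"))
```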
@ -1772,12 +2041,19 @@ class Parser(metaclass=_Parser):
if not skip_group_by_token and not self._match(TokenType.GROUP_BY):
return None
expressions = self._parse_csv(self._parse_conjunction)
grouping_sets = self._parse_grouping_sets()
with_ = self._match(TokenType.WITH)
cube = self._match(TokenType.CUBE) and (with_ or self._parse_wrapped_id_vars())
rollup = self._match(TokenType.ROLLUP) and (with_ or self._parse_wrapped_id_vars())
return self.expression(
exp.Group,
expressions=self._parse_csv(self._parse_conjunction),
grouping_sets=self._parse_grouping_sets(),
cube=self._match(TokenType.CUBE) and self._parse_wrapped_id_vars(),
rollup=self._match(TokenType.ROLLUP) and self._parse_wrapped_id_vars(),
expressions=expressions,
grouping_sets=grouping_sets,
cube=cube,
rollup=rollup,
)
def _parse_grouping_sets(self) -> t.Optional[t.List[t.Optional[exp.Expression]]]:
@ -1788,11 +2064,11 @@ class Parser(metaclass=_Parser):
def _parse_grouping_set(self) -> t.Optional[exp.Expression]:
if self._match(TokenType.L_PAREN):
grouping_set = self._parse_csv(self._parse_id_var)
grouping_set = self._parse_csv(self._parse_column)
self._match_r_paren()
return self.expression(exp.Tuple, expressions=grouping_set)
return self._parse_id_var()
return self._parse_column()
def _parse_having(self, skip_having_token: bool = False) -> t.Optional[exp.Expression]:
if not skip_having_token and not self._match(TokenType.HAVING):
@ -2268,7 +2544,6 @@ class Parser(metaclass=_Parser):
args = self._parse_csv(self._parse_lambda)
if function:
# Clickhouse supports function calls like foo(x, y)(z), so for these we need to also parse the
# second parameter list (i.e. "(z)") and the corresponding function will receive both arg lists.
if count_params(function) == 2:
@ -2541,9 +2816,10 @@ class Parser(metaclass=_Parser):
return self.expression(exp.PrimaryKey, expressions=expressions, options=options)
def _parse_bracket(self, this: t.Optional[exp.Expression]) -> t.Optional[exp.Expression]:
if not self._match(TokenType.L_BRACKET):
if not self._match_set((TokenType.L_BRACKET, TokenType.L_BRACE)):
return this
bracket_kind = self._prev.token_type
expressions: t.List[t.Optional[exp.Expression]]
if self._match(TokenType.COLON):
@ -2551,14 +2827,19 @@ class Parser(metaclass=_Parser):
else:
expressions = self._parse_csv(lambda: self._parse_slice(self._parse_conjunction()))
if not this or this.name.upper() == "ARRAY":
# https://duckdb.org/docs/sql/data_types/struct.html#creating-structs
if bracket_kind == TokenType.L_BRACE:
this = self.expression(exp.Struct, expressions=expressions)
elif not this or this.name.upper() == "ARRAY":
this = self.expression(exp.Array, expressions=expressions)
else:
expressions = apply_index_offset(expressions, -self.index_offset)
this = self.expression(exp.Bracket, this=this, expressions=expressions)
if not self._match(TokenType.R_BRACKET):
if not self._match(TokenType.R_BRACKET) and bracket_kind == TokenType.L_BRACKET:
self.raise_error("Expected ]")
elif not self._match(TokenType.R_BRACE) and bracket_kind == TokenType.L_BRACE:
self.raise_error("Expected }")
this.comments = self._prev_comments
return self._parse_bracket(this)
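For context (not part of the diff), a hedged sketch of the brace handling: `{...}` now parses into `exp.Struct`, matching DuckDB's struct-literal syntax, and DuckDB re-emits the braced form. The commented output is the expected round-trip, not a verified transcript:

```python
import sqlglot

print(sqlglot.transpile("SELECT {'a': 1, 'b': 2}", read="duckdb", write="duckdb")[0])
# e.g. SELECT {'a': 1, 'b': 2}
```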
@ -2727,7 +3008,7 @@ class Parser(metaclass=_Parser):
position = self._prev.text.upper()
expression = self._parse_term()
if self._match(TokenType.FROM):
if self._match_set((TokenType.FROM, TokenType.COMMA)):
this = self._parse_term()
else:
this = expression
@ -2792,14 +3073,8 @@ class Parser(metaclass=_Parser):
return self.expression(exp.Window, this=this, alias=self._parse_id_var(False))
window_alias = self._parse_id_var(any_token=False, tokens=self.WINDOW_ALIAS_TOKENS)
partition = None
if self._match(TokenType.PARTITION_BY):
partition = self._parse_csv(self._parse_conjunction)
partition = self._parse_partition_by()
order = self._parse_order()
spec = None
kind = self._match_set((TokenType.ROWS, TokenType.RANGE)) and self._prev.text
if kind:
@ -2816,6 +3091,8 @@ class Parser(metaclass=_Parser):
end=end["value"],
end_side=end["side"],
)
else:
spec = None
self._match_r_paren()
@ -3060,6 +3337,12 @@ class Parser(metaclass=_Parser):
def _parse_drop_column(self) -> t.Optional[exp.Expression]:
return self._match(TokenType.DROP) and self._parse_drop(default_kind="COLUMN")
# https://docs.aws.amazon.com/athena/latest/ug/alter-table-drop-partition.html
def _parse_drop_partition(self, exists: t.Optional[bool] = None) -> exp.Expression:
return self.expression(
exp.DropPartition, expressions=self._parse_csv(self._parse_partition), exists=exists
)
def _parse_add_constraint(self) -> t.Optional[exp.Expression]:
this = None
kind = self._prev.token_type
@ -3092,14 +3375,24 @@ class Parser(metaclass=_Parser):
actions: t.Optional[exp.Expression | t.List[t.Optional[exp.Expression]]] = None
index = self._index
if self._match_text_seq("ADD"):
if self._match(TokenType.DELETE):
actions = [self.expression(exp.Delete, where=self._parse_where())]
elif self._match_text_seq("ADD"):
if self._match_set(self.ADD_CONSTRAINT_TOKENS):
actions = self._parse_csv(self._parse_add_constraint)
else:
self._retreat(index)
actions = self._parse_csv(self._parse_add_column)
elif self._match_text_seq("DROP", advance=False):
actions = self._parse_csv(self._parse_drop_column)
elif self._match_text_seq("DROP"):
partition_exists = self._parse_exists()
if self._match(TokenType.PARTITION, advance=False):
actions = self._parse_csv(
lambda: self._parse_drop_partition(exists=partition_exists)
)
else:
self._retreat(index)
actions = self._parse_csv(self._parse_drop_column)
elif self._match_text_seq("RENAME", "TO"):
actions = self.expression(exp.RenameTable, this=self._parse_table(schema=True))
elif self._match_text_seq("ALTER"):
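For context (not part of the diff), a hedged sketch of the two new ALTER TABLE actions wired up above; outputs are the expected round-trips, not verified transcripts:

```python
import sqlglot

print(sqlglot.transpile("ALTER TABLE t DELETE WHERE a = 1")[0])
# e.g. ALTER TABLE t DELETE WHERE a = 1
print(sqlglot.transpile("ALTER TABLE t DROP PARTITION (dt = '2023-01-01')", read="hive")[0])
# e.g. ALTER TABLE t DROP PARTITION(dt = '2023-01-01')
```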


@ -22,6 +22,7 @@ class TokenType(AutoName):
DCOLON = auto()
SEMICOLON = auto()
STAR = auto()
BACKSLASH = auto()
SLASH = auto()
LT = auto()
LTE = auto()
@ -157,18 +158,14 @@ class TokenType(AutoName):
DELETE = auto()
DESC = auto()
DESCRIBE = auto()
DETERMINISTIC = auto()
DISTINCT = auto()
DISTINCT_FROM = auto()
DISTKEY = auto()
DISTRIBUTE_BY = auto()
DISTSTYLE = auto()
DIV = auto()
DROP = auto()
ELSE = auto()
ENCODE = auto()
END = auto()
ENGINE = auto()
ESCAPE = auto()
EXCEPT = auto()
EXECUTE = auto()
@ -182,10 +179,11 @@ class TokenType(AutoName):
FOR = auto()
FOREIGN_KEY = auto()
FORMAT = auto()
FROM = auto()
FULL = auto()
FUNCTION = auto()
FROM = auto()
GENERATED = auto()
GLOB = auto()
GLOBAL = auto()
GROUP_BY = auto()
GROUPING_SETS = auto()
@ -195,7 +193,6 @@ class TokenType(AutoName):
IF = auto()
IGNORE_NULLS = auto()
ILIKE = auto()
IMMUTABLE = auto()
IN = auto()
INDEX = auto()
INNER = auto()
@ -217,8 +214,8 @@ class TokenType(AutoName):
LIMIT = auto()
LOAD_DATA = auto()
LOCAL = auto()
LOCATION = auto()
MAP = auto()
MATCH_RECOGNIZE = auto()
MATERIALIZED = auto()
MERGE = auto()
MOD = auto()
@ -242,7 +239,6 @@ class TokenType(AutoName):
OVERWRITE = auto()
PARTITION = auto()
PARTITION_BY = auto()
PARTITIONED_BY = auto()
PERCENT = auto()
PIVOT = auto()
PLACEHOLDER = auto()
@ -258,7 +254,6 @@ class TokenType(AutoName):
REPLACE = auto()
RESPECT_NULLS = auto()
REFERENCES = auto()
RETURNS = auto()
RIGHT = auto()
RLIKE = auto()
ROLLBACK = auto()
@ -277,10 +272,7 @@ class TokenType(AutoName):
SOME = auto()
SORTKEY = auto()
SORT_BY = auto()
STABLE = auto()
STORED = auto()
STRUCT = auto()
TABLE_FORMAT = auto()
TABLE_SAMPLE = auto()
TEMPORARY = auto()
TOP = auto()
@ -414,6 +406,7 @@ class Tokenizer(metaclass=_Tokenizer):
"+": TokenType.PLUS,
";": TokenType.SEMICOLON,
"/": TokenType.SLASH,
"\\": TokenType.BACKSLASH,
"*": TokenType.STAR,
"~": TokenType.TILDA,
"?": TokenType.PLACEHOLDER,
@ -448,9 +441,11 @@ class Tokenizer(metaclass=_Tokenizer):
},
**{
f"{prefix}{key}": TokenType.BLOCK_END
for key in ("}}", "%}", "#}")
for key in ("%}", "#}")
for prefix in ("", "+", "-")
},
"+}}": TokenType.BLOCK_END,
"-}}": TokenType.BLOCK_END,
"/*+": TokenType.HINT,
"==": TokenType.EQ,
"::": TokenType.DCOLON,
@ -503,17 +498,13 @@ class Tokenizer(metaclass=_Tokenizer):
"DELETE": TokenType.DELETE,
"DESC": TokenType.DESC,
"DESCRIBE": TokenType.DESCRIBE,
"DETERMINISTIC": TokenType.DETERMINISTIC,
"DISTINCT": TokenType.DISTINCT,
"DISTINCT FROM": TokenType.DISTINCT_FROM,
"DISTKEY": TokenType.DISTKEY,
"DISTRIBUTE BY": TokenType.DISTRIBUTE_BY,
"DISTSTYLE": TokenType.DISTSTYLE,
"DIV": TokenType.DIV,
"DROP": TokenType.DROP,
"ELSE": TokenType.ELSE,
"END": TokenType.END,
"ENGINE": TokenType.ENGINE,
"ESCAPE": TokenType.ESCAPE,
"EXCEPT": TokenType.EXCEPT,
"EXECUTE": TokenType.EXECUTE,
@ -530,13 +521,13 @@ class Tokenizer(metaclass=_Tokenizer):
"FORMAT": TokenType.FORMAT,
"FROM": TokenType.FROM,
"GENERATED": TokenType.GENERATED,
"GLOB": TokenType.GLOB,
"GROUP BY": TokenType.GROUP_BY,
"GROUPING SETS": TokenType.GROUPING_SETS,
"HAVING": TokenType.HAVING,
"IDENTITY": TokenType.IDENTITY,
"IF": TokenType.IF,
"ILIKE": TokenType.ILIKE,
"IMMUTABLE": TokenType.IMMUTABLE,
"IGNORE NULLS": TokenType.IGNORE_NULLS,
"IN": TokenType.IN,
"INDEX": TokenType.INDEX,
@ -548,7 +539,6 @@ class Tokenizer(metaclass=_Tokenizer):
"IS": TokenType.IS,
"ISNULL": TokenType.ISNULL,
"JOIN": TokenType.JOIN,
"LANGUAGE": TokenType.LANGUAGE,
"LATERAL": TokenType.LATERAL,
"LAZY": TokenType.LAZY,
"LEADING": TokenType.LEADING,
@ -557,7 +547,6 @@ class Tokenizer(metaclass=_Tokenizer):
"LIMIT": TokenType.LIMIT,
"LOAD DATA": TokenType.LOAD_DATA,
"LOCAL": TokenType.LOCAL,
"LOCATION": TokenType.LOCATION,
"MATERIALIZED": TokenType.MATERIALIZED,
"MERGE": TokenType.MERGE,
"NATURAL": TokenType.NATURAL,
@ -582,8 +571,8 @@ class Tokenizer(metaclass=_Tokenizer):
"OVERWRITE": TokenType.OVERWRITE,
"PARTITION": TokenType.PARTITION,
"PARTITION BY": TokenType.PARTITION_BY,
"PARTITIONED BY": TokenType.PARTITIONED_BY,
"PARTITIONED_BY": TokenType.PARTITIONED_BY,
"PARTITIONED BY": TokenType.PARTITION_BY,
"PARTITIONED_BY": TokenType.PARTITION_BY,
"PERCENT": TokenType.PERCENT,
"PIVOT": TokenType.PIVOT,
"PRECEDING": TokenType.PRECEDING,
@ -596,7 +585,6 @@ class Tokenizer(metaclass=_Tokenizer):
"REPLACE": TokenType.REPLACE,
"RESPECT NULLS": TokenType.RESPECT_NULLS,
"REFERENCES": TokenType.REFERENCES,
"RETURNS": TokenType.RETURNS,
"RIGHT": TokenType.RIGHT,
"RLIKE": TokenType.RLIKE,
"ROLLBACK": TokenType.ROLLBACK,
@ -613,11 +601,7 @@ class Tokenizer(metaclass=_Tokenizer):
"SOME": TokenType.SOME,
"SORTKEY": TokenType.SORTKEY,
"SORT BY": TokenType.SORT_BY,
"STABLE": TokenType.STABLE,
"STORED": TokenType.STORED,
"TABLE": TokenType.TABLE,
"TABLE_FORMAT": TokenType.TABLE_FORMAT,
"TBLPROPERTIES": TokenType.PROPERTIES,
"TABLESAMPLE": TokenType.TABLE_SAMPLE,
"TEMP": TokenType.TEMPORARY,
"TEMPORARY": TokenType.TEMPORARY,


@ -27,20 +27,18 @@ def unalias_group(expression: exp.Expression) -> exp.Expression:
"""
if isinstance(expression, exp.Group) and isinstance(expression.parent, exp.Select):
aliased_selects = {
e.alias: (i, e.this)
e.alias: i
for i, e in enumerate(expression.parent.expressions, start=1)
if isinstance(e, exp.Alias)
}
expression = expression.copy()
top_level_expression = None
for item, parent, _ in expression.walk(bfs=False):
top_level_expression = item if isinstance(parent, exp.Group) else top_level_expression
if isinstance(item, exp.Column) and not item.table:
alias_index, col_expression = aliased_selects.get(item.name, (None, None))
if alias_index and top_level_expression != col_expression:
item.replace(exp.Literal.number(alias_index))
for group_by in expression.expressions:
if (
isinstance(group_by, exp.Column)
and not group_by.table
and group_by.name in aliased_selects
):
group_by.replace(exp.Literal.number(aliased_selects.get(group_by.name)))
return expression
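For context (not part of the diff), a hedged sketch of the simplified transform in action — an aliased GROUP BY column is replaced by its ordinal position in the select list:

```python
import sqlglot
from sqlglot.transforms import unalias_group

expr = sqlglot.parse_one("SELECT a AS b FROM x GROUP BY b")
print(expr.transform(unalias_group).sql())  # e.g. SELECT a AS b FROM x GROUP BY 1
```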
@ -63,22 +61,21 @@ def eliminate_distinct_on(expression: exp.Expression) -> exp.Expression:
and expression.args["distinct"].args.get("on")
and isinstance(expression.args["distinct"].args["on"], exp.Tuple)
):
distinct_cols = [e.copy() for e in expression.args["distinct"].args["on"].expressions]
outer_selects = [e.copy() for e in expression.expressions]
nested = expression.copy()
nested.args["distinct"].pop()
distinct_cols = expression.args["distinct"].args["on"].expressions
expression.args["distinct"].pop()
outer_selects = expression.selects
row_number = find_new_name(expression.named_selects, "_row_number")
window = exp.Window(
this=exp.RowNumber(),
partition_by=distinct_cols,
)
order = nested.args.get("order")
order = expression.args.get("order")
if order:
window.set("order", order.copy())
order.pop()
window = exp.alias_(window, row_number)
nested.select(window, copy=False)
return exp.select(*outer_selects).from_(nested.subquery()).where(f'"{row_number}" = 1')
expression.select(window, copy=False)
return exp.select(*outer_selects).from_(expression.subquery()).where(f'"{row_number}" = 1')
return expression
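For context (not part of the diff), a hedged sketch of the rewritten transform: the DISTINCT ON select is now mutated in place and wrapped in a ROW_NUMBER subquery. Exact output formatting is unverified:

```python
import sqlglot
from sqlglot.transforms import eliminate_distinct_on

expr = sqlglot.parse_one("SELECT DISTINCT ON (a) a, b FROM t ORDER BY c", read="postgres")
print(expr.transform(eliminate_distinct_on).sql())
# roughly: SELECT a, b FROM (SELECT a, b, ROW_NUMBER() OVER (PARTITION BY a ORDER BY c)
#          AS _row_number FROM t) WHERE "_row_number" = 1
```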
@ -120,7 +117,7 @@ def preprocess(
"""
def _to_sql(self, expression):
expression = transforms[0](expression)
expression = transforms[0](expression.copy())
for t in transforms[1:]:
expression = t(expression)
return to_sql(self, expression)