Merging upstream version 10.6.3.

Signed-off-by: Daniel Baumann <daniel@debian.org>

parent d03a55eda6
commit ece6881255

48 changed files with 906 additions and 266 deletions
README.md | 10
@@ -1,12 +1,12 @@
 # SQLGlot
 
-SQLGlot is a no dependency Python SQL parser, transpiler, optimizer, and engine. It can be used to format SQL or translate between [19 different dialects](https://github.com/tobymao/sqlglot/blob/main/sqlglot/dialects/__init__.py) like [DuckDB](https://duckdb.org/), [Presto](https://prestodb.io/), [Spark](https://spark.apache.org/), [Snowflake](https://www.snowflake.com/en/), and [BigQuery](https://cloud.google.com/bigquery/). It aims to read a wide variety of SQL inputs and output syntactically correct SQL in the targeted dialects.
+SQLGlot is a no-dependency SQL parser, transpiler, optimizer, and engine. It can be used to format SQL or translate between [19 different dialects](https://github.com/tobymao/sqlglot/blob/main/sqlglot/dialects/__init__.py) like [DuckDB](https://duckdb.org/), [Presto](https://prestodb.io/), [Spark](https://spark.apache.org/), [Snowflake](https://www.snowflake.com/en/), and [BigQuery](https://cloud.google.com/bigquery/). It aims to read a wide variety of SQL inputs and output syntactically correct SQL in the targeted dialects.
 
-It is a very comprehensive generic SQL parser with a robust [test suite](https://github.com/tobymao/sqlglot/blob/main/tests/). It is also quite [performant](#benchmarks) while being written purely in Python.
+It is a very comprehensive generic SQL parser with a robust [test suite](https://github.com/tobymao/sqlglot/blob/main/tests/). It is also quite [performant](#benchmarks), while being written purely in Python.
 
 You can easily [customize](#custom-dialects) the parser, [analyze](#metadata) queries, traverse expression trees, and programmatically [build](#build-and-modify-sql) SQL.
 
-Syntax [errors](#parser-errors) are highlighted and dialect incompatibilities can warn or raise depending on configurations. However, it should be noted that the parser is very lenient when it comes to detecting errors, because it aims to consume as much SQL as possible. On one hand, this makes its implementation simpler, and thus more comprehensible, but on the other hand it means that syntax errors may sometimes go unnoticed.
+Syntax [errors](#parser-errors) are highlighted and dialect incompatibilities can warn or raise depending on configurations. However, it should be noted that SQL validation is not SQLGlot’s goal, so some syntax errors may go unnoticed.
 
 Contributions are very welcome in SQLGlot; read the [contribution guide](https://github.com/tobymao/sqlglot/blob/main/CONTRIBUTING.md) to get started!
 
@@ -432,6 +432,8 @@ user_id price
 2    3.0
 ```
 
+See also: [Writing a Python SQL engine from scratch](https://github.com/tobymao/sqlglot/blob/main/posts/python_sql_engine.md).
+
 ## Used By
 * [Fugue](https://github.com/fugue-project/fugue)
 * [ibis](https://github.com/ibis-project/ibis)
@@ -442,7 +444,7 @@ user_id price
 
 ## Documentation
 
-SQLGlot uses [pdocs](https://pdoc.dev/) to serve its API documentation:
+SQLGlot uses [pdoc](https://pdoc.dev/) to serve its API documentation:
 
 ```
 make docs-serve
@@ -33,7 +33,13 @@ from sqlglot.parser import Parser
 from sqlglot.schema import MappingSchema, Schema
 from sqlglot.tokens import Tokenizer, TokenType
 
-__version__ = "10.6.0"
+if t.TYPE_CHECKING:
+    from sqlglot.dialects.dialect import DialectType
+
+T = t.TypeVar("T", bound=Expression)
+
+__version__ = "10.6.3"
 
 pretty = False
 """Whether to format generated SQL by default."""
@@ -42,9 +48,7 @@ schema = MappingSchema()
 """The default schema used by SQLGlot (e.g. in the optimizer)."""
 
 
-def parse(
-    sql: str, read: t.Optional[str | Dialect] = None, **opts
-) -> t.List[t.Optional[Expression]]:
+def parse(sql: str, read: DialectType = None, **opts) -> t.List[t.Optional[Expression]]:
     """
     Parses the given SQL string into a collection of syntax trees, one per parsed SQL statement.
 
@@ -60,9 +64,57 @@ def parse(
     return dialect.parse(sql, **opts)
 
 
+@t.overload
 def parse_one(
     sql: str,
-    read: t.Optional[str | Dialect] = None,
+    read: None = None,
+    into: t.Type[T] = ...,
+    **opts,
+) -> T:
+    ...
+
+
+@t.overload
+def parse_one(
+    sql: str,
+    read: DialectType,
+    into: t.Type[T],
+    **opts,
+) -> T:
+    ...
+
+
+@t.overload
+def parse_one(
+    sql: str,
+    read: None = None,
+    into: t.Union[str, t.Collection[t.Union[str, t.Type[Expression]]]] = ...,
+    **opts,
+) -> Expression:
+    ...
+
+
+@t.overload
+def parse_one(
+    sql: str,
+    read: DialectType,
+    into: t.Union[str, t.Collection[t.Union[str, t.Type[Expression]]]],
+    **opts,
+) -> Expression:
+    ...
+
+
+@t.overload
+def parse_one(
+    sql: str,
+    **opts,
+) -> Expression:
+    ...
+
+
+def parse_one(
+    sql: str,
+    read: DialectType = None,
     into: t.Optional[exp.IntoType] = None,
     **opts,
 ) -> Expression:
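These overloads exist so type checkers can narrow `parse_one`'s return type from the `into` argument. A minimal usage sketch (not part of the diff) of the API shown above:

```
import sqlglot
from sqlglot import exp

# Without `into`, parse_one returns a generic Expression subclass.
select = sqlglot.parse_one("SELECT a FROM t")
print(type(select).__name__)  # Select

# With `into`, the input is parsed as (and typed as) the requested node.
dtype = sqlglot.parse_one("DECIMAL(10, 2)", into=exp.DataType)
print(dtype.sql())  # DECIMAL(10, 2)
```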
@@ -96,8 +148,8 @@ def parse_one(
 
 def transpile(
     sql: str,
-    read: t.Optional[str | Dialect] = None,
-    write: t.Optional[str | Dialect] = None,
+    read: DialectType = None,
+    write: DialectType = None,
     identity: bool = True,
     error_level: t.Optional[ErrorLevel] = None,
     **opts,
@@ -260,11 +260,7 @@ class Column:
         """
         if isinstance(dataType, DataType):
             dataType = dataType.simpleString()
-        new_expression = exp.Cast(
-            this=self.column_expression,
-            to=sqlglot.parse_one(dataType, into=exp.DataType, read="spark"),  # type: ignore
-        )
-        return Column(new_expression)
+        return Column(exp.cast(self.column_expression, dataType, dialect="spark"))
 
     def startswith(self, value: t.Union[str, Column]) -> Column:
         value = self._lit(value) if not isinstance(value, Column) else value
@@ -536,15 +536,15 @@ def month(col: ColumnOrName) -> Column:
 
 
 def dayofweek(col: ColumnOrName) -> Column:
-    return Column.invoke_anonymous_function(col, "DAYOFWEEK")
+    return Column.invoke_expression_over_column(col, glotexp.DayOfWeek)
 
 
 def dayofmonth(col: ColumnOrName) -> Column:
-    return Column.invoke_anonymous_function(col, "DAYOFMONTH")
+    return Column.invoke_expression_over_column(col, glotexp.DayOfMonth)
 
 
 def dayofyear(col: ColumnOrName) -> Column:
-    return Column.invoke_anonymous_function(col, "DAYOFYEAR")
+    return Column.invoke_expression_over_column(col, glotexp.DayOfYear)
 
 
 def hour(col: ColumnOrName) -> Column:
@@ -560,7 +560,7 @@ def second(col: ColumnOrName) -> Column:
 
 
 def weekofyear(col: ColumnOrName) -> Column:
-    return Column.invoke_anonymous_function(col, "WEEKOFYEAR")
+    return Column.invoke_expression_over_column(col, glotexp.WeekOfYear)
 
 
 def make_date(year: ColumnOrName, month: ColumnOrName, day: ColumnOrName) -> Column:
@@ -1144,10 +1144,16 @@ def aggregate(
     merge_exp = _get_lambda_from_func(merge)
     if finish is not None:
         finish_exp = _get_lambda_from_func(finish)
-        return Column.invoke_anonymous_function(
-            col, "AGGREGATE", initialValue, Column(merge_exp), Column(finish_exp)
+        return Column.invoke_expression_over_column(
+            col,
+            glotexp.Reduce,
+            initial=initialValue,
+            merge=Column(merge_exp),
+            finish=Column(finish_exp),
         )
-    return Column.invoke_anonymous_function(col, "AGGREGATE", initialValue, Column(merge_exp))
+    return Column.invoke_expression_over_column(
+        col, glotexp.Reduce, initial=initialValue, merge=Column(merge_exp)
+    )
 
 
 def transform(
@@ -222,14 +222,6 @@ class BigQuery(Dialect):
             exp.DataType.Type.NVARCHAR: "STRING",
         }
 
-        ROOT_PROPERTIES = {
-            exp.LanguageProperty,
-            exp.ReturnsProperty,
-            exp.VolatilityProperty,
-        }
-
-        WITH_PROPERTIES = {exp.Property}
-
         EXPLICIT_UNION = True
 
         def array_sql(self, expression: exp.Array) -> str:
@@ -122,9 +122,15 @@ class Dialect(metaclass=_Dialect):
     def get_or_raise(cls, dialect):
         if not dialect:
             return cls
+        if isinstance(dialect, _Dialect):
+            return dialect
+        if isinstance(dialect, Dialect):
+            return dialect.__class__
 
         result = cls.get(dialect)
         if not result:
             raise ValueError(f"Unknown dialect '{dialect}'")
 
         return result
 
     @classmethod
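With this change, `get_or_raise` accepts dialect classes and instances as well as names. A short sketch of the accepted forms (the DuckDB import path is the project's, everything else follows directly from the code above):

```
from sqlglot.dialects.dialect import Dialect
from sqlglot.dialects.duckdb import DuckDB

assert Dialect.get_or_raise("duckdb") is DuckDB   # by name
assert Dialect.get_or_raise(DuckDB) is DuckDB     # by class
assert Dialect.get_or_raise(DuckDB()) is DuckDB   # by instance

try:
    Dialect.get_or_raise("nope")
except ValueError as e:
    print(e)  # Unknown dialect 'nope'
```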
@@ -196,6 +202,10 @@ class Dialect(metaclass=_Dialect):
         )
 
 
+if t.TYPE_CHECKING:
+    DialectType = t.Union[str, Dialect, t.Type[Dialect], None]
+
+
 def rename_func(name):
     def _rename(self, expression):
         args = flatten(expression.args.values())
@@ -137,7 +137,10 @@ class Drill(Dialect):
             exp.DataType.Type.DATETIME: "TIMESTAMP",
         }
 
-        ROOT_PROPERTIES = {exp.PartitionedByProperty}
+        PROPERTIES_LOCATION = {
+            **generator.Generator.PROPERTIES_LOCATION,  # type: ignore
+            exp.PartitionedByProperty: exp.Properties.Location.POST_SCHEMA_ROOT,
+        }
 
         TRANSFORMS = {
             **generator.Generator.TRANSFORMS,  # type: ignore
@@ -20,10 +20,6 @@ from sqlglot.helper import seq_get
 from sqlglot.tokens import TokenType
 
 
-def _unix_to_time(self, expression):
-    return f"TO_TIMESTAMP(CAST({self.sql(expression, 'this')} AS BIGINT))"
-
-
 def _str_to_time_sql(self, expression):
     return f"STRPTIME({self.sql(expression, 'this')}, {self.format_time(expression)})"
@@ -113,7 +109,7 @@ class DuckDB(Dialect):
             "STR_SPLIT_REGEX": exp.RegexpSplit.from_arg_list,
             "STRING_SPLIT_REGEX": exp.RegexpSplit.from_arg_list,
             "STRUCT_PACK": exp.Struct.from_arg_list,
-            "TO_TIMESTAMP": exp.TimeStrToTime.from_arg_list,
+            "TO_TIMESTAMP": exp.UnixToTime.from_arg_list,
             "UNNEST": exp.Explode.from_arg_list,
         }
@@ -162,9 +158,9 @@ class DuckDB(Dialect):
             exp.TsOrDiToDi: lambda self, e: f"CAST(SUBSTR(REPLACE(CAST({self.sql(e, 'this')} AS TEXT), '-', ''), 1, 8) AS INT)",
             exp.TsOrDsAdd: _ts_or_ds_add,
             exp.TsOrDsToDate: _ts_or_ds_to_date_sql,
-            exp.UnixToStr: lambda self, e: f"STRFTIME({_unix_to_time(self, e)}, {self.format_time(e)})",
-            exp.UnixToTime: _unix_to_time,
-            exp.UnixToTimeStr: lambda self, e: f"CAST({_unix_to_time(self, e)} AS TEXT)",
+            exp.UnixToStr: lambda self, e: f"STRFTIME(TO_TIMESTAMP({self.sql(e, 'this')}), {self.format_time(e)})",
+            exp.UnixToTime: rename_func("TO_TIMESTAMP"),
+            exp.UnixToTimeStr: lambda self, e: f"CAST(TO_TIMESTAMP({self.sql(e, 'this')}) AS TEXT)",
         }
 
         TYPE_MAPPING = {
@@ -322,17 +322,11 @@ class Hive(Dialect):
             exp.LastDateOfMonth: rename_func("LAST_DAY"),
         }
 
-        WITH_PROPERTIES = {exp.Property}
-
-        ROOT_PROPERTIES = {
-            exp.PartitionedByProperty,
-            exp.FileFormatProperty,
-            exp.SchemaCommentProperty,
-            exp.LocationProperty,
-            exp.TableFormatProperty,
-            exp.RowFormatDelimitedProperty,
-            exp.RowFormatSerdeProperty,
-            exp.SerdeProperties,
-        }
+        PROPERTIES_LOCATION = {
+            **generator.Generator.PROPERTIES_LOCATION,  # type: ignore
+            exp.FileFormatProperty: exp.Properties.Location.POST_SCHEMA_ROOT,
+            exp.PartitionedByProperty: exp.Properties.Location.POST_SCHEMA_ROOT,
+            exp.TableFormatProperty: exp.Properties.Location.POST_SCHEMA_ROOT,
+        }
 
         def with_properties(self, properties):
@@ -1,7 +1,5 @@
 from __future__ import annotations
 
-import typing as t
-
 from sqlglot import exp, generator, parser, tokens
 from sqlglot.dialects.dialect import (
     Dialect,
@@ -98,6 +96,8 @@ def _date_add_sql(kind):
 
 
 class MySQL(Dialect):
+    time_format = "'%Y-%m-%d %T'"
+
     # https://prestodb.io/docs/current/functions/datetime.html#mysql-date-functions
     time_mapping = {
         "%M": "%B",
@@ -110,6 +110,7 @@ class MySQL(Dialect):
         "%u": "%W",
         "%k": "%-H",
         "%l": "%-I",
+        "%T": "%H:%M:%S",
     }
 
     class Tokenizer(tokens.Tokenizer):
@@ -428,6 +429,7 @@ class MySQL(Dialect):
         )
 
     class Generator(generator.Generator):
+        LOCKING_READS_SUPPORTED = True
         NULL_ORDERING_SUPPORTED = False
 
         TRANSFORMS = {
@@ -449,23 +451,12 @@ class MySQL(Dialect):
             exp.StrPosition: strposition_to_locate_sql,
         }
 
-        ROOT_PROPERTIES = {
-            exp.EngineProperty,
-            exp.AutoIncrementProperty,
-            exp.CharacterSetProperty,
-            exp.CollateProperty,
-            exp.SchemaCommentProperty,
-            exp.LikeProperty,
-        }
-
         TYPE_MAPPING = generator.Generator.TYPE_MAPPING.copy()
         TYPE_MAPPING.pop(exp.DataType.Type.MEDIUMTEXT)
         TYPE_MAPPING.pop(exp.DataType.Type.LONGTEXT)
         TYPE_MAPPING.pop(exp.DataType.Type.MEDIUMBLOB)
         TYPE_MAPPING.pop(exp.DataType.Type.LONGBLOB)
 
-        WITH_PROPERTIES: t.Set[t.Type[exp.Property]] = set()
-
         def show_sql(self, expression):
             this = f" {expression.name}"
             full = " FULL" if expression.args.get("full") else ""
@@ -44,6 +44,8 @@ class Oracle(Dialect):
     }
 
     class Generator(generator.Generator):
+        LOCKING_READS_SUPPORTED = True
+
         TYPE_MAPPING = {
             **generator.Generator.TYPE_MAPPING,  # type: ignore
             exp.DataType.Type.TINYINT: "NUMBER",
@@ -69,6 +71,7 @@ class Oracle(Dialect):
             exp.StrToTime: lambda self, e: f"TO_TIMESTAMP({self.sql(e, 'this')}, {self.format_time(e)})",
             exp.TimeToStr: lambda self, e: f"TO_CHAR({self.sql(e, 'this')}, {self.format_time(e)})",
             exp.UnixToTime: lambda self, e: f"TO_DATE('1970-01-01','YYYY-MM-DD') + ({self.sql(e, 'this')} / 86400)",
+            exp.Substring: rename_func("SUBSTR"),
         }
 
         def query_modifiers(self, expression, *sqls):
@@ -90,6 +93,7 @@ class Oracle(Dialect):
                 self.sql(expression, "order"),
                 self.sql(expression, "offset"),  # offset before limit in oracle
                 self.sql(expression, "limit"),
+                self.sql(expression, "lock"),
                 sep="",
             )
@@ -148,6 +148,22 @@ def _serial_to_generated(expression):
     return expression
 
 
+def _generate_series(args):
+    # The goal is to convert step values like '1 day' or INTERVAL '1 day' into INTERVAL '1' day
+    step = seq_get(args, 2)
+
+    if step is None:
+        # Postgres allows calls with just two arguments -- the "step" argument defaults to 1
+        return exp.GenerateSeries.from_arg_list(args)
+
+    if step.is_string:
+        args[2] = exp.to_interval(step.this)
+    elif isinstance(step, exp.Interval) and not step.args.get("unit"):
+        args[2] = exp.to_interval(step.this.this)
+
+    return exp.GenerateSeries.from_arg_list(args)
+
+
 def _to_timestamp(args):
     # TO_TIMESTAMP accepts either a single double argument or (text, text)
     if len(args) == 1:
@@ -195,29 +211,6 @@ class Postgres(Dialect):
         HEX_STRINGS = [("x'", "'"), ("X'", "'")]
         BYTE_STRINGS = [("e'", "'"), ("E'", "'")]
 
-        CREATABLES = (
-            "AGGREGATE",
-            "CAST",
-            "CONVERSION",
-            "COLLATION",
-            "DEFAULT CONVERSION",
-            "CONSTRAINT",
-            "DOMAIN",
-            "EXTENSION",
-            "FOREIGN",
-            "FUNCTION",
-            "OPERATOR",
-            "POLICY",
-            "ROLE",
-            "RULE",
-            "SEQUENCE",
-            "TEXT",
-            "TRIGGER",
-            "TYPE",
-            "UNLOGGED",
-            "USER",
-        )
-
         KEYWORDS = {
             **tokens.Tokenizer.KEYWORDS,
             "~~": TokenType.LIKE,
@@ -243,8 +236,6 @@ class Postgres(Dialect):
             "TEMP": TokenType.TEMPORARY,
             "UUID": TokenType.UUID,
             "CSTRING": TokenType.PSEUDO_TYPE,
-            **{f"CREATE {kind}": TokenType.COMMAND for kind in CREATABLES},
-            **{f"DROP {kind}": TokenType.COMMAND for kind in CREATABLES},
         }
         QUOTES = ["'", "$$"]
         SINGLE_TOKENS = {
@@ -257,8 +248,10 @@ class Postgres(Dialect):
 
         FUNCTIONS = {
             **parser.Parser.FUNCTIONS,  # type: ignore
+            "NOW": exp.CurrentTimestamp.from_arg_list,
             "TO_TIMESTAMP": _to_timestamp,
             "TO_CHAR": format_time_lambda(exp.TimeToStr, "postgres"),
+            "GENERATE_SERIES": _generate_series,
         }
 
         BITWISE = {
@@ -272,6 +265,8 @@ class Postgres(Dialect):
         }
 
     class Generator(generator.Generator):
+        LOCKING_READS_SUPPORTED = True
+
         TYPE_MAPPING = {
             **generator.Generator.TYPE_MAPPING,  # type: ignore
             exp.DataType.Type.TINYINT: "SMALLINT",
@@ -105,6 +105,29 @@ def _ts_or_ds_add_sql(self, expression):
     return f"DATE_ADD({unit}, {e}, DATE_PARSE(SUBSTR({this}, 1, 10), {Presto.date_format}))"
 
 
+def _sequence_sql(self, expression):
+    start = expression.args["start"]
+    end = expression.args["end"]
+    step = expression.args.get("step", 1)  # Postgres defaults to 1 for generate_series
+
+    target_type = None
+
+    if isinstance(start, exp.Cast):
+        target_type = start.to
+    elif isinstance(end, exp.Cast):
+        target_type = end.to
+
+    if target_type and target_type.this == exp.DataType.Type.TIMESTAMP:
+        to = target_type.copy()
+
+        if target_type is start.to:
+            end = exp.Cast(this=end, to=to)
+        else:
+            start = exp.Cast(this=start, to=to)
+
+    return f"SEQUENCE({self.format_args(start, end, step)})"
+
+
 def _ensure_utf8(charset):
     if charset.name.lower() != "utf-8":
         raise UnsupportedError(f"Unsupported charset {charset}")
@@ -145,7 +168,7 @@ def _from_unixtime(args):
 class Presto(Dialect):
     index_offset = 1
     null_ordering = "nulls_are_last"
-    time_format = "'%Y-%m-%d %H:%i:%S'"
+    time_format = MySQL.time_format  # type: ignore
     time_mapping = MySQL.time_mapping  # type: ignore
 
     class Tokenizer(tokens.Tokenizer):
@@ -197,7 +220,10 @@ class Presto(Dialect):
     class Generator(generator.Generator):
         STRUCT_DELIMITER = ("(", ")")
 
-        ROOT_PROPERTIES = {exp.SchemaCommentProperty}
+        PROPERTIES_LOCATION = {
+            **generator.Generator.PROPERTIES_LOCATION,  # type: ignore
+            exp.LocationProperty: exp.Properties.Location.UNSUPPORTED,
+        }
 
         TYPE_MAPPING = {
             **generator.Generator.TYPE_MAPPING,  # type: ignore
@@ -223,6 +249,7 @@ class Presto(Dialect):
             exp.BitwiseOr: lambda self, e: f"BITWISE_OR({self.sql(e, 'this')}, {self.sql(e, 'expression')})",
             exp.BitwiseRightShift: lambda self, e: f"BITWISE_ARITHMETIC_SHIFT_RIGHT({self.sql(e, 'this')}, {self.sql(e, 'expression')})",
             exp.BitwiseXor: lambda self, e: f"BITWISE_XOR({self.sql(e, 'this')}, {self.sql(e, 'expression')})",
+            exp.CurrentTimestamp: lambda *_: "CURRENT_TIMESTAMP",
             exp.DataType: _datatype_sql,
             exp.DateAdd: lambda self, e: f"""DATE_ADD({self.sql(e, 'unit') or "'day'"}, {self.sql(e, 'expression')}, {self.sql(e, 'this')})""",
             exp.DateDiff: lambda self, e: f"""DATE_DIFF({self.sql(e, 'unit') or "'day'"}, {self.sql(e, 'expression')}, {self.sql(e, 'this')})""",
@@ -231,6 +258,7 @@ class Presto(Dialect):
             exp.Decode: _decode_sql,
             exp.DiToDate: lambda self, e: f"CAST(DATE_PARSE(CAST({self.sql(e, 'this')} AS VARCHAR), {Presto.dateint_format}) AS DATE)",
             exp.Encode: _encode_sql,
+            exp.GenerateSeries: _sequence_sql,
             exp.Hex: rename_func("TO_HEX"),
             exp.If: if_sql,
             exp.ILike: no_ilike_sql,
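Taken together, the Postgres `_generate_series` parser hook and this Presto `_sequence_sql` generator hook route GENERATE_SERIES through the new `exp.GenerateSeries` node. A hedged sketch of the intended behavior (the exact output text is my expectation, not taken from the diff):

```
import sqlglot

# Postgres: a string step like '1 day' is normalized to a proper interval.
print(sqlglot.transpile("SELECT GENERATE_SERIES(a, b, '1 day')", read="postgres")[0])
# expected: SELECT GENERATE_SERIES(a, b, INTERVAL '1' day)

# Postgres -> Presto: GenerateSeries is rendered as SEQUENCE(...).
print(sqlglot.transpile("SELECT GENERATE_SERIES(a, b, INTERVAL '1 day')", read="postgres", write="presto")[0])
# expected: SELECT SEQUENCE(a, b, INTERVAL '1' day)
```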
@@ -61,14 +61,9 @@ class Redshift(Postgres):
             exp.DataType.Type.INT: "INTEGER",
         }
 
-        ROOT_PROPERTIES = {
-            exp.DistKeyProperty,
-            exp.SortKeyProperty,
-            exp.DistStyleProperty,
-        }
-
-        WITH_PROPERTIES = {
-            exp.LikeProperty,
-        }
+        PROPERTIES_LOCATION = {
+            **Postgres.Generator.PROPERTIES_LOCATION,  # type: ignore
+            exp.LikeProperty: exp.Properties.Location.POST_SCHEMA_WITH,
+        }
 
         TRANSFORMS = {
@@ -234,15 +234,6 @@ class Snowflake(Dialect):
             "replace": "RENAME",
         }
 
-        ROOT_PROPERTIES = {
-            exp.PartitionedByProperty,
-            exp.ReturnsProperty,
-            exp.LanguageProperty,
-            exp.SchemaCommentProperty,
-            exp.ExecuteAsProperty,
-            exp.VolatilityProperty,
-        }
-
         def except_op(self, expression):
             if not expression.args.get("distinct", False):
                 self.unsupported("EXCEPT with All is not supported in Snowflake")
@@ -73,6 +73,19 @@ class Spark(Hive):
             ),
             "APPROX_PERCENTILE": exp.ApproxQuantile.from_arg_list,
             "IIF": exp.If.from_arg_list,
+            "AGGREGATE": exp.Reduce.from_arg_list,
+            "DAYOFWEEK": lambda args: exp.DayOfWeek(
+                this=exp.TsOrDsToDate(this=seq_get(args, 0)),
+            ),
+            "DAYOFMONTH": lambda args: exp.DayOfMonth(
+                this=exp.TsOrDsToDate(this=seq_get(args, 0)),
+            ),
+            "DAYOFYEAR": lambda args: exp.DayOfYear(
+                this=exp.TsOrDsToDate(this=seq_get(args, 0)),
+            ),
+            "WEEKOFYEAR": lambda args: exp.WeekOfYear(
+                this=exp.TsOrDsToDate(this=seq_get(args, 0)),
+            ),
         }
 
         FUNCTION_PARSERS = {
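These parser entries turn Spark's day/week helpers into structured expressions, wrapping the argument in `TsOrDsToDate` so string dates are coerced first. A quick sketch of what that enables (a usage example, not part of the diff):

```
import sqlglot
from sqlglot import exp

ast = sqlglot.parse_one("SELECT DAYOFWEEK('2023-01-01')", read="spark")
node = ast.find(exp.DayOfWeek)

assert node is not None
assert isinstance(node.this, exp.TsOrDsToDate)  # implicit string-to-date coercion
print(ast.sql("spark"))  # regenerates a DAYOFWEEK(...) call
```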
@@ -105,6 +118,14 @@ class Spark(Hive):
             exp.DataType.Type.BIGINT: "LONG",
         }
 
+        PROPERTIES_LOCATION = {
+            **Hive.Generator.PROPERTIES_LOCATION,  # type: ignore
+            exp.EngineProperty: exp.Properties.Location.UNSUPPORTED,
+            exp.AutoIncrementProperty: exp.Properties.Location.UNSUPPORTED,
+            exp.CharacterSetProperty: exp.Properties.Location.UNSUPPORTED,
+            exp.CollateProperty: exp.Properties.Location.UNSUPPORTED,
+        }
+
         TRANSFORMS = {
             **Hive.Generator.TRANSFORMS,  # type: ignore
             exp.ApproxDistinct: rename_func("APPROX_COUNT_DISTINCT"),
@@ -126,11 +147,27 @@ class Spark(Hive):
             exp.VariancePop: rename_func("VAR_POP"),
             exp.DateFromParts: rename_func("MAKE_DATE"),
             exp.LogicalOr: rename_func("BOOL_OR"),
+            exp.DayOfWeek: rename_func("DAYOFWEEK"),
+            exp.DayOfMonth: rename_func("DAYOFMONTH"),
+            exp.DayOfYear: rename_func("DAYOFYEAR"),
+            exp.WeekOfYear: rename_func("WEEKOFYEAR"),
+            exp.AtTimeZone: lambda self, e: f"FROM_UTC_TIMESTAMP({self.sql(e, 'this')}, {self.sql(e, 'zone')})",
         }
         TRANSFORMS.pop(exp.ArraySort)
         TRANSFORMS.pop(exp.ILike)
 
         WRAP_DERIVED_VALUES = False
 
+        def cast_sql(self, expression: exp.Cast) -> str:
+            if isinstance(expression.this, exp.Cast) and expression.this.is_type(
+                exp.DataType.Type.JSON
+            ):
+                schema = f"'{self.sql(expression, 'to')}'"
+                return f"FROM_JSON({self.format_args(self.sql(expression.this, 'this'), schema)})"
+            if expression.to.is_type(exp.DataType.Type.JSON):
+                return f"TO_JSON({self.sql(expression, 'this')})"
+
+            return super(Spark.Generator, self).cast_sql(expression)
+
     class Tokenizer(Hive.Tokenizer):
         HEX_STRINGS = [("X'", "'")]
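The new `cast_sql` override rewrites JSON casts into Spark's JSON helpers. A small sketch of the TO_JSON direction (my expected output, based on the code above):

```
import sqlglot

# Casting to JSON becomes TO_JSON in Spark.
print(sqlglot.transpile("SELECT CAST(x AS JSON)", write="spark")[0])
# expected: SELECT TO_JSON(x)
```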
@@ -31,6 +31,5 @@ class Tableau(Dialect):
     class Parser(parser.Parser):
         FUNCTIONS = {
             **parser.Parser.FUNCTIONS,  # type: ignore
-            "IFNULL": exp.Coalesce.from_arg_list,
             "COUNTD": lambda args: exp.Count(this=exp.Distinct(expressions=args)),
         }
@@ -76,6 +76,14 @@ class Teradata(Dialect):
         )
 
     class Generator(generator.Generator):
+        PROPERTIES_LOCATION = {
+            **generator.Generator.PROPERTIES_LOCATION,  # type: ignore
+            exp.PartitionedByProperty: exp.Properties.Location.POST_INDEX,
+        }
+
+        def partitionedbyproperty_sql(self, expression: exp.PartitionedByProperty) -> str:
+            return f"PARTITION BY {self.sql(expression, 'this')}"
+
         # FROM before SET in Teradata UPDATE syntax
         # https://docs.teradata.com/r/Enterprise_IntelliFlex_VMware/Teradata-VantageTM-SQL-Data-Manipulation-Language-17.20/Statement-Syntax/UPDATE/UPDATE-Syntax-Basic-Form-FROM-Clause
         def update_sql(self, expression: exp.Update) -> str:
@@ -412,6 +412,8 @@ class TSQL(Dialect):
         return self.expression(exp.UserDefinedFunction, this=this, expressions=expressions)
 
     class Generator(generator.Generator):
+        LOCKING_READS_SUPPORTED = True
+
         TYPE_MAPPING = {
             **generator.Generator.TYPE_MAPPING,  # type: ignore
             exp.DataType.Type.BOOLEAN: "BIT",
@@ -14,10 +14,6 @@ from sqlglot import Dialect
 from sqlglot import expressions as exp
 from sqlglot.helper import ensure_collection
 
-if t.TYPE_CHECKING:
-    T = t.TypeVar("T")
-    Edit = t.Union[Insert, Remove, Move, Update, Keep]
-
 
 @dataclass(frozen=True)
 class Insert:
@@ -56,6 +52,11 @@ class Keep:
     target: exp.Expression
 
 
+if t.TYPE_CHECKING:
+    T = t.TypeVar("T")
+    Edit = t.Union[Insert, Remove, Move, Update, Keep]
+
+
 def diff(source: exp.Expression, target: exp.Expression) -> t.List[Edit]:
     """
     Returns the list of changes between the source and the target expressions.
@@ -1,5 +1,13 @@
+"""
+.. include:: ../../posts/python_sql_engine.md
+----
+"""
+
+from __future__ import annotations
+
 import logging
 import time
+import typing as t
 
 from sqlglot import maybe_parse
 from sqlglot.errors import ExecuteError
@@ -11,42 +19,63 @@ from sqlglot.schema import ensure_schema
 
 logger = logging.getLogger("sqlglot")
 
+if t.TYPE_CHECKING:
+    from sqlglot.dialects.dialect import DialectType
+    from sqlglot.executor.table import Tables
+    from sqlglot.expressions import Expression
+    from sqlglot.schema import Schema
+
 
-def execute(sql, schema=None, read=None, tables=None):
+def execute(
+    sql: str | Expression,
+    schema: t.Optional[t.Dict | Schema] = None,
+    read: DialectType = None,
+    tables: t.Optional[t.Dict] = None,
+) -> Table:
     """
     Run a sql query against data.
 
     Args:
-        sql (str|sqlglot.Expression): a sql statement
-        schema (dict|sqlglot.optimizer.Schema): database schema.
-            This can either be an instance of `sqlglot.optimizer.Schema` or a mapping in one of
-            the following forms:
+        sql: a sql statement.
+        schema: database schema.
+            This can either be an instance of `Schema` or a mapping in one of the following forms:
             1. {table: {col: type}}
             2. {db: {table: {col: type}}}
             3. {catalog: {db: {table: {col: type}}}}
-        read (str): the SQL dialect to apply during parsing
-            (eg. "spark", "hive", "presto", "mysql").
-        tables (dict): additional tables to register.
+        read: the SQL dialect to apply during parsing (eg. "spark", "hive", "presto", "mysql").
+        tables: additional tables to register.
 
     Returns:
-        sqlglot.executor.Table: Simple columnar data structure.
+        Simple columnar data structure.
     """
-    tables = ensure_tables(tables)
+    tables_ = ensure_tables(tables)
+
     if not schema:
         schema = {
             name: {column: type(table[0][column]).__name__ for column in table.columns}
-            for name, table in tables.mapping.items()
+            for name, table in tables_.mapping.items()
         }
+
     schema = ensure_schema(schema)
-    if tables.supported_table_args and tables.supported_table_args != schema.supported_table_args:
+
+    if tables_.supported_table_args and tables_.supported_table_args != schema.supported_table_args:
         raise ExecuteError("Tables must support the same table args as schema")
+
     expression = maybe_parse(sql, dialect=read)
+
     now = time.time()
     expression = optimize(expression, schema, leave_tables_isolated=True)
+
     logger.debug("Optimization finished: %f", time.time() - now)
     logger.debug("Optimized SQL: %s", expression.sql(pretty=True))
+
     plan = Plan(expression)
+
     logger.debug("Logical Plan: %s", plan)
+
     now = time.time()
-    result = PythonExecutor(tables=tables).execute(plan)
+    result = PythonExecutor(tables=tables_).execute(plan)
+
     logger.debug("Query finished: %f", time.time() - now)
+
     return result
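Behavior is unchanged by the typed signature: if no schema is given, it is inferred from the Python types in the first row of each table. A minimal usage sketch (the table name and data are illustrative):

```
from sqlglot.executor import execute

tables = {"sushi": [{"id": 1, "price": 1.0}, {"id": 2, "price": 2.0}]}

# schema is inferred here as {"sushi": {"id": "int", "price": "float"}}
result = execute("SELECT id, price FROM sushi WHERE price > 1.5", tables=tables)
print(result)  # a columnar Table containing the matching rows
```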
@@ -171,5 +171,6 @@ ENV = {
     "STRPOSITION": str_position,
     "SUB": null_if_any(lambda e, this: e - this),
     "SUBSTRING": substring,
+    "TIMESTRTOTIME": null_if_any(lambda arg: datetime.datetime.fromisoformat(arg)),
     "UPPER": null_if_any(lambda arg: arg.upper()),
 }
@@ -1,5 +1,7 @@
 from __future__ import annotations
 
+import typing as t
+
 from sqlglot.helper import dict_depth
 from sqlglot.schema import AbstractMappingSchema
@@ -106,11 +108,11 @@ class Tables(AbstractMappingSchema[Table]):
     pass
 
 
-def ensure_tables(d: dict | None) -> Tables:
+def ensure_tables(d: t.Optional[t.Dict]) -> Tables:
     return Tables(_ensure_tables(d))
 
 
-def _ensure_tables(d: dict | None) -> dict:
+def _ensure_tables(d: t.Optional[t.Dict]) -> t.Dict:
     if not d:
         return {}
@@ -127,4 +129,5 @@ def _ensure_tables(d: dict | None) -> dict:
         columns = tuple(table[0]) if table else ()
         rows = [tuple(row[c] for c in columns) for row in table]
         result[name] = Table(columns=columns, rows=rows)
+
     return result
@@ -32,13 +32,7 @@ from sqlglot.helper import (
 from sqlglot.tokens import Token
 
 if t.TYPE_CHECKING:
-    from sqlglot.dialects.dialect import Dialect
-
-    IntoType = t.Union[
-        str,
-        t.Type[Expression],
-        t.Collection[t.Union[str, t.Type[Expression]]],
-    ]
+    from sqlglot.dialects.dialect import DialectType
 
 
 class _Expression(type):
@@ -427,7 +421,7 @@ class Expression(metaclass=_Expression):
     def __repr__(self):
         return self._to_s()
 
-    def sql(self, dialect: Dialect | str | None = None, **opts) -> str:
+    def sql(self, dialect: DialectType = None, **opts) -> str:
         """
         Returns SQL string representation of this tree.
 
@@ -595,6 +589,14 @@ class Expression(metaclass=_Expression):
         return load(obj)
 
 
+if t.TYPE_CHECKING:
+    IntoType = t.Union[
+        str,
+        t.Type[Expression],
+        t.Collection[t.Union[str, t.Type[Expression]]],
+    ]
+
+
 class Condition(Expression):
     def and_(self, *expressions, dialect=None, **opts):
         """
@@ -1285,6 +1287,18 @@ class Property(Expression):
     arg_types = {"this": True, "value": True}
 
 
+class AlgorithmProperty(Property):
+    arg_types = {"this": True}
+
+
+class DefinerProperty(Property):
+    arg_types = {"this": True}
+
+
+class SqlSecurityProperty(Property):
+    arg_types = {"definer": True}
+
+
 class TableFormatProperty(Property):
     arg_types = {"this": True}
@@ -1425,13 +1439,15 @@ class IsolatedLoadingProperty(Property):
 
 
 class Properties(Expression):
-    arg_types = {"expressions": True, "before": False}
+    arg_types = {"expressions": True}
 
     NAME_TO_PROPERTY = {
+        "ALGORITHM": AlgorithmProperty,
         "AUTO_INCREMENT": AutoIncrementProperty,
         "CHARACTER SET": CharacterSetProperty,
         "COLLATE": CollateProperty,
        "COMMENT": SchemaCommentProperty,
+        "DEFINER": DefinerProperty,
         "DISTKEY": DistKeyProperty,
         "DISTSTYLE": DistStyleProperty,
         "ENGINE": EngineProperty,
@@ -1447,6 +1463,14 @@ class Properties(Expression):
 
     PROPERTY_TO_NAME = {v: k for k, v in NAME_TO_PROPERTY.items()}
 
+    class Location(AutoName):
+        POST_CREATE = auto()
+        PRE_SCHEMA = auto()
+        POST_INDEX = auto()
+        POST_SCHEMA_ROOT = auto()
+        POST_SCHEMA_WITH = auto()
+        UNSUPPORTED = auto()
+
     @classmethod
     def from_dict(cls, properties_dict) -> Properties:
         expressions = []
@@ -1592,6 +1616,7 @@ QUERY_MODIFIERS = {
     "order": False,
     "limit": False,
     "offset": False,
+    "lock": False,
 }
 
@@ -1713,6 +1738,12 @@ class Schema(Expression):
     arg_types = {"this": False, "expressions": False}
 
 
+# Used to represent the FOR UPDATE and FOR SHARE locking read types.
+# https://dev.mysql.com/doc/refman/8.0/en/innodb-locking-reads.html
+class Lock(Expression):
+    arg_types = {"update": True}
+
+
 class Select(Subqueryable):
     arg_types = {
         "with": False,
@@ -2243,6 +2274,30 @@ class Select(Subqueryable):
             properties=properties_expression,
         )
 
+    def lock(self, update: bool = True, copy: bool = True) -> Select:
+        """
+        Set the locking read mode for this expression.
+
+        Examples:
+            >>> Select().select("x").from_("tbl").where("x = 'a'").lock().sql("mysql")
+            "SELECT x FROM tbl WHERE x = 'a' FOR UPDATE"
+
+            >>> Select().select("x").from_("tbl").where("x = 'a'").lock(update=False).sql("mysql")
+            "SELECT x FROM tbl WHERE x = 'a' FOR SHARE"
+
+        Args:
+            update: if `True`, the locking type will be `FOR UPDATE`, else it will be `FOR SHARE`.
+            copy: if `False`, modify this expression instance in-place.
+
+        Returns:
+            The modified expression.
+        """
+        inst = _maybe_copy(self, copy)
+        inst.set("lock", Lock(update=update))
+
+        return inst
+
     @property
     def named_selects(self) -> t.List[str]:
         return [e.output_name for e in self.expressions if e.alias_or_name]
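A usage sketch of the new builder, mirroring the docstring above; MySQL is one of the dialects that enables `LOCKING_READS_SUPPORTED` in this change:

```
from sqlglot import select

query = select("x").from_("tbl").where("x = 'a'")

print(query.lock().sql("mysql"))
# SELECT x FROM tbl WHERE x = 'a' FOR UPDATE

print(query.lock(update=False).sql("mysql"))
# SELECT x FROM tbl WHERE x = 'a' FOR SHARE
```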
@@ -2456,24 +2511,28 @@ class DataType(Expression):
 
     @classmethod
     def build(
-        cls, dtype: str | DataType.Type, dialect: t.Optional[str | Dialect] = None, **kwargs
+        cls, dtype: str | DataType | DataType.Type, dialect: DialectType = None, **kwargs
     ) -> DataType:
         from sqlglot import parse_one
 
         if isinstance(dtype, str):
-            data_type_exp: t.Optional[Expression]
             if dtype.upper() in cls.Type.__members__:
-                data_type_exp = DataType(this=DataType.Type[dtype.upper()])
+                data_type_exp: t.Optional[Expression] = DataType(this=DataType.Type[dtype.upper()])
             else:
                 data_type_exp = parse_one(dtype, read=dialect, into=DataType)
             if data_type_exp is None:
                 raise ValueError(f"Unparsable data type value: {dtype}")
         elif isinstance(dtype, DataType.Type):
             data_type_exp = DataType(this=dtype)
+        elif isinstance(dtype, DataType):
+            return dtype
         else:
             raise ValueError(f"Invalid data type: {type(dtype)}. Expected str or DataType.Type")
+
         return DataType(**{**data_type_exp.args, **kwargs})
 
+    def is_type(self, dtype: DataType.Type) -> bool:
+        return self.this == dtype
+
 
 # https://www.postgresql.org/docs/15/datatype-pseudo.html
 class PseudoType(Expression):
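A sketch of the two new behaviors: `build` now short-circuits when it is already given a `DataType`, and `is_type` gives `DataType` (and, below, `Cast`) a uniform type check:

```
from sqlglot import exp

dt = exp.DataType.build("varchar(100)")
print(dt.sql())                               # VARCHAR(100)
print(dt.is_type(exp.DataType.Type.VARCHAR))  # True

# Passing an existing DataType is now accepted and returned as-is.
assert exp.DataType.build(dt) is dt
```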
@@ -2840,6 +2899,10 @@ class Array(Func):
     is_var_len_args = True
 
 
+class GenerateSeries(Func):
+    arg_types = {"start": True, "end": True, "step": False}
+
+
 class ArrayAgg(AggFunc):
     pass
@@ -2909,6 +2972,9 @@ class Cast(Func):
     def output_name(self):
         return self.name
 
+    def is_type(self, dtype: DataType.Type) -> bool:
+        return self.to.is_type(dtype)
+
 
 class Collate(Binary):
     pass
@@ -2989,6 +3055,22 @@ class DatetimeTrunc(Func, TimeUnit):
     arg_types = {"this": True, "unit": True, "zone": False}
 
 
+class DayOfWeek(Func):
+    _sql_names = ["DAY_OF_WEEK", "DAYOFWEEK"]
+
+
+class DayOfMonth(Func):
+    _sql_names = ["DAY_OF_MONTH", "DAYOFMONTH"]
+
+
+class DayOfYear(Func):
+    _sql_names = ["DAY_OF_YEAR", "DAYOFYEAR"]
+
+
+class WeekOfYear(Func):
+    _sql_names = ["WEEK_OF_YEAR", "WEEKOFYEAR"]
+
+
 class LastDateOfMonth(Func):
     pass
@@ -3239,7 +3321,7 @@ class ReadCSV(Func):
 
 
 class Reduce(Func):
-    arg_types = {"this": True, "initial": True, "merge": True, "finish": True}
+    arg_types = {"this": True, "initial": True, "merge": True, "finish": False}
 
 
 class RegexpLike(Func):
@@ -3476,7 +3558,7 @@ def maybe_parse(
     sql_or_expression: str | Expression,
     *,
     into: t.Optional[IntoType] = None,
-    dialect: t.Optional[str] = None,
+    dialect: DialectType = None,
     prefix: t.Optional[str] = None,
     **opts,
 ) -> Expression:
@@ -3959,6 +4041,28 @@ def to_identifier(alias, quoted=None) -> t.Optional[Identifier]:
     return identifier
 
 
+INTERVAL_STRING_RE = re.compile(r"\s*([0-9]+)\s*([a-zA-Z]+)\s*")
+
+
+def to_interval(interval: str | Literal) -> Interval:
+    """Builds an interval expression from a string like '1 day' or '5 months'."""
+    if isinstance(interval, Literal):
+        if not interval.is_string:
+            raise ValueError("Invalid interval string.")
+
+        interval = interval.this
+
+    interval_parts = INTERVAL_STRING_RE.match(interval)  # type: ignore
+
+    if not interval_parts:
+        raise ValueError("Invalid interval string.")
+
+    return Interval(
+        this=Literal.string(interval_parts.group(1)),
+        unit=Var(this=interval_parts.group(2)),
+    )
+
+
 @t.overload
 def to_table(sql_path: str | Table, **kwargs) -> Table:
     ...
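A usage sketch for the new helper (output formatting follows the node structure built above):

```
from sqlglot import exp

interval = exp.to_interval("5 months")
print(interval.sql())  # INTERVAL '5' months

# String Literal nodes are unwrapped first, so this is equivalent:
print(exp.to_interval(exp.Literal.string("5 months")).sql())
```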
@@ -4050,7 +4154,8 @@ def alias_(expression, alias, table=False, dialect=None, quoted=None, **opts):
 def subquery(expression, alias=None, dialect=None, **opts):
     """
     Build a subquery expression.
-    Expample:
+
+    Example:
         >>> subquery('select x from tbl', 'bar').select('x').sql()
         'SELECT x FROM (SELECT x FROM tbl) AS bar'
 
@@ -4072,6 +4177,7 @@ def subquery(expression, alias=None, dialect=None, **opts):
 def column(col, table=None, quoted=None) -> Column:
     """
     Build a Column.
+
     Args:
         col (str | Expression): column name
         table (str | Expression): table name
@@ -4084,6 +4190,24 @@ def column(col, table=None, quoted=None) -> Column:
     )
 
 
+def cast(expression: str | Expression, to: str | DataType | DataType.Type, **opts) -> Cast:
+    """Cast an expression to a data type.
+
+    Example:
+        >>> cast('x + 1', 'int').sql()
+        'CAST(x + 1 AS INT)'
+
+    Args:
+        expression: The expression to cast.
+        to: The datatype to cast to.
+
+    Returns:
+        A cast node.
+    """
+    expression = maybe_parse(expression, **opts)
+    return Cast(this=expression, to=DataType.build(to, **opts))
+
+
 def table_(table, db=None, catalog=None, quoted=None, alias=None) -> Table:
     """Build a Table.
|
@ -4137,7 +4261,7 @@ def values(
|
||||||
types = list(columns.values())
|
types = list(columns.values())
|
||||||
expressions[0].set(
|
expressions[0].set(
|
||||||
"expressions",
|
"expressions",
|
||||||
[Cast(this=x, to=types[i]) for i, x in enumerate(expressions[0].expressions)],
|
[cast(x, types[i]) for i, x in enumerate(expressions[0].expressions)],
|
||||||
)
|
)
|
||||||
return Values(
|
return Values(
|
||||||
expressions=expressions,
|
expressions=expressions,
|
||||||
|
@ -4373,7 +4497,7 @@ def expand(expression: Expression, sources: t.Dict[str, Subqueryable], copy=True
|
||||||
return expression.transform(_expand, copy=copy)
|
return expression.transform(_expand, copy=copy)
|
||||||
|
|
||||||
|
|
||||||
def func(name: str, *args, dialect: t.Optional[Dialect | str] = None, **kwargs) -> Func:
|
def func(name: str, *args, dialect: DialectType = None, **kwargs) -> Func:
|
||||||
"""
|
"""
|
||||||
Returns a Func expression.
|
Returns a Func expression.
|
||||||
|
|
||||||
|
@@ -67,6 +67,7 @@ class Generator:
         exp.VolatilityProperty: lambda self, e: e.name,
         exp.WithJournalTableProperty: lambda self, e: f"WITH JOURNAL TABLE={self.sql(e, 'this')}",
         exp.LogProperty: lambda self, e: f"{'NO ' if e.args.get('no') else ''}LOG",
+        exp.SqlSecurityProperty: lambda self, e: f"SQL SECURITY {'DEFINER' if e.args.get('definer') else 'INVOKER'}",
     }

     # Whether 'CREATE ... TRANSIENT ... TABLE' is allowed

@@ -75,6 +76,9 @@ class Generator:
     # Whether or not null ordering is supported in order by
     NULL_ORDERING_SUPPORTED = True

+    # Whether or not locking reads (i.e. SELECT ... FOR UPDATE/SHARE) are supported
+    LOCKING_READS_SUPPORTED = False
+
     # Always do union distinct or union all
     EXPLICIT_UNION = False

@@ -99,34 +103,42 @@ class Generator:

     STRUCT_DELIMITER = ("<", ">")

-    BEFORE_PROPERTIES = {
-        exp.FallbackProperty,
-        exp.WithJournalTableProperty,
-        exp.LogProperty,
-        exp.JournalProperty,
-        exp.AfterJournalProperty,
-        exp.ChecksumProperty,
-        exp.FreespaceProperty,
-        exp.MergeBlockRatioProperty,
-        exp.DataBlocksizeProperty,
-        exp.BlockCompressionProperty,
-        exp.IsolatedLoadingProperty,
-    }
-
-    ROOT_PROPERTIES = {
-        exp.ReturnsProperty,
-        exp.LanguageProperty,
-        exp.DistStyleProperty,
-        exp.DistKeyProperty,
-        exp.SortKeyProperty,
-        exp.LikeProperty,
-    }
-
-    WITH_PROPERTIES = {
-        exp.Property,
-        exp.FileFormatProperty,
-        exp.PartitionedByProperty,
-        exp.TableFormatProperty,
-    }
+    PROPERTIES_LOCATION = {
+        exp.AfterJournalProperty: exp.Properties.Location.PRE_SCHEMA,
+        exp.AlgorithmProperty: exp.Properties.Location.POST_CREATE,
+        exp.AutoIncrementProperty: exp.Properties.Location.POST_SCHEMA_ROOT,
+        exp.BlockCompressionProperty: exp.Properties.Location.PRE_SCHEMA,
+        exp.CharacterSetProperty: exp.Properties.Location.POST_SCHEMA_ROOT,
+        exp.ChecksumProperty: exp.Properties.Location.PRE_SCHEMA,
+        exp.CollateProperty: exp.Properties.Location.POST_SCHEMA_ROOT,
+        exp.DataBlocksizeProperty: exp.Properties.Location.PRE_SCHEMA,
+        exp.DefinerProperty: exp.Properties.Location.POST_CREATE,
+        exp.DistKeyProperty: exp.Properties.Location.POST_SCHEMA_ROOT,
+        exp.DistStyleProperty: exp.Properties.Location.POST_SCHEMA_ROOT,
+        exp.EngineProperty: exp.Properties.Location.POST_SCHEMA_ROOT,
+        exp.ExecuteAsProperty: exp.Properties.Location.POST_SCHEMA_ROOT,
+        exp.FallbackProperty: exp.Properties.Location.PRE_SCHEMA,
+        exp.FileFormatProperty: exp.Properties.Location.POST_SCHEMA_WITH,
+        exp.FreespaceProperty: exp.Properties.Location.PRE_SCHEMA,
+        exp.IsolatedLoadingProperty: exp.Properties.Location.PRE_SCHEMA,
+        exp.JournalProperty: exp.Properties.Location.PRE_SCHEMA,
+        exp.LanguageProperty: exp.Properties.Location.POST_SCHEMA_ROOT,
+        exp.LikeProperty: exp.Properties.Location.POST_SCHEMA_ROOT,
+        exp.LocationProperty: exp.Properties.Location.POST_SCHEMA_ROOT,
+        exp.LogProperty: exp.Properties.Location.PRE_SCHEMA,
+        exp.MergeBlockRatioProperty: exp.Properties.Location.PRE_SCHEMA,
+        exp.PartitionedByProperty: exp.Properties.Location.POST_SCHEMA_WITH,
+        exp.Property: exp.Properties.Location.POST_SCHEMA_WITH,
+        exp.ReturnsProperty: exp.Properties.Location.POST_SCHEMA_ROOT,
+        exp.RowFormatDelimitedProperty: exp.Properties.Location.POST_SCHEMA_ROOT,
+        exp.RowFormatSerdeProperty: exp.Properties.Location.POST_SCHEMA_ROOT,
+        exp.SchemaCommentProperty: exp.Properties.Location.POST_SCHEMA_ROOT,
+        exp.SerdeProperties: exp.Properties.Location.POST_SCHEMA_ROOT,
+        exp.SortKeyProperty: exp.Properties.Location.POST_SCHEMA_ROOT,
+        exp.SqlSecurityProperty: exp.Properties.Location.POST_CREATE,
+        exp.TableFormatProperty: exp.Properties.Location.POST_SCHEMA_WITH,
+        exp.VolatilityProperty: exp.Properties.Location.POST_SCHEMA_ROOT,
+        exp.WithJournalTableProperty: exp.Properties.Location.PRE_SCHEMA,
+    }

     WITH_SEPARATED_COMMENTS = (exp.Select, exp.From, exp.Where, exp.Binary)

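Because placement is now a single class-level dict instead of three sets, a dialect can re-home one property with a dict merge; a hypothetical override (the `EngineProperty` move is an assumption for illustration, not part of this commit):

```python
from sqlglot import exp
from sqlglot.generator import Generator

class MyGenerator(Generator):
    # illustrative only: emit ENGINE right after CREATE instead of after the schema
    PROPERTIES_LOCATION = {
        **Generator.PROPERTIES_LOCATION,
        exp.EngineProperty: exp.Properties.Location.POST_CREATE,
    }
```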
@@ -284,10 +296,10 @@ class Generator:
         )
         return f"({self.sep('')}{this_sql}{self.seg(')', sep='')}"

-    def no_identify(self, func: t.Callable[[], str]) -> str:
+    def no_identify(self, func: t.Callable[..., str], *args, **kwargs) -> str:
         original = self.identify
         self.identify = False
-        result = func()
+        result = func(*args, **kwargs)
         self.identify = original
         return result

@@ -455,19 +467,33 @@ class Generator:

     def create_sql(self, expression: exp.Create) -> str:
         kind = self.sql(expression, "kind").upper()
-        has_before_properties = expression.args.get("properties")
-        has_before_properties = (
-            has_before_properties.args.get("before") if has_before_properties else None
-        )
-        if kind == "TABLE" and has_before_properties:
+        properties = expression.args.get("properties")
+        properties_exp = expression.copy()
+        properties_locs = self.locate_properties(properties) if properties else {}
+        if properties_locs.get(exp.Properties.Location.POST_SCHEMA_ROOT) or properties_locs.get(
+            exp.Properties.Location.POST_SCHEMA_WITH
+        ):
+            properties_exp.set(
+                "properties",
+                exp.Properties(
+                    expressions=[
+                        *properties_locs[exp.Properties.Location.POST_SCHEMA_ROOT],
+                        *properties_locs[exp.Properties.Location.POST_SCHEMA_WITH],
+                    ]
+                ),
+            )
+        if kind == "TABLE" and properties_locs.get(exp.Properties.Location.PRE_SCHEMA):
             this_name = self.sql(expression.this, "this")
-            this_properties = self.sql(expression, "properties")
+            this_properties = self.properties(
+                exp.Properties(expressions=properties_locs[exp.Properties.Location.PRE_SCHEMA]),
+                wrapped=False,
+            )
             this_schema = f"({self.expressions(expression.this)})"
             this = f"{this_name}, {this_properties} {this_schema}"
-            properties = ""
+            properties_sql = ""
         else:
             this = self.sql(expression, "this")
-            properties = self.sql(expression, "properties")
+            properties_sql = self.sql(properties_exp, "properties")
         begin = " BEGIN" if expression.args.get("begin") else ""
         expression_sql = self.sql(expression, "expression")
         expression_sql = f" AS{begin}{self.sep()}{expression_sql}" if expression_sql else ""
@@ -514,11 +540,31 @@ class Generator:
                 if index.args.get("columns")
                 else ""
             )
+            if index.args.get("primary") and properties_locs.get(
+                exp.Properties.Location.POST_INDEX
+            ):
+                postindex_props_sql = self.properties(
+                    exp.Properties(
+                        expressions=properties_locs[exp.Properties.Location.POST_INDEX]
+                    ),
+                    wrapped=False,
+                )
+                ind_columns = f"{ind_columns} {postindex_props_sql}"
+
             indexes_sql.append(
                 f"{ind_unique}{ind_primary}{ind_amp} INDEX{ind_name}{ind_columns}"
             )
         index_sql = "".join(indexes_sql)

+        postcreate_props_sql = ""
+        if properties_locs.get(exp.Properties.Location.POST_CREATE):
+            postcreate_props_sql = self.properties(
+                exp.Properties(expressions=properties_locs[exp.Properties.Location.POST_CREATE]),
+                sep=" ",
+                prefix=" ",
+                wrapped=False,
+            )
+
         modifiers = "".join(
             (
                 replace,
@@ -531,6 +577,7 @@ class Generator:
                 multiset,
                 global_temporary,
                 volatile,
+                postcreate_props_sql,
             )
         )
         no_schema_binding = (

@@ -539,7 +586,7 @@ class Generator:

         post_expression_modifiers = "".join((data, statistics, no_primary_index))

-        expression_sql = f"CREATE{modifiers} {kind}{exists_sql} {this}{properties}{expression_sql}{post_expression_modifiers}{index_sql}{no_schema_binding}"
+        expression_sql = f"CREATE{modifiers} {kind}{exists_sql} {this}{properties_sql}{expression_sql}{post_expression_modifiers}{index_sql}{no_schema_binding}"
         return self.prepend_ctes(expression, expression_sql)

     def describe_sql(self, expression: exp.Describe) -> str:
@@ -665,24 +712,19 @@ class Generator:
         return f"PARTITION({self.expressions(expression)})"

     def properties_sql(self, expression: exp.Properties) -> str:
-        before_properties = []
         root_properties = []
         with_properties = []

         for p in expression.expressions:
-            p_class = p.__class__
-            if p_class in self.BEFORE_PROPERTIES:
-                before_properties.append(p)
-            elif p_class in self.WITH_PROPERTIES:
+            p_loc = self.PROPERTIES_LOCATION[p.__class__]
+            if p_loc == exp.Properties.Location.POST_SCHEMA_WITH:
                 with_properties.append(p)
-            elif p_class in self.ROOT_PROPERTIES:
+            elif p_loc == exp.Properties.Location.POST_SCHEMA_ROOT:
                 root_properties.append(p)

-        return (
-            self.properties(exp.Properties(expressions=before_properties), before=True)
-            + self.root_properties(exp.Properties(expressions=root_properties))
-            + self.with_properties(exp.Properties(expressions=with_properties))
-        )
+        return self.root_properties(
+            exp.Properties(expressions=root_properties)
+        ) + self.with_properties(exp.Properties(expressions=with_properties))

     def root_properties(self, properties: exp.Properties) -> str:
         if properties.expressions:
@@ -695,17 +737,41 @@ class Generator:
         prefix: str = "",
         sep: str = ", ",
         suffix: str = "",
-        before: bool = False,
+        wrapped: bool = True,
     ) -> str:
         if properties.expressions:
             expressions = self.expressions(properties, sep=sep, indent=False)
-            expressions = expressions if before else self.wrap(expressions)
+            expressions = self.wrap(expressions) if wrapped else expressions
             return f"{prefix}{' ' if prefix and prefix != ' ' else ''}{expressions}{suffix}"
         return ""

     def with_properties(self, properties: exp.Properties) -> str:
         return self.properties(properties, prefix=self.seg("WITH"))

+    def locate_properties(
+        self, properties: exp.Properties
+    ) -> t.Dict[exp.Properties.Location, list[exp.Property]]:
+        properties_locs: t.Dict[exp.Properties.Location, list[exp.Property]] = {
+            key: [] for key in exp.Properties.Location
+        }
+
+        for p in properties.expressions:
+            p_loc = self.PROPERTIES_LOCATION[p.__class__]
+            if p_loc == exp.Properties.Location.PRE_SCHEMA:
+                properties_locs[exp.Properties.Location.PRE_SCHEMA].append(p)
+            elif p_loc == exp.Properties.Location.POST_INDEX:
+                properties_locs[exp.Properties.Location.POST_INDEX].append(p)
+            elif p_loc == exp.Properties.Location.POST_SCHEMA_ROOT:
+                properties_locs[exp.Properties.Location.POST_SCHEMA_ROOT].append(p)
+            elif p_loc == exp.Properties.Location.POST_SCHEMA_WITH:
+                properties_locs[exp.Properties.Location.POST_SCHEMA_WITH].append(p)
+            elif p_loc == exp.Properties.Location.POST_CREATE:
+                properties_locs[exp.Properties.Location.POST_CREATE].append(p)
+            elif p_loc == exp.Properties.Location.UNSUPPORTED:
+                self.unsupported(f"Unsupported property {p.key}")
+
+        return properties_locs
+
     def property_sql(self, expression: exp.Property) -> str:
         property_cls = expression.__class__
         if property_cls == exp.Property:
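`locate_properties` simply buckets each property under its configured location; a small sketch exercising it directly (the printed bucket contents are my expectation):

```python
from sqlglot import exp
from sqlglot.generator import Generator

gen = Generator()
props = exp.Properties(expressions=[exp.EngineProperty(this=exp.Literal.string("InnoDB"))])
buckets = gen.locate_properties(props)
# the EngineProperty lands in the POST_SCHEMA_ROOT bucket; all other buckets are empty
print(buckets[exp.Properties.Location.POST_SCHEMA_ROOT])
```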
@@ -713,7 +779,7 @@ class Generator:

         property_name = exp.Properties.PROPERTY_TO_NAME.get(property_cls)
         if not property_name:
-            self.unsupported(f"Unsupported property {property_name}")
+            self.unsupported(f"Unsupported property {expression.key}")

         return f"{property_name}={self.sql(expression, 'this')}"

@@ -975,7 +1041,7 @@ class Generator:
         rollup = self.expressions(expression, key="rollup", indent=False)
         rollup = f"{self.seg('ROLLUP')} {self.wrap(rollup)}" if rollup else ""

-        return f"{group_by}{grouping_sets}{cube}{rollup}"
+        return f"{group_by}{csv(grouping_sets, cube, rollup, sep=',')}"

     def having_sql(self, expression: exp.Having) -> str:
         this = self.indent(self.sql(expression, "this"))
@@ -1015,7 +1081,7 @@ class Generator:
     def lambda_sql(self, expression: exp.Lambda, arrow_sep: str = "->") -> str:
         args = self.expressions(expression, flat=True)
         args = f"({args})" if len(args.split(",")) > 1 else args
-        return self.no_identify(lambda: f"{args} {arrow_sep} {self.sql(expression, 'this')}")
+        return f"{args} {arrow_sep} {self.sql(expression, 'this')}"

     def lateral_sql(self, expression: exp.Lateral) -> str:
         this = self.sql(expression, "this")
@@ -1043,6 +1109,14 @@ class Generator:
         this = self.sql(expression, "this")
         return f"{this}{self.seg('OFFSET')} {self.sql(expression, 'expression')}"

+    def lock_sql(self, expression: exp.Lock) -> str:
+        if self.LOCKING_READS_SUPPORTED:
+            lock_type = "UPDATE" if expression.args["update"] else "SHARE"
+            return self.seg(f"FOR {lock_type}")
+
+        self.unsupported("Locking reads using 'FOR UPDATE/SHARE' are not supported")
+        return ""
+
     def literal_sql(self, expression: exp.Literal) -> str:
         text = expression.this or ""
         if expression.is_string:
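Together with the new `_parse_lock` and the `LOCKING_READS_SUPPORTED` flag, this makes locking reads dialect-aware; the MySQL tests later in this diff pin the behavior, and a minimal sketch looks like:

```python
import sqlglot

expr = sqlglot.parse_one("SELECT a FROM tbl FOR UPDATE", read="mysql")
print(expr.sql(dialect="postgres"))  # SELECT a FROM tbl FOR UPDATE
print(expr.sql())  # SELECT a FROM tbl  (default dialect drops the clause with a warning)
```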
@@ -1163,6 +1237,7 @@ class Generator:
             self.sql(expression, "order"),
             self.sql(expression, "limit"),
             self.sql(expression, "offset"),
+            self.sql(expression, "lock"),
             sep="",
         )

@@ -1773,7 +1848,7 @@ class Generator:

     def userdefinedfunction_sql(self, expression: exp.UserDefinedFunction) -> str:
         this = self.sql(expression, "this")
-        expressions = self.no_identify(lambda: self.expressions(expression))
+        expressions = self.no_identify(self.expressions, expression)
         expressions = (
             self.wrap(expressions) if expression.args.get("wrapped") else f" {expressions}"
         )
@@ -9,6 +9,9 @@ from sqlglot.optimizer import Scope, build_scope, optimize
 from sqlglot.optimizer.qualify_columns import qualify_columns
 from sqlglot.optimizer.qualify_tables import qualify_tables

+if t.TYPE_CHECKING:
+    from sqlglot.dialects.dialect import DialectType
+

 @dataclass(frozen=True)
 class Node:

@@ -36,7 +39,7 @@ def lineage(
     schema: t.Optional[t.Dict | Schema] = None,
     sources: t.Optional[t.Dict[str, str | exp.Subqueryable]] = None,
     rules: t.Sequence[t.Callable] = (qualify_tables, qualify_columns),
-    dialect: t.Optional[str] = None,
+    dialect: DialectType = None,
 ) -> Node:
     """Build the lineage graph for a column of a SQL query.

@@ -126,7 +129,7 @@ class LineageHTML:
     def __init__(
         self,
         node: Node,
-        dialect: t.Optional[str] = None,
+        dialect: DialectType = None,
         imports: bool = True,
         **opts: t.Any,
     ):
@@ -114,7 +114,7 @@ def _eliminate_union(scope, existing_ctes, taken):
     taken[alias] = scope

     # Try to maintain the selections
-    expressions = scope.expression.args.get("expressions")
+    expressions = scope.selects
     selects = [
         exp.alias_(exp.column(e.alias_or_name, table=alias), alias=e.alias_or_name)
         for e in expressions

@@ -300,7 +300,7 @@ class Scope:
             list[exp.Expression]: expressions
         """
         if isinstance(self.expression, exp.Union):
-            return []
+            return self.expression.unnest().selects
         return self.expression.selects

     @property
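A union scope now reports the projections of its left-most SELECT instead of an empty list; a small sketch of the observable behavior (built with `build_scope`, which these optimizer modules already export):

```python
from sqlglot import parse_one
from sqlglot.optimizer import build_scope

scope = build_scope(parse_one("SELECT a FROM t UNION SELECT b FROM u"))
# previously [] for a Union; now the left side's selections
print([s.sql() for s in scope.selects])  # expected: ['a']
```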
@@ -456,8 +456,10 @@ def extract_interval(interval):


 def date_literal(date):
-    expr_type = exp.DataType.build("DATETIME" if isinstance(date, datetime.datetime) else "DATE")
-    return exp.Cast(this=exp.Literal.string(date), to=expr_type)
+    return exp.cast(
+        exp.Literal.string(date),
+        "DATETIME" if isinstance(date, datetime.datetime) else "DATE",
+    )


 def boolean_literal(condition):
@@ -80,6 +80,7 @@ class Parser(metaclass=_Parser):
             length=exp.Literal.number(10),
         ),
         "VAR_MAP": parse_var_map,
+        "IFNULL": exp.Coalesce.from_arg_list,
     }

     NO_PAREN_FUNCTIONS = {
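Mapping `IFNULL` to `exp.Coalesce` in the base parser means every dialect normalizes it the same way, matching the `test_if_null` case added later in this diff:

```python
import sqlglot

print(sqlglot.transpile("SELECT IFNULL(1, NULL) FROM foo")[0])
# SELECT COALESCE(1, NULL) FROM foo
```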
@@ -567,6 +568,8 @@ class Parser(metaclass=_Parser):
             default=self._prev.text.upper() == "DEFAULT"
         ),
         "BLOCKCOMPRESSION": lambda self: self._parse_blockcompression(),
+        "ALGORITHM": lambda self: self._parse_property_assignment(exp.AlgorithmProperty),
+        "DEFINER": lambda self: self._parse_definer(),
     }

     CONSTRAINT_PARSERS = {

@@ -608,6 +611,7 @@ class Parser(metaclass=_Parser):
         "order": lambda self: self._parse_order(),
         "limit": lambda self: self._parse_limit(),
         "offset": lambda self: self._parse_offset(),
+        "lock": lambda self: self._parse_lock(),
     }

     SHOW_PARSERS: t.Dict[str, t.Callable] = {}
@@ -850,7 +854,7 @@ class Parser(metaclass=_Parser):
         self.raise_error(error_message)

     def _find_sql(self, start: Token, end: Token) -> str:
-        return self.sql[self._find_token(start) : self._find_token(end)]
+        return self.sql[self._find_token(start) : self._find_token(end) + len(end.text)]

     def _find_token(self, token: Token) -> int:
         line = 1
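`_find_token` returns the offset where a token starts, so the old slice stopped just before the final token's text; adding `len(end.text)` extends the slice past it. A tiny illustration with hypothetical offsets:

```python
sql = "CREATE TABLE t"
start_offset, end_offset = 0, 13  # "CREATE" starts at 0, "t" starts at 13
print(sql[start_offset:end_offset])             # 'CREATE TABLE ' - last token lost
print(sql[start_offset:end_offset + len("t")])  # 'CREATE TABLE t'
```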
@@ -901,6 +905,7 @@ class Parser(metaclass=_Parser):
         return expression

     def _parse_drop(self, default_kind: t.Optional[str] = None) -> t.Optional[exp.Expression]:
+        start = self._prev
         temporary = self._match(TokenType.TEMPORARY)
         materialized = self._match(TokenType.MATERIALIZED)
         kind = self._match_set(self.CREATABLES) and self._prev.text

@@ -908,8 +913,7 @@ class Parser(metaclass=_Parser):
             if default_kind:
                 kind = default_kind
             else:
-                self.raise_error(f"Expected {self.CREATABLES}")
-                return None
+                return self._parse_as_command(start)

         return self.expression(
             exp.Drop,
@@ -929,6 +933,7 @@ class Parser(metaclass=_Parser):
         )

     def _parse_create(self) -> t.Optional[exp.Expression]:
+        start = self._prev
         replace = self._match_pair(TokenType.OR, TokenType.REPLACE)
         set_ = self._match(TokenType.SET)  # Teradata
         multiset = self._match_text_seq("MULTISET")  # Teradata

@@ -943,16 +948,19 @@ class Parser(metaclass=_Parser):
         if self._match_pair(TokenType.TABLE, TokenType.FUNCTION, advance=False):
             self._match(TokenType.TABLE)

+        properties = None
         create_token = self._match_set(self.CREATABLES) and self._prev

         if not create_token:
-            self.raise_error(f"Expected {self.CREATABLES}")
-            return None
+            properties = self._parse_properties()
+            create_token = self._match_set(self.CREATABLES) and self._prev
+
+            if not properties or not create_token:
+                return self._parse_as_command(start)

         exists = self._parse_exists(not_=True)
         this = None
         expression = None
-        properties = None
         data = None
         statistics = None
         no_primary_index = None
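The retry means statements like MySQL's `CREATE ALGORITHM=... DEFINER=... VIEW` now parse: the leading properties are consumed first, then `VIEW` is matched, and anything still unparseable degrades to a raw `exp.Command` instead of raising. A hedged sketch (the exact round-trip formatting is my expectation, not verified against this build):

```python
import sqlglot

sql = "CREATE ALGORITHM=MERGE DEFINER=admin@localhost SQL SECURITY DEFINER VIEW v AS SELECT 1"
# expected to round-trip unchanged under the mysql dialect
print(sqlglot.transpile(sql, read="mysql")[0])
```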
@@ -1006,6 +1014,14 @@ class Parser(metaclass=_Parser):
             indexes = []
             while True:
                 index = self._parse_create_table_index()
+
+                # post index PARTITION BY property
+                if self._match(TokenType.PARTITION_BY, advance=False):
+                    if properties:
+                        properties.expressions.append(self._parse_property())
+                    else:
+                        properties = self._parse_properties()
+
                 if not index:
                     break
                 else:
@@ -1040,6 +1056,9 @@ class Parser(metaclass=_Parser):
         )

     def _parse_property_before(self) -> t.Optional[exp.Expression]:
+        self._match(TokenType.COMMA)
+
+        # parsers look to _prev for no/dual/default, so need to consume first
         self._match_text_seq("NO")
         self._match_text_seq("DUAL")
         self._match_text_seq("DEFAULT")
@@ -1059,6 +1078,9 @@ class Parser(metaclass=_Parser):
         if self._match_pair(TokenType.COMPOUND, TokenType.SORTKEY):
             return self._parse_sortkey(compound=True)

+        if self._match_text_seq("SQL", "SECURITY"):
+            return self.expression(exp.SqlSecurityProperty, definer=self._match_text_seq("DEFINER"))
+
         assignment = self._match_pair(
             TokenType.VAR, TokenType.EQ, advance=False
         ) or self._match_pair(TokenType.STRING, TokenType.EQ, advance=False)
@@ -1083,7 +1105,6 @@ class Parser(metaclass=_Parser):

         while True:
             if before:
-                self._match(TokenType.COMMA)
                 identified_property = self._parse_property_before()
             else:
                 identified_property = self._parse_property()

@@ -1094,7 +1115,7 @@ class Parser(metaclass=_Parser):
                 properties.append(p)

         if properties:
-            return self.expression(exp.Properties, expressions=properties, before=before)
+            return self.expression(exp.Properties, expressions=properties)

         return None

@@ -1118,6 +1139,19 @@ class Parser(metaclass=_Parser):

         return self._parse_withisolatedloading()

+    # https://dev.mysql.com/doc/refman/8.0/en/create-view.html
+    def _parse_definer(self) -> t.Optional[exp.Expression]:
+        self._match(TokenType.EQ)
+
+        user = self._parse_id_var()
+        self._match(TokenType.PARAMETER)
+        host = self._parse_id_var() or (self._match(TokenType.MOD) and self._prev.text)
+
+        if not user or not host:
+            return None
+
+        return exp.DefinerProperty(this=f"{user}@{host}")
+
     def _parse_withjournaltable(self) -> exp.Expression:
         self._match_text_seq("WITH", "JOURNAL", "TABLE")
         self._match(TokenType.EQ)
@@ -1695,12 +1729,10 @@ class Parser(metaclass=_Parser):
                     paren += 1
                 if self._curr.token_type == TokenType.R_PAREN:
                     paren -= 1
+                    end = self._prev
                 self._advance()
             if paren > 0:
                 self.raise_error("Expecting )", self._curr)
-            if not self._curr:
-                self.raise_error("Expecting pattern", self._curr)
-            end = self._prev
             pattern = exp.Var(this=self._find_sql(start, end))
         else:
             pattern = None
@@ -2044,9 +2076,16 @@ class Parser(metaclass=_Parser):
         expressions = self._parse_csv(self._parse_conjunction)
         grouping_sets = self._parse_grouping_sets()

+        self._match(TokenType.COMMA)
         with_ = self._match(TokenType.WITH)
-        cube = self._match(TokenType.CUBE) and (with_ or self._parse_wrapped_id_vars())
-        rollup = self._match(TokenType.ROLLUP) and (with_ or self._parse_wrapped_id_vars())
+        cube = self._match(TokenType.CUBE) and (
+            with_ or self._parse_wrapped_csv(self._parse_column)
+        )
+
+        self._match(TokenType.COMMA)
+        rollup = self._match(TokenType.ROLLUP) and (
+            with_ or self._parse_wrapped_csv(self._parse_column)
+        )

         return self.expression(
             exp.Group,
@@ -2149,6 +2188,14 @@ class Parser(metaclass=_Parser):
         self._match_set((TokenType.ROW, TokenType.ROWS))
         return self.expression(exp.Offset, this=this, expression=count)

+    def _parse_lock(self) -> t.Optional[exp.Expression]:
+        if self._match_text_seq("FOR", "UPDATE"):
+            return self.expression(exp.Lock, update=True)
+        if self._match_text_seq("FOR", "SHARE"):
+            return self.expression(exp.Lock, update=False)
+
+        return None
+
     def _parse_set_operations(self, this: t.Optional[exp.Expression]) -> t.Optional[exp.Expression]:
         if not self._match_set(self.SET_OPERATIONS):
             return this
@@ -2330,12 +2377,21 @@ class Parser(metaclass=_Parser):
             maybe_func = True

         if not nested and self._match_pair(TokenType.L_BRACKET, TokenType.R_BRACKET):
-            return exp.DataType(
+            this = exp.DataType(
                 this=exp.DataType.Type.ARRAY,
                 expressions=[exp.DataType.build(type_token.value, expressions=expressions)],
                 nested=True,
             )

+            while self._match_pair(TokenType.L_BRACKET, TokenType.R_BRACKET):
+                this = exp.DataType(
+                    this=exp.DataType.Type.ARRAY,
+                    expressions=[this],
+                    nested=True,
+                )
+
+            return this
+
         if self._match(TokenType.L_BRACKET):
             self._retreat(index)
             return None
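Each extra `[]` pair now wraps the type in another ARRAY level; the DuckDB test added later in this diff pins the behavior, and the equivalent transpile call is:

```python
import sqlglot

print(sqlglot.transpile("cast([[1]] as int[][])", read="duckdb", write="spark")[0])
# CAST(ARRAY(ARRAY(1)) AS ARRAY<ARRAY<INT>>)
```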
@@ -2430,7 +2486,12 @@ class Parser(metaclass=_Parser):
             self.raise_error("Expected type")
         elif op:
             self._advance()
-            field = exp.Literal.string(self._prev.text)
+            value = self._prev.text
+            field = (
+                exp.Literal.number(value)
+                if self._prev.token_type == TokenType.NUMBER
+                else exp.Literal.string(value)
+            )
         else:
             field = self._parse_star() or self._parse_function() or self._parse_id_var()

@@ -2752,7 +2813,23 @@ class Parser(metaclass=_Parser):
             if not self._curr:
                 break

-            if self._match_text_seq("NOT", "ENFORCED"):
+            if self._match(TokenType.ON):
+                action = None
+                on = self._advance_any() and self._prev.text
+
+                if self._match(TokenType.NO_ACTION):
+                    action = "NO ACTION"
+                elif self._match(TokenType.CASCADE):
+                    action = "CASCADE"
+                elif self._match_pair(TokenType.SET, TokenType.NULL):
+                    action = "SET NULL"
+                elif self._match_pair(TokenType.SET, TokenType.DEFAULT):
+                    action = "SET DEFAULT"
+                else:
+                    self.raise_error("Invalid key constraint")
+
+                options.append(f"ON {on} {action}")
+            elif self._match_text_seq("NOT", "ENFORCED"):
                 options.append("NOT ENFORCED")
             elif self._match_text_seq("DEFERRABLE"):
                 options.append("DEFERRABLE")
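Foreign-key options now cover all standard referential actions rather than only `NO ACTION`; a hedged round-trip sketch (the emitted spacing is my expectation):

```python
import sqlglot

sql = "CREATE TABLE t (x INT REFERENCES p (id) ON DELETE CASCADE ON UPDATE SET NULL)"
# expected to round-trip unchanged
print(sqlglot.transpile(sql)[0])
```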
@@ -2762,10 +2839,6 @@ class Parser(metaclass=_Parser):
                 options.append("NORELY")
             elif self._match_text_seq("MATCH", "FULL"):
                 options.append("MATCH FULL")
-            elif self._match_text_seq("ON", "UPDATE", "NO ACTION"):
-                options.append("ON UPDATE NO ACTION")
-            elif self._match_text_seq("ON", "DELETE", "NO ACTION"):
-                options.append("ON DELETE NO ACTION")
             else:
                 break

@@ -3158,7 +3231,9 @@ class Parser(metaclass=_Parser):
                 prefix += self._prev.text

         if (any_token and self._advance_any()) or self._match_set(tokens or self.ID_VAR_TOKENS):
-            return exp.Identifier(this=prefix + self._prev.text, quoted=False)
+            quoted = self._prev.token_type == TokenType.STRING
+            return exp.Identifier(this=prefix + self._prev.text, quoted=quoted)
+
         return None

     def _parse_string(self) -> t.Optional[exp.Expression]:
@@ -3486,6 +3561,11 @@ class Parser(metaclass=_Parser):
     def _parse_set(self) -> exp.Expression:
         return self.expression(exp.Set, expressions=self._parse_csv(self._parse_set_item))

+    def _parse_as_command(self, start: Token) -> exp.Command:
+        while self._curr:
+            self._advance()
+        return exp.Command(this=self._find_sql(start, self._prev))
+
     def _find_parser(
         self, parsers: t.Dict[str, t.Callable], trie: t.Dict
     ) -> t.Optional[t.Callable]:
@@ -11,6 +11,7 @@ from sqlglot.trie import in_trie, new_trie

 if t.TYPE_CHECKING:
     from sqlglot.dataframe.sql.types import StructType
+    from sqlglot.dialects.dialect import DialectType

     ColumnMapping = t.Union[t.Dict, str, StructType, t.List]

@@ -153,7 +154,7 @@ class MappingSchema(AbstractMappingSchema[t.Dict[str, str]], Schema):
         self,
         schema: t.Optional[t.Dict] = None,
         visible: t.Optional[t.Dict] = None,
-        dialect: t.Optional[str] = None,
+        dialect: DialectType = None,
     ) -> None:
         self.dialect = dialect
         self.visible = visible or {}
@@ -665,6 +665,7 @@ class Tokenizer(metaclass=_Tokenizer):
         "STRING": TokenType.TEXT,
         "TEXT": TokenType.TEXT,
         "CLOB": TokenType.TEXT,
+        "LONGVARCHAR": TokenType.TEXT,
         "BINARY": TokenType.BINARY,
         "BLOB": TokenType.VARBINARY,
         "BYTEA": TokenType.VARBINARY,
@@ -170,7 +170,7 @@ class TestBigQuery(Validator):
                 "bigquery": "CURRENT_TIMESTAMP()",
                 "duckdb": "CURRENT_TIMESTAMP()",
                 "postgres": "CURRENT_TIMESTAMP",
-                "presto": "CURRENT_TIMESTAMP()",
+                "presto": "CURRENT_TIMESTAMP",
                 "hive": "CURRENT_TIMESTAMP()",
                 "spark": "CURRENT_TIMESTAMP()",
             },

@@ -181,7 +181,7 @@ class TestBigQuery(Validator):
                 "bigquery": "CURRENT_TIMESTAMP()",
                 "duckdb": "CURRENT_TIMESTAMP()",
                 "postgres": "CURRENT_TIMESTAMP",
-                "presto": "CURRENT_TIMESTAMP()",
+                "presto": "CURRENT_TIMESTAMP",
                 "hive": "CURRENT_TIMESTAMP()",
                 "spark": "CURRENT_TIMESTAMP()",
             },
@@ -1,6 +1,7 @@
 import unittest

 from sqlglot import Dialect, Dialects, ErrorLevel, UnsupportedError, parse_one
+from sqlglot.dialects import Hive


 class Validator(unittest.TestCase):

@@ -67,6 +68,11 @@ class TestDialect(Validator):
         self.assertIsNotNone(Dialect.get_or_raise(dialect))
         self.assertIsNotNone(Dialect[dialect.value])

+    def test_get_or_raise(self):
+        self.assertEqual(Dialect.get_or_raise(Hive), Hive)
+        self.assertEqual(Dialect.get_or_raise(Hive()), Hive)
+        self.assertEqual(Dialect.get_or_raise("hive"), Hive)
+
     def test_cast(self):
         self.validate_all(
             "CAST(a AS TEXT)",
@@ -280,6 +286,21 @@ class TestDialect(Validator):
             write={"oracle": "CAST(a AS NUMBER)"},
         )

+    def test_if_null(self):
+        self.validate_all(
+            "SELECT IFNULL(1, NULL) FROM foo",
+            write={
+                "": "SELECT COALESCE(1, NULL) FROM foo",
+                "redshift": "SELECT COALESCE(1, NULL) FROM foo",
+                "postgres": "SELECT COALESCE(1, NULL) FROM foo",
+                "mysql": "SELECT COALESCE(1, NULL) FROM foo",
+                "duckdb": "SELECT COALESCE(1, NULL) FROM foo",
+                "spark": "SELECT COALESCE(1, NULL) FROM foo",
+                "bigquery": "SELECT COALESCE(1, NULL) FROM foo",
+                "presto": "SELECT COALESCE(1, NULL) FROM foo",
+            },
+        )
+
     def test_time(self):
         self.validate_all(
             "STR_TO_TIME(x, '%Y-%m-%dT%H:%M:%S')",
@@ -287,10 +308,10 @@ class TestDialect(Validator):
                 "duckdb": "STRPTIME(x, '%Y-%m-%dT%H:%M:%S')",
             },
             write={
-                "mysql": "STR_TO_DATE(x, '%Y-%m-%dT%H:%i:%S')",
+                "mysql": "STR_TO_DATE(x, '%Y-%m-%dT%T')",
                 "duckdb": "STRPTIME(x, '%Y-%m-%dT%H:%M:%S')",
                 "hive": "CAST(FROM_UNIXTIME(UNIX_TIMESTAMP(x, 'yyyy-MM-ddTHH:mm:ss')) AS TIMESTAMP)",
-                "presto": "DATE_PARSE(x, '%Y-%m-%dT%H:%i:%S')",
+                "presto": "DATE_PARSE(x, '%Y-%m-%dT%T')",
                 "drill": "TO_TIMESTAMP(x, 'yyyy-MM-dd''T''HH:mm:ss')",
                 "redshift": "TO_TIMESTAMP(x, 'YYYY-MM-DDTHH:MI:SS')",
                 "spark": "TO_TIMESTAMP(x, 'yyyy-MM-ddTHH:mm:ss')",
@@ -356,7 +377,7 @@ class TestDialect(Validator):
             write={
                 "duckdb": "EPOCH(CAST('2020-01-01' AS TIMESTAMP))",
                 "hive": "UNIX_TIMESTAMP('2020-01-01')",
-                "presto": "TO_UNIXTIME(DATE_PARSE('2020-01-01', '%Y-%m-%d %H:%i:%S'))",
+                "presto": "TO_UNIXTIME(DATE_PARSE('2020-01-01', '%Y-%m-%d %T'))",
             },
         )
         self.validate_all(

@@ -418,7 +439,7 @@ class TestDialect(Validator):
         self.validate_all(
             "UNIX_TO_STR(x, y)",
             write={
-                "duckdb": "STRFTIME(TO_TIMESTAMP(CAST(x AS BIGINT)), y)",
+                "duckdb": "STRFTIME(TO_TIMESTAMP(x), y)",
                 "hive": "FROM_UNIXTIME(x, y)",
                 "presto": "DATE_FORMAT(FROM_UNIXTIME(x), y)",
                 "starrocks": "FROM_UNIXTIME(x, y)",

@@ -427,7 +448,7 @@ class TestDialect(Validator):
         self.validate_all(
             "UNIX_TO_TIME(x)",
             write={
-                "duckdb": "TO_TIMESTAMP(CAST(x AS BIGINT))",
+                "duckdb": "TO_TIMESTAMP(x)",
                 "hive": "FROM_UNIXTIME(x)",
                 "oracle": "TO_DATE('1970-01-01','YYYY-MM-DD') + (x / 86400)",
                 "postgres": "TO_TIMESTAMP(x)",

@@ -438,7 +459,7 @@ class TestDialect(Validator):
         self.validate_all(
             "UNIX_TO_TIME_STR(x)",
             write={
-                "duckdb": "CAST(TO_TIMESTAMP(CAST(x AS BIGINT)) AS TEXT)",
+                "duckdb": "CAST(TO_TIMESTAMP(x) AS TEXT)",
                 "hive": "FROM_UNIXTIME(x)",
                 "presto": "CAST(FROM_UNIXTIME(x) AS VARCHAR)",
             },
@@ -575,10 +596,10 @@ class TestDialect(Validator):
             },
             write={
                 "drill": "TO_DATE(x, 'yyyy-MM-dd''T''HH:mm:ss')",
-                "mysql": "STR_TO_DATE(x, '%Y-%m-%dT%H:%i:%S')",
-                "starrocks": "STR_TO_DATE(x, '%Y-%m-%dT%H:%i:%S')",
+                "mysql": "STR_TO_DATE(x, '%Y-%m-%dT%T')",
+                "starrocks": "STR_TO_DATE(x, '%Y-%m-%dT%T')",
                 "hive": "CAST(FROM_UNIXTIME(UNIX_TIMESTAMP(x, 'yyyy-MM-ddTHH:mm:ss')) AS DATE)",
-                "presto": "CAST(DATE_PARSE(x, '%Y-%m-%dT%H:%i:%S') AS DATE)",
+                "presto": "CAST(DATE_PARSE(x, '%Y-%m-%dT%T') AS DATE)",
                 "spark": "TO_DATE(x, 'yyyy-MM-ddTHH:mm:ss')",
             },
         )

@@ -709,6 +730,7 @@ class TestDialect(Validator):
                 "hive": "REDUCE(x, 0, (acc, x) -> acc + x, acc -> acc)",
                 "presto": "REDUCE(x, 0, (acc, x) -> acc + x, acc -> acc)",
                 "spark": "AGGREGATE(x, 0, (acc, x) -> acc + x, acc -> acc)",
+                "presto": "REDUCE(x, 0, (acc, x) -> acc + x, acc -> acc)",
             },
         )

@@ -1381,3 +1403,21 @@ SELECT
                 "spark": "MERGE INTO a AS b USING c AS d ON b.id = d.id WHEN MATCHED AND EXISTS(SELECT b.name EXCEPT SELECT d.name) THEN UPDATE SET b.name = d.name",
             },
         )
+
+    def test_substring(self):
+        self.validate_all(
+            "SUBSTR('123456', 2, 3)",
+            write={
+                "bigquery": "SUBSTR('123456', 2, 3)",
+                "oracle": "SUBSTR('123456', 2, 3)",
+                "postgres": "SUBSTR('123456', 2, 3)",
+            },
+        )
+        self.validate_all(
+            "SUBSTRING('123456', 2, 3)",
+            write={
+                "bigquery": "SUBSTRING('123456', 2, 3)",
+                "oracle": "SUBSTR('123456', 2, 3)",
+                "postgres": "SUBSTRING('123456' FROM 2 FOR 3)",
+            },
+        )
@@ -22,7 +22,7 @@ class TestDuckDB(Validator):
             "EPOCH_MS(x)",
             write={
                 "bigquery": "UNIX_TO_TIME(x / 1000)",
-                "duckdb": "TO_TIMESTAMP(CAST(x / 1000 AS BIGINT))",
+                "duckdb": "TO_TIMESTAMP(x / 1000)",
                 "presto": "FROM_UNIXTIME(x / 1000)",
                 "spark": "FROM_UNIXTIME(x / 1000)",
             },

@@ -41,7 +41,7 @@ class TestDuckDB(Validator):
             "STRFTIME(x, '%Y-%m-%d %H:%M:%S')",
             write={
                 "duckdb": "STRFTIME(x, '%Y-%m-%d %H:%M:%S')",
-                "presto": "DATE_FORMAT(x, '%Y-%m-%d %H:%i:%S')",
+                "presto": "DATE_FORMAT(x, '%Y-%m-%d %T')",
                 "hive": "DATE_FORMAT(x, 'yyyy-MM-dd HH:mm:ss')",
             },
         )
@@ -58,9 +58,10 @@ class TestDuckDB(Validator):
         self.validate_all(
             "TO_TIMESTAMP(x)",
             write={
-                "duckdb": "CAST(x AS TIMESTAMP)",
-                "presto": "CAST(x AS TIMESTAMP)",
-                "hive": "CAST(x AS TIMESTAMP)",
+                "bigquery": "UNIX_TO_TIME(x)",
+                "duckdb": "TO_TIMESTAMP(x)",
+                "presto": "FROM_UNIXTIME(x)",
+                "hive": "FROM_UNIXTIME(x)",
             },
         )
         self.validate_all(
@@ -334,6 +335,14 @@ class TestDuckDB(Validator):
             },
         )

+        self.validate_all(
+            "cast([[1]] as int[][])",
+            write={
+                "duckdb": "CAST(LIST_VALUE(LIST_VALUE(1)) AS INT[][])",
+                "spark": "CAST(ARRAY(ARRAY(1)) AS ARRAY<ARRAY<INT>>)",
+            },
+        )
+
     def test_bool_or(self):
         self.validate_all(
             "SELECT a, LOGICAL_OR(b) FROM table GROUP BY a",
@@ -259,7 +259,7 @@ class TestHive(Validator):
         self.validate_all(
             """from_unixtime(x, "yyyy-MM-dd'T'HH")""",
             write={
-                "duckdb": "STRFTIME(TO_TIMESTAMP(CAST(x AS BIGINT)), '%Y-%m-%d''T''%H')",
+                "duckdb": "STRFTIME(TO_TIMESTAMP(x), '%Y-%m-%d''T''%H')",
                 "presto": "DATE_FORMAT(FROM_UNIXTIME(x), '%Y-%m-%d''T''%H')",
                 "hive": "FROM_UNIXTIME(x, 'yyyy-MM-dd\\'T\\'HH')",
                 "spark": "FROM_UNIXTIME(x, 'yyyy-MM-dd\\'T\\'HH')",

@@ -269,7 +269,7 @@ class TestHive(Validator):
             "DATE_FORMAT('2020-01-01', 'yyyy-MM-dd HH:mm:ss')",
             write={
                 "duckdb": "STRFTIME(CAST('2020-01-01' AS TIMESTAMP), '%Y-%m-%d %H:%M:%S')",
-                "presto": "DATE_FORMAT(CAST('2020-01-01' AS TIMESTAMP), '%Y-%m-%d %H:%i:%S')",
+                "presto": "DATE_FORMAT(CAST('2020-01-01' AS TIMESTAMP), '%Y-%m-%d %T')",
                 "hive": "DATE_FORMAT(CAST('2020-01-01' AS TIMESTAMP), 'yyyy-MM-dd HH:mm:ss')",
                 "spark": "DATE_FORMAT(CAST('2020-01-01' AS TIMESTAMP), 'yyyy-MM-dd HH:mm:ss')",
             },

@@ -308,7 +308,7 @@ class TestHive(Validator):
             "UNIX_TIMESTAMP(x)",
             write={
                 "duckdb": "EPOCH(STRPTIME(x, '%Y-%m-%d %H:%M:%S'))",
-                "presto": "TO_UNIXTIME(DATE_PARSE(x, '%Y-%m-%d %H:%i:%S'))",
+                "presto": "TO_UNIXTIME(DATE_PARSE(x, '%Y-%m-%d %T'))",
                 "hive": "UNIX_TIMESTAMP(x)",
                 "spark": "UNIX_TIMESTAMP(x)",
                 "": "STR_TO_UNIX(x, '%Y-%m-%d %H:%M:%S')",
@@ -195,6 +195,26 @@ class TestMySQL(Validator):
         )

     def test_mysql(self):
+        self.validate_all(
+            "SELECT a FROM tbl FOR UPDATE",
+            write={
+                "": "SELECT a FROM tbl",
+                "mysql": "SELECT a FROM tbl FOR UPDATE",
+                "oracle": "SELECT a FROM tbl FOR UPDATE",
+                "postgres": "SELECT a FROM tbl FOR UPDATE",
+                "tsql": "SELECT a FROM tbl FOR UPDATE",
+            },
+        )
+        self.validate_all(
+            "SELECT a FROM tbl FOR SHARE",
+            write={
+                "": "SELECT a FROM tbl",
+                "mysql": "SELECT a FROM tbl FOR SHARE",
+                "oracle": "SELECT a FROM tbl FOR SHARE",
+                "postgres": "SELECT a FROM tbl FOR SHARE",
+                "tsql": "SELECT a FROM tbl FOR SHARE",
+            },
+        )
         self.validate_all(
             "GROUP_CONCAT(DISTINCT x ORDER BY y DESC)",
             write={
@@ -112,6 +112,22 @@ class TestPostgres(Validator):
         self.validate_identity("x ~ 'y'")
         self.validate_identity("x ~* 'y'")

+        self.validate_all(
+            "GENERATE_SERIES(a, b, ' 2 days ')",
+            write={
+                "postgres": "GENERATE_SERIES(a, b, INTERVAL '2' days)",
+                "presto": "SEQUENCE(a, b, INTERVAL '2' days)",
+                "trino": "SEQUENCE(a, b, INTERVAL '2' days)",
+            },
+        )
+        self.validate_all(
+            "GENERATE_SERIES('2019-01-01'::TIMESTAMP, NOW(), '1day')",
+            write={
+                "postgres": "GENERATE_SERIES(CAST('2019-01-01' AS TIMESTAMP), CURRENT_TIMESTAMP, INTERVAL '1' day)",
+                "presto": "SEQUENCE(CAST('2019-01-01' AS TIMESTAMP), CAST(CURRENT_TIMESTAMP AS TIMESTAMP), INTERVAL '1' day)",
+                "trino": "SEQUENCE(CAST('2019-01-01' AS TIMESTAMP), CAST(CURRENT_TIMESTAMP AS TIMESTAMP), INTERVAL '1' day)",
+            },
+        )
         self.validate_all(
             "END WORK AND NO CHAIN",
             write={"postgres": "COMMIT AND NO CHAIN"},
@@ -249,7 +265,7 @@ class TestPostgres(Validator):
         )
         self.validate_all(
             "'[1,2,3]'::json->2",
-            write={"postgres": "CAST('[1,2,3]' AS JSON) -> '2'"},
+            write={"postgres": "CAST('[1,2,3]' AS JSON) -> 2"},
         )
         self.validate_all(
             """'{"a":1,"b":2}'::json->'b'""",

@@ -265,7 +281,7 @@ class TestPostgres(Validator):
         )
         self.validate_all(
             """'[1,2,3]'::json->>2""",
-            write={"postgres": "CAST('[1,2,3]' AS JSON) ->> '2'"},
+            write={"postgres": "CAST('[1,2,3]' AS JSON) ->> 2"},
         )
         self.validate_all(
             """'{"a":1,"b":2}'::json->>'b'""",
|
|
@@ -111,7 +111,7 @@ class TestPresto(Validator):
             "DATE_FORMAT(x, '%Y-%m-%d %H:%i:%S')",
             write={
                 "duckdb": "STRFTIME(x, '%Y-%m-%d %H:%M:%S')",
-                "presto": "DATE_FORMAT(x, '%Y-%m-%d %H:%i:%S')",
+                "presto": "DATE_FORMAT(x, '%Y-%m-%d %T')",
                 "hive": "DATE_FORMAT(x, 'yyyy-MM-dd HH:mm:ss')",
                 "spark": "DATE_FORMAT(x, 'yyyy-MM-dd HH:mm:ss')",
             },
@@ -120,7 +120,7 @@ class TestPresto(Validator):
             "DATE_PARSE(x, '%Y-%m-%d %H:%i:%S')",
             write={
                 "duckdb": "STRPTIME(x, '%Y-%m-%d %H:%M:%S')",
-                "presto": "DATE_PARSE(x, '%Y-%m-%d %H:%i:%S')",
+                "presto": "DATE_PARSE(x, '%Y-%m-%d %T')",
                 "hive": "CAST(x AS TIMESTAMP)",
                 "spark": "TO_TIMESTAMP(x, 'yyyy-MM-dd HH:mm:ss')",
             },
@@ -134,6 +134,12 @@ class TestPresto(Validator):
                 "spark": "TO_TIMESTAMP(x, 'yyyy-MM-dd')",
             },
         )
+        self.validate_all(
+            "DATE_FORMAT(x, '%T')",
+            write={
+                "hive": "DATE_FORMAT(x, 'HH:mm:ss')",
+            },
+        )
         self.validate_all(
             "DATE_PARSE(SUBSTR(x, 1, 10), '%Y-%m-%d')",
             write={
@@ -146,7 +152,7 @@ class TestPresto(Validator):
         self.validate_all(
             "FROM_UNIXTIME(x)",
             write={
-                "duckdb": "TO_TIMESTAMP(CAST(x AS BIGINT))",
+                "duckdb": "TO_TIMESTAMP(x)",
                 "presto": "FROM_UNIXTIME(x)",
                 "hive": "FROM_UNIXTIME(x)",
                 "spark": "FROM_UNIXTIME(x)",
@@ -177,11 +183,51 @@ class TestPresto(Validator):
         self.validate_all(
             "NOW()",
             write={
-                "presto": "CURRENT_TIMESTAMP()",
+                "presto": "CURRENT_TIMESTAMP",
                 "hive": "CURRENT_TIMESTAMP()",
             },
         )
+
+        self.validate_all(
+            "DAY_OF_WEEK(timestamp '2012-08-08 01:00')",
+            write={
+                "spark": "DAYOFWEEK(CAST('2012-08-08 01:00' AS TIMESTAMP))",
+                "presto": "DAY_OF_WEEK(CAST('2012-08-08 01:00' AS TIMESTAMP))",
+            },
+        )
+
+        self.validate_all(
+            "DAY_OF_MONTH(timestamp '2012-08-08 01:00')",
+            write={
+                "spark": "DAYOFMONTH(CAST('2012-08-08 01:00' AS TIMESTAMP))",
+                "presto": "DAY_OF_MONTH(CAST('2012-08-08 01:00' AS TIMESTAMP))",
+            },
+        )
+
+        self.validate_all(
+            "DAY_OF_YEAR(timestamp '2012-08-08 01:00')",
+            write={
+                "spark": "DAYOFYEAR(CAST('2012-08-08 01:00' AS TIMESTAMP))",
+                "presto": "DAY_OF_YEAR(CAST('2012-08-08 01:00' AS TIMESTAMP))",
+            },
+        )
+
+        self.validate_all(
+            "WEEK_OF_YEAR(timestamp '2012-08-08 01:00')",
+            write={
+                "spark": "WEEKOFYEAR(CAST('2012-08-08 01:00' AS TIMESTAMP))",
+                "presto": "WEEK_OF_YEAR(CAST('2012-08-08 01:00' AS TIMESTAMP))",
+            },
+        )
+
+        self.validate_all(
+            "SELECT timestamp '2012-10-31 00:00' AT TIME ZONE 'America/Sao_Paulo'",
+            write={
+                "spark": "SELECT FROM_UTC_TIMESTAMP(CAST('2012-10-31 00:00' AS TIMESTAMP), 'America/Sao_Paulo')",
+                "presto": "SELECT CAST('2012-10-31 00:00' AS TIMESTAMP) AT TIME ZONE 'America/Sao_Paulo'",
+            },
+        )

     def test_ddl(self):
         self.validate_all(
             "CREATE TABLE test WITH (FORMAT = 'PARQUET') AS SELECT 1",
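A sketch of the new day/week extraction mappings above, assuming sqlglot.transpile: Presto's underscored DAY_OF_WEEK family maps onto Spark's unseparated DAYOFWEEK spellings, and the timestamp literal is rewritten as a CAST.

    import sqlglot

    out = sqlglot.transpile(
        "DAY_OF_WEEK(timestamp '2012-08-08 01:00')", read="presto", write="spark"
    )[0]
    assert out == "DAYOFWEEK(CAST('2012-08-08 01:00' AS TIMESTAMP))"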
@@ -314,6 +360,11 @@ class TestPresto(Validator):
     def test_presto(self):
         self.validate_identity("SELECT BOOL_OR(a > 10) FROM asd AS T(a)")
+        self.validate_identity("SELECT * FROM (VALUES (1))")
+        self.validate_identity("START TRANSACTION READ WRITE, ISOLATION LEVEL SERIALIZABLE")
+        self.validate_identity("START TRANSACTION ISOLATION LEVEL REPEATABLE READ")
+        self.validate_identity("APPROX_PERCENTILE(a, b, c, d)")
+
         self.validate_all(
             'SELECT a."b" FROM "foo"',
             write={
@@ -455,10 +506,6 @@ class TestPresto(Validator):
                 "spark": UnsupportedError,
             },
         )
-        self.validate_identity("SELECT * FROM (VALUES (1))")
-        self.validate_identity("START TRANSACTION READ WRITE, ISOLATION LEVEL SERIALIZABLE")
-        self.validate_identity("START TRANSACTION ISOLATION LEVEL REPEATABLE READ")
-        self.validate_identity("APPROX_PERCENTILE(a, b, c, d)")

     def test_encode_decode(self):
         self.validate_all(
|
||||||
"presto": "FROM_HEX(x)",
|
"presto": "FROM_HEX(x)",
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def test_json(self):
|
||||||
|
self.validate_all(
|
||||||
|
"SELECT CAST(JSON '[1,23,456]' AS ARRAY(INTEGER))",
|
||||||
|
write={
|
||||||
|
"spark": "SELECT FROM_JSON('[1,23,456]', 'ARRAY<INT>')",
|
||||||
|
"presto": "SELECT CAST(CAST('[1,23,456]' AS JSON) AS ARRAY(INTEGER))",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
self.validate_all(
|
||||||
|
"""SELECT CAST(JSON '{"k1":1,"k2":23,"k3":456}' AS MAP(VARCHAR, INTEGER))""",
|
||||||
|
write={
|
||||||
|
"spark": 'SELECT FROM_JSON(\'{"k1":1,"k2":23,"k3":456}\', \'MAP<STRING, INT>\')',
|
||||||
|
"presto": 'SELECT CAST(CAST(\'{"k1":1,"k2":23,"k3":456}\' AS JSON) AS MAP(VARCHAR, INTEGER))',
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
self.validate_all(
|
||||||
|
"SELECT CAST(ARRAY [1, 23, 456] AS JSON)",
|
||||||
|
write={
|
||||||
|
"spark": "SELECT TO_JSON(ARRAY(1, 23, 456))",
|
||||||
|
"presto": "SELECT CAST(ARRAY[1, 23, 456] AS JSON)",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
|
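The new test_json cases above pin down how Presto JSON casts map to Spark's FROM_JSON and TO_JSON. A minimal sketch, again via sqlglot.transpile:

    import sqlglot

    out = sqlglot.transpile(
        "SELECT CAST(JSON '[1,23,456]' AS ARRAY(INTEGER))", read="presto", write="spark"
    )[0]
    assert out == "SELECT FROM_JSON('[1,23,456]', 'ARRAY<INT>')"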
@@ -212,6 +212,17 @@ TBLPROPERTIES (
         self.validate_identity("TRIM(BOTH 'SL' FROM 'SSparkSQLS')")
         self.validate_identity("TRIM(LEADING 'SL' FROM 'SSparkSQLS')")
         self.validate_identity("TRIM(TRAILING 'SL' FROM 'SSparkSQLS')")
+
+        self.validate_all(
+            "AGGREGATE(my_arr, 0, (acc, x) -> acc + x, s -> s * 2)",
+            write={
+                "trino": "REDUCE(my_arr, 0, (acc, x) -> acc + x, s -> s * 2)",
+                "duckdb": "REDUCE(my_arr, 0, (acc, x) -> acc + x, s -> s * 2)",
+                "hive": "REDUCE(my_arr, 0, (acc, x) -> acc + x, s -> s * 2)",
+                "presto": "REDUCE(my_arr, 0, (acc, x) -> acc + x, s -> s * 2)",
+                "spark": "AGGREGATE(my_arr, 0, (acc, x) -> acc + x, s -> s * 2)",
+            },
+        )
         self.validate_all(
             "TRIM('SL', 'SSparkSQLS')", write={"spark": "TRIM('SL' FROM 'SSparkSQLS')"}
         )
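Per the Spark hunk above, AGGREGATE is Spark's spelling of the higher-order fold that the other listed dialects call REDUCE; a quick sketch of that mapping, assuming sqlglot.transpile:

    import sqlglot

    out = sqlglot.transpile(
        "AGGREGATE(my_arr, 0, (acc, x) -> acc + x, s -> s * 2)", read="spark", write="presto"
    )[0]
    assert out == "REDUCE(my_arr, 0, (acc, x) -> acc + x, s -> s * 2)"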
@@ -92,3 +92,9 @@ class TestSQLite(Validator):
                 "sqlite": "SELECT FIRST_VALUE(Name) OVER (PARTITION BY AlbumId ORDER BY Bytes DESC) AS LargestTrack FROM tracks"
             },
         )
+
+    def test_longvarchar_dtype(self):
+        self.validate_all(
+            "CREATE TABLE foo (bar LONGVARCHAR)",
+            write={"sqlite": "CREATE TABLE foo (bar TEXT)"},
+        )
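The SQLite hunk above maps the legacy LONGVARCHAR type onto TEXT. A one-line sketch, assuming sqlglot.transpile with the same read dialect these tests use:

    import sqlglot

    out = sqlglot.transpile("CREATE TABLE foo (bar LONGVARCHAR)", read="sqlite", write="sqlite")[0]
    assert out == "CREATE TABLE foo (bar TEXT)"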
@@ -21,3 +21,6 @@ class TestTeradata(Validator):
                 "mysql": "UPDATE A SET col2 = '' FROM schema.tableA AS A, (SELECT col1 FROM schema.tableA GROUP BY col1) AS B WHERE A.col1 = B.col1",
             },
         )
+
+    def test_create(self):
+        self.validate_identity("CREATE TABLE x (y INT) PRIMARY INDEX (y) PARTITION BY y INDEX (y)")
tests/fixtures/identity.sql vendored (16 changes)
@@ -161,6 +161,7 @@ SELECT 1 FROM test
 SELECT * FROM a, b, (SELECT 1) AS c
 SELECT a FROM test
 SELECT 1 AS filter
+SELECT 1 AS "quoted alias"
 SELECT SUM(x) AS filter
 SELECT 1 AS range FROM test
 SELECT 1 AS count FROM test
@@ -264,7 +265,9 @@ SELECT a FROM test GROUP BY GROUPING SETS (x, ())
 SELECT a FROM test GROUP BY GROUPING SETS (x, (x, y), (x, y, z), q)
 SELECT a FROM test GROUP BY CUBE (x)
 SELECT a FROM test GROUP BY ROLLUP (x)
-SELECT a FROM test GROUP BY CUBE (x) ROLLUP (x, y, z)
+SELECT t.a FROM test AS t GROUP BY ROLLUP (t.x)
+SELECT a FROM test GROUP BY GROUPING SETS ((x, y)), ROLLUP (b)
+SELECT a FROM test GROUP BY CUBE (x), ROLLUP (x, y, z)
 SELECT CASE WHEN a < b THEN 1 WHEN a < c THEN 2 ELSE 3 END FROM test
 SELECT CASE 1 WHEN 1 THEN 1 ELSE 2 END
 SELECT CASE 1 WHEN 1 THEN MAP('a', 'b') ELSE MAP('b', 'c') END['a']
@@ -339,7 +342,6 @@ SELECT CAST(a AS ARRAY<INT>) FROM test
 SELECT CAST(a AS VARIANT) FROM test
 SELECT TRY_CAST(a AS INT) FROM test
 SELECT COALESCE(a, b, c) FROM test
-SELECT IFNULL(a, b) FROM test
 SELECT ANY_VALUE(a) FROM test
 SELECT 1 FROM a JOIN b ON a.x = b.x
 SELECT 1 FROM a JOIN b AS c ON a.x = b.x
@@ -510,6 +512,14 @@ CREATE TABLE z (a INT UNIQUE AUTO_INCREMENT)
 CREATE TABLE z (a INT REFERENCES parent(b, c))
 CREATE TABLE z (a INT PRIMARY KEY, b INT REFERENCES foo(id))
 CREATE TABLE z (a INT, FOREIGN KEY (a) REFERENCES parent(b, c))
+CREATE TABLE foo (bar INT REFERENCES baz(baz_id) ON DELETE NO ACTION)
+CREATE TABLE foo (bar INT REFERENCES baz(baz_id) ON DELETE CASCADE)
+CREATE TABLE foo (bar INT REFERENCES baz(baz_id) ON DELETE SET NULL)
+CREATE TABLE foo (bar INT REFERENCES baz(baz_id) ON DELETE SET DEFAULT)
+CREATE TABLE foo (bar INT REFERENCES baz(baz_id) ON UPDATE NO ACTION)
+CREATE TABLE foo (bar INT REFERENCES baz(baz_id) ON UPDATE CASCADE)
+CREATE TABLE foo (bar INT REFERENCES baz(baz_id) ON UPDATE SET NULL)
+CREATE TABLE foo (bar INT REFERENCES baz(baz_id) ON UPDATE SET DEFAULT)
 CREATE TABLE asd AS SELECT asd FROM asd WITH NO DATA
 CREATE TABLE asd AS SELECT asd FROM asd WITH DATA
 CREATE TABLE products (x INT GENERATED BY DEFAULT AS IDENTITY)
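The fixtures added above cover the referential actions on foreign keys. Each line in identity.sql asserts that a statement parses and regenerates byte-for-byte, roughly (per the test_identity loop shown later in this diff):

    import sqlglot

    sql = "CREATE TABLE foo (bar INT REFERENCES baz(baz_id) ON DELETE CASCADE)"
    # Identity fixtures round-trip through the default dialect unchanged.
    assert sqlglot.transpile(sql)[0] == sql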
@@ -526,6 +536,7 @@ CREATE TABLE a, DUAL JOURNAL, DUAL AFTER JOURNAL, MERGEBLOCKRATIO=1 PERCENT, DAT
 CREATE TABLE a, DUAL BEFORE JOURNAL, LOCAL AFTER JOURNAL, MAXIMUM DATABLOCKSIZE, BLOCKCOMPRESSION=AUTOTEMP(c1 INT) (a INT)
 CREATE SET GLOBAL TEMPORARY TABLE a, NO BEFORE JOURNAL, NO AFTER JOURNAL, MINIMUM DATABLOCKSIZE, BLOCKCOMPRESSION=NEVER (a INT)
 CREATE MULTISET VOLATILE TABLE a, NOT LOCAL AFTER JOURNAL, FREESPACE=1 PERCENT, DATABLOCKSIZE=10 BYTES, WITH NO CONCURRENT ISOLATED LOADING FOR ALL (a INT)
+CREATE ALGORITHM=UNDEFINED DEFINER=foo@% SQL SECURITY DEFINER VIEW a AS (SELECT a FROM b)
 CREATE TEMPORARY TABLE x AS SELECT a FROM d
 CREATE TEMPORARY TABLE IF NOT EXISTS x AS SELECT a FROM d
 CREATE VIEW x AS SELECT a FROM b
@@ -555,6 +566,7 @@ CREATE UNIQUE INDEX IF NOT EXISTS my_idx ON tbl (a, b)
 CREATE SCHEMA x
 CREATE SCHEMA IF NOT EXISTS y
 CREATE PROCEDURE IF NOT EXISTS a.b.c() AS 'DECLARE BEGIN; END'
+CREATE OR REPLACE STAGE
 DESCRIBE x
 DROP INDEX a.b.c
 DROP FUNCTION a.b.c (INT)
@@ -50,6 +50,10 @@ WITH cte AS (SELECT 1 AS x, 2 AS y) SELECT cte.x AS x, cte.y AS y FROM cte AS ct
 (SELECT a FROM (SELECT b FROM x)) UNION (SELECT a FROM (SELECT b FROM y));
 WITH cte AS (SELECT b FROM x), cte_2 AS (SELECT a FROM cte AS cte), cte_3 AS (SELECT b FROM y), cte_4 AS (SELECT a FROM cte_3 AS cte_3) (SELECT cte_2.a AS a FROM cte_2 AS cte_2) UNION (SELECT cte_4.a AS a FROM cte_4 AS cte_4);
+
+-- Three unions
+SELECT a FROM x UNION ALL SELECT a FROM y UNION ALL SELECT a FROM z;
+WITH cte AS (SELECT a FROM x), cte_2 AS (SELECT a FROM y), cte_3 AS (SELECT a FROM z), cte_4 AS (SELECT cte_2.a AS a FROM cte_2 AS cte_2 UNION ALL SELECT cte_3.a AS a FROM cte_3 AS cte_3) SELECT cte.a AS a FROM cte AS cte UNION ALL SELECT cte_4.a AS a FROM cte_4 AS cte_4;

 -- Subquery
 SELECT a FROM x WHERE b = (SELECT y.c FROM y);
 SELECT a FROM x WHERE b = (SELECT y.c FROM y);
tests/fixtures/pretty.sql vendored (2 changes)
@@ -99,7 +99,7 @@ WITH cte1 AS (
   GROUPING SETS (
     a,
     (b, c)
-  )
+  ),
   CUBE (
     y,
     z
@@ -62,6 +62,16 @@ class TestBuild(unittest.TestCase):
             lambda: select("x").from_("tbl").where("x > 0").where("x < 9", append=False),
             "SELECT x FROM tbl WHERE x < 9",
         ),
+        (
+            lambda: select("x").from_("tbl").where("x > 0").lock(),
+            "SELECT x FROM tbl WHERE x > 0 FOR UPDATE",
+            "mysql",
+        ),
+        (
+            lambda: select("x").from_("tbl").where("x > 0").lock(update=False),
+            "SELECT x FROM tbl WHERE x > 0 FOR SHARE",
+            "postgres",
+        ),
         (
             lambda: select("x", "y").from_("tbl").group_by("x"),
             "SELECT x, y FROM tbl GROUP BY x",
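The two tuples added above exercise the new lock() builder method; a minimal sketch of the builder API they drive, with the dialect names taken from the test tuples:

    from sqlglot import select

    # lock() renders FOR UPDATE; lock(update=False) renders FOR SHARE.
    print(select("x").from_("tbl").where("x > 0").lock().sql(dialect="mysql"))
    # SELECT x FROM tbl WHERE x > 0 FOR UPDATE
    print(select("x").from_("tbl").where("x > 0").lock(update=False).sql(dialect="postgres"))
    # SELECT x FROM tbl WHERE x > 0 FOR SHARE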
@@ -466,6 +466,7 @@ class TestExpressions(unittest.TestCase):
         self.assertIsInstance(parse_one("BEGIN DEFERRED TRANSACTION"), exp.Transaction)
         self.assertIsInstance(parse_one("COMMIT"), exp.Commit)
         self.assertIsInstance(parse_one("ROLLBACK"), exp.Rollback)
+        self.assertIsInstance(parse_one("GENERATE_SERIES(a, b, c)"), exp.GenerateSeries)

     def test_column(self):
         dot = parse_one("a.b.c")
@@ -630,6 +631,19 @@ FROM foo""",
 FROM foo""",
         )

+    def test_to_interval(self):
+        self.assertEqual(exp.to_interval("1day").sql(), "INTERVAL '1' day")
+        self.assertEqual(exp.to_interval(" 5 months").sql(), "INTERVAL '5' months")
+        with self.assertRaises(ValueError):
+            exp.to_interval("bla")
+
+        self.assertEqual(exp.to_interval(exp.Literal.string("1day")).sql(), "INTERVAL '1' day")
+        self.assertEqual(
+            exp.to_interval(exp.Literal.string(" 5 months")).sql(), "INTERVAL '5' months"
+        )
+        with self.assertRaises(ValueError):
+            exp.to_interval(exp.Literal.string("bla"))
+
     def test_to_table(self):
         table_only = exp.to_table("table_name")
         self.assertEqual(table_only.name, "table_name")
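The test_to_interval hunk above documents the new exp.to_interval helper: it accepts either a plain string or a string Literal and raises ValueError on input it cannot parse. A short sketch:

    from sqlglot import exp

    assert exp.to_interval("1day").sql() == "INTERVAL '1' day"
    assert exp.to_interval(exp.Literal.string(" 5 months")).sql() == "INTERVAL '5' months"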
@@ -326,12 +326,12 @@ FROM bar /* comment 5 */, tbl /* comment 6 */""",
         self.validate("TIME_TO_UNIX(x)", "EPOCH(x)", write="duckdb")
         self.validate(
             "UNIX_TO_STR(123, 'y')",
-            "STRFTIME(TO_TIMESTAMP(CAST(123 AS BIGINT)), 'y')",
+            "STRFTIME(TO_TIMESTAMP(123), 'y')",
             write="duckdb",
         )
         self.validate(
             "UNIX_TO_TIME(123)",
-            "TO_TIMESTAMP(CAST(123 AS BIGINT))",
+            "TO_TIMESTAMP(123)",
             write="duckdb",
         )
@@ -426,6 +426,9 @@ FROM bar /* comment 5 */, tbl /* comment 6 */""",
         mock_logger.warning.assert_any_call("Applying array index offset (%s)", 1)
         mock_logger.warning.assert_any_call("Applying array index offset (%s)", -1)

+    def test_identify_lambda(self):
+        self.validate("x(y -> y)", 'X("y" -> "y")', identify=True)
+
     def test_identity(self):
         self.assertEqual(transpile("")[0], "")
         for sql in load_sql_fixtures("identity.sql"):