
Merging upstream version 10.6.3.

Signed-off-by: Daniel Baumann <daniel@debian.org>
Daniel Baumann 2025-02-13 15:09:58 +01:00
parent d03a55eda6
commit ece6881255
Signed by: daniel
GPG key ID: FBB4F0E80A80222F
48 changed files with 906 additions and 266 deletions


@@ -1,12 +1,12 @@
 # SQLGlot
-SQLGlot is a no dependency Python SQL parser, transpiler, optimizer, and engine. It can be used to format SQL or translate between [19 different dialects](https://github.com/tobymao/sqlglot/blob/main/sqlglot/dialects/__init__.py) like [DuckDB](https://duckdb.org/), [Presto](https://prestodb.io/), [Spark](https://spark.apache.org/), [Snowflake](https://www.snowflake.com/en/), and [BigQuery](https://cloud.google.com/bigquery/). It aims to read a wide variety of SQL inputs and output syntactically correct SQL in the targeted dialects.
+SQLGlot is a no-dependency SQL parser, transpiler, optimizer, and engine. It can be used to format SQL or translate between [19 different dialects](https://github.com/tobymao/sqlglot/blob/main/sqlglot/dialects/__init__.py) like [DuckDB](https://duckdb.org/), [Presto](https://prestodb.io/), [Spark](https://spark.apache.org/), [Snowflake](https://www.snowflake.com/en/), and [BigQuery](https://cloud.google.com/bigquery/). It aims to read a wide variety of SQL inputs and output syntactically correct SQL in the targeted dialects.
-It is a very comprehensive generic SQL parser with a robust [test suite](https://github.com/tobymao/sqlglot/blob/main/tests/). It is also quite [performant](#benchmarks) while being written purely in Python.
+It is a very comprehensive generic SQL parser with a robust [test suite](https://github.com/tobymao/sqlglot/blob/main/tests/). It is also quite [performant](#benchmarks), while being written purely in Python.
 You can easily [customize](#custom-dialects) the parser, [analyze](#metadata) queries, traverse expression trees, and programmatically [build](#build-and-modify-sql) SQL.
-Syntax [errors](#parser-errors) are highlighted and dialect incompatibilities can warn or raise depending on configurations. However, it should be noted that the parser is very lenient when it comes to detecting errors, because it aims to consume as much SQL as possible. On one hand, this makes its implementation simpler, and thus more comprehensible, but on the other hand it means that syntax errors may sometimes go unnoticed.
+Syntax [errors](#parser-errors) are highlighted and dialect incompatibilities can warn or raise depending on configurations. However, it should be noted that SQL validation is not SQLGlot's goal, so some syntax errors may go unnoticed.
 Contributions are very welcome in SQLGlot; read the [contribution guide](https://github.com/tobymao/sqlglot/blob/main/CONTRIBUTING.md) to get started!
@@ -432,6 +432,8 @@ user_id price
 2 3.0
 ```
+See also: [Writing a Python SQL engine from scratch](https://github.com/tobymao/sqlglot/blob/main/posts/python_sql_engine.md).
 ## Used By
 * [Fugue](https://github.com/fugue-project/fugue)
 * [ibis](https://github.com/ibis-project/ibis)
@@ -442,7 +444,7 @@ user_id price
 ## Documentation
-SQLGlot uses [pdocs](https://pdoc.dev/) to serve its API documentation:
+SQLGlot uses [pdoc](https://pdoc.dev/) to serve its API documentation:
 ```
 make docs-serve

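For context on the README change above, a minimal usage sketch of the transpile API it describes (this snippet is not part of the changeset and assumes sqlglot 10.6.x):

```python
import sqlglot

# Translate a DuckDB expression into Hive SQL; transpile() returns a list with
# one SQL string per statement in the input.
print(sqlglot.transpile("SELECT EPOCH_MS(1618088028295)", read="duckdb", write="hive")[0])
```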

@@ -33,7 +33,13 @@ from sqlglot.parser import Parser
 from sqlglot.schema import MappingSchema, Schema
 from sqlglot.tokens import Tokenizer, TokenType
-__version__ = "10.6.0"
+if t.TYPE_CHECKING:
+    from sqlglot.dialects.dialect import DialectType
+
+    T = t.TypeVar("T", bound=Expression)
+
+__version__ = "10.6.3"
 pretty = False
 """Whether to format generated SQL by default."""
@@ -42,9 +48,7 @@ schema = MappingSchema()
 """The default schema used by SQLGlot (e.g. in the optimizer)."""
-def parse(
-    sql: str, read: t.Optional[str | Dialect] = None, **opts
-) -> t.List[t.Optional[Expression]]:
+def parse(sql: str, read: DialectType = None, **opts) -> t.List[t.Optional[Expression]]:
     """
     Parses the given SQL string into a collection of syntax trees, one per parsed SQL statement.
@@ -60,9 +64,57 @@ def parse(
     return dialect.parse(sql, **opts)
+@t.overload
 def parse_one(
     sql: str,
-    read: t.Optional[str | Dialect] = None,
+    read: None = None,
+    into: t.Type[T] = ...,
+    **opts,
+) -> T:
+    ...
+@t.overload
+def parse_one(
+    sql: str,
+    read: DialectType,
+    into: t.Type[T],
+    **opts,
+) -> T:
+    ...
+@t.overload
+def parse_one(
+    sql: str,
+    read: None = None,
+    into: t.Union[str, t.Collection[t.Union[str, t.Type[Expression]]]] = ...,
+    **opts,
+) -> Expression:
+    ...
+@t.overload
+def parse_one(
+    sql: str,
+    read: DialectType,
+    into: t.Union[str, t.Collection[t.Union[str, t.Type[Expression]]]],
+    **opts,
+) -> Expression:
+    ...
+@t.overload
+def parse_one(
+    sql: str,
+    **opts,
+) -> Expression:
+    ...
+def parse_one(
+    sql: str,
+    read: DialectType = None,
     into: t.Optional[exp.IntoType] = None,
     **opts,
 ) -> Expression:
@@ -96,8 +148,8 @@ def parse_one(
 def transpile(
     sql: str,
-    read: t.Optional[str | Dialect] = None,
-    write: t.Optional[str | Dialect] = None,
+    read: DialectType = None,
+    write: DialectType = None,
     identity: bool = True,
     error_level: t.Optional[ErrorLevel] = None,
     **opts,

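The new `@t.overload` declarations above let type checkers narrow the return type of `parse_one` when `into=` is given. A small illustrative sketch (not part of the diff, assuming sqlglot 10.6.x):

```python
import sqlglot
from sqlglot import exp

# Without into=, parse_one returns the root Expression of the single statement.
ast = sqlglot.parse_one("SELECT a FROM t WHERE a > 1", read="duckdb")
print(type(ast))  # <class 'sqlglot.expressions.Select'>

# With into=, the string is parsed directly into the requested node type, and the
# overloads let type checkers infer exp.DataType for the result.
dtype = sqlglot.parse_one("ARRAY<INT>", read="spark", into=exp.DataType)
print(dtype.sql())  # ARRAY<INT>
```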

@@ -260,11 +260,7 @@ class Column:
         """
         if isinstance(dataType, DataType):
             dataType = dataType.simpleString()
-        new_expression = exp.Cast(
-            this=self.column_expression,
-            to=sqlglot.parse_one(dataType, into=exp.DataType, read="spark"),  # type: ignore
-        )
-        return Column(new_expression)
+        return Column(exp.cast(self.column_expression, dataType, dialect="spark"))

     def startswith(self, value: t.Union[str, Column]) -> Column:
         value = self._lit(value) if not isinstance(value, Column) else value

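The dataframe `cast` refactor above relies on the new `exp.cast` helper introduced later in this changeset (in expressions.py). A minimal sketch of what that helper does, outside the diff:

```python
from sqlglot import expressions as exp

# exp.cast parses the expression and the target type (optionally with a dialect)
# and wraps them in a Cast node.
print(exp.cast("price", "INT").sql())  # CAST(price AS INT)
```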

@@ -536,15 +536,15 @@ def month(col: ColumnOrName) -> Column:
 def dayofweek(col: ColumnOrName) -> Column:
-    return Column.invoke_anonymous_function(col, "DAYOFWEEK")
+    return Column.invoke_expression_over_column(col, glotexp.DayOfWeek)

 def dayofmonth(col: ColumnOrName) -> Column:
-    return Column.invoke_anonymous_function(col, "DAYOFMONTH")
+    return Column.invoke_expression_over_column(col, glotexp.DayOfMonth)

 def dayofyear(col: ColumnOrName) -> Column:
-    return Column.invoke_anonymous_function(col, "DAYOFYEAR")
+    return Column.invoke_expression_over_column(col, glotexp.DayOfYear)

 def hour(col: ColumnOrName) -> Column:
@@ -560,7 +560,7 @@ def second(col: ColumnOrName) -> Column:
 def weekofyear(col: ColumnOrName) -> Column:
-    return Column.invoke_anonymous_function(col, "WEEKOFYEAR")
+    return Column.invoke_expression_over_column(col, glotexp.WeekOfYear)

 def make_date(year: ColumnOrName, month: ColumnOrName, day: ColumnOrName) -> Column:
@@ -1144,10 +1144,16 @@ def aggregate(
     merge_exp = _get_lambda_from_func(merge)
     if finish is not None:
         finish_exp = _get_lambda_from_func(finish)
-        return Column.invoke_anonymous_function(
-            col, "AGGREGATE", initialValue, Column(merge_exp), Column(finish_exp)
-        )
-    return Column.invoke_anonymous_function(col, "AGGREGATE", initialValue, Column(merge_exp))
+        return Column.invoke_expression_over_column(
+            col,
+            glotexp.Reduce,
+            initial=initialValue,
+            merge=Column(merge_exp),
+            finish=Column(finish_exp),
+        )
+    return Column.invoke_expression_over_column(
+        col, glotexp.Reduce, initial=initialValue, merge=Column(merge_exp)
+    )

 def transform(

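These dataframe helpers now build typed expression nodes (`glotexp.DayOfWeek`, `glotexp.Reduce`, and so on) instead of anonymous function calls, so the resulting trees can be re-rendered for other dialects. A hedged sketch of the effect (not part of the diff; the exact output depends on the target dialect's rules):

```python
import sqlglot

# DAYOFWEEK parses into a typed DayOfWeek node in the Spark dialect, so Presto
# gets its own rendering instead of a verbatim pass-through of the function name.
print(sqlglot.transpile("SELECT DAYOFWEEK(ds)", read="spark", write="presto")[0])
```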

@@ -222,14 +222,6 @@ class BigQuery(Dialect):
             exp.DataType.Type.NVARCHAR: "STRING",
         }
-        ROOT_PROPERTIES = {
-            exp.LanguageProperty,
-            exp.ReturnsProperty,
-            exp.VolatilityProperty,
-        }
-
-        WITH_PROPERTIES = {exp.Property}
         EXPLICIT_UNION = True

         def array_sql(self, expression: exp.Array) -> str:


@@ -122,9 +122,15 @@ class Dialect(metaclass=_Dialect):
     def get_or_raise(cls, dialect):
         if not dialect:
             return cls
+        if isinstance(dialect, _Dialect):
+            return dialect
+        if isinstance(dialect, Dialect):
+            return dialect.__class__
         result = cls.get(dialect)
         if not result:
             raise ValueError(f"Unknown dialect '{dialect}'")
         return result

     @classmethod
@@ -196,6 +202,10 @@ class Dialect(metaclass=_Dialect):
         )

+if t.TYPE_CHECKING:
+    DialectType = t.Union[str, Dialect, t.Type[Dialect], None]

 def rename_func(name):
     def _rename(self, expression):
         args = flatten(expression.args.values())

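With the branches added above, any dialect argument can be a name, a `Dialect` subclass, or a `Dialect` instance; that is what the new `DialectType` alias captures. A short sketch (not part of the diff):

```python
from sqlglot.dialects.dialect import Dialect
from sqlglot.dialects.duckdb import DuckDB

# All three forms resolve to the same dialect class now that classes and
# instances are accepted alongside strings.
assert Dialect.get_or_raise("duckdb") is DuckDB
assert Dialect.get_or_raise(DuckDB) is DuckDB
assert Dialect.get_or_raise(DuckDB()) is DuckDB
```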

@@ -137,7 +137,10 @@ class Drill(Dialect):
             exp.DataType.Type.DATETIME: "TIMESTAMP",
         }
-        ROOT_PROPERTIES = {exp.PartitionedByProperty}
+        PROPERTIES_LOCATION = {
+            **generator.Generator.PROPERTIES_LOCATION,  # type: ignore
+            exp.PartitionedByProperty: exp.Properties.Location.POST_SCHEMA_ROOT,
+        }

         TRANSFORMS = {
             **generator.Generator.TRANSFORMS,  # type: ignore


@@ -20,10 +20,6 @@ from sqlglot.helper import seq_get
 from sqlglot.tokens import TokenType
-def _unix_to_time(self, expression):
-    return f"TO_TIMESTAMP(CAST({self.sql(expression, 'this')} AS BIGINT))"

 def _str_to_time_sql(self, expression):
     return f"STRPTIME({self.sql(expression, 'this')}, {self.format_time(expression)})"
@@ -113,7 +109,7 @@ class DuckDB(Dialect):
             "STR_SPLIT_REGEX": exp.RegexpSplit.from_arg_list,
             "STRING_SPLIT_REGEX": exp.RegexpSplit.from_arg_list,
             "STRUCT_PACK": exp.Struct.from_arg_list,
-            "TO_TIMESTAMP": exp.TimeStrToTime.from_arg_list,
+            "TO_TIMESTAMP": exp.UnixToTime.from_arg_list,
             "UNNEST": exp.Explode.from_arg_list,
         }
@@ -162,9 +158,9 @@ class DuckDB(Dialect):
             exp.TsOrDiToDi: lambda self, e: f"CAST(SUBSTR(REPLACE(CAST({self.sql(e, 'this')} AS TEXT), '-', ''), 1, 8) AS INT)",
             exp.TsOrDsAdd: _ts_or_ds_add,
             exp.TsOrDsToDate: _ts_or_ds_to_date_sql,
-            exp.UnixToStr: lambda self, e: f"STRFTIME({_unix_to_time(self, e)}, {self.format_time(e)})",
-            exp.UnixToTime: _unix_to_time,
-            exp.UnixToTimeStr: lambda self, e: f"CAST({_unix_to_time(self, e)} AS TEXT)",
+            exp.UnixToStr: lambda self, e: f"STRFTIME(TO_TIMESTAMP({self.sql(e, 'this')}), {self.format_time(e)})",
+            exp.UnixToTime: rename_func("TO_TIMESTAMP"),
+            exp.UnixToTimeStr: lambda self, e: f"CAST(TO_TIMESTAMP({self.sql(e, 'this')}) AS TEXT)",
         }

         TYPE_MAPPING = {

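DuckDB's `TO_TIMESTAMP` takes a Unix epoch, so mapping it to `exp.UnixToTime` (rather than `exp.TimeStrToTime`) lets epoch conversions round-trip between dialects. A hedged sketch, not part of the diff; the exact output depends on the target dialect:

```python
import sqlglot

# The DuckDB epoch-to-timestamp call is now understood as exp.UnixToTime,
# so other dialects can substitute their own equivalent construct.
print(sqlglot.transpile("SELECT TO_TIMESTAMP(1618088028)", read="duckdb", write="presto")[0])
```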

@@ -322,17 +322,11 @@ class Hive(Dialect):
             exp.LastDateOfMonth: rename_func("LAST_DAY"),
         }
-        WITH_PROPERTIES = {exp.Property}
-
-        ROOT_PROPERTIES = {
-            exp.PartitionedByProperty,
-            exp.FileFormatProperty,
-            exp.SchemaCommentProperty,
-            exp.LocationProperty,
-            exp.TableFormatProperty,
-            exp.RowFormatDelimitedProperty,
-            exp.RowFormatSerdeProperty,
-            exp.SerdeProperties,
+        PROPERTIES_LOCATION = {
+            **generator.Generator.PROPERTIES_LOCATION,  # type: ignore
+            exp.FileFormatProperty: exp.Properties.Location.POST_SCHEMA_ROOT,
+            exp.PartitionedByProperty: exp.Properties.Location.POST_SCHEMA_ROOT,
+            exp.TableFormatProperty: exp.Properties.Location.POST_SCHEMA_ROOT,
         }

         def with_properties(self, properties):


@@ -1,7 +1,5 @@
 from __future__ import annotations
-import typing as t

 from sqlglot import exp, generator, parser, tokens
 from sqlglot.dialects.dialect import (
     Dialect,
@@ -98,6 +96,8 @@ def _date_add_sql(kind):
 class MySQL(Dialect):
+    time_format = "'%Y-%m-%d %T'"

     # https://prestodb.io/docs/current/functions/datetime.html#mysql-date-functions
     time_mapping = {
         "%M": "%B",
@@ -110,6 +110,7 @@ class MySQL(Dialect):
         "%u": "%W",
         "%k": "%-H",
         "%l": "%-I",
+        "%T": "%H:%M:%S",
     }

     class Tokenizer(tokens.Tokenizer):
@@ -428,6 +429,7 @@ class MySQL(Dialect):
         )

     class Generator(generator.Generator):
+        LOCKING_READS_SUPPORTED = True
         NULL_ORDERING_SUPPORTED = False

         TRANSFORMS = {
@@ -449,23 +451,12 @@ class MySQL(Dialect):
             exp.StrPosition: strposition_to_locate_sql,
         }
-        ROOT_PROPERTIES = {
-            exp.EngineProperty,
-            exp.AutoIncrementProperty,
-            exp.CharacterSetProperty,
-            exp.CollateProperty,
-            exp.SchemaCommentProperty,
-            exp.LikeProperty,
-        }

         TYPE_MAPPING = generator.Generator.TYPE_MAPPING.copy()
         TYPE_MAPPING.pop(exp.DataType.Type.MEDIUMTEXT)
         TYPE_MAPPING.pop(exp.DataType.Type.LONGTEXT)
         TYPE_MAPPING.pop(exp.DataType.Type.MEDIUMBLOB)
         TYPE_MAPPING.pop(exp.DataType.Type.LONGBLOB)
-        WITH_PROPERTIES: t.Set[t.Type[exp.Property]] = set()

         def show_sql(self, expression):
             this = f" {expression.name}"
             full = " FULL" if expression.args.get("full") else ""


@@ -44,6 +44,8 @@ class Oracle(Dialect):
     }

     class Generator(generator.Generator):
+        LOCKING_READS_SUPPORTED = True

         TYPE_MAPPING = {
             **generator.Generator.TYPE_MAPPING,  # type: ignore
             exp.DataType.Type.TINYINT: "NUMBER",
@@ -69,6 +71,7 @@ class Oracle(Dialect):
             exp.StrToTime: lambda self, e: f"TO_TIMESTAMP({self.sql(e, 'this')}, {self.format_time(e)})",
             exp.TimeToStr: lambda self, e: f"TO_CHAR({self.sql(e, 'this')}, {self.format_time(e)})",
             exp.UnixToTime: lambda self, e: f"TO_DATE('1970-01-01','YYYY-MM-DD') + ({self.sql(e, 'this')} / 86400)",
+            exp.Substring: rename_func("SUBSTR"),
         }

         def query_modifiers(self, expression, *sqls):
@@ -90,6 +93,7 @@ class Oracle(Dialect):
                 self.sql(expression, "order"),
                 self.sql(expression, "offset"),  # offset before limit in oracle
                 self.sql(expression, "limit"),
+                self.sql(expression, "lock"),
                 sep="",
             )


@@ -148,6 +148,22 @@ def _serial_to_generated(expression):
     return expression

+def _generate_series(args):
+    # The goal is to convert step values like '1 day' or INTERVAL '1 day' into INTERVAL '1' day
+    step = seq_get(args, 2)
+
+    if step is None:
+        # Postgres allows calls with just two arguments -- the "step" argument defaults to 1
+        return exp.GenerateSeries.from_arg_list(args)
+
+    if step.is_string:
+        args[2] = exp.to_interval(step.this)
+    elif isinstance(step, exp.Interval) and not step.args.get("unit"):
+        args[2] = exp.to_interval(step.this.this)
+
+    return exp.GenerateSeries.from_arg_list(args)

 def _to_timestamp(args):
     # TO_TIMESTAMP accepts either a single double argument or (text, text)
     if len(args) == 1:
@@ -195,29 +211,6 @@ class Postgres(Dialect):
         HEX_STRINGS = [("x'", "'"), ("X'", "'")]
         BYTE_STRINGS = [("e'", "'"), ("E'", "'")]
-        CREATABLES = (
-            "AGGREGATE",
-            "CAST",
-            "CONVERSION",
-            "COLLATION",
-            "DEFAULT CONVERSION",
-            "CONSTRAINT",
-            "DOMAIN",
-            "EXTENSION",
-            "FOREIGN",
-            "FUNCTION",
-            "OPERATOR",
-            "POLICY",
-            "ROLE",
-            "RULE",
-            "SEQUENCE",
-            "TEXT",
-            "TRIGGER",
-            "TYPE",
-            "UNLOGGED",
-            "USER",
-        )

         KEYWORDS = {
             **tokens.Tokenizer.KEYWORDS,
             "~~": TokenType.LIKE,
@@ -243,8 +236,6 @@ class Postgres(Dialect):
             "TEMP": TokenType.TEMPORARY,
             "UUID": TokenType.UUID,
             "CSTRING": TokenType.PSEUDO_TYPE,
-            **{f"CREATE {kind}": TokenType.COMMAND for kind in CREATABLES},
-            **{f"DROP {kind}": TokenType.COMMAND for kind in CREATABLES},
         }
         QUOTES = ["'", "$$"]
         SINGLE_TOKENS = {
@@ -257,8 +248,10 @@ class Postgres(Dialect):
         FUNCTIONS = {
             **parser.Parser.FUNCTIONS,  # type: ignore
+            "NOW": exp.CurrentTimestamp.from_arg_list,
             "TO_TIMESTAMP": _to_timestamp,
             "TO_CHAR": format_time_lambda(exp.TimeToStr, "postgres"),
+            "GENERATE_SERIES": _generate_series,
         }

         BITWISE = {
@@ -272,6 +265,8 @@ class Postgres(Dialect):
         }

     class Generator(generator.Generator):
+        LOCKING_READS_SUPPORTED = True

         TYPE_MAPPING = {
             **generator.Generator.TYPE_MAPPING,  # type: ignore
             exp.DataType.Type.TINYINT: "SMALLINT",

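`_generate_series` above normalizes the step argument into a proper INTERVAL using the new `exp.to_interval` helper added later in this changeset (in expressions.py). A short sketch of that helper on its own, not part of the diff:

```python
from sqlglot import expressions as exp

# to_interval parses strings like '1 day' into an Interval node with a quoted
# quantity and a unit, i.e. INTERVAL '1' day.
print(exp.to_interval("1 day").sql())
```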

@@ -105,6 +105,29 @@ def _ts_or_ds_add_sql(self, expression):
     return f"DATE_ADD({unit}, {e}, DATE_PARSE(SUBSTR({this}, 1, 10), {Presto.date_format}))"

+def _sequence_sql(self, expression):
+    start = expression.args["start"]
+    end = expression.args["end"]
+    step = expression.args.get("step", 1)  # Postgres defaults to 1 for generate_series
+
+    target_type = None
+
+    if isinstance(start, exp.Cast):
+        target_type = start.to
+    elif isinstance(end, exp.Cast):
+        target_type = end.to
+
+    if target_type and target_type.this == exp.DataType.Type.TIMESTAMP:
+        to = target_type.copy()
+
+        if target_type is start.to:
+            end = exp.Cast(this=end, to=to)
+        else:
+            start = exp.Cast(this=start, to=to)
+
+    return f"SEQUENCE({self.format_args(start, end, step)})"

 def _ensure_utf8(charset):
     if charset.name.lower() != "utf-8":
         raise UnsupportedError(f"Unsupported charset {charset}")
@@ -145,7 +168,7 @@ def _from_unixtime(args):
 class Presto(Dialect):
     index_offset = 1
     null_ordering = "nulls_are_last"
-    time_format = "'%Y-%m-%d %H:%i:%S'"
+    time_format = MySQL.time_format  # type: ignore
     time_mapping = MySQL.time_mapping  # type: ignore

     class Tokenizer(tokens.Tokenizer):
@@ -197,7 +220,10 @@ class Presto(Dialect):
     class Generator(generator.Generator):
         STRUCT_DELIMITER = ("(", ")")
-        ROOT_PROPERTIES = {exp.SchemaCommentProperty}
+        PROPERTIES_LOCATION = {
+            **generator.Generator.PROPERTIES_LOCATION,  # type: ignore
+            exp.LocationProperty: exp.Properties.Location.UNSUPPORTED,
+        }

         TYPE_MAPPING = {
             **generator.Generator.TYPE_MAPPING,  # type: ignore
@@ -223,6 +249,7 @@ class Presto(Dialect):
             exp.BitwiseOr: lambda self, e: f"BITWISE_OR({self.sql(e, 'this')}, {self.sql(e, 'expression')})",
             exp.BitwiseRightShift: lambda self, e: f"BITWISE_ARITHMETIC_SHIFT_RIGHT({self.sql(e, 'this')}, {self.sql(e, 'expression')})",
             exp.BitwiseXor: lambda self, e: f"BITWISE_XOR({self.sql(e, 'this')}, {self.sql(e, 'expression')})",
+            exp.CurrentTimestamp: lambda *_: "CURRENT_TIMESTAMP",
             exp.DataType: _datatype_sql,
             exp.DateAdd: lambda self, e: f"""DATE_ADD({self.sql(e, 'unit') or "'day'"}, {self.sql(e, 'expression')}, {self.sql(e, 'this')})""",
             exp.DateDiff: lambda self, e: f"""DATE_DIFF({self.sql(e, 'unit') or "'day'"}, {self.sql(e, 'expression')}, {self.sql(e, 'this')})""",
@@ -231,6 +258,7 @@ class Presto(Dialect):
             exp.Decode: _decode_sql,
             exp.DiToDate: lambda self, e: f"CAST(DATE_PARSE(CAST({self.sql(e, 'this')} AS VARCHAR), {Presto.dateint_format}) AS DATE)",
             exp.Encode: _encode_sql,
+            exp.GenerateSeries: _sequence_sql,
             exp.Hex: rename_func("TO_HEX"),
             exp.If: if_sql,
             exp.ILike: no_ilike_sql,

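With the Postgres parser change and `_sequence_sql` above, a `GENERATE_SERIES` call can now be rewritten as Presto's `SEQUENCE`. A hedged sketch (not part of the diff; the exact argument formatting may vary):

```python
import sqlglot

# Postgres GENERATE_SERIES parses into exp.GenerateSeries, which Presto renders
# with its SEQUENCE function.
print(sqlglot.transpile("SELECT GENERATE_SERIES(1, 5, 2)", read="postgres", write="presto")[0])
```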

@@ -61,14 +61,9 @@ class Redshift(Postgres):
             exp.DataType.Type.INT: "INTEGER",
         }
-        ROOT_PROPERTIES = {
-            exp.DistKeyProperty,
-            exp.SortKeyProperty,
-            exp.DistStyleProperty,
-        }
-
-        WITH_PROPERTIES = {
-            exp.LikeProperty,
+        PROPERTIES_LOCATION = {
+            **Postgres.Generator.PROPERTIES_LOCATION,  # type: ignore
+            exp.LikeProperty: exp.Properties.Location.POST_SCHEMA_WITH,
         }

         TRANSFORMS = {


@@ -234,15 +234,6 @@ class Snowflake(Dialect):
             "replace": "RENAME",
         }
-        ROOT_PROPERTIES = {
-            exp.PartitionedByProperty,
-            exp.ReturnsProperty,
-            exp.LanguageProperty,
-            exp.SchemaCommentProperty,
-            exp.ExecuteAsProperty,
-            exp.VolatilityProperty,
-        }

         def except_op(self, expression):
             if not expression.args.get("distinct", False):
                 self.unsupported("EXCEPT with All is not supported in Snowflake")


@@ -73,6 +73,19 @@ class Spark(Hive):
             ),
             "APPROX_PERCENTILE": exp.ApproxQuantile.from_arg_list,
             "IIF": exp.If.from_arg_list,
+            "AGGREGATE": exp.Reduce.from_arg_list,
+            "DAYOFWEEK": lambda args: exp.DayOfWeek(
+                this=exp.TsOrDsToDate(this=seq_get(args, 0)),
+            ),
+            "DAYOFMONTH": lambda args: exp.DayOfMonth(
+                this=exp.TsOrDsToDate(this=seq_get(args, 0)),
+            ),
+            "DAYOFYEAR": lambda args: exp.DayOfYear(
+                this=exp.TsOrDsToDate(this=seq_get(args, 0)),
+            ),
+            "WEEKOFYEAR": lambda args: exp.WeekOfYear(
+                this=exp.TsOrDsToDate(this=seq_get(args, 0)),
+            ),
         }

         FUNCTION_PARSERS = {
@@ -105,6 +118,14 @@ class Spark(Hive):
             exp.DataType.Type.BIGINT: "LONG",
         }
+        PROPERTIES_LOCATION = {
+            **Hive.Generator.PROPERTIES_LOCATION,  # type: ignore
+            exp.EngineProperty: exp.Properties.Location.UNSUPPORTED,
+            exp.AutoIncrementProperty: exp.Properties.Location.UNSUPPORTED,
+            exp.CharacterSetProperty: exp.Properties.Location.UNSUPPORTED,
+            exp.CollateProperty: exp.Properties.Location.UNSUPPORTED,
+        }

         TRANSFORMS = {
             **Hive.Generator.TRANSFORMS,  # type: ignore
             exp.ApproxDistinct: rename_func("APPROX_COUNT_DISTINCT"),
@@ -126,11 +147,27 @@ class Spark(Hive):
             exp.VariancePop: rename_func("VAR_POP"),
             exp.DateFromParts: rename_func("MAKE_DATE"),
             exp.LogicalOr: rename_func("BOOL_OR"),
+            exp.DayOfWeek: rename_func("DAYOFWEEK"),
+            exp.DayOfMonth: rename_func("DAYOFMONTH"),
+            exp.DayOfYear: rename_func("DAYOFYEAR"),
+            exp.WeekOfYear: rename_func("WEEKOFYEAR"),
+            exp.AtTimeZone: lambda self, e: f"FROM_UTC_TIMESTAMP({self.sql(e, 'this')}, {self.sql(e, 'zone')})",
         }
         TRANSFORMS.pop(exp.ArraySort)
         TRANSFORMS.pop(exp.ILike)

         WRAP_DERIVED_VALUES = False

+        def cast_sql(self, expression: exp.Cast) -> str:
+            if isinstance(expression.this, exp.Cast) and expression.this.is_type(
+                exp.DataType.Type.JSON
+            ):
+                schema = f"'{self.sql(expression, 'to')}'"
+                return f"FROM_JSON({self.format_args(self.sql(expression.this, 'this'), schema)})"
+
+            if expression.to.is_type(exp.DataType.Type.JSON):
+                return f"TO_JSON({self.sql(expression, 'this')})"
+
+            return super(Spark.Generator, self).cast_sql(expression)

     class Tokenizer(Hive.Tokenizer):
         HEX_STRINGS = [("X'", "'")]

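The new `cast_sql` override above special-cases the JSON pseudo-type: casting to JSON becomes `TO_JSON`, and re-casting a JSON value to a struct type becomes `FROM_JSON`. A hedged sketch of the first case (not part of the diff):

```python
import sqlglot

# A CAST to JSON should be rendered as Spark's TO_JSON call.
print(sqlglot.transpile("SELECT CAST(col AS JSON)", write="spark")[0])  # expected: SELECT TO_JSON(col)
```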

@@ -31,6 +31,5 @@ class Tableau(Dialect):
     class Parser(parser.Parser):
         FUNCTIONS = {
             **parser.Parser.FUNCTIONS,  # type: ignore
-            "IFNULL": exp.Coalesce.from_arg_list,
             "COUNTD": lambda args: exp.Count(this=exp.Distinct(expressions=args)),
         }


@@ -76,6 +76,14 @@ class Teradata(Dialect):
         )

     class Generator(generator.Generator):
+        PROPERTIES_LOCATION = {
+            **generator.Generator.PROPERTIES_LOCATION,  # type: ignore
+            exp.PartitionedByProperty: exp.Properties.Location.POST_INDEX,
+        }
+
+        def partitionedbyproperty_sql(self, expression: exp.PartitionedByProperty) -> str:
+            return f"PARTITION BY {self.sql(expression, 'this')}"

         # FROM before SET in Teradata UPDATE syntax
         # https://docs.teradata.com/r/Enterprise_IntelliFlex_VMware/Teradata-VantageTM-SQL-Data-Manipulation-Language-17.20/Statement-Syntax/UPDATE/UPDATE-Syntax-Basic-Form-FROM-Clause
         def update_sql(self, expression: exp.Update) -> str:


@@ -412,6 +412,8 @@ class TSQL(Dialect):
         return self.expression(exp.UserDefinedFunction, this=this, expressions=expressions)

     class Generator(generator.Generator):
+        LOCKING_READS_SUPPORTED = True

         TYPE_MAPPING = {
             **generator.Generator.TYPE_MAPPING,  # type: ignore
             exp.DataType.Type.BOOLEAN: "BIT",


@@ -14,10 +14,6 @@ from sqlglot import Dialect
 from sqlglot import expressions as exp
 from sqlglot.helper import ensure_collection
-if t.TYPE_CHECKING:
-    T = t.TypeVar("T")
-    Edit = t.Union[Insert, Remove, Move, Update, Keep]

 @dataclass(frozen=True)
 class Insert:
@@ -56,6 +52,11 @@ class Keep:
     target: exp.Expression

+if t.TYPE_CHECKING:
+    T = t.TypeVar("T")
+    Edit = t.Union[Insert, Remove, Move, Update, Keep]

 def diff(source: exp.Expression, target: exp.Expression) -> t.List[Edit]:
     """
     Returns the list of changes between the source and the target expressions.


@@ -1,5 +1,13 @@
+"""
+.. include:: ../../posts/python_sql_engine.md
+----
+"""
+from __future__ import annotations
+
 import logging
 import time
+import typing as t

 from sqlglot import maybe_parse
 from sqlglot.errors import ExecuteError
@@ -11,42 +19,63 @@ from sqlglot.schema import ensure_schema
 logger = logging.getLogger("sqlglot")

+if t.TYPE_CHECKING:
+    from sqlglot.dialects.dialect import DialectType
+    from sqlglot.executor.table import Tables
+    from sqlglot.expressions import Expression
+    from sqlglot.schema import Schema

-def execute(sql, schema=None, read=None, tables=None):
+def execute(
+    sql: str | Expression,
+    schema: t.Optional[t.Dict | Schema] = None,
+    read: DialectType = None,
+    tables: t.Optional[t.Dict] = None,
+) -> Table:
     """
     Run a sql query against data.

     Args:
-        sql (str|sqlglot.Expression): a sql statement
-        schema (dict|sqlglot.optimizer.Schema): database schema.
-            This can either be an instance of `sqlglot.optimizer.Schema` or a mapping in one of
-            the following forms:
+        sql: a sql statement.
+        schema: database schema.
+            This can either be an instance of `Schema` or a mapping in one of the following forms:
             1. {table: {col: type}}
             2. {db: {table: {col: type}}}
             3. {catalog: {db: {table: {col: type}}}}
-        read (str): the SQL dialect to apply during parsing
-            (eg. "spark", "hive", "presto", "mysql").
-        tables (dict): additional tables to register.
+        read: the SQL dialect to apply during parsing (eg. "spark", "hive", "presto", "mysql").
+        tables: additional tables to register.

     Returns:
-        sqlglot.executor.Table: Simple columnar data structure.
+        Simple columnar data structure.
     """
-    tables = ensure_tables(tables)
+    tables_ = ensure_tables(tables)

     if not schema:
         schema = {
             name: {column: type(table[0][column]).__name__ for column in table.columns}
-            for name, table in tables.mapping.items()
+            for name, table in tables_.mapping.items()
         }

     schema = ensure_schema(schema)

-    if tables.supported_table_args and tables.supported_table_args != schema.supported_table_args:
+    if tables_.supported_table_args and tables_.supported_table_args != schema.supported_table_args:
         raise ExecuteError("Tables must support the same table args as schema")

     expression = maybe_parse(sql, dialect=read)

     now = time.time()
     expression = optimize(expression, schema, leave_tables_isolated=True)
     logger.debug("Optimization finished: %f", time.time() - now)
     logger.debug("Optimized SQL: %s", expression.sql(pretty=True))

     plan = Plan(expression)
     logger.debug("Logical Plan: %s", plan)

     now = time.time()
-    result = PythonExecutor(tables=tables).execute(plan)
+    result = PythonExecutor(tables=tables_).execute(plan)
     logger.debug("Query finished: %f", time.time() - now)

     return result

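The `execute` entry point documented above runs a SQL statement directly against plain Python data structures. A brief usage sketch (not part of the diff, assuming sqlglot 10.6.x):

```python
from sqlglot.executor import execute

# The schema is inferred from the row values when it is not supplied explicitly.
result = execute(
    "SELECT name, SUM(price) AS total FROM sushi GROUP BY name",
    tables={"sushi": [{"name": "maguro", "price": 1.5}, {"name": "sake", "price": 2.0}]},
)
print(result.columns)
print(result.rows)
```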

@@ -171,5 +171,6 @@ ENV = {
     "STRPOSITION": str_position,
     "SUB": null_if_any(lambda e, this: e - this),
     "SUBSTRING": substring,
+    "TIMESTRTOTIME": null_if_any(lambda arg: datetime.datetime.fromisoformat(arg)),
     "UPPER": null_if_any(lambda arg: arg.upper()),
 }


@@ -1,5 +1,7 @@
 from __future__ import annotations

+import typing as t

 from sqlglot.helper import dict_depth
 from sqlglot.schema import AbstractMappingSchema
@@ -106,11 +108,11 @@ class Tables(AbstractMappingSchema[Table]):
     pass

-def ensure_tables(d: dict | None) -> Tables:
+def ensure_tables(d: t.Optional[t.Dict]) -> Tables:
     return Tables(_ensure_tables(d))

-def _ensure_tables(d: dict | None) -> dict:
+def _ensure_tables(d: t.Optional[t.Dict]) -> t.Dict:
     if not d:
         return {}
@@ -127,4 +129,5 @@ def _ensure_tables(d: dict | None) -> dict:
         columns = tuple(table[0]) if table else ()
         rows = [tuple(row[c] for c in columns) for row in table]
         result[name] = Table(columns=columns, rows=rows)

     return result


@@ -32,13 +32,7 @@ from sqlglot.helper import (
 from sqlglot.tokens import Token

 if t.TYPE_CHECKING:
-    from sqlglot.dialects.dialect import Dialect
-
-IntoType = t.Union[
-    str,
-    t.Type[Expression],
-    t.Collection[t.Union[str, t.Type[Expression]]],
-]
+    from sqlglot.dialects.dialect import DialectType

 class _Expression(type):
@@ -427,7 +421,7 @@ class Expression(metaclass=_Expression):
     def __repr__(self):
         return self._to_s()

-    def sql(self, dialect: Dialect | str | None = None, **opts) -> str:
+    def sql(self, dialect: DialectType = None, **opts) -> str:
         """
         Returns SQL string representation of this tree.
@@ -595,6 +589,14 @@ class Expression(metaclass=_Expression):
         return load(obj)

+if t.TYPE_CHECKING:
+    IntoType = t.Union[
+        str,
+        t.Type[Expression],
+        t.Collection[t.Union[str, t.Type[Expression]]],
+    ]

 class Condition(Expression):
     def and_(self, *expressions, dialect=None, **opts):
         """
@@ -1285,6 +1287,18 @@ class Property(Expression):
     arg_types = {"this": True, "value": True}

+class AlgorithmProperty(Property):
+    arg_types = {"this": True}

+class DefinerProperty(Property):
+    arg_types = {"this": True}

+class SqlSecurityProperty(Property):
+    arg_types = {"definer": True}

 class TableFormatProperty(Property):
     arg_types = {"this": True}
@@ -1425,13 +1439,15 @@ class IsolatedLoadingProperty(Property):
 class Properties(Expression):
-    arg_types = {"expressions": True, "before": False}
+    arg_types = {"expressions": True}

     NAME_TO_PROPERTY = {
+        "ALGORITHM": AlgorithmProperty,
         "AUTO_INCREMENT": AutoIncrementProperty,
         "CHARACTER SET": CharacterSetProperty,
         "COLLATE": CollateProperty,
         "COMMENT": SchemaCommentProperty,
+        "DEFINER": DefinerProperty,
         "DISTKEY": DistKeyProperty,
         "DISTSTYLE": DistStyleProperty,
         "ENGINE": EngineProperty,
@@ -1447,6 +1463,14 @@ class Properties(Expression):
     PROPERTY_TO_NAME = {v: k for k, v in NAME_TO_PROPERTY.items()}

+    class Location(AutoName):
+        POST_CREATE = auto()
+        PRE_SCHEMA = auto()
+        POST_INDEX = auto()
+        POST_SCHEMA_ROOT = auto()
+        POST_SCHEMA_WITH = auto()
+        UNSUPPORTED = auto()

     @classmethod
     def from_dict(cls, properties_dict) -> Properties:
         expressions = []
@@ -1592,6 +1616,7 @@ QUERY_MODIFIERS = {
     "order": False,
     "limit": False,
     "offset": False,
+    "lock": False,
 }
@@ -1713,6 +1738,12 @@ class Schema(Expression):
     arg_types = {"this": False, "expressions": False}

+# Used to represent the FOR UPDATE and FOR SHARE locking read types.
+# https://dev.mysql.com/doc/refman/8.0/en/innodb-locking-reads.html
+class Lock(Expression):
+    arg_types = {"update": True}

 class Select(Subqueryable):
     arg_types = {
         "with": False,
@@ -2243,6 +2274,30 @@ class Select(Subqueryable):
             properties=properties_expression,
         )

+    def lock(self, update: bool = True, copy: bool = True) -> Select:
+        """
+        Set the locking read mode for this expression.
+
+        Examples:
+            >>> Select().select("x").from_("tbl").where("x = 'a'").lock().sql("mysql")
+            "SELECT x FROM tbl WHERE x = 'a' FOR UPDATE"
+
+            >>> Select().select("x").from_("tbl").where("x = 'a'").lock(update=False).sql("mysql")
+            "SELECT x FROM tbl WHERE x = 'a' FOR SHARE"
+
+        Args:
+            update: if `True`, the locking type will be `FOR UPDATE`, else it will be `FOR SHARE`.
+            copy: if `False`, modify this expression instance in-place.
+
+        Returns:
+            The modified expression.
+        """
+        inst = _maybe_copy(self, copy)
+        inst.set("lock", Lock(update=update))
+
+        return inst

     @property
     def named_selects(self) -> t.List[str]:
         return [e.output_name for e in self.expressions if e.alias_or_name]
@@ -2456,24 +2511,28 @@ class DataType(Expression):
     @classmethod
     def build(
-        cls, dtype: str | DataType.Type, dialect: t.Optional[str | Dialect] = None, **kwargs
+        cls, dtype: str | DataType | DataType.Type, dialect: DialectType = None, **kwargs
     ) -> DataType:
         from sqlglot import parse_one

         if isinstance(dtype, str):
-            data_type_exp: t.Optional[Expression]
             if dtype.upper() in cls.Type.__members__:
-                data_type_exp = DataType(this=DataType.Type[dtype.upper()])
+                data_type_exp: t.Optional[Expression] = DataType(this=DataType.Type[dtype.upper()])
             else:
                 data_type_exp = parse_one(dtype, read=dialect, into=DataType)
             if data_type_exp is None:
                 raise ValueError(f"Unparsable data type value: {dtype}")
         elif isinstance(dtype, DataType.Type):
             data_type_exp = DataType(this=dtype)
+        elif isinstance(dtype, DataType):
+            return dtype
         else:
             raise ValueError(f"Invalid data type: {type(dtype)}. Expected str or DataType.Type")
         return DataType(**{**data_type_exp.args, **kwargs})

+    def is_type(self, dtype: DataType.Type) -> bool:
+        return self.this == dtype

 # https://www.postgresql.org/docs/15/datatype-pseudo.html
 class PseudoType(Expression):
@@ -2840,6 +2899,10 @@ class Array(Func):
     is_var_len_args = True

+class GenerateSeries(Func):
+    arg_types = {"start": True, "end": True, "step": False}

 class ArrayAgg(AggFunc):
     pass
@@ -2909,6 +2972,9 @@ class Cast(Func):
     def output_name(self):
         return self.name

+    def is_type(self, dtype: DataType.Type) -> bool:
+        return self.to.is_type(dtype)

 class Collate(Binary):
     pass
@@ -2989,6 +3055,22 @@ class DatetimeTrunc(Func, TimeUnit):
     arg_types = {"this": True, "unit": True, "zone": False}

+class DayOfWeek(Func):
+    _sql_names = ["DAY_OF_WEEK", "DAYOFWEEK"]

+class DayOfMonth(Func):
+    _sql_names = ["DAY_OF_MONTH", "DAYOFMONTH"]

+class DayOfYear(Func):
+    _sql_names = ["DAY_OF_YEAR", "DAYOFYEAR"]

+class WeekOfYear(Func):
+    _sql_names = ["WEEK_OF_YEAR", "WEEKOFYEAR"]

 class LastDateOfMonth(Func):
     pass
@@ -3239,7 +3321,7 @@ class ReadCSV(Func):
 class Reduce(Func):
-    arg_types = {"this": True, "initial": True, "merge": True, "finish": True}
+    arg_types = {"this": True, "initial": True, "merge": True, "finish": False}

 class RegexpLike(Func):
@@ -3476,7 +3558,7 @@ def maybe_parse(
     sql_or_expression: str | Expression,
     *,
     into: t.Optional[IntoType] = None,
-    dialect: t.Optional[str] = None,
+    dialect: DialectType = None,
     prefix: t.Optional[str] = None,
     **opts,
 ) -> Expression:
@@ -3959,6 +4041,28 @@ def to_identifier(alias, quoted=None) -> t.Optional[Identifier]:
     return identifier

+INTERVAL_STRING_RE = re.compile(r"\s*([0-9]+)\s*([a-zA-Z]+)\s*")

+def to_interval(interval: str | Literal) -> Interval:
+    """Builds an interval expression from a string like '1 day' or '5 months'."""
+    if isinstance(interval, Literal):
+        if not interval.is_string:
+            raise ValueError("Invalid interval string.")
+
+        interval = interval.this
+
+    interval_parts = INTERVAL_STRING_RE.match(interval)  # type: ignore
+
+    if not interval_parts:
+        raise ValueError("Invalid interval string.")
+
+    return Interval(
+        this=Literal.string(interval_parts.group(1)),
+        unit=Var(this=interval_parts.group(2)),
+    )

 @t.overload
 def to_table(sql_path: str | Table, **kwargs) -> Table:
     ...
@@ -4050,7 +4154,8 @@ def alias_(expression, alias, table=False, dialect=None, quoted=None, **opts):
 def subquery(expression, alias=None, dialect=None, **opts):
     """
     Build a subquery expression.
-    Expample:
+
+    Example:
         >>> subquery('select x from tbl', 'bar').select('x').sql()
         'SELECT x FROM (SELECT x FROM tbl) AS bar'
@@ -4072,6 +4177,7 @@ def subquery(expression, alias=None, dialect=None, **opts):
 def column(col, table=None, quoted=None) -> Column:
     """
     Build a Column.
+
     Args:
         col (str | Expression): column name
         table (str | Expression): table name
@@ -4084,6 +4190,24 @@ def column(col, table=None, quoted=None) -> Column:
     )

+def cast(expression: str | Expression, to: str | DataType | DataType.Type, **opts) -> Cast:
+    """Cast an expression to a data type.
+
+    Example:
+        >>> cast('x + 1', 'int').sql()
+        'CAST(x + 1 AS INT)'
+
+    Args:
+        expression: The expression to cast.
+        to: The datatype to cast to.
+
+    Returns:
+        A cast node.
+    """
+    expression = maybe_parse(expression, **opts)
+    return Cast(this=expression, to=DataType.build(to, **opts))

 def table_(table, db=None, catalog=None, quoted=None, alias=None) -> Table:
     """Build a Table.
@@ -4137,7 +4261,7 @@ def values(
         types = list(columns.values())
         expressions[0].set(
             "expressions",
-            [Cast(this=x, to=types[i]) for i, x in enumerate(expressions[0].expressions)],
+            [cast(x, types[i]) for i, x in enumerate(expressions[0].expressions)],
         )
         return Values(
             expressions=expressions,
@@ -4373,7 +4497,7 @@ def expand(expression: Expression, sources: t.Dict[str, Subqueryable], copy=True
     return expression.transform(_expand, copy=copy)

-def func(name: str, *args, dialect: t.Optional[Dialect | str] = None, **kwargs) -> Func:
+def func(name: str, *args, dialect: DialectType = None, **kwargs) -> Func:
     """
     Returns a Func expression.

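Two small additions in the expressions module above are easy to miss: `DataType.build` now accepts an existing `DataType` and returns it unchanged, and `Cast.is_type` delegates to `DataType.is_type`. A brief sketch (not part of the diff):

```python
from sqlglot import expressions as exp

dt = exp.DataType.build("text")
assert exp.DataType.build(dt) is dt          # an existing DataType passes through

node = exp.cast("x", "text")
assert node.is_type(exp.DataType.Type.TEXT)  # Cast.is_type checks the target type
```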

@@ -67,6 +67,7 @@ class Generator:
         exp.VolatilityProperty: lambda self, e: e.name,
         exp.WithJournalTableProperty: lambda self, e: f"WITH JOURNAL TABLE={self.sql(e, 'this')}",
         exp.LogProperty: lambda self, e: f"{'NO ' if e.args.get('no') else ''}LOG",
+        exp.SqlSecurityProperty: lambda self, e: f"SQL SECURITY {'DEFINER' if e.args.get('definer') else 'INVOKER'}",
     }

     # Whether 'CREATE ... TRANSIENT ... TABLE' is allowed
@@ -75,6 +76,9 @@ class Generator:
     # Whether or not null ordering is supported in order by
     NULL_ORDERING_SUPPORTED = True

+    # Whether or not locking reads (i.e. SELECT ... FOR UPDATE/SHARE) are supported
+    LOCKING_READS_SUPPORTED = False

     # Always do union distinct or union all
     EXPLICIT_UNION = False
@@ -99,34 +103,42 @@ class Generator:
     STRUCT_DELIMITER = ("<", ">")

-    BEFORE_PROPERTIES = {
-        exp.FallbackProperty,
-        exp.WithJournalTableProperty,
-        exp.LogProperty,
-        exp.JournalProperty,
-        exp.AfterJournalProperty,
-        exp.ChecksumProperty,
-        exp.FreespaceProperty,
-        exp.MergeBlockRatioProperty,
-        exp.DataBlocksizeProperty,
-        exp.BlockCompressionProperty,
-        exp.IsolatedLoadingProperty,
-    }
-
-    ROOT_PROPERTIES = {
-        exp.ReturnsProperty,
-        exp.LanguageProperty,
-        exp.DistStyleProperty,
-        exp.DistKeyProperty,
-        exp.SortKeyProperty,
-        exp.LikeProperty,
-    }
-
-    WITH_PROPERTIES = {
-        exp.Property,
-        exp.FileFormatProperty,
-        exp.PartitionedByProperty,
-        exp.TableFormatProperty,
+    PROPERTIES_LOCATION = {
+        exp.AfterJournalProperty: exp.Properties.Location.PRE_SCHEMA,
+        exp.AlgorithmProperty: exp.Properties.Location.POST_CREATE,
+        exp.AutoIncrementProperty: exp.Properties.Location.POST_SCHEMA_ROOT,
+        exp.BlockCompressionProperty: exp.Properties.Location.PRE_SCHEMA,
+        exp.CharacterSetProperty: exp.Properties.Location.POST_SCHEMA_ROOT,
+        exp.ChecksumProperty: exp.Properties.Location.PRE_SCHEMA,
+        exp.CollateProperty: exp.Properties.Location.POST_SCHEMA_ROOT,
+        exp.DataBlocksizeProperty: exp.Properties.Location.PRE_SCHEMA,
+        exp.DefinerProperty: exp.Properties.Location.POST_CREATE,
+        exp.DistKeyProperty: exp.Properties.Location.POST_SCHEMA_ROOT,
+        exp.DistStyleProperty: exp.Properties.Location.POST_SCHEMA_ROOT,
+        exp.EngineProperty: exp.Properties.Location.POST_SCHEMA_ROOT,
+        exp.ExecuteAsProperty: exp.Properties.Location.POST_SCHEMA_ROOT,
+        exp.FallbackProperty: exp.Properties.Location.PRE_SCHEMA,
+        exp.FileFormatProperty: exp.Properties.Location.POST_SCHEMA_WITH,
+        exp.FreespaceProperty: exp.Properties.Location.PRE_SCHEMA,
+        exp.IsolatedLoadingProperty: exp.Properties.Location.PRE_SCHEMA,
+        exp.JournalProperty: exp.Properties.Location.PRE_SCHEMA,
+        exp.LanguageProperty: exp.Properties.Location.POST_SCHEMA_ROOT,
+        exp.LikeProperty: exp.Properties.Location.POST_SCHEMA_ROOT,
+        exp.LocationProperty: exp.Properties.Location.POST_SCHEMA_ROOT,
+        exp.LogProperty: exp.Properties.Location.PRE_SCHEMA,
+        exp.MergeBlockRatioProperty: exp.Properties.Location.PRE_SCHEMA,
+        exp.PartitionedByProperty: exp.Properties.Location.POST_SCHEMA_WITH,
+        exp.Property: exp.Properties.Location.POST_SCHEMA_WITH,
+        exp.ReturnsProperty: exp.Properties.Location.POST_SCHEMA_ROOT,
+        exp.RowFormatDelimitedProperty: exp.Properties.Location.POST_SCHEMA_ROOT,
+        exp.RowFormatSerdeProperty: exp.Properties.Location.POST_SCHEMA_ROOT,
+        exp.SchemaCommentProperty: exp.Properties.Location.POST_SCHEMA_ROOT,
+        exp.SerdeProperties: exp.Properties.Location.POST_SCHEMA_ROOT,
+        exp.SortKeyProperty: exp.Properties.Location.POST_SCHEMA_ROOT,
+        exp.SqlSecurityProperty: exp.Properties.Location.POST_CREATE,
+        exp.TableFormatProperty: exp.Properties.Location.POST_SCHEMA_WITH,
+        exp.VolatilityProperty: exp.Properties.Location.POST_SCHEMA_ROOT,
+        exp.WithJournalTableProperty: exp.Properties.Location.PRE_SCHEMA,
     }

     WITH_SEPARATED_COMMENTS = (exp.Select, exp.From, exp.Where, exp.Binary)
@@ -284,10 +296,10 @@ class Generator:
         )
         return f"({self.sep('')}{this_sql}{self.seg(')', sep='')}"

-    def no_identify(self, func: t.Callable[[], str]) -> str:
+    def no_identify(self, func: t.Callable[..., str], *args, **kwargs) -> str:
         original = self.identify
         self.identify = False
-        result = func()
+        result = func(*args, **kwargs)
         self.identify = original
         return result
@@ -455,19 +467,33 @@ class Generator:
     def create_sql(self, expression: exp.Create) -> str:
         kind = self.sql(expression, "kind").upper()
-        has_before_properties = expression.args.get("properties")
-        has_before_properties = (
-            has_before_properties.args.get("before") if has_before_properties else None
-        )
+        properties = expression.args.get("properties")
+        properties_exp = expression.copy()
+        properties_locs = self.locate_properties(properties) if properties else {}
+        if properties_locs.get(exp.Properties.Location.POST_SCHEMA_ROOT) or properties_locs.get(
+            exp.Properties.Location.POST_SCHEMA_WITH
+        ):
+            properties_exp.set(
+                "properties",
+                exp.Properties(
+                    expressions=[
+                        *properties_locs[exp.Properties.Location.POST_SCHEMA_ROOT],
+                        *properties_locs[exp.Properties.Location.POST_SCHEMA_WITH],
+                    ]
+                ),
+            )
-        if kind == "TABLE" and has_before_properties:
+        if kind == "TABLE" and properties_locs.get(exp.Properties.Location.PRE_SCHEMA):
             this_name = self.sql(expression.this, "this")
-            this_properties = self.sql(expression, "properties")
+            this_properties = self.properties(
+                exp.Properties(expressions=properties_locs[exp.Properties.Location.PRE_SCHEMA]),
+                wrapped=False,
+            )
             this_schema = f"({self.expressions(expression.this)})"
             this = f"{this_name}, {this_properties} {this_schema}"
-            properties = ""
+            properties_sql = ""
         else:
             this = self.sql(expression, "this")
-            properties = self.sql(expression, "properties")
+            properties_sql = self.sql(properties_exp, "properties")
         begin = " BEGIN" if expression.args.get("begin") else ""
         expression_sql = self.sql(expression, "expression")
         expression_sql = f" AS{begin}{self.sep()}{expression_sql}" if expression_sql else ""
@@ -514,11 +540,31 @@ class Generator:
                     if index.args.get("columns")
                     else ""
                 )
+                if index.args.get("primary") and properties_locs.get(
+                    exp.Properties.Location.POST_INDEX
+                ):
+                    postindex_props_sql = self.properties(
+                        exp.Properties(
+                            expressions=properties_locs[exp.Properties.Location.POST_INDEX]
+                        ),
+                        wrapped=False,
+                    )
+                    ind_columns = f"{ind_columns} {postindex_props_sql}"
+
                 indexes_sql.append(
                     f"{ind_unique}{ind_primary}{ind_amp} INDEX{ind_name}{ind_columns}"
                 )
             index_sql = "".join(indexes_sql)

+        postcreate_props_sql = ""
+        if properties_locs.get(exp.Properties.Location.POST_CREATE):
+            postcreate_props_sql = self.properties(
+                exp.Properties(expressions=properties_locs[exp.Properties.Location.POST_CREATE]),
+                sep=" ",
+                prefix=" ",
+                wrapped=False,
+            )

         modifiers = "".join(
             (
                 replace,
@@ -531,6 +577,7 @@ class Generator:
                 multiset,
                 global_temporary,
                 volatile,
+                postcreate_props_sql,
             )
         )
         no_schema_binding = (
@@ -539,7 +586,7 @@ class Generator:
         post_expression_modifiers = "".join((data, statistics, no_primary_index))

-        expression_sql = f"CREATE{modifiers} {kind}{exists_sql} {this}{properties}{expression_sql}{post_expression_modifiers}{index_sql}{no_schema_binding}"
+        expression_sql = f"CREATE{modifiers} {kind}{exists_sql} {this}{properties_sql}{expression_sql}{post_expression_modifiers}{index_sql}{no_schema_binding}"
         return self.prepend_ctes(expression, expression_sql)

     def describe_sql(self, expression: exp.Describe) -> str:
@@ -665,24 +712,19 @@ class Generator:
         return f"PARTITION({self.expressions(expression)})"

     def properties_sql(self, expression: exp.Properties) -> str:
-        before_properties = []
         root_properties = []
         with_properties = []

         for p in expression.expressions:
-            p_class = p.__class__
-            if p_class in self.BEFORE_PROPERTIES:
-                before_properties.append(p)
-            elif p_class in self.WITH_PROPERTIES:
+            p_loc = self.PROPERTIES_LOCATION[p.__class__]
+            if p_loc == exp.Properties.Location.POST_SCHEMA_WITH:
                 with_properties.append(p)
-            elif p_class in self.ROOT_PROPERTIES:
+            elif p_loc == exp.Properties.Location.POST_SCHEMA_ROOT:
                 root_properties.append(p)

-        return (
-            self.properties(exp.Properties(expressions=before_properties), before=True)
-            + self.root_properties(exp.Properties(expressions=root_properties))
-            + self.with_properties(exp.Properties(expressions=with_properties))
-        )
+        return self.root_properties(
+            exp.Properties(expressions=root_properties)
+        ) + self.with_properties(exp.Properties(expressions=with_properties))

     def root_properties(self, properties: exp.Properties) -> str:
         if properties.expressions:
@@ -695,17 +737,41 @@ class Generator:
         prefix: str = "",
         sep: str = ", ",
         suffix: str = "",
-        before: bool = False,
+        wrapped: bool = True,
     ) -> str:
         if properties.expressions:
             expressions = self.expressions(properties, sep=sep, indent=False)
-            expressions = expressions if before else self.wrap(expressions)
+            expressions = self.wrap(expressions) if wrapped else expressions
             return f"{prefix}{' ' if prefix and prefix != ' ' else ''}{expressions}{suffix}"
         return ""

     def with_properties(self, properties: exp.Properties) -> str:
         return self.properties(properties, prefix=self.seg("WITH"))

+    def locate_properties(
+        self, properties: exp.Properties
+    ) -> t.Dict[exp.Properties.Location, list[exp.Property]]:
+        properties_locs: t.Dict[exp.Properties.Location, list[exp.Property]] = {
+            key: [] for key in exp.Properties.Location
+        }
+
+        for p in properties.expressions:
+            p_loc = self.PROPERTIES_LOCATION[p.__class__]
+            if p_loc == exp.Properties.Location.PRE_SCHEMA:
+                properties_locs[exp.Properties.Location.PRE_SCHEMA].append(p)
+            elif p_loc == exp.Properties.Location.POST_INDEX:
+                properties_locs[exp.Properties.Location.POST_INDEX].append(p)
+            elif p_loc == exp.Properties.Location.POST_SCHEMA_ROOT:
+                properties_locs[exp.Properties.Location.POST_SCHEMA_ROOT].append(p)
+            elif p_loc == exp.Properties.Location.POST_SCHEMA_WITH:
+                properties_locs[exp.Properties.Location.POST_SCHEMA_WITH].append(p)
+            elif p_loc == exp.Properties.Location.POST_CREATE:
+                properties_locs[exp.Properties.Location.POST_CREATE].append(p)
+            elif p_loc == exp.Properties.Location.UNSUPPORTED:
+                self.unsupported(f"Unsupported property {p.key}")
+
+        return properties_locs

     def property_sql(self, expression: exp.Property) -> str:
         property_cls = expression.__class__
         if property_cls == exp.Property:
@@ -713,7 +779,7 @@ class Generator:
         property_name = exp.Properties.PROPERTY_TO_NAME.get(property_cls)
         if not property_name:
-            self.unsupported(f"Unsupported property {property_name}")
+            self.unsupported(f"Unsupported property {expression.key}")

         return f"{property_name}={self.sql(expression, 'this')}"
@@ -975,7 +1041,7 @@ class Generator:
         rollup = self.expressions(expression, key="rollup", indent=False)
         rollup = f"{self.seg('ROLLUP')} {self.wrap(rollup)}" if rollup else ""
-        return f"{group_by}{grouping_sets}{cube}{rollup}"
+        return f"{group_by}{csv(grouping_sets, cube, rollup, sep=',')}"

     def having_sql(self, expression: exp.Having) -> str:
         this = self.indent(self.sql(expression, "this"))
@@ -1015,7 +1081,7 @@ class Generator:
     def lambda_sql(self, expression: exp.Lambda, arrow_sep: str = "->") -> str:
         args = self.expressions(expression, flat=True)
         args = f"({args})" if len(args.split(",")) > 1 else args
-        return self.no_identify(lambda: f"{args} {arrow_sep} {self.sql(expression, 'this')}")
+        return f"{args} {arrow_sep} {self.sql(expression, 'this')}"

     def lateral_sql(self, expression: exp.Lateral) -> str:
         this = self.sql(expression, "this")
@@ -1043,6 +1109,14 @@ class Generator:
         this = self.sql(expression, "this")
         return f"{this}{self.seg('OFFSET')} {self.sql(expression, 'expression')}"

+    def lock_sql(self, expression: exp.Lock) -> str:
+        if self.LOCKING_READS_SUPPORTED:
+            lock_type = "UPDATE" if expression.args["update"] else "SHARE"
+            return self.seg(f"FOR {lock_type}")
+
+        self.unsupported("Locking reads using 'FOR UPDATE/SHARE' are not supported")
return ""
def literal_sql(self, expression: exp.Literal) -> str: def literal_sql(self, expression: exp.Literal) -> str:
text = expression.this or "" text = expression.this or ""
if expression.is_string: if expression.is_string:
@ -1163,6 +1237,7 @@ class Generator:
self.sql(expression, "order"), self.sql(expression, "order"),
self.sql(expression, "limit"), self.sql(expression, "limit"),
self.sql(expression, "offset"), self.sql(expression, "offset"),
self.sql(expression, "lock"),
sep="", sep="",
) )
@ -1773,7 +1848,7 @@ class Generator:
def userdefinedfunction_sql(self, expression: exp.UserDefinedFunction) -> str: def userdefinedfunction_sql(self, expression: exp.UserDefinedFunction) -> str:
this = self.sql(expression, "this") this = self.sql(expression, "this")
expressions = self.no_identify(lambda: self.expressions(expression)) expressions = self.no_identify(self.expressions, expression)
expressions = ( expressions = (
self.wrap(expressions) if expression.args.get("wrapped") else f" {expressions}" self.wrap(expressions) if expression.args.get("wrapped") else f" {expressions}"
) )
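With properties now grouped by `exp.Properties.Location`, clauses such as `ALGORITHM`, `DEFINER` and `SQL SECURITY` are rendered between `CREATE` and the object kind rather than as trailing `WITH` properties. A rough round-trip sketch, reusing the identity fixture added later in this diff:

```python
from sqlglot import parse_one

sql = "CREATE ALGORITHM=UNDEFINED DEFINER=foo@% SQL SECURITY DEFINER VIEW a AS (SELECT a FROM b)"
print(parse_one(sql).sql())
# CREATE ALGORITHM=UNDEFINED DEFINER=foo@% SQL SECURITY DEFINER VIEW a AS (SELECT a FROM b)
```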
@ -9,6 +9,9 @@ from sqlglot.optimizer import Scope, build_scope, optimize
from sqlglot.optimizer.qualify_columns import qualify_columns from sqlglot.optimizer.qualify_columns import qualify_columns
from sqlglot.optimizer.qualify_tables import qualify_tables from sqlglot.optimizer.qualify_tables import qualify_tables
if t.TYPE_CHECKING:
from sqlglot.dialects.dialect import DialectType
@dataclass(frozen=True) @dataclass(frozen=True)
class Node: class Node:
@ -36,7 +39,7 @@ def lineage(
schema: t.Optional[t.Dict | Schema] = None, schema: t.Optional[t.Dict | Schema] = None,
sources: t.Optional[t.Dict[str, str | exp.Subqueryable]] = None, sources: t.Optional[t.Dict[str, str | exp.Subqueryable]] = None,
rules: t.Sequence[t.Callable] = (qualify_tables, qualify_columns), rules: t.Sequence[t.Callable] = (qualify_tables, qualify_columns),
dialect: t.Optional[str] = None, dialect: DialectType = None,
) -> Node: ) -> Node:
"""Build the lineage graph for a column of a SQL query. """Build the lineage graph for a column of a SQL query.
@ -126,7 +129,7 @@ class LineageHTML:
def __init__( def __init__(
self, self,
node: Node, node: Node,
dialect: t.Optional[str] = None, dialect: DialectType = None,
imports: bool = True, imports: bool = True,
**opts: t.Any, **opts: t.Any,
): ):
@ -114,7 +114,7 @@ def _eliminate_union(scope, existing_ctes, taken):
taken[alias] = scope taken[alias] = scope
# Try to maintain the selections # Try to maintain the selections
expressions = scope.expression.args.get("expressions") expressions = scope.selects
selects = [ selects = [
exp.alias_(exp.column(e.alias_or_name, table=alias), alias=e.alias_or_name) exp.alias_(exp.column(e.alias_or_name, table=alias), alias=e.alias_or_name)
for e in expressions for e in expressions
@ -300,7 +300,7 @@ class Scope:
list[exp.Expression]: expressions list[exp.Expression]: expressions
""" """
if isinstance(self.expression, exp.Union): if isinstance(self.expression, exp.Union):
return [] return self.expression.unnest().selects
return self.expression.selects return self.expression.selects
@property @property
@ -456,8 +456,10 @@ def extract_interval(interval):
def date_literal(date): def date_literal(date):
expr_type = exp.DataType.build("DATETIME" if isinstance(date, datetime.datetime) else "DATE") return exp.cast(
return exp.Cast(this=exp.Literal.string(date), to=expr_type) exp.Literal.string(date),
"DATETIME" if isinstance(date, datetime.datetime) else "DATE",
)
def boolean_literal(condition): def boolean_literal(condition):
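`date_literal` now goes through the `exp.cast` helper instead of constructing `exp.Cast` nodes by hand. A small hand-written sketch of that helper (my own example, not taken from the diff):

```python
from sqlglot import exp

print(exp.cast(exp.Literal.string("2023-01-01"), "DATE").sql())
# CAST('2023-01-01' AS DATE)
```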
@ -80,6 +80,7 @@ class Parser(metaclass=_Parser):
length=exp.Literal.number(10), length=exp.Literal.number(10),
), ),
"VAR_MAP": parse_var_map, "VAR_MAP": parse_var_map,
"IFNULL": exp.Coalesce.from_arg_list,
} }
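Registering `IFNULL` as `exp.Coalesce.from_arg_list` normalizes the function at parse time, so it transpiles to `COALESCE` everywhere rather than passing through as an anonymous function. Sketch, mirroring the `test_if_null` case added later in this diff:

```python
import sqlglot

print(sqlglot.transpile("SELECT IFNULL(1, NULL) FROM foo", write="postgres")[0])
# SELECT COALESCE(1, NULL) FROM foo
```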
NO_PAREN_FUNCTIONS = { NO_PAREN_FUNCTIONS = {
@ -567,6 +568,8 @@ class Parser(metaclass=_Parser):
default=self._prev.text.upper() == "DEFAULT" default=self._prev.text.upper() == "DEFAULT"
), ),
"BLOCKCOMPRESSION": lambda self: self._parse_blockcompression(), "BLOCKCOMPRESSION": lambda self: self._parse_blockcompression(),
"ALGORITHM": lambda self: self._parse_property_assignment(exp.AlgorithmProperty),
"DEFINER": lambda self: self._parse_definer(),
} }
CONSTRAINT_PARSERS = { CONSTRAINT_PARSERS = {
@ -608,6 +611,7 @@ class Parser(metaclass=_Parser):
"order": lambda self: self._parse_order(), "order": lambda self: self._parse_order(),
"limit": lambda self: self._parse_limit(), "limit": lambda self: self._parse_limit(),
"offset": lambda self: self._parse_offset(), "offset": lambda self: self._parse_offset(),
"lock": lambda self: self._parse_lock(),
} }
SHOW_PARSERS: t.Dict[str, t.Callable] = {} SHOW_PARSERS: t.Dict[str, t.Callable] = {}
@ -850,7 +854,7 @@ class Parser(metaclass=_Parser):
self.raise_error(error_message) self.raise_error(error_message)
def _find_sql(self, start: Token, end: Token) -> str: def _find_sql(self, start: Token, end: Token) -> str:
return self.sql[self._find_token(start) : self._find_token(end)] return self.sql[self._find_token(start) : self._find_token(end) + len(end.text)]
def _find_token(self, token: Token) -> int: def _find_token(self, token: Token) -> int:
line = 1 line = 1
@ -901,6 +905,7 @@ class Parser(metaclass=_Parser):
return expression return expression
def _parse_drop(self, default_kind: t.Optional[str] = None) -> t.Optional[exp.Expression]: def _parse_drop(self, default_kind: t.Optional[str] = None) -> t.Optional[exp.Expression]:
start = self._prev
temporary = self._match(TokenType.TEMPORARY) temporary = self._match(TokenType.TEMPORARY)
materialized = self._match(TokenType.MATERIALIZED) materialized = self._match(TokenType.MATERIALIZED)
kind = self._match_set(self.CREATABLES) and self._prev.text kind = self._match_set(self.CREATABLES) and self._prev.text
@ -908,8 +913,7 @@ class Parser(metaclass=_Parser):
if default_kind: if default_kind:
kind = default_kind kind = default_kind
else: else:
self.raise_error(f"Expected {self.CREATABLES}") return self._parse_as_command(start)
return None
return self.expression( return self.expression(
exp.Drop, exp.Drop,
@ -929,6 +933,7 @@ class Parser(metaclass=_Parser):
) )
def _parse_create(self) -> t.Optional[exp.Expression]: def _parse_create(self) -> t.Optional[exp.Expression]:
start = self._prev
replace = self._match_pair(TokenType.OR, TokenType.REPLACE) replace = self._match_pair(TokenType.OR, TokenType.REPLACE)
set_ = self._match(TokenType.SET) # Teradata set_ = self._match(TokenType.SET) # Teradata
multiset = self._match_text_seq("MULTISET") # Teradata multiset = self._match_text_seq("MULTISET") # Teradata
@ -943,16 +948,19 @@ class Parser(metaclass=_Parser):
if self._match_pair(TokenType.TABLE, TokenType.FUNCTION, advance=False): if self._match_pair(TokenType.TABLE, TokenType.FUNCTION, advance=False):
self._match(TokenType.TABLE) self._match(TokenType.TABLE)
properties = None
create_token = self._match_set(self.CREATABLES) and self._prev create_token = self._match_set(self.CREATABLES) and self._prev
if not create_token: if not create_token:
self.raise_error(f"Expected {self.CREATABLES}") properties = self._parse_properties()
return None create_token = self._match_set(self.CREATABLES) and self._prev
if not properties or not create_token:
return self._parse_as_command(start)
exists = self._parse_exists(not_=True) exists = self._parse_exists(not_=True)
this = None this = None
expression = None expression = None
properties = None
data = None data = None
statistics = None statistics = None
no_primary_index = None no_primary_index = None
@ -1006,6 +1014,14 @@ class Parser(metaclass=_Parser):
indexes = [] indexes = []
while True: while True:
index = self._parse_create_table_index() index = self._parse_create_table_index()
# post index PARTITION BY property
if self._match(TokenType.PARTITION_BY, advance=False):
if properties:
properties.expressions.append(self._parse_property())
else:
properties = self._parse_properties()
if not index: if not index:
break break
else: else:
@ -1040,6 +1056,9 @@ class Parser(metaclass=_Parser):
) )
def _parse_property_before(self) -> t.Optional[exp.Expression]: def _parse_property_before(self) -> t.Optional[exp.Expression]:
self._match(TokenType.COMMA)
# parsers look to _prev for no/dual/default, so need to consume first
self._match_text_seq("NO") self._match_text_seq("NO")
self._match_text_seq("DUAL") self._match_text_seq("DUAL")
self._match_text_seq("DEFAULT") self._match_text_seq("DEFAULT")
@ -1059,6 +1078,9 @@ class Parser(metaclass=_Parser):
if self._match_pair(TokenType.COMPOUND, TokenType.SORTKEY): if self._match_pair(TokenType.COMPOUND, TokenType.SORTKEY):
return self._parse_sortkey(compound=True) return self._parse_sortkey(compound=True)
if self._match_text_seq("SQL", "SECURITY"):
return self.expression(exp.SqlSecurityProperty, definer=self._match_text_seq("DEFINER"))
assignment = self._match_pair( assignment = self._match_pair(
TokenType.VAR, TokenType.EQ, advance=False TokenType.VAR, TokenType.EQ, advance=False
) or self._match_pair(TokenType.STRING, TokenType.EQ, advance=False) ) or self._match_pair(TokenType.STRING, TokenType.EQ, advance=False)
@ -1083,7 +1105,6 @@ class Parser(metaclass=_Parser):
while True: while True:
if before: if before:
self._match(TokenType.COMMA)
identified_property = self._parse_property_before() identified_property = self._parse_property_before()
else: else:
identified_property = self._parse_property() identified_property = self._parse_property()
@ -1094,7 +1115,7 @@ class Parser(metaclass=_Parser):
properties.append(p) properties.append(p)
if properties: if properties:
return self.expression(exp.Properties, expressions=properties, before=before) return self.expression(exp.Properties, expressions=properties)
return None return None
@ -1118,6 +1139,19 @@ class Parser(metaclass=_Parser):
return self._parse_withisolatedloading() return self._parse_withisolatedloading()
# https://dev.mysql.com/doc/refman/8.0/en/create-view.html
def _parse_definer(self) -> t.Optional[exp.Expression]:
self._match(TokenType.EQ)
user = self._parse_id_var()
self._match(TokenType.PARAMETER)
host = self._parse_id_var() or (self._match(TokenType.MOD) and self._prev.text)
if not user or not host:
return None
return exp.DefinerProperty(this=f"{user}@{host}")
def _parse_withjournaltable(self) -> exp.Expression: def _parse_withjournaltable(self) -> exp.Expression:
self._match_text_seq("WITH", "JOURNAL", "TABLE") self._match_text_seq("WITH", "JOURNAL", "TABLE")
self._match(TokenType.EQ) self._match(TokenType.EQ)
@ -1695,12 +1729,10 @@ class Parser(metaclass=_Parser):
paren += 1 paren += 1
if self._curr.token_type == TokenType.R_PAREN: if self._curr.token_type == TokenType.R_PAREN:
paren -= 1 paren -= 1
end = self._prev
self._advance() self._advance()
if paren > 0: if paren > 0:
self.raise_error("Expecting )", self._curr) self.raise_error("Expecting )", self._curr)
if not self._curr:
self.raise_error("Expecting pattern", self._curr)
end = self._prev
pattern = exp.Var(this=self._find_sql(start, end)) pattern = exp.Var(this=self._find_sql(start, end))
else: else:
pattern = None pattern = None
@ -2044,9 +2076,16 @@ class Parser(metaclass=_Parser):
expressions = self._parse_csv(self._parse_conjunction) expressions = self._parse_csv(self._parse_conjunction)
grouping_sets = self._parse_grouping_sets() grouping_sets = self._parse_grouping_sets()
self._match(TokenType.COMMA)
with_ = self._match(TokenType.WITH) with_ = self._match(TokenType.WITH)
cube = self._match(TokenType.CUBE) and (with_ or self._parse_wrapped_id_vars()) cube = self._match(TokenType.CUBE) and (
rollup = self._match(TokenType.ROLLUP) and (with_ or self._parse_wrapped_id_vars()) with_ or self._parse_wrapped_csv(self._parse_column)
)
self._match(TokenType.COMMA)
rollup = self._match(TokenType.ROLLUP) and (
with_ or self._parse_wrapped_csv(self._parse_column)
)
return self.expression( return self.expression(
exp.Group, exp.Group,
@ -2149,6 +2188,14 @@ class Parser(metaclass=_Parser):
self._match_set((TokenType.ROW, TokenType.ROWS)) self._match_set((TokenType.ROW, TokenType.ROWS))
return self.expression(exp.Offset, this=this, expression=count) return self.expression(exp.Offset, this=this, expression=count)
def _parse_lock(self) -> t.Optional[exp.Expression]:
if self._match_text_seq("FOR", "UPDATE"):
return self.expression(exp.Lock, update=True)
if self._match_text_seq("FOR", "SHARE"):
return self.expression(exp.Lock, update=False)
return None
def _parse_set_operations(self, this: t.Optional[exp.Expression]) -> t.Optional[exp.Expression]: def _parse_set_operations(self, this: t.Optional[exp.Expression]) -> t.Optional[exp.Expression]:
if not self._match_set(self.SET_OPERATIONS): if not self._match_set(self.SET_OPERATIONS):
return this return this
@ -2330,12 +2377,21 @@ class Parser(metaclass=_Parser):
maybe_func = True maybe_func = True
if not nested and self._match_pair(TokenType.L_BRACKET, TokenType.R_BRACKET): if not nested and self._match_pair(TokenType.L_BRACKET, TokenType.R_BRACKET):
return exp.DataType( this = exp.DataType(
this=exp.DataType.Type.ARRAY, this=exp.DataType.Type.ARRAY,
expressions=[exp.DataType.build(type_token.value, expressions=expressions)], expressions=[exp.DataType.build(type_token.value, expressions=expressions)],
nested=True, nested=True,
) )
while self._match_pair(TokenType.L_BRACKET, TokenType.R_BRACKET):
this = exp.DataType(
this=exp.DataType.Type.ARRAY,
expressions=[this],
nested=True,
)
return this
if self._match(TokenType.L_BRACKET): if self._match(TokenType.L_BRACKET):
self._retreat(index) self._retreat(index)
return None return None
@ -2430,7 +2486,12 @@ class Parser(metaclass=_Parser):
self.raise_error("Expected type") self.raise_error("Expected type")
elif op: elif op:
self._advance() self._advance()
field = exp.Literal.string(self._prev.text) value = self._prev.text
field = (
exp.Literal.number(value)
if self._prev.token_type == TokenType.NUMBER
else exp.Literal.string(value)
)
else: else:
field = self._parse_star() or self._parse_function() or self._parse_id_var() field = self._parse_star() or self._parse_function() or self._parse_id_var()
@ -2752,7 +2813,23 @@ class Parser(metaclass=_Parser):
if not self._curr: if not self._curr:
break break
if self._match_text_seq("NOT", "ENFORCED"): if self._match(TokenType.ON):
action = None
on = self._advance_any() and self._prev.text
if self._match(TokenType.NO_ACTION):
action = "NO ACTION"
elif self._match(TokenType.CASCADE):
action = "CASCADE"
elif self._match_pair(TokenType.SET, TokenType.NULL):
action = "SET NULL"
elif self._match_pair(TokenType.SET, TokenType.DEFAULT):
action = "SET DEFAULT"
else:
self.raise_error("Invalid key constraint")
options.append(f"ON {on} {action}")
elif self._match_text_seq("NOT", "ENFORCED"):
options.append("NOT ENFORCED") options.append("NOT ENFORCED")
elif self._match_text_seq("DEFERRABLE"): elif self._match_text_seq("DEFERRABLE"):
options.append("DEFERRABLE") options.append("DEFERRABLE")
@ -2762,10 +2839,6 @@ class Parser(metaclass=_Parser):
options.append("NORELY") options.append("NORELY")
elif self._match_text_seq("MATCH", "FULL"): elif self._match_text_seq("MATCH", "FULL"):
options.append("MATCH FULL") options.append("MATCH FULL")
elif self._match_text_seq("ON", "UPDATE", "NO ACTION"):
options.append("ON UPDATE NO ACTION")
elif self._match_text_seq("ON", "DELETE", "NO ACTION"):
options.append("ON DELETE NO ACTION")
else: else:
break break
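The foreign-key constraint parser now records `ON DELETE` / `ON UPDATE` referential actions generically instead of only the two hard-coded `NO ACTION` forms. A round-trip sketch, reusing one of the `identity.sql` fixtures added later in this diff:

```python
from sqlglot import parse_one

sql = "CREATE TABLE foo (bar INT REFERENCES baz(baz_id) ON DELETE CASCADE)"
print(parse_one(sql).sql())
# CREATE TABLE foo (bar INT REFERENCES baz(baz_id) ON DELETE CASCADE)
```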
@ -3158,7 +3231,9 @@ class Parser(metaclass=_Parser):
prefix += self._prev.text prefix += self._prev.text
if (any_token and self._advance_any()) or self._match_set(tokens or self.ID_VAR_TOKENS): if (any_token and self._advance_any()) or self._match_set(tokens or self.ID_VAR_TOKENS):
return exp.Identifier(this=prefix + self._prev.text, quoted=False) quoted = self._prev.token_type == TokenType.STRING
return exp.Identifier(this=prefix + self._prev.text, quoted=quoted)
return None return None
def _parse_string(self) -> t.Optional[exp.Expression]: def _parse_string(self) -> t.Optional[exp.Expression]:
@ -3486,6 +3561,11 @@ class Parser(metaclass=_Parser):
def _parse_set(self) -> exp.Expression: def _parse_set(self) -> exp.Expression:
return self.expression(exp.Set, expressions=self._parse_csv(self._parse_set_item)) return self.expression(exp.Set, expressions=self._parse_csv(self._parse_set_item))
def _parse_as_command(self, start: Token) -> exp.Command:
while self._curr:
self._advance()
return exp.Command(this=self._find_sql(start, self._prev))
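`_parse_as_command` lets `CREATE`/`DROP` statements with unrecognized targets fall back to a raw command node instead of raising. A sketch of the expected effect, based on the `CREATE OR REPLACE STAGE` fixture added later in this diff (the `Command` class name shown is my inference from this hunk, not something the tests assert directly):

```python
from sqlglot import parse_one

stmt = parse_one("CREATE OR REPLACE STAGE")
print(stmt.__class__.__name__, "->", stmt.sql())
# Command -> CREATE OR REPLACE STAGE
```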
def _find_parser( def _find_parser(
self, parsers: t.Dict[str, t.Callable], trie: t.Dict self, parsers: t.Dict[str, t.Callable], trie: t.Dict
) -> t.Optional[t.Callable]: ) -> t.Optional[t.Callable]:
@ -11,6 +11,7 @@ from sqlglot.trie import in_trie, new_trie
if t.TYPE_CHECKING: if t.TYPE_CHECKING:
from sqlglot.dataframe.sql.types import StructType from sqlglot.dataframe.sql.types import StructType
from sqlglot.dialects.dialect import DialectType
ColumnMapping = t.Union[t.Dict, str, StructType, t.List] ColumnMapping = t.Union[t.Dict, str, StructType, t.List]
@ -153,7 +154,7 @@ class MappingSchema(AbstractMappingSchema[t.Dict[str, str]], Schema):
self, self,
schema: t.Optional[t.Dict] = None, schema: t.Optional[t.Dict] = None,
visible: t.Optional[t.Dict] = None, visible: t.Optional[t.Dict] = None,
dialect: t.Optional[str] = None, dialect: DialectType = None,
) -> None: ) -> None:
self.dialect = dialect self.dialect = dialect
self.visible = visible or {} self.visible = visible or {}
@ -665,6 +665,7 @@ class Tokenizer(metaclass=_Tokenizer):
"STRING": TokenType.TEXT, "STRING": TokenType.TEXT,
"TEXT": TokenType.TEXT, "TEXT": TokenType.TEXT,
"CLOB": TokenType.TEXT, "CLOB": TokenType.TEXT,
"LONGVARCHAR": TokenType.TEXT,
"BINARY": TokenType.BINARY, "BINARY": TokenType.BINARY,
"BLOB": TokenType.VARBINARY, "BLOB": TokenType.VARBINARY,
"BYTEA": TokenType.VARBINARY, "BYTEA": TokenType.VARBINARY,
@ -170,7 +170,7 @@ class TestBigQuery(Validator):
"bigquery": "CURRENT_TIMESTAMP()", "bigquery": "CURRENT_TIMESTAMP()",
"duckdb": "CURRENT_TIMESTAMP()", "duckdb": "CURRENT_TIMESTAMP()",
"postgres": "CURRENT_TIMESTAMP", "postgres": "CURRENT_TIMESTAMP",
"presto": "CURRENT_TIMESTAMP()", "presto": "CURRENT_TIMESTAMP",
"hive": "CURRENT_TIMESTAMP()", "hive": "CURRENT_TIMESTAMP()",
"spark": "CURRENT_TIMESTAMP()", "spark": "CURRENT_TIMESTAMP()",
}, },
@ -181,7 +181,7 @@ class TestBigQuery(Validator):
"bigquery": "CURRENT_TIMESTAMP()", "bigquery": "CURRENT_TIMESTAMP()",
"duckdb": "CURRENT_TIMESTAMP()", "duckdb": "CURRENT_TIMESTAMP()",
"postgres": "CURRENT_TIMESTAMP", "postgres": "CURRENT_TIMESTAMP",
"presto": "CURRENT_TIMESTAMP()", "presto": "CURRENT_TIMESTAMP",
"hive": "CURRENT_TIMESTAMP()", "hive": "CURRENT_TIMESTAMP()",
"spark": "CURRENT_TIMESTAMP()", "spark": "CURRENT_TIMESTAMP()",
}, },
@ -1,6 +1,7 @@
import unittest import unittest
from sqlglot import Dialect, Dialects, ErrorLevel, UnsupportedError, parse_one from sqlglot import Dialect, Dialects, ErrorLevel, UnsupportedError, parse_one
from sqlglot.dialects import Hive
class Validator(unittest.TestCase): class Validator(unittest.TestCase):
@ -67,6 +68,11 @@ class TestDialect(Validator):
self.assertIsNotNone(Dialect.get_or_raise(dialect)) self.assertIsNotNone(Dialect.get_or_raise(dialect))
self.assertIsNotNone(Dialect[dialect.value]) self.assertIsNotNone(Dialect[dialect.value])
def test_get_or_raise(self):
self.assertEqual(Dialect.get_or_raise(Hive), Hive)
self.assertEqual(Dialect.get_or_raise(Hive()), Hive)
self.assertEqual(Dialect.get_or_raise("hive"), Hive)
def test_cast(self): def test_cast(self):
self.validate_all( self.validate_all(
"CAST(a AS TEXT)", "CAST(a AS TEXT)",
@ -280,6 +286,21 @@ class TestDialect(Validator):
write={"oracle": "CAST(a AS NUMBER)"}, write={"oracle": "CAST(a AS NUMBER)"},
) )
def test_if_null(self):
self.validate_all(
"SELECT IFNULL(1, NULL) FROM foo",
write={
"": "SELECT COALESCE(1, NULL) FROM foo",
"redshift": "SELECT COALESCE(1, NULL) FROM foo",
"postgres": "SELECT COALESCE(1, NULL) FROM foo",
"mysql": "SELECT COALESCE(1, NULL) FROM foo",
"duckdb": "SELECT COALESCE(1, NULL) FROM foo",
"spark": "SELECT COALESCE(1, NULL) FROM foo",
"bigquery": "SELECT COALESCE(1, NULL) FROM foo",
"presto": "SELECT COALESCE(1, NULL) FROM foo",
},
)
def test_time(self): def test_time(self):
self.validate_all( self.validate_all(
"STR_TO_TIME(x, '%Y-%m-%dT%H:%M:%S')", "STR_TO_TIME(x, '%Y-%m-%dT%H:%M:%S')",
@ -287,10 +308,10 @@ class TestDialect(Validator):
"duckdb": "STRPTIME(x, '%Y-%m-%dT%H:%M:%S')", "duckdb": "STRPTIME(x, '%Y-%m-%dT%H:%M:%S')",
}, },
write={ write={
"mysql": "STR_TO_DATE(x, '%Y-%m-%dT%H:%i:%S')", "mysql": "STR_TO_DATE(x, '%Y-%m-%dT%T')",
"duckdb": "STRPTIME(x, '%Y-%m-%dT%H:%M:%S')", "duckdb": "STRPTIME(x, '%Y-%m-%dT%H:%M:%S')",
"hive": "CAST(FROM_UNIXTIME(UNIX_TIMESTAMP(x, 'yyyy-MM-ddTHH:mm:ss')) AS TIMESTAMP)", "hive": "CAST(FROM_UNIXTIME(UNIX_TIMESTAMP(x, 'yyyy-MM-ddTHH:mm:ss')) AS TIMESTAMP)",
"presto": "DATE_PARSE(x, '%Y-%m-%dT%H:%i:%S')", "presto": "DATE_PARSE(x, '%Y-%m-%dT%T')",
"drill": "TO_TIMESTAMP(x, 'yyyy-MM-dd''T''HH:mm:ss')", "drill": "TO_TIMESTAMP(x, 'yyyy-MM-dd''T''HH:mm:ss')",
"redshift": "TO_TIMESTAMP(x, 'YYYY-MM-DDTHH:MI:SS')", "redshift": "TO_TIMESTAMP(x, 'YYYY-MM-DDTHH:MI:SS')",
"spark": "TO_TIMESTAMP(x, 'yyyy-MM-ddTHH:mm:ss')", "spark": "TO_TIMESTAMP(x, 'yyyy-MM-ddTHH:mm:ss')",
@ -356,7 +377,7 @@ class TestDialect(Validator):
write={ write={
"duckdb": "EPOCH(CAST('2020-01-01' AS TIMESTAMP))", "duckdb": "EPOCH(CAST('2020-01-01' AS TIMESTAMP))",
"hive": "UNIX_TIMESTAMP('2020-01-01')", "hive": "UNIX_TIMESTAMP('2020-01-01')",
"presto": "TO_UNIXTIME(DATE_PARSE('2020-01-01', '%Y-%m-%d %H:%i:%S'))", "presto": "TO_UNIXTIME(DATE_PARSE('2020-01-01', '%Y-%m-%d %T'))",
}, },
) )
self.validate_all( self.validate_all(
@ -418,7 +439,7 @@ class TestDialect(Validator):
self.validate_all( self.validate_all(
"UNIX_TO_STR(x, y)", "UNIX_TO_STR(x, y)",
write={ write={
"duckdb": "STRFTIME(TO_TIMESTAMP(CAST(x AS BIGINT)), y)", "duckdb": "STRFTIME(TO_TIMESTAMP(x), y)",
"hive": "FROM_UNIXTIME(x, y)", "hive": "FROM_UNIXTIME(x, y)",
"presto": "DATE_FORMAT(FROM_UNIXTIME(x), y)", "presto": "DATE_FORMAT(FROM_UNIXTIME(x), y)",
"starrocks": "FROM_UNIXTIME(x, y)", "starrocks": "FROM_UNIXTIME(x, y)",
@ -427,7 +448,7 @@ class TestDialect(Validator):
self.validate_all( self.validate_all(
"UNIX_TO_TIME(x)", "UNIX_TO_TIME(x)",
write={ write={
"duckdb": "TO_TIMESTAMP(CAST(x AS BIGINT))", "duckdb": "TO_TIMESTAMP(x)",
"hive": "FROM_UNIXTIME(x)", "hive": "FROM_UNIXTIME(x)",
"oracle": "TO_DATE('1970-01-01','YYYY-MM-DD') + (x / 86400)", "oracle": "TO_DATE('1970-01-01','YYYY-MM-DD') + (x / 86400)",
"postgres": "TO_TIMESTAMP(x)", "postgres": "TO_TIMESTAMP(x)",
@ -438,7 +459,7 @@ class TestDialect(Validator):
self.validate_all( self.validate_all(
"UNIX_TO_TIME_STR(x)", "UNIX_TO_TIME_STR(x)",
write={ write={
"duckdb": "CAST(TO_TIMESTAMP(CAST(x AS BIGINT)) AS TEXT)", "duckdb": "CAST(TO_TIMESTAMP(x) AS TEXT)",
"hive": "FROM_UNIXTIME(x)", "hive": "FROM_UNIXTIME(x)",
"presto": "CAST(FROM_UNIXTIME(x) AS VARCHAR)", "presto": "CAST(FROM_UNIXTIME(x) AS VARCHAR)",
}, },
@ -575,10 +596,10 @@ class TestDialect(Validator):
}, },
write={ write={
"drill": "TO_DATE(x, 'yyyy-MM-dd''T''HH:mm:ss')", "drill": "TO_DATE(x, 'yyyy-MM-dd''T''HH:mm:ss')",
"mysql": "STR_TO_DATE(x, '%Y-%m-%dT%H:%i:%S')", "mysql": "STR_TO_DATE(x, '%Y-%m-%dT%T')",
"starrocks": "STR_TO_DATE(x, '%Y-%m-%dT%H:%i:%S')", "starrocks": "STR_TO_DATE(x, '%Y-%m-%dT%T')",
"hive": "CAST(FROM_UNIXTIME(UNIX_TIMESTAMP(x, 'yyyy-MM-ddTHH:mm:ss')) AS DATE)", "hive": "CAST(FROM_UNIXTIME(UNIX_TIMESTAMP(x, 'yyyy-MM-ddTHH:mm:ss')) AS DATE)",
"presto": "CAST(DATE_PARSE(x, '%Y-%m-%dT%H:%i:%S') AS DATE)", "presto": "CAST(DATE_PARSE(x, '%Y-%m-%dT%T') AS DATE)",
"spark": "TO_DATE(x, 'yyyy-MM-ddTHH:mm:ss')", "spark": "TO_DATE(x, 'yyyy-MM-ddTHH:mm:ss')",
}, },
) )
@ -709,6 +730,7 @@ class TestDialect(Validator):
"hive": "REDUCE(x, 0, (acc, x) -> acc + x, acc -> acc)", "hive": "REDUCE(x, 0, (acc, x) -> acc + x, acc -> acc)",
"presto": "REDUCE(x, 0, (acc, x) -> acc + x, acc -> acc)", "presto": "REDUCE(x, 0, (acc, x) -> acc + x, acc -> acc)",
"spark": "AGGREGATE(x, 0, (acc, x) -> acc + x, acc -> acc)", "spark": "AGGREGATE(x, 0, (acc, x) -> acc + x, acc -> acc)",
"presto": "REDUCE(x, 0, (acc, x) -> acc + x, acc -> acc)",
}, },
) )
@ -1381,3 +1403,21 @@ SELECT
"spark": "MERGE INTO a AS b USING c AS d ON b.id = d.id WHEN MATCHED AND EXISTS(SELECT b.name EXCEPT SELECT d.name) THEN UPDATE SET b.name = d.name", "spark": "MERGE INTO a AS b USING c AS d ON b.id = d.id WHEN MATCHED AND EXISTS(SELECT b.name EXCEPT SELECT d.name) THEN UPDATE SET b.name = d.name",
}, },
) )
def test_substring(self):
self.validate_all(
"SUBSTR('123456', 2, 3)",
write={
"bigquery": "SUBSTR('123456', 2, 3)",
"oracle": "SUBSTR('123456', 2, 3)",
"postgres": "SUBSTR('123456', 2, 3)",
},
)
self.validate_all(
"SUBSTRING('123456', 2, 3)",
write={
"bigquery": "SUBSTRING('123456', 2, 3)",
"oracle": "SUBSTR('123456', 2, 3)",
"postgres": "SUBSTRING('123456' FROM 2 FOR 3)",
},
)
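The new `test_substring` cases above boil down to: `SUBSTR`/`SUBSTRING` keep their spelling where the target dialect accepts it, and Postgres gets the `FROM ... FOR ...` form. A minimal sketch of the same transpilation:

```python
import sqlglot

print(sqlglot.transpile("SUBSTRING('123456', 2, 3)", write="postgres")[0])
# SUBSTRING('123456' FROM 2 FOR 3)
print(sqlglot.transpile("SUBSTRING('123456', 2, 3)", write="oracle")[0])
# SUBSTR('123456', 2, 3)
```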
@ -22,7 +22,7 @@ class TestDuckDB(Validator):
"EPOCH_MS(x)", "EPOCH_MS(x)",
write={ write={
"bigquery": "UNIX_TO_TIME(x / 1000)", "bigquery": "UNIX_TO_TIME(x / 1000)",
"duckdb": "TO_TIMESTAMP(CAST(x / 1000 AS BIGINT))", "duckdb": "TO_TIMESTAMP(x / 1000)",
"presto": "FROM_UNIXTIME(x / 1000)", "presto": "FROM_UNIXTIME(x / 1000)",
"spark": "FROM_UNIXTIME(x / 1000)", "spark": "FROM_UNIXTIME(x / 1000)",
}, },
@ -41,7 +41,7 @@ class TestDuckDB(Validator):
"STRFTIME(x, '%Y-%m-%d %H:%M:%S')", "STRFTIME(x, '%Y-%m-%d %H:%M:%S')",
write={ write={
"duckdb": "STRFTIME(x, '%Y-%m-%d %H:%M:%S')", "duckdb": "STRFTIME(x, '%Y-%m-%d %H:%M:%S')",
"presto": "DATE_FORMAT(x, '%Y-%m-%d %H:%i:%S')", "presto": "DATE_FORMAT(x, '%Y-%m-%d %T')",
"hive": "DATE_FORMAT(x, 'yyyy-MM-dd HH:mm:ss')", "hive": "DATE_FORMAT(x, 'yyyy-MM-dd HH:mm:ss')",
}, },
) )
@ -58,9 +58,10 @@ class TestDuckDB(Validator):
self.validate_all( self.validate_all(
"TO_TIMESTAMP(x)", "TO_TIMESTAMP(x)",
write={ write={
"duckdb": "CAST(x AS TIMESTAMP)", "bigquery": "UNIX_TO_TIME(x)",
"presto": "CAST(x AS TIMESTAMP)", "duckdb": "TO_TIMESTAMP(x)",
"hive": "CAST(x AS TIMESTAMP)", "presto": "FROM_UNIXTIME(x)",
"hive": "FROM_UNIXTIME(x)",
}, },
) )
self.validate_all( self.validate_all(
@ -334,6 +335,14 @@ class TestDuckDB(Validator):
}, },
) )
self.validate_all(
"cast([[1]] as int[][])",
write={
"duckdb": "CAST(LIST_VALUE(LIST_VALUE(1)) AS INT[][])",
"spark": "CAST(ARRAY(ARRAY(1)) AS ARRAY<ARRAY<INT>>)",
},
)
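The parser change that stacks trailing `[]` pairs into nested `ARRAY` types is what makes the `INT[][]` case above work. A sketch of the same round trip:

```python
import sqlglot

print(sqlglot.transpile("cast([[1]] as int[][])", read="duckdb", write="spark")[0])
# CAST(ARRAY(ARRAY(1)) AS ARRAY<ARRAY<INT>>)
```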
def test_bool_or(self): def test_bool_or(self):
self.validate_all( self.validate_all(
"SELECT a, LOGICAL_OR(b) FROM table GROUP BY a", "SELECT a, LOGICAL_OR(b) FROM table GROUP BY a",
@ -259,7 +259,7 @@ class TestHive(Validator):
self.validate_all( self.validate_all(
"""from_unixtime(x, "yyyy-MM-dd'T'HH")""", """from_unixtime(x, "yyyy-MM-dd'T'HH")""",
write={ write={
"duckdb": "STRFTIME(TO_TIMESTAMP(CAST(x AS BIGINT)), '%Y-%m-%d''T''%H')", "duckdb": "STRFTIME(TO_TIMESTAMP(x), '%Y-%m-%d''T''%H')",
"presto": "DATE_FORMAT(FROM_UNIXTIME(x), '%Y-%m-%d''T''%H')", "presto": "DATE_FORMAT(FROM_UNIXTIME(x), '%Y-%m-%d''T''%H')",
"hive": "FROM_UNIXTIME(x, 'yyyy-MM-dd\\'T\\'HH')", "hive": "FROM_UNIXTIME(x, 'yyyy-MM-dd\\'T\\'HH')",
"spark": "FROM_UNIXTIME(x, 'yyyy-MM-dd\\'T\\'HH')", "spark": "FROM_UNIXTIME(x, 'yyyy-MM-dd\\'T\\'HH')",
@ -269,7 +269,7 @@ class TestHive(Validator):
"DATE_FORMAT('2020-01-01', 'yyyy-MM-dd HH:mm:ss')", "DATE_FORMAT('2020-01-01', 'yyyy-MM-dd HH:mm:ss')",
write={ write={
"duckdb": "STRFTIME(CAST('2020-01-01' AS TIMESTAMP), '%Y-%m-%d %H:%M:%S')", "duckdb": "STRFTIME(CAST('2020-01-01' AS TIMESTAMP), '%Y-%m-%d %H:%M:%S')",
"presto": "DATE_FORMAT(CAST('2020-01-01' AS TIMESTAMP), '%Y-%m-%d %H:%i:%S')", "presto": "DATE_FORMAT(CAST('2020-01-01' AS TIMESTAMP), '%Y-%m-%d %T')",
"hive": "DATE_FORMAT(CAST('2020-01-01' AS TIMESTAMP), 'yyyy-MM-dd HH:mm:ss')", "hive": "DATE_FORMAT(CAST('2020-01-01' AS TIMESTAMP), 'yyyy-MM-dd HH:mm:ss')",
"spark": "DATE_FORMAT(CAST('2020-01-01' AS TIMESTAMP), 'yyyy-MM-dd HH:mm:ss')", "spark": "DATE_FORMAT(CAST('2020-01-01' AS TIMESTAMP), 'yyyy-MM-dd HH:mm:ss')",
}, },
@ -308,7 +308,7 @@ class TestHive(Validator):
"UNIX_TIMESTAMP(x)", "UNIX_TIMESTAMP(x)",
write={ write={
"duckdb": "EPOCH(STRPTIME(x, '%Y-%m-%d %H:%M:%S'))", "duckdb": "EPOCH(STRPTIME(x, '%Y-%m-%d %H:%M:%S'))",
"presto": "TO_UNIXTIME(DATE_PARSE(x, '%Y-%m-%d %H:%i:%S'))", "presto": "TO_UNIXTIME(DATE_PARSE(x, '%Y-%m-%d %T'))",
"hive": "UNIX_TIMESTAMP(x)", "hive": "UNIX_TIMESTAMP(x)",
"spark": "UNIX_TIMESTAMP(x)", "spark": "UNIX_TIMESTAMP(x)",
"": "STR_TO_UNIX(x, '%Y-%m-%d %H:%M:%S')", "": "STR_TO_UNIX(x, '%Y-%m-%d %H:%M:%S')",
@ -195,6 +195,26 @@ class TestMySQL(Validator):
) )
def test_mysql(self): def test_mysql(self):
self.validate_all(
"SELECT a FROM tbl FOR UPDATE",
write={
"": "SELECT a FROM tbl",
"mysql": "SELECT a FROM tbl FOR UPDATE",
"oracle": "SELECT a FROM tbl FOR UPDATE",
"postgres": "SELECT a FROM tbl FOR UPDATE",
"tsql": "SELECT a FROM tbl FOR UPDATE",
},
)
self.validate_all(
"SELECT a FROM tbl FOR SHARE",
write={
"": "SELECT a FROM tbl",
"mysql": "SELECT a FROM tbl FOR SHARE",
"oracle": "SELECT a FROM tbl FOR SHARE",
"postgres": "SELECT a FROM tbl FOR SHARE",
"tsql": "SELECT a FROM tbl FOR SHARE",
},
)
self.validate_all( self.validate_all(
"GROUP_CONCAT(DISTINCT x ORDER BY y DESC)", "GROUP_CONCAT(DISTINCT x ORDER BY y DESC)",
write={ write={
@ -112,6 +112,22 @@ class TestPostgres(Validator):
self.validate_identity("x ~ 'y'") self.validate_identity("x ~ 'y'")
self.validate_identity("x ~* 'y'") self.validate_identity("x ~* 'y'")
self.validate_all(
"GENERATE_SERIES(a, b, ' 2 days ')",
write={
"postgres": "GENERATE_SERIES(a, b, INTERVAL '2' days)",
"presto": "SEQUENCE(a, b, INTERVAL '2' days)",
"trino": "SEQUENCE(a, b, INTERVAL '2' days)",
},
)
self.validate_all(
"GENERATE_SERIES('2019-01-01'::TIMESTAMP, NOW(), '1day')",
write={
"postgres": "GENERATE_SERIES(CAST('2019-01-01' AS TIMESTAMP), CURRENT_TIMESTAMP, INTERVAL '1' day)",
"presto": "SEQUENCE(CAST('2019-01-01' AS TIMESTAMP), CAST(CURRENT_TIMESTAMP AS TIMESTAMP), INTERVAL '1' day)",
"trino": "SEQUENCE(CAST('2019-01-01' AS TIMESTAMP), CAST(CURRENT_TIMESTAMP AS TIMESTAMP), INTERVAL '1' day)",
},
)
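Per the new cases above, a Postgres `GENERATE_SERIES` with a textual step is normalized to a proper interval and becomes `SEQUENCE` on Presto/Trino. Sketch:

```python
import sqlglot

print(sqlglot.transpile("GENERATE_SERIES(a, b, ' 2 days ')", read="postgres", write="presto")[0])
# SEQUENCE(a, b, INTERVAL '2' days)
```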
self.validate_all( self.validate_all(
"END WORK AND NO CHAIN", "END WORK AND NO CHAIN",
write={"postgres": "COMMIT AND NO CHAIN"}, write={"postgres": "COMMIT AND NO CHAIN"},
@ -249,7 +265,7 @@ class TestPostgres(Validator):
) )
self.validate_all( self.validate_all(
"'[1,2,3]'::json->2", "'[1,2,3]'::json->2",
write={"postgres": "CAST('[1,2,3]' AS JSON) -> '2'"}, write={"postgres": "CAST('[1,2,3]' AS JSON) -> 2"},
) )
self.validate_all( self.validate_all(
"""'{"a":1,"b":2}'::json->'b'""", """'{"a":1,"b":2}'::json->'b'""",
@ -265,7 +281,7 @@ class TestPostgres(Validator):
) )
self.validate_all( self.validate_all(
"""'[1,2,3]'::json->>2""", """'[1,2,3]'::json->>2""",
write={"postgres": "CAST('[1,2,3]' AS JSON) ->> '2'"}, write={"postgres": "CAST('[1,2,3]' AS JSON) ->> 2"},
) )
self.validate_all( self.validate_all(
"""'{"a":1,"b":2}'::json->>'b'""", """'{"a":1,"b":2}'::json->>'b'""",
@ -111,7 +111,7 @@ class TestPresto(Validator):
"DATE_FORMAT(x, '%Y-%m-%d %H:%i:%S')", "DATE_FORMAT(x, '%Y-%m-%d %H:%i:%S')",
write={ write={
"duckdb": "STRFTIME(x, '%Y-%m-%d %H:%M:%S')", "duckdb": "STRFTIME(x, '%Y-%m-%d %H:%M:%S')",
"presto": "DATE_FORMAT(x, '%Y-%m-%d %H:%i:%S')", "presto": "DATE_FORMAT(x, '%Y-%m-%d %T')",
"hive": "DATE_FORMAT(x, 'yyyy-MM-dd HH:mm:ss')", "hive": "DATE_FORMAT(x, 'yyyy-MM-dd HH:mm:ss')",
"spark": "DATE_FORMAT(x, 'yyyy-MM-dd HH:mm:ss')", "spark": "DATE_FORMAT(x, 'yyyy-MM-dd HH:mm:ss')",
}, },
@ -120,7 +120,7 @@ class TestPresto(Validator):
"DATE_PARSE(x, '%Y-%m-%d %H:%i:%S')", "DATE_PARSE(x, '%Y-%m-%d %H:%i:%S')",
write={ write={
"duckdb": "STRPTIME(x, '%Y-%m-%d %H:%M:%S')", "duckdb": "STRPTIME(x, '%Y-%m-%d %H:%M:%S')",
"presto": "DATE_PARSE(x, '%Y-%m-%d %H:%i:%S')", "presto": "DATE_PARSE(x, '%Y-%m-%d %T')",
"hive": "CAST(x AS TIMESTAMP)", "hive": "CAST(x AS TIMESTAMP)",
"spark": "TO_TIMESTAMP(x, 'yyyy-MM-dd HH:mm:ss')", "spark": "TO_TIMESTAMP(x, 'yyyy-MM-dd HH:mm:ss')",
}, },
@ -134,6 +134,12 @@ class TestPresto(Validator):
"spark": "TO_TIMESTAMP(x, 'yyyy-MM-dd')", "spark": "TO_TIMESTAMP(x, 'yyyy-MM-dd')",
}, },
) )
self.validate_all(
"DATE_FORMAT(x, '%T')",
write={
"hive": "DATE_FORMAT(x, 'HH:mm:ss')",
},
)
self.validate_all( self.validate_all(
"DATE_PARSE(SUBSTR(x, 1, 10), '%Y-%m-%d')", "DATE_PARSE(SUBSTR(x, 1, 10), '%Y-%m-%d')",
write={ write={
@ -146,7 +152,7 @@ class TestPresto(Validator):
self.validate_all( self.validate_all(
"FROM_UNIXTIME(x)", "FROM_UNIXTIME(x)",
write={ write={
"duckdb": "TO_TIMESTAMP(CAST(x AS BIGINT))", "duckdb": "TO_TIMESTAMP(x)",
"presto": "FROM_UNIXTIME(x)", "presto": "FROM_UNIXTIME(x)",
"hive": "FROM_UNIXTIME(x)", "hive": "FROM_UNIXTIME(x)",
"spark": "FROM_UNIXTIME(x)", "spark": "FROM_UNIXTIME(x)",
@ -177,11 +183,51 @@ class TestPresto(Validator):
self.validate_all( self.validate_all(
"NOW()", "NOW()",
write={ write={
"presto": "CURRENT_TIMESTAMP()", "presto": "CURRENT_TIMESTAMP",
"hive": "CURRENT_TIMESTAMP()", "hive": "CURRENT_TIMESTAMP()",
}, },
) )
self.validate_all(
"DAY_OF_WEEK(timestamp '2012-08-08 01:00')",
write={
"spark": "DAYOFWEEK(CAST('2012-08-08 01:00' AS TIMESTAMP))",
"presto": "DAY_OF_WEEK(CAST('2012-08-08 01:00' AS TIMESTAMP))",
},
)
self.validate_all(
"DAY_OF_MONTH(timestamp '2012-08-08 01:00')",
write={
"spark": "DAYOFMONTH(CAST('2012-08-08 01:00' AS TIMESTAMP))",
"presto": "DAY_OF_MONTH(CAST('2012-08-08 01:00' AS TIMESTAMP))",
},
)
self.validate_all(
"DAY_OF_YEAR(timestamp '2012-08-08 01:00')",
write={
"spark": "DAYOFYEAR(CAST('2012-08-08 01:00' AS TIMESTAMP))",
"presto": "DAY_OF_YEAR(CAST('2012-08-08 01:00' AS TIMESTAMP))",
},
)
self.validate_all(
"WEEK_OF_YEAR(timestamp '2012-08-08 01:00')",
write={
"spark": "WEEKOFYEAR(CAST('2012-08-08 01:00' AS TIMESTAMP))",
"presto": "WEEK_OF_YEAR(CAST('2012-08-08 01:00' AS TIMESTAMP))",
},
)
self.validate_all(
"SELECT timestamp '2012-10-31 00:00' AT TIME ZONE 'America/Sao_Paulo'",
write={
"spark": "SELECT FROM_UTC_TIMESTAMP(CAST('2012-10-31 00:00' AS TIMESTAMP), 'America/Sao_Paulo')",
"presto": "SELECT CAST('2012-10-31 00:00' AS TIMESTAMP) AT TIME ZONE 'America/Sao_Paulo'",
},
)
def test_ddl(self): def test_ddl(self):
self.validate_all( self.validate_all(
"CREATE TABLE test WITH (FORMAT = 'PARQUET') AS SELECT 1", "CREATE TABLE test WITH (FORMAT = 'PARQUET') AS SELECT 1",
@ -314,6 +360,11 @@ class TestPresto(Validator):
def test_presto(self): def test_presto(self):
self.validate_identity("SELECT BOOL_OR(a > 10) FROM asd AS T(a)") self.validate_identity("SELECT BOOL_OR(a > 10) FROM asd AS T(a)")
self.validate_identity("SELECT * FROM (VALUES (1))")
self.validate_identity("START TRANSACTION READ WRITE, ISOLATION LEVEL SERIALIZABLE")
self.validate_identity("START TRANSACTION ISOLATION LEVEL REPEATABLE READ")
self.validate_identity("APPROX_PERCENTILE(a, b, c, d)")
self.validate_all( self.validate_all(
'SELECT a."b" FROM "foo"', 'SELECT a."b" FROM "foo"',
write={ write={
@ -455,10 +506,6 @@ class TestPresto(Validator):
"spark": UnsupportedError, "spark": UnsupportedError,
}, },
) )
self.validate_identity("SELECT * FROM (VALUES (1))")
self.validate_identity("START TRANSACTION READ WRITE, ISOLATION LEVEL SERIALIZABLE")
self.validate_identity("START TRANSACTION ISOLATION LEVEL REPEATABLE READ")
self.validate_identity("APPROX_PERCENTILE(a, b, c, d)")
def test_encode_decode(self): def test_encode_decode(self):
self.validate_all( self.validate_all(
@ -529,3 +576,27 @@ class TestPresto(Validator):
"presto": "FROM_HEX(x)", "presto": "FROM_HEX(x)",
}, },
) )
def test_json(self):
self.validate_all(
"SELECT CAST(JSON '[1,23,456]' AS ARRAY(INTEGER))",
write={
"spark": "SELECT FROM_JSON('[1,23,456]', 'ARRAY<INT>')",
"presto": "SELECT CAST(CAST('[1,23,456]' AS JSON) AS ARRAY(INTEGER))",
},
)
self.validate_all(
"""SELECT CAST(JSON '{"k1":1,"k2":23,"k3":456}' AS MAP(VARCHAR, INTEGER))""",
write={
"spark": 'SELECT FROM_JSON(\'{"k1":1,"k2":23,"k3":456}\', \'MAP<STRING, INT>\')',
"presto": 'SELECT CAST(CAST(\'{"k1":1,"k2":23,"k3":456}\' AS JSON) AS MAP(VARCHAR, INTEGER))',
},
)
self.validate_all(
"SELECT CAST(ARRAY [1, 23, 456] AS JSON)",
write={
"spark": "SELECT TO_JSON(ARRAY(1, 23, 456))",
"presto": "SELECT CAST(ARRAY[1, 23, 456] AS JSON)",
},
)
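The new `test_json` cases show Presto `CAST(JSON ...)` expressions mapping onto Spark's `FROM_JSON`/`TO_JSON`. A one-line sketch of the first case:

```python
import sqlglot

print(sqlglot.transpile("SELECT CAST(JSON '[1,23,456]' AS ARRAY(INTEGER))", read="presto", write="spark")[0])
# SELECT FROM_JSON('[1,23,456]', 'ARRAY<INT>')
```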
@ -212,6 +212,17 @@ TBLPROPERTIES (
self.validate_identity("TRIM(BOTH 'SL' FROM 'SSparkSQLS')") self.validate_identity("TRIM(BOTH 'SL' FROM 'SSparkSQLS')")
self.validate_identity("TRIM(LEADING 'SL' FROM 'SSparkSQLS')") self.validate_identity("TRIM(LEADING 'SL' FROM 'SSparkSQLS')")
self.validate_identity("TRIM(TRAILING 'SL' FROM 'SSparkSQLS')") self.validate_identity("TRIM(TRAILING 'SL' FROM 'SSparkSQLS')")
self.validate_all(
"AGGREGATE(my_arr, 0, (acc, x) -> acc + x, s -> s * 2)",
write={
"trino": "REDUCE(my_arr, 0, (acc, x) -> acc + x, s -> s * 2)",
"duckdb": "REDUCE(my_arr, 0, (acc, x) -> acc + x, s -> s * 2)",
"hive": "REDUCE(my_arr, 0, (acc, x) -> acc + x, s -> s * 2)",
"presto": "REDUCE(my_arr, 0, (acc, x) -> acc + x, s -> s * 2)",
"spark": "AGGREGATE(my_arr, 0, (acc, x) -> acc + x, s -> s * 2)",
},
)
self.validate_all( self.validate_all(
"TRIM('SL', 'SSparkSQLS')", write={"spark": "TRIM('SL' FROM 'SSparkSQLS')"} "TRIM('SL', 'SSparkSQLS')", write={"spark": "TRIM('SL' FROM 'SSparkSQLS')"}
) )
@ -92,3 +92,9 @@ class TestSQLite(Validator):
"sqlite": "SELECT FIRST_VALUE(Name) OVER (PARTITION BY AlbumId ORDER BY Bytes DESC) AS LargestTrack FROM tracks" "sqlite": "SELECT FIRST_VALUE(Name) OVER (PARTITION BY AlbumId ORDER BY Bytes DESC) AS LargestTrack FROM tracks"
}, },
) )
def test_longvarchar_dtype(self):
self.validate_all(
"CREATE TABLE foo (bar LONGVARCHAR)",
write={"sqlite": "CREATE TABLE foo (bar TEXT)"},
)
@ -21,3 +21,6 @@ class TestTeradata(Validator):
"mysql": "UPDATE A SET col2 = '' FROM schema.tableA AS A, (SELECT col1 FROM schema.tableA GROUP BY col1) AS B WHERE A.col1 = B.col1", "mysql": "UPDATE A SET col2 = '' FROM schema.tableA AS A, (SELECT col1 FROM schema.tableA GROUP BY col1) AS B WHERE A.col1 = B.col1",
}, },
) )
def test_create(self):
self.validate_identity("CREATE TABLE x (y INT) PRIMARY INDEX (y) PARTITION BY y INDEX (y)")
@ -161,6 +161,7 @@ SELECT 1 FROM test
SELECT * FROM a, b, (SELECT 1) AS c SELECT * FROM a, b, (SELECT 1) AS c
SELECT a FROM test SELECT a FROM test
SELECT 1 AS filter SELECT 1 AS filter
SELECT 1 AS "quoted alias"
SELECT SUM(x) AS filter SELECT SUM(x) AS filter
SELECT 1 AS range FROM test SELECT 1 AS range FROM test
SELECT 1 AS count FROM test SELECT 1 AS count FROM test
@ -264,7 +265,9 @@ SELECT a FROM test GROUP BY GROUPING SETS (x, ())
SELECT a FROM test GROUP BY GROUPING SETS (x, (x, y), (x, y, z), q) SELECT a FROM test GROUP BY GROUPING SETS (x, (x, y), (x, y, z), q)
SELECT a FROM test GROUP BY CUBE (x) SELECT a FROM test GROUP BY CUBE (x)
SELECT a FROM test GROUP BY ROLLUP (x) SELECT a FROM test GROUP BY ROLLUP (x)
SELECT a FROM test GROUP BY CUBE (x) ROLLUP (x, y, z) SELECT t.a FROM test AS t GROUP BY ROLLUP (t.x)
SELECT a FROM test GROUP BY GROUPING SETS ((x, y)), ROLLUP (b)
SELECT a FROM test GROUP BY CUBE (x), ROLLUP (x, y, z)
SELECT CASE WHEN a < b THEN 1 WHEN a < c THEN 2 ELSE 3 END FROM test SELECT CASE WHEN a < b THEN 1 WHEN a < c THEN 2 ELSE 3 END FROM test
SELECT CASE 1 WHEN 1 THEN 1 ELSE 2 END SELECT CASE 1 WHEN 1 THEN 1 ELSE 2 END
SELECT CASE 1 WHEN 1 THEN MAP('a', 'b') ELSE MAP('b', 'c') END['a'] SELECT CASE 1 WHEN 1 THEN MAP('a', 'b') ELSE MAP('b', 'c') END['a']
@ -339,7 +342,6 @@ SELECT CAST(a AS ARRAY<INT>) FROM test
SELECT CAST(a AS VARIANT) FROM test SELECT CAST(a AS VARIANT) FROM test
SELECT TRY_CAST(a AS INT) FROM test SELECT TRY_CAST(a AS INT) FROM test
SELECT COALESCE(a, b, c) FROM test SELECT COALESCE(a, b, c) FROM test
SELECT IFNULL(a, b) FROM test
SELECT ANY_VALUE(a) FROM test SELECT ANY_VALUE(a) FROM test
SELECT 1 FROM a JOIN b ON a.x = b.x SELECT 1 FROM a JOIN b ON a.x = b.x
SELECT 1 FROM a JOIN b AS c ON a.x = b.x SELECT 1 FROM a JOIN b AS c ON a.x = b.x
@ -510,6 +512,14 @@ CREATE TABLE z (a INT UNIQUE AUTO_INCREMENT)
CREATE TABLE z (a INT REFERENCES parent(b, c)) CREATE TABLE z (a INT REFERENCES parent(b, c))
CREATE TABLE z (a INT PRIMARY KEY, b INT REFERENCES foo(id)) CREATE TABLE z (a INT PRIMARY KEY, b INT REFERENCES foo(id))
CREATE TABLE z (a INT, FOREIGN KEY (a) REFERENCES parent(b, c)) CREATE TABLE z (a INT, FOREIGN KEY (a) REFERENCES parent(b, c))
CREATE TABLE foo (bar INT REFERENCES baz(baz_id) ON DELETE NO ACTION)
CREATE TABLE foo (bar INT REFERENCES baz(baz_id) ON DELETE CASCADE)
CREATE TABLE foo (bar INT REFERENCES baz(baz_id) ON DELETE SET NULL)
CREATE TABLE foo (bar INT REFERENCES baz(baz_id) ON DELETE SET DEFAULT)
CREATE TABLE foo (bar INT REFERENCES baz(baz_id) ON UPDATE NO ACTION)
CREATE TABLE foo (bar INT REFERENCES baz(baz_id) ON UPDATE CASCADE)
CREATE TABLE foo (bar INT REFERENCES baz(baz_id) ON UPDATE SET NULL)
CREATE TABLE foo (bar INT REFERENCES baz(baz_id) ON UPDATE SET DEFAULT)
CREATE TABLE asd AS SELECT asd FROM asd WITH NO DATA CREATE TABLE asd AS SELECT asd FROM asd WITH NO DATA
CREATE TABLE asd AS SELECT asd FROM asd WITH DATA CREATE TABLE asd AS SELECT asd FROM asd WITH DATA
CREATE TABLE products (x INT GENERATED BY DEFAULT AS IDENTITY) CREATE TABLE products (x INT GENERATED BY DEFAULT AS IDENTITY)
@ -526,6 +536,7 @@ CREATE TABLE a, DUAL JOURNAL, DUAL AFTER JOURNAL, MERGEBLOCKRATIO=1 PERCENT, DAT
CREATE TABLE a, DUAL BEFORE JOURNAL, LOCAL AFTER JOURNAL, MAXIMUM DATABLOCKSIZE, BLOCKCOMPRESSION=AUTOTEMP(c1 INT) (a INT) CREATE TABLE a, DUAL BEFORE JOURNAL, LOCAL AFTER JOURNAL, MAXIMUM DATABLOCKSIZE, BLOCKCOMPRESSION=AUTOTEMP(c1 INT) (a INT)
CREATE SET GLOBAL TEMPORARY TABLE a, NO BEFORE JOURNAL, NO AFTER JOURNAL, MINIMUM DATABLOCKSIZE, BLOCKCOMPRESSION=NEVER (a INT) CREATE SET GLOBAL TEMPORARY TABLE a, NO BEFORE JOURNAL, NO AFTER JOURNAL, MINIMUM DATABLOCKSIZE, BLOCKCOMPRESSION=NEVER (a INT)
CREATE MULTISET VOLATILE TABLE a, NOT LOCAL AFTER JOURNAL, FREESPACE=1 PERCENT, DATABLOCKSIZE=10 BYTES, WITH NO CONCURRENT ISOLATED LOADING FOR ALL (a INT) CREATE MULTISET VOLATILE TABLE a, NOT LOCAL AFTER JOURNAL, FREESPACE=1 PERCENT, DATABLOCKSIZE=10 BYTES, WITH NO CONCURRENT ISOLATED LOADING FOR ALL (a INT)
CREATE ALGORITHM=UNDEFINED DEFINER=foo@% SQL SECURITY DEFINER VIEW a AS (SELECT a FROM b)
CREATE TEMPORARY TABLE x AS SELECT a FROM d CREATE TEMPORARY TABLE x AS SELECT a FROM d
CREATE TEMPORARY TABLE IF NOT EXISTS x AS SELECT a FROM d CREATE TEMPORARY TABLE IF NOT EXISTS x AS SELECT a FROM d
CREATE VIEW x AS SELECT a FROM b CREATE VIEW x AS SELECT a FROM b
@ -555,6 +566,7 @@ CREATE UNIQUE INDEX IF NOT EXISTS my_idx ON tbl (a, b)
CREATE SCHEMA x CREATE SCHEMA x
CREATE SCHEMA IF NOT EXISTS y CREATE SCHEMA IF NOT EXISTS y
CREATE PROCEDURE IF NOT EXISTS a.b.c() AS 'DECLARE BEGIN; END' CREATE PROCEDURE IF NOT EXISTS a.b.c() AS 'DECLARE BEGIN; END'
CREATE OR REPLACE STAGE
DESCRIBE x DESCRIBE x
DROP INDEX a.b.c DROP INDEX a.b.c
DROP FUNCTION a.b.c (INT) DROP FUNCTION a.b.c (INT)
@ -50,6 +50,10 @@ WITH cte AS (SELECT 1 AS x, 2 AS y) SELECT cte.x AS x, cte.y AS y FROM cte AS ct
(SELECT a FROM (SELECT b FROM x)) UNION (SELECT a FROM (SELECT b FROM y)); (SELECT a FROM (SELECT b FROM x)) UNION (SELECT a FROM (SELECT b FROM y));
WITH cte AS (SELECT b FROM x), cte_2 AS (SELECT a FROM cte AS cte), cte_3 AS (SELECT b FROM y), cte_4 AS (SELECT a FROM cte_3 AS cte_3) (SELECT cte_2.a AS a FROM cte_2 AS cte_2) UNION (SELECT cte_4.a AS a FROM cte_4 AS cte_4); WITH cte AS (SELECT b FROM x), cte_2 AS (SELECT a FROM cte AS cte), cte_3 AS (SELECT b FROM y), cte_4 AS (SELECT a FROM cte_3 AS cte_3) (SELECT cte_2.a AS a FROM cte_2 AS cte_2) UNION (SELECT cte_4.a AS a FROM cte_4 AS cte_4);
-- Three unions
SELECT a FROM x UNION ALL SELECT a FROM y UNION ALL SELECT a FROM z;
WITH cte AS (SELECT a FROM x), cte_2 AS (SELECT a FROM y), cte_3 AS (SELECT a FROM z), cte_4 AS (SELECT cte_2.a AS a FROM cte_2 AS cte_2 UNION ALL SELECT cte_3.a AS a FROM cte_3 AS cte_3) SELECT cte.a AS a FROM cte AS cte UNION ALL SELECT cte_4.a AS a FROM cte_4 AS cte_4;
-- Subquery -- Subquery
SELECT a FROM x WHERE b = (SELECT y.c FROM y); SELECT a FROM x WHERE b = (SELECT y.c FROM y);
SELECT a FROM x WHERE b = (SELECT y.c FROM y); SELECT a FROM x WHERE b = (SELECT y.c FROM y);
@ -99,7 +99,7 @@ WITH cte1 AS (
GROUPING SETS ( GROUPING SETS (
a, a,
(b, c) (b, c)
) ),
CUBE ( CUBE (
y, y,
z z
@ -62,6 +62,16 @@ class TestBuild(unittest.TestCase):
lambda: select("x").from_("tbl").where("x > 0").where("x < 9", append=False), lambda: select("x").from_("tbl").where("x > 0").where("x < 9", append=False),
"SELECT x FROM tbl WHERE x < 9", "SELECT x FROM tbl WHERE x < 9",
), ),
(
lambda: select("x").from_("tbl").where("x > 0").lock(),
"SELECT x FROM tbl WHERE x > 0 FOR UPDATE",
"mysql",
),
(
lambda: select("x").from_("tbl").where("x > 0").lock(update=False),
"SELECT x FROM tbl WHERE x > 0 FOR SHARE",
"postgres",
),
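The new `lock()` builder cases above correspond to the parser/generator `FOR UPDATE`/`FOR SHARE` support earlier in this diff. A sketch of the builder API as the test exercises it:

```python
from sqlglot import select

print(select("x").from_("tbl").where("x > 0").lock().sql(dialect="mysql"))
# SELECT x FROM tbl WHERE x > 0 FOR UPDATE
print(select("x").from_("tbl").where("x > 0").lock(update=False).sql(dialect="postgres"))
# SELECT x FROM tbl WHERE x > 0 FOR SHARE
```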
( (
lambda: select("x", "y").from_("tbl").group_by("x"), lambda: select("x", "y").from_("tbl").group_by("x"),
"SELECT x, y FROM tbl GROUP BY x", "SELECT x, y FROM tbl GROUP BY x",
@ -466,6 +466,7 @@ class TestExpressions(unittest.TestCase):
self.assertIsInstance(parse_one("BEGIN DEFERRED TRANSACTION"), exp.Transaction) self.assertIsInstance(parse_one("BEGIN DEFERRED TRANSACTION"), exp.Transaction)
self.assertIsInstance(parse_one("COMMIT"), exp.Commit) self.assertIsInstance(parse_one("COMMIT"), exp.Commit)
self.assertIsInstance(parse_one("ROLLBACK"), exp.Rollback) self.assertIsInstance(parse_one("ROLLBACK"), exp.Rollback)
self.assertIsInstance(parse_one("GENERATE_SERIES(a, b, c)"), exp.GenerateSeries)
def test_column(self): def test_column(self):
dot = parse_one("a.b.c") dot = parse_one("a.b.c")
@ -630,6 +631,19 @@ FROM foo""",
FROM foo""", FROM foo""",
) )
def test_to_interval(self):
self.assertEqual(exp.to_interval("1day").sql(), "INTERVAL '1' day")
self.assertEqual(exp.to_interval(" 5 months").sql(), "INTERVAL '5' months")
with self.assertRaises(ValueError):
exp.to_interval("bla")
self.assertEqual(exp.to_interval(exp.Literal.string("1day")).sql(), "INTERVAL '1' day")
self.assertEqual(
exp.to_interval(exp.Literal.string(" 5 months")).sql(), "INTERVAL '5' months"
)
with self.assertRaises(ValueError):
exp.to_interval(exp.Literal.string("bla"))
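`exp.to_interval` (exercised above) turns a `'<number> <unit>'` string or string literal into an `exp.Interval`, raising `ValueError` on anything it cannot parse. Sketch:

```python
from sqlglot import exp

print(exp.to_interval("1day").sql())
# INTERVAL '1' day
print(exp.to_interval(exp.Literal.string(" 5 months")).sql())
# INTERVAL '5' months
```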
def test_to_table(self): def test_to_table(self):
table_only = exp.to_table("table_name") table_only = exp.to_table("table_name")
self.assertEqual(table_only.name, "table_name") self.assertEqual(table_only.name, "table_name")
@ -326,12 +326,12 @@ FROM bar /* comment 5 */, tbl /* comment 6 */""",
self.validate("TIME_TO_UNIX(x)", "EPOCH(x)", write="duckdb") self.validate("TIME_TO_UNIX(x)", "EPOCH(x)", write="duckdb")
self.validate( self.validate(
"UNIX_TO_STR(123, 'y')", "UNIX_TO_STR(123, 'y')",
"STRFTIME(TO_TIMESTAMP(CAST(123 AS BIGINT)), 'y')", "STRFTIME(TO_TIMESTAMP(123), 'y')",
write="duckdb", write="duckdb",
) )
self.validate( self.validate(
"UNIX_TO_TIME(123)", "UNIX_TO_TIME(123)",
"TO_TIMESTAMP(CAST(123 AS BIGINT))", "TO_TIMESTAMP(123)",
write="duckdb", write="duckdb",
) )
@ -426,6 +426,9 @@ FROM bar /* comment 5 */, tbl /* comment 6 */""",
mock_logger.warning.assert_any_call("Applying array index offset (%s)", 1) mock_logger.warning.assert_any_call("Applying array index offset (%s)", 1)
mock_logger.warning.assert_any_call("Applying array index offset (%s)", -1) mock_logger.warning.assert_any_call("Applying array index offset (%s)", -1)
def test_identify_lambda(self):
self.validate("x(y -> y)", 'X("y" -> "y")', identify=True)
def test_identity(self): def test_identity(self):
self.assertEqual(transpile("")[0], "") self.assertEqual(transpile("")[0], "")
for sql in load_sql_fixtures("identity.sql"): for sql in load_sql_fixtures("identity.sql"):