
Merging upstream version 18.2.0.

Signed-off-by: Daniel Baumann <daniel@debian.org>
Daniel Baumann 2025-02-13 20:58:22 +01:00
parent 985db29269
commit 53cf4a81a6
Signed by: daniel
GPG key ID: FBB4F0E80A80222F
124 changed files with 60313 additions and 50346 deletions


@ -21,10 +21,12 @@ Currently many of the common operations are covered and more functionality will
* Ex: `['cola', 'colb']`
* The lack of types may limit functionality in future releases.
* See [Registering Custom Schema](#registering-custom-schema-class) for information on how to skip this step if the information is stored externally.
* If your output SQL dialect is not Spark, then configure the SparkSession to use that dialect
* Ex: `SparkSession().builder.config("sqlframe.dialect", "bigquery").getOrCreate()`
* See [dialects](https://github.com/tobymao/sqlglot/tree/main/sqlglot/dialects) for a full list of dialects.
* Add `.sql(pretty=True)` to your final DataFrame command to return a list of sql statements to run that command.
* In most cases a single SQL statement is returned. Currently the only exception is when caching DataFrames which isn't supported in other dialects.
* Spark is the default output dialect. See [dialects](https://github.com/tobymao/sqlglot/tree/main/sqlglot/dialects) for a full list of dialects.
* Ex: `.sql(pretty=True, dialect='bigquery')`
* In most cases a single SQL statement is returned. Currently the only exception is when caching DataFrames which isn't supported in other dialects.
* Ex: `.sql(pretty=True)`
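
Putting the bullets above together, a minimal sketch (the `employee` schema and column names here are placeholders; registration works as described above):

```python
import sqlglot
from sqlglot.dataframe.sql.session import SparkSession

# Register the table schema up front (placeholder columns)
sqlglot.schema.add_table('employee', {'employee_id': 'INT', 'fname': 'STRING'}, dialect="bigquery")

# Configure the output dialect once on the session
spark = SparkSession.builder.config("sqlframe.dialect", "bigquery").getOrCreate()

# .sql(pretty=True) returns a list of SQL strings, usually just one
for statement in spark.table('employee').select('fname').sql(pretty=True):
    print(statement)
```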
## Examples
@ -33,6 +35,8 @@ import sqlglot
from sqlglot.dataframe.sql.session import SparkSession
from sqlglot.dataframe.sql import functions as F
dialect = "spark"
sqlglot.schema.add_table(
'employee',
{
@ -41,10 +45,10 @@ sqlglot.schema.add_table(
'lname': 'STRING',
'age': 'INT',
},
dialect="spark",
dialect=dialect,
) # Register the table structure prior to reading from the table
spark = SparkSession()
spark = SparkSession.builder.config("sqlframe.dialect", dialect).getOrCreate()
df = (
spark
@ -53,7 +57,7 @@ df = (
.agg(F.countDistinct(F.col("employee_id")).alias("num_employees"))
)
print(df.sql(pretty=True)) # Spark will be the dialect used by default
print(df.sql(pretty=True))
```
```sparksql
@ -81,7 +85,7 @@ class ExternalSchema(Schema):
sqlglot.schema = ExternalSchema()
spark = SparkSession()
spark = SparkSession() # Spark will be used by default if not specified in SparkSession config
df = (
spark
@ -119,11 +123,14 @@ schema = types.StructType([
])
sql_statements = (
SparkSession()
SparkSession
.builder
.config("sqlframe.dialect", "bigquery")
.getOrCreate()
.createDataFrame(data, schema)
.groupBy(F.col("age"))
.agg(F.countDistinct(F.col("employee_id")).alias("num_employees"))
.sql(dialect="bigquery")
.sql()
)
result = None
@ -166,11 +173,14 @@ schema = types.StructType([
])
sql_statements = (
SparkSession()
SparkSession
.builder
.config("sqlframe.dialect", "snowflake")
.getOrCreate()
.createDataFrame(data, schema)
.groupBy(F.col("age"))
.agg(F.countDistinct(F.col("lname")).alias("num_employees"))
.sql(dialect="snowflake")
.sql()
)
try:
@ -210,7 +220,7 @@ sql_statements = (
.createDataFrame(data, schema)
.groupBy(F.col("age"))
.agg(F.countDistinct(F.col("employee_id")).alias("num_employees"))
.sql(dialect="spark")
.sql()
)
pyspark = PySparkSession.builder.master("local[*]").getOrCreate()


@ -5,7 +5,6 @@ import typing as t
import sqlglot
from sqlglot import expressions as exp
from sqlglot.dataframe.sql.types import DataType
from sqlglot.dialects import Spark
from sqlglot.helper import flatten, is_iterable
if t.TYPE_CHECKING:
@ -15,19 +14,20 @@ if t.TYPE_CHECKING:
class Column:
def __init__(self, expression: t.Optional[t.Union[ColumnOrLiteral, exp.Expression]]):
from sqlglot.dataframe.sql.session import SparkSession
if isinstance(expression, Column):
expression = expression.expression # type: ignore
elif expression is None or not isinstance(expression, (str, exp.Expression)):
expression = self._lit(expression).expression # type: ignore
expression = sqlglot.maybe_parse(expression, dialect="spark")
elif not isinstance(expression, exp.Column):
expression = sqlglot.maybe_parse(expression, dialect=SparkSession().dialect).transform(
SparkSession().dialect.normalize_identifier, copy=False
)
if expression is None:
raise ValueError(f"Could not parse {expression}")
if isinstance(expression, exp.Column):
expression.transform(Spark.normalize_identifier, copy=False)
self.expression: exp.Expression = expression
self.expression: exp.Expression = expression # type: ignore
def __repr__(self):
return repr(self.expression)
@ -207,7 +207,9 @@ class Column:
return Column(expression)
def sql(self, **kwargs) -> str:
return self.expression.sql(**{"dialect": "spark", **kwargs})
from sqlglot.dataframe.sql.session import SparkSession
return self.expression.sql(**{"dialect": SparkSession().dialect, **kwargs})
def alias(self, name: str) -> Column:
new_expression = exp.alias_(self.column_expression, name)
@ -264,9 +266,11 @@ class Column:
Functionality Difference: PySpark cast accepts a datatype instance of the datatype class.
Sqlglot doesn't currently replicate this class, so it only accepts a string.
"""
from sqlglot.dataframe.sql.session import SparkSession
if isinstance(dataType, DataType):
dataType = dataType.simpleString()
return Column(exp.cast(self.column_expression, dataType, dialect="spark"))
return Column(exp.cast(self.column_expression, dataType, dialect=SparkSession().dialect))
def startswith(self, value: t.Union[str, Column]) -> Column:
value = self._lit(value) if not isinstance(value, Column) else value


@ -1,12 +1,13 @@
from __future__ import annotations
import functools
import logging
import typing as t
import zlib
from copy import copy
import sqlglot
from sqlglot import expressions as exp
from sqlglot import Dialect, expressions as exp
from sqlglot.dataframe.sql import functions as F
from sqlglot.dataframe.sql.column import Column
from sqlglot.dataframe.sql.group import GroupedData
@ -18,6 +19,7 @@ from sqlglot.dataframe.sql.util import get_tables_from_expression_with_join
from sqlglot.dataframe.sql.window import Window
from sqlglot.helper import ensure_list, object_to_dict, seq_get
from sqlglot.optimizer import optimize as optimize_func
from sqlglot.optimizer.qualify_columns import quote_identifiers
if t.TYPE_CHECKING:
from sqlglot.dataframe.sql._typing import (
@ -27,7 +29,9 @@ if t.TYPE_CHECKING:
OutputExpressionContainer,
)
from sqlglot.dataframe.sql.session import SparkSession
from sqlglot.dialects.dialect import DialectType
logger = logging.getLogger("sqlglot")
JOIN_HINTS = {
"BROADCAST",
@ -264,7 +268,9 @@ class DataFrame:
@classmethod
def _create_hash_from_expression(cls, expression: exp.Expression) -> str:
value = expression.sql(dialect="spark").encode("utf-8")
from sqlglot.dataframe.sql.session import SparkSession
value = expression.sql(dialect=SparkSession().dialect).encode("utf-8")
return f"t{zlib.crc32(value)}"[:6]
def _get_select_expressions(
@ -291,7 +297,15 @@ class DataFrame:
select_expressions.append(expression_select_pair) # type: ignore
return select_expressions
def sql(self, dialect="spark", optimize=True, **kwargs) -> t.List[str]:
def sql(
self, dialect: t.Optional[DialectType] = None, optimize: bool = True, **kwargs
) -> t.List[str]:
from sqlglot.dataframe.sql.session import SparkSession
if dialect and Dialect.get_or_raise(dialect)() != SparkSession().dialect:
logger.warning(
f"The recommended way of defining a dialect is by doing `SparkSession.builder.config('sqlframe.dialect', '{dialect}').getOrCreate()`. It is no longer needed then when calling `sql`. If you run into issues try updating your query to use this pattern."
)
df = self._resolve_pending_hints()
select_expressions = df._get_select_expressions()
output_expressions: t.List[t.Union[exp.Select, exp.Cache, exp.Drop]] = []
@ -299,7 +313,10 @@ class DataFrame:
for expression_type, select_expression in select_expressions:
select_expression = select_expression.transform(replace_id_value, replacement_mapping)
if optimize:
select_expression = t.cast(exp.Select, optimize_func(select_expression))
quote_identifiers(select_expression)
select_expression = t.cast(
exp.Select, optimize_func(select_expression, dialect=SparkSession().dialect)
)
select_expression = df._replace_cte_names_with_hashes(select_expression)
expression: t.Union[exp.Select, exp.Cache, exp.Drop]
if expression_type == exp.Cache:
@ -313,10 +330,12 @@ class DataFrame:
sqlglot.schema.add_table(
cache_table_name,
{
expression.alias_or_name: expression.type.sql("spark")
expression.alias_or_name: expression.type.sql(
dialect=SparkSession().dialect
)
for expression in select_expression.expressions
},
dialect="spark",
dialect=SparkSession().dialect,
)
cache_storage_level = select_expression.args["cache_storage_level"]
options = [
@ -345,7 +364,8 @@ class DataFrame:
output_expressions.append(expression)
return [
expression.sql(**{"dialect": dialect, **kwargs}) for expression in output_expressions
expression.sql(**{"dialect": SparkSession().dialect, **kwargs})
for expression in output_expressions
]
def copy(self, **kwargs) -> DataFrame:


@ -368,9 +368,7 @@ def covar_samp(col1: ColumnOrName, col2: ColumnOrName) -> Column:
def first(col: ColumnOrName, ignorenulls: t.Optional[bool] = None) -> Column:
if ignorenulls is not None:
return Column.invoke_anonymous_function(col, "FIRST", ignorenulls)
return Column.invoke_anonymous_function(col, "FIRST")
return Column.invoke_expression_over_column(col, expression.First, ignore_nulls=ignorenulls)
def grouping_id(*cols: ColumnOrName) -> Column:
@ -394,9 +392,7 @@ def isnull(col: ColumnOrName) -> Column:
def last(col: ColumnOrName, ignorenulls: t.Optional[bool] = None) -> Column:
if ignorenulls is not None:
return Column.invoke_anonymous_function(col, "LAST", ignorenulls)
return Column.invoke_anonymous_function(col, "LAST")
return Column.invoke_expression_over_column(col, expression.Last, ignore_nulls=ignorenulls)
def monotonically_increasing_id() -> Column:


@ -5,7 +5,6 @@ import typing as t
from sqlglot import expressions as exp
from sqlglot.dataframe.sql.column import Column
from sqlglot.dataframe.sql.util import get_tables_from_expression_with_join
from sqlglot.dialects import Spark
from sqlglot.helper import ensure_list
NORMALIZE_INPUT = t.TypeVar("NORMALIZE_INPUT", bound=t.Union[str, exp.Expression, Column])
@ -20,7 +19,7 @@ def normalize(spark: SparkSession, expression_context: exp.Select, expr: t.List[
for expression in expressions:
identifiers = expression.find_all(exp.Identifier)
for identifier in identifiers:
Spark.normalize_identifier(identifier)
identifier.transform(spark.dialect.normalize_identifier)
replace_alias_name_with_cte_name(spark, expression_context, identifier)
replace_branch_and_sequence_ids_with_cte_name(spark, expression_context, identifier)


@ -4,7 +4,6 @@ import typing as t
import sqlglot
from sqlglot import expressions as exp
from sqlglot.dialects import Spark
from sqlglot.helper import object_to_dict
if t.TYPE_CHECKING:
@ -18,15 +17,25 @@ class DataFrameReader:
def table(self, tableName: str) -> DataFrame:
from sqlglot.dataframe.sql.dataframe import DataFrame
from sqlglot.dataframe.sql.session import SparkSession
sqlglot.schema.add_table(tableName, dialect="spark")
sqlglot.schema.add_table(tableName, dialect=SparkSession().dialect)
return DataFrame(
self.spark,
exp.Select()
.from_(exp.to_table(tableName, dialect="spark").transform(Spark.normalize_identifier))
.from_(
exp.to_table(tableName, dialect=SparkSession().dialect).transform(
SparkSession().dialect.normalize_identifier
)
)
.select(
*(column for column in sqlglot.schema.column_names(tableName, dialect="spark"))
*(
column
for column in sqlglot.schema.column_names(
tableName, dialect=SparkSession().dialect
)
)
),
)
@ -63,6 +72,8 @@ class DataFrameWriter:
return self.copy(by_name=True)
def insertInto(self, tableName: str, overwrite: t.Optional[bool] = None) -> DataFrameWriter:
from sqlglot.dataframe.sql.session import SparkSession
output_expression_container = exp.Insert(
**{
"this": exp.to_table(tableName),
@ -71,7 +82,9 @@ class DataFrameWriter:
)
df = self._df.copy(output_expression_container=output_expression_container)
if self._by_name:
columns = sqlglot.schema.column_names(tableName, only_visible=True, dialect="spark")
columns = sqlglot.schema.column_names(
tableName, only_visible=True, dialect=SparkSession().dialect
)
df = df._convert_leaf_to_cte().select(*columns)
return self.copy(_df=df)


@ -5,31 +5,35 @@ import uuid
from collections import defaultdict
import sqlglot
from sqlglot import expressions as exp
from sqlglot import Dialect, expressions as exp
from sqlglot.dataframe.sql import functions as F
from sqlglot.dataframe.sql.dataframe import DataFrame
from sqlglot.dataframe.sql.readwriter import DataFrameReader
from sqlglot.dataframe.sql.types import StructType
from sqlglot.dataframe.sql.util import get_column_mapping_from_schema_input
from sqlglot.helper import classproperty
if t.TYPE_CHECKING:
from sqlglot.dataframe.sql._typing import ColumnLiterals, SchemaInput
class SparkSession:
known_ids: t.ClassVar[t.Set[str]] = set()
known_branch_ids: t.ClassVar[t.Set[str]] = set()
known_sequence_ids: t.ClassVar[t.Set[str]] = set()
name_to_sequence_id_mapping: t.ClassVar[t.Dict[str, t.List[str]]] = defaultdict(list)
DEFAULT_DIALECT = "spark"
_instance = None
def __init__(self):
self.incrementing_id = 1
if not hasattr(self, "known_ids"):
self.known_ids = set()
self.known_branch_ids = set()
self.known_sequence_ids = set()
self.name_to_sequence_id_mapping = defaultdict(list)
self.incrementing_id = 1
self.dialect = Dialect.get_or_raise(self.DEFAULT_DIALECT)()
def __getattr__(self, name: str) -> SparkSession:
return self
def __call__(self, *args, **kwargs) -> SparkSession:
return self
def __new__(cls, *args, **kwargs) -> SparkSession:
if cls._instance is None:
cls._instance = super().__new__(cls)
return cls._instance
@property
def read(self) -> DataFrameReader:
@ -101,7 +105,7 @@ class SparkSession:
return DataFrame(self, sel_expression)
def sql(self, sqlQuery: str) -> DataFrame:
expression = sqlglot.parse_one(sqlQuery, read="spark")
expression = sqlglot.parse_one(sqlQuery, read=self.dialect)
if isinstance(expression, exp.Select):
df = DataFrame(self, expression)
df = df._convert_leaf_to_cte()
@ -149,3 +153,38 @@ class SparkSession:
def _add_alias_to_mapping(self, name: str, sequence_id: str):
self.name_to_sequence_id_mapping[name].append(sequence_id)
class Builder:
SQLFRAME_DIALECT_KEY = "sqlframe.dialect"
def __init__(self):
self.dialect = "spark"
def __getattr__(self, item) -> SparkSession.Builder:
return self
def __call__(self, *args, **kwargs):
return self
def config(
self,
key: t.Optional[str] = None,
value: t.Optional[t.Any] = None,
*,
map: t.Optional[t.Dict[str, t.Any]] = None,
**kwargs: t.Any,
) -> SparkSession.Builder:
if key == self.SQLFRAME_DIALECT_KEY:
self.dialect = value
elif map and self.SQLFRAME_DIALECT_KEY in map:
self.dialect = map[self.SQLFRAME_DIALECT_KEY]
return self
def getOrCreate(self) -> SparkSession:
spark = SparkSession()
spark.dialect = Dialect.get_or_raise(self.dialect)()
return spark
@classproperty
def builder(cls) -> Builder:
return cls.Builder()


@ -48,7 +48,9 @@ class WindowSpec:
return WindowSpec(self.expression.copy())
def sql(self, **kwargs) -> str:
return self.expression.sql(dialect="spark", **kwargs)
from sqlglot.dataframe.sql.session import SparkSession
return self.expression.sql(dialect=SparkSession().dialect, **kwargs)
def partitionBy(self, *cols: t.Union[ColumnOrName, t.List[ColumnOrName]]) -> WindowSpec:
from sqlglot.dataframe.sql.column import Column


@ -12,6 +12,7 @@ from sqlglot.dialects.dialect import (
datestrtodate_sql,
format_time_lambda,
inline_array_sql,
json_keyvalue_comma_sql,
max_or_greatest,
min_or_least,
no_ilike_sql,
@ -29,8 +30,8 @@ logger = logging.getLogger("sqlglot")
def _date_add_sql(
data_type: str, kind: str
) -> t.Callable[[generator.Generator, exp.Expression], str]:
def func(self, expression):
) -> t.Callable[[BigQuery.Generator, exp.Expression], str]:
def func(self: BigQuery.Generator, expression: exp.Expression) -> str:
this = self.sql(expression, "this")
unit = expression.args.get("unit")
unit = exp.var(unit.name.upper() if unit else "DAY")
@ -40,7 +41,7 @@ def _date_add_sql(
return func
def _derived_table_values_to_unnest(self: generator.Generator, expression: exp.Values) -> str:
def _derived_table_values_to_unnest(self: BigQuery.Generator, expression: exp.Values) -> str:
if not expression.find_ancestor(exp.From, exp.Join):
return self.values_sql(expression)
@ -64,7 +65,7 @@ def _derived_table_values_to_unnest(self: generator.Generator, expression: exp.V
return self.unnest_sql(exp.Unnest(expressions=[exp.Array(expressions=structs)]))
def _returnsproperty_sql(self: generator.Generator, expression: exp.ReturnsProperty) -> str:
def _returnsproperty_sql(self: BigQuery.Generator, expression: exp.ReturnsProperty) -> str:
this = expression.this
if isinstance(this, exp.Schema):
this = f"{this.this} <{self.expressions(this)}>"
@ -73,7 +74,7 @@ def _returnsproperty_sql(self: generator.Generator, expression: exp.ReturnsPrope
return f"RETURNS {this}"
def _create_sql(self: generator.Generator, expression: exp.Create) -> str:
def _create_sql(self: BigQuery.Generator, expression: exp.Create) -> str:
kind = expression.args["kind"]
returns = expression.find(exp.ReturnsProperty)
@ -94,14 +95,20 @@ def _unqualify_unnest(expression: exp.Expression) -> exp.Expression:
These are added by the optimizer's qualify_column step.
"""
from sqlglot.optimizer.scope import Scope
from sqlglot.optimizer.scope import find_all_in_scope
if isinstance(expression, exp.Select):
for unnest in expression.find_all(exp.Unnest):
if isinstance(unnest.parent, (exp.From, exp.Join)) and unnest.alias:
for column in Scope(expression).find_all(exp.Column):
if column.table == unnest.alias:
column.set("table", None)
unnest_aliases = {
unnest.alias
for unnest in find_all_in_scope(expression, exp.Unnest)
if isinstance(unnest.parent, (exp.From, exp.Join))
}
if unnest_aliases:
for column in expression.find_all(exp.Column):
if column.table in unnest_aliases:
column.set("table", None)
elif column.db in unnest_aliases:
column.set("db", None)
return expression
@ -261,6 +268,7 @@ class BigQuery(Dialect):
"TIMESTAMP": TokenType.TIMESTAMPTZ,
"NOT DETERMINISTIC": TokenType.VOLATILE,
"UNKNOWN": TokenType.NULL,
"FOR SYSTEM_TIME": TokenType.TIMESTAMP_SNAPSHOT,
}
KEYWORDS.pop("DIV")
@ -270,6 +278,8 @@ class BigQuery(Dialect):
LOG_BASE_FIRST = False
LOG_DEFAULTS_TO_LN = True
SUPPORTS_USER_DEFINED_TYPES = False
FUNCTIONS = {
**parser.Parser.FUNCTIONS,
"DATE": _parse_date,
@ -299,6 +309,8 @@ class BigQuery(Dialect):
if re.compile(str(seq_get(args, 1))).groups == 1
else None,
),
"SHA256": lambda args: exp.SHA2(this=seq_get(args, 0), length=exp.Literal.number(256)),
"SHA512": lambda args: exp.SHA2(this=seq_get(args, 0), length=exp.Literal.number(512)),
"SPLIT": lambda args: exp.Split(
# https://cloud.google.com/bigquery/docs/reference/standard-sql/string_functions#split
this=seq_get(args, 0),
@ -346,7 +358,7 @@ class BigQuery(Dialect):
}
def _parse_table_part(self, schema: bool = False) -> t.Optional[exp.Expression]:
this = super()._parse_table_part(schema=schema)
this = super()._parse_table_part(schema=schema) or self._parse_number()
# https://cloud.google.com/bigquery/docs/reference/standard-sql/lexical#table_names
if isinstance(this, exp.Identifier):
@ -356,6 +368,17 @@ class BigQuery(Dialect):
table_name += f"-{self._prev.text}"
this = exp.Identifier(this=table_name, quoted=this.args.get("quoted"))
elif isinstance(this, exp.Literal):
table_name = this.name
if (
self._curr
and self._prev.end == self._curr.start - 1
and self._parse_var(any_token=True)
):
table_name += self._prev.text
this = exp.Identifier(this=table_name, quoted=True)
return this
@ -374,6 +397,27 @@ class BigQuery(Dialect):
return table
def _parse_json_object(self) -> exp.JSONObject:
json_object = super()._parse_json_object()
array_kv_pair = seq_get(json_object.expressions, 0)
# Converts BQ's "signature 2" of JSON_OBJECT into SQLGlot's canonical representation
# https://cloud.google.com/bigquery/docs/reference/standard-sql/json_functions#json_object_signature2
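# e.g. (illustrative): JSON_OBJECT(['a', 'b'], [1, 2]) is normalized to JSON_OBJECT('a', 1, 'b', 2)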
if (
array_kv_pair
and isinstance(array_kv_pair.this, exp.Array)
and isinstance(array_kv_pair.expression, exp.Array)
):
keys = array_kv_pair.this.expressions
values = array_kv_pair.expression.expressions
json_object.set(
"expressions",
[exp.JSONKeyValue(this=k, expression=v) for k, v in zip(keys, values)],
)
return json_object
class Generator(generator.Generator):
EXPLICIT_UNION = True
INTERVAL_ALLOWS_PLURAL_FORM = False
@ -383,6 +427,7 @@ class BigQuery(Dialect):
LIMIT_FETCH = "LIMIT"
RENAME_TABLE_WITH_DB = False
ESCAPE_LINE_BREAK = True
NVL2_SUPPORTED = False
TRANSFORMS = {
**generator.Generator.TRANSFORMS,
@ -405,6 +450,7 @@ class BigQuery(Dialect):
exp.ILike: no_ilike_sql,
exp.IntDiv: rename_func("DIV"),
exp.JSONFormat: rename_func("TO_JSON_STRING"),
exp.JSONKeyValue: json_keyvalue_comma_sql,
exp.Max: max_or_greatest,
exp.MD5: lambda self, e: self.func("TO_HEX", self.func("MD5", e.this)),
exp.MD5Digest: rename_func("MD5"),
@ -428,6 +474,9 @@ class BigQuery(Dialect):
_alias_ordered_group,
]
),
exp.SHA2: lambda self, e: self.func(
f"SHA256" if e.text("length") == "256" else "SHA512", e.this
),
exp.StabilityProperty: lambda self, e: "DETERMINISTIC"
if e.name == "IMMUTABLE"
else "NOT DETERMINISTIC",
@ -591,6 +640,13 @@ class BigQuery(Dialect):
return super().attimezone_sql(expression)
def cast_sql(self, expression: exp.Cast, safe_prefix: t.Optional[str] = None) -> str:
# https://cloud.google.com/bigquery/docs/reference/standard-sql/lexical#json_literals
if expression.is_type("json"):
return f"JSON {self.sql(expression, 'this')}"
return super().cast_sql(expression, safe_prefix=safe_prefix)
def trycast_sql(self, expression: exp.TryCast) -> str:
return self.cast_sql(expression, safe_prefix="SAFE_")
@ -630,3 +686,9 @@ class BigQuery(Dialect):
def with_properties(self, properties: exp.Properties) -> str:
return self.properties(properties, prefix=self.seg("OPTIONS"))
def version_sql(self, expression: exp.Version) -> str:
if expression.name == "TIMESTAMP":
expression = expression.copy()
expression.set("this", "SYSTEM_TIME")
return super().version_sql(expression)


@ -11,6 +11,7 @@ from sqlglot.dialects.dialect import (
var_map_sql,
)
from sqlglot.errors import ParseError
from sqlglot.helper import seq_get
from sqlglot.parser import parse_var_map
from sqlglot.tokens import Token, TokenType
@ -63,9 +64,23 @@ class ClickHouse(Dialect):
}
class Parser(parser.Parser):
SUPPORTS_USER_DEFINED_TYPES = False
FUNCTIONS = {
**parser.Parser.FUNCTIONS,
"ANY": exp.AnyValue.from_arg_list,
"DATE_ADD": lambda args: exp.DateAdd(
this=seq_get(args, 2), expression=seq_get(args, 1), unit=seq_get(args, 0)
),
"DATEADD": lambda args: exp.DateAdd(
this=seq_get(args, 2), expression=seq_get(args, 1), unit=seq_get(args, 0)
),
"DATE_DIFF": lambda args: exp.DateDiff(
this=seq_get(args, 2), expression=seq_get(args, 1), unit=seq_get(args, 0)
),
"DATEDIFF": lambda args: exp.DateDiff(
this=seq_get(args, 2), expression=seq_get(args, 1), unit=seq_get(args, 0)
),
"MAP": parse_var_map,
"MATCH": exp.RegexpLike.from_arg_list,
"UNIQ": exp.ApproxDistinct.from_arg_list,
@ -147,7 +162,7 @@ class ClickHouse(Dialect):
this = self._parse_id_var()
self._match(TokenType.COLON)
kind = self._parse_types(check_func=False) or (
kind = self._parse_types(check_func=False, allow_identifiers=False) or (
self._match_text_seq("IDENTIFIER") and "Identifier"
)
@ -249,7 +264,7 @@ class ClickHouse(Dialect):
def _parse_func_params(
self, this: t.Optional[exp.Func] = None
) -> t.Optional[t.List[t.Optional[exp.Expression]]]:
) -> t.Optional[t.List[exp.Expression]]:
if self._match_pair(TokenType.R_PAREN, TokenType.L_PAREN):
return self._parse_csv(self._parse_lambda)
@ -267,9 +282,7 @@ class ClickHouse(Dialect):
return self.expression(exp.Quantile, this=params[0], quantile=this)
return self.expression(exp.Quantile, this=this, quantile=exp.Literal.number(0.5))
def _parse_wrapped_id_vars(
self, optional: bool = False
) -> t.List[t.Optional[exp.Expression]]:
def _parse_wrapped_id_vars(self, optional: bool = False) -> t.List[exp.Expression]:
return super()._parse_wrapped_id_vars(optional=True)
def _parse_primary_key(
@ -292,9 +305,22 @@ class ClickHouse(Dialect):
class Generator(generator.Generator):
QUERY_HINTS = False
STRUCT_DELIMITER = ("(", ")")
NVL2_SUPPORTED = False
STRING_TYPE_MAPPING = {
exp.DataType.Type.CHAR: "String",
exp.DataType.Type.LONGBLOB: "String",
exp.DataType.Type.LONGTEXT: "String",
exp.DataType.Type.MEDIUMBLOB: "String",
exp.DataType.Type.MEDIUMTEXT: "String",
exp.DataType.Type.TEXT: "String",
exp.DataType.Type.VARBINARY: "String",
exp.DataType.Type.VARCHAR: "String",
}
TYPE_MAPPING = {
**generator.Generator.TYPE_MAPPING,
**STRING_TYPE_MAPPING,
exp.DataType.Type.ARRAY: "Array",
exp.DataType.Type.BIGINT: "Int64",
exp.DataType.Type.DATETIME64: "DateTime64",
@ -328,6 +354,12 @@ class ClickHouse(Dialect):
exp.ApproxDistinct: rename_func("uniq"),
exp.Array: inline_array_sql,
exp.CastToStrType: rename_func("CAST"),
exp.DateAdd: lambda self, e: self.func(
"DATE_ADD", exp.Literal.string(e.text("unit") or "day"), e.expression, e.this
),
exp.DateDiff: lambda self, e: self.func(
"DATE_DIFF", exp.Literal.string(e.text("unit") or "day"), e.expression, e.this
),
exp.Final: lambda self, e: f"{self.sql(e, 'this')} FINAL",
exp.Map: lambda self, e: _lower_func(var_map_sql(self, e)),
exp.PartitionedByProperty: lambda self, e: f"PARTITION BY {self.sql(e, 'this')}",
@ -364,6 +396,16 @@ class ClickHouse(Dialect):
"NAMED COLLECTION",
}
def datatype_sql(self, expression: exp.DataType) -> str:
# String is the standard ClickHouse type, every other variant is just an alias.
# Additionally, any supplied length parameter will be ignored.
#
# https://clickhouse.com/docs/en/sql-reference/data-types/string
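# e.g. (illustrative): VARCHAR(255), TEXT and LONGBLOB are all generated as String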
if expression.this in self.STRING_TYPE_MAPPING:
return "String"
return super().datatype_sql(expression)
def safeconcat_sql(self, expression: exp.SafeConcat) -> str:
# Clickhouse errors out if we try to cast a NULL value to TEXT
expression = expression.copy()


@ -1,7 +1,7 @@
from __future__ import annotations
from sqlglot import exp, transforms
from sqlglot.dialects.dialect import parse_date_delta
from sqlglot.dialects.dialect import parse_date_delta, timestamptrunc_sql
from sqlglot.dialects.spark import Spark
from sqlglot.dialects.tsql import generate_date_delta_with_unit_sql
from sqlglot.tokens import TokenType
@ -28,6 +28,19 @@ class Databricks(Spark):
**Spark.Generator.TRANSFORMS,
exp.DateAdd: generate_date_delta_with_unit_sql,
exp.DateDiff: generate_date_delta_with_unit_sql,
exp.DatetimeAdd: lambda self, e: self.func(
"TIMESTAMPADD", e.text("unit"), e.expression, e.this
),
exp.DatetimeSub: lambda self, e: self.func(
"TIMESTAMPADD",
e.text("unit"),
exp.Mul(this=e.expression.copy(), expression=exp.Literal.number(-1)),
e.this,
),
exp.DatetimeDiff: lambda self, e: self.func(
"TIMESTAMPDIFF", e.text("unit"), e.expression, e.this
),
exp.DatetimeTrunc: timestamptrunc_sql,
exp.JSONExtract: lambda self, e: self.binary(e, ":"),
exp.Select: transforms.preprocess(
[


@ -109,8 +109,7 @@ class _Dialect(type):
for k, v in vars(klass).items()
if not callable(v) and not isinstance(v, classmethod) and not k.startswith("__")
},
"STRING_ESCAPE": klass.tokenizer_class.STRING_ESCAPES[0],
"IDENTIFIER_ESCAPE": klass.tokenizer_class.IDENTIFIER_ESCAPES[0],
"TOKENIZER_CLASS": klass.tokenizer_class,
}
if enum not in ("", "bigquery"):
@ -345,7 +344,7 @@ def arrow_json_extract_scalar_sql(
def inline_array_sql(self: Generator, expression: exp.Array) -> str:
return f"[{self.expressions(expression)}]"
return f"[{self.expressions(expression, flat=True)}]"
def no_ilike_sql(self: Generator, expression: exp.ILike) -> str:
@ -415,9 +414,9 @@ def str_position_sql(self: Generator, expression: exp.StrPosition) -> str:
def struct_extract_sql(self: Generator, expression: exp.StructExtract) -> str:
this = self.sql(expression, "this")
struct_key = self.sql(exp.Identifier(this=expression.expression.copy(), quoted=True))
return f"{this}.{struct_key}"
return (
f"{self.sql(expression, 'this')}.{self.sql(exp.to_identifier(expression.expression.name))}"
)
def var_map_sql(
@ -722,3 +721,12 @@ def binary_from_function(expr_type: t.Type[B]) -> t.Callable[[t.List], B]:
# Used to represent DATE_TRUNC in Doris, Postgres and Starrocks dialects
def parse_timestamp_trunc(args: t.List) -> exp.TimestampTrunc:
return exp.TimestampTrunc(this=seq_get(args, 1), unit=seq_get(args, 0))
def any_value_to_max_sql(self: Generator, expression: exp.AnyValue) -> str:
return self.func("MAX", expression.this)
# Used to generate JSON_OBJECT with a comma in BigQuery and MySQL instead of colon
def json_keyvalue_comma_sql(self, expression: exp.JSONKeyValue) -> str:
return f"{self.sql(expression, 'this')}, {self.sql(expression, 'expression')}"


@ -37,7 +37,6 @@ class Doris(MySQL):
**MySQL.Generator.TRANSFORMS,
exp.ApproxDistinct: approx_count_distinct_sql,
exp.ArrayAgg: rename_func("COLLECT_LIST"),
exp.Coalesce: rename_func("NVL"),
exp.CurrentTimestamp: lambda *_: "NOW()",
exp.DateTrunc: lambda self, e: self.func(
"DATE_TRUNC", e.this, "'" + e.text("unit") + "'"


@ -16,8 +16,8 @@ from sqlglot.dialects.dialect import (
)
def _date_add_sql(kind: str) -> t.Callable[[generator.Generator, exp.DateAdd | exp.DateSub], str]:
def func(self: generator.Generator, expression: exp.DateAdd | exp.DateSub) -> str:
def _date_add_sql(kind: str) -> t.Callable[[Drill.Generator, exp.DateAdd | exp.DateSub], str]:
def func(self: Drill.Generator, expression: exp.DateAdd | exp.DateSub) -> str:
this = self.sql(expression, "this")
unit = exp.var(expression.text("unit").upper() or "DAY")
return f"DATE_{kind}({this}, {self.sql(exp.Interval(this=expression.expression.copy(), unit=unit))})"
@ -25,7 +25,7 @@ def _date_add_sql(kind: str) -> t.Callable[[generator.Generator, exp.DateAdd | e
return func
def _str_to_date(self: generator.Generator, expression: exp.StrToDate) -> str:
def _str_to_date(self: Drill.Generator, expression: exp.StrToDate) -> str:
this = self.sql(expression, "this")
time_format = self.format_time(expression)
if time_format == Drill.DATE_FORMAT:
@ -73,7 +73,6 @@ class Drill(Dialect):
}
class Tokenizer(tokens.Tokenizer):
QUOTES = ["'"]
IDENTIFIERS = ["`"]
STRING_ESCAPES = ["\\"]
ENCODE = "utf-8"
@ -81,6 +80,7 @@ class Drill(Dialect):
class Parser(parser.Parser):
STRICT_CAST = False
CONCAT_NULL_OUTPUTS_STRING = True
SUPPORTS_USER_DEFINED_TYPES = False
FUNCTIONS = {
**parser.Parser.FUNCTIONS,
@ -95,6 +95,7 @@ class Drill(Dialect):
JOIN_HINTS = False
TABLE_HINTS = False
QUERY_HINTS = False
NVL2_SUPPORTED = False
TYPE_MAPPING = {
**generator.Generator.TYPE_MAPPING,


@ -13,6 +13,7 @@ from sqlglot.dialects.dialect import (
datestrtodate_sql,
encode_decode_sql,
format_time_lambda,
inline_array_sql,
no_comment_column_constraint_sql,
no_properties_sql,
no_safe_divide_sql,
@ -30,13 +31,13 @@ from sqlglot.helper import seq_get
from sqlglot.tokens import TokenType
def _ts_or_ds_add_sql(self: generator.Generator, expression: exp.TsOrDsAdd) -> str:
def _ts_or_ds_add_sql(self: DuckDB.Generator, expression: exp.TsOrDsAdd) -> str:
this = self.sql(expression, "this")
unit = self.sql(expression, "unit").strip("'") or "DAY"
return f"CAST({this} AS DATE) + {self.sql(exp.Interval(this=expression.expression.copy(), unit=unit))}"
def _date_delta_sql(self: generator.Generator, expression: exp.DateAdd | exp.DateSub) -> str:
def _date_delta_sql(self: DuckDB.Generator, expression: exp.DateAdd | exp.DateSub) -> str:
this = self.sql(expression, "this")
unit = self.sql(expression, "unit").strip("'") or "DAY"
op = "+" if isinstance(expression, exp.DateAdd) else "-"
@ -44,7 +45,7 @@ def _date_delta_sql(self: generator.Generator, expression: exp.DateAdd | exp.Dat
# BigQuery -> DuckDB conversion for the DATE function
def _date_sql(self: generator.Generator, expression: exp.Date) -> str:
def _date_sql(self: DuckDB.Generator, expression: exp.Date) -> str:
result = f"CAST({self.sql(expression, 'this')} AS DATE)"
zone = self.sql(expression, "zone")
@ -58,13 +59,13 @@ def _date_sql(self: generator.Generator, expression: exp.Date) -> str:
return result
def _array_sort_sql(self: generator.Generator, expression: exp.ArraySort) -> str:
def _array_sort_sql(self: DuckDB.Generator, expression: exp.ArraySort) -> str:
if expression.expression:
self.unsupported("DUCKDB ARRAY_SORT does not support a comparator")
return f"ARRAY_SORT({self.sql(expression, 'this')})"
def _sort_array_sql(self: generator.Generator, expression: exp.SortArray) -> str:
def _sort_array_sql(self: DuckDB.Generator, expression: exp.SortArray) -> str:
this = self.sql(expression, "this")
if expression.args.get("asc") == exp.false():
return f"ARRAY_REVERSE_SORT({this})"
@ -79,14 +80,14 @@ def _parse_date_diff(args: t.List) -> exp.Expression:
return exp.DateDiff(this=seq_get(args, 2), expression=seq_get(args, 1), unit=seq_get(args, 0))
def _struct_sql(self: generator.Generator, expression: exp.Struct) -> str:
def _struct_sql(self: DuckDB.Generator, expression: exp.Struct) -> str:
args = [
f"'{e.name or e.this.name}': {self.sql(e, 'expression')}" for e in expression.expressions
]
return f"{{{', '.join(args)}}}"
def _datatype_sql(self: generator.Generator, expression: exp.DataType) -> str:
def _datatype_sql(self: DuckDB.Generator, expression: exp.DataType) -> str:
if expression.is_type("array"):
return f"{self.expressions(expression, flat=True)}[]"
@ -97,7 +98,7 @@ def _datatype_sql(self: generator.Generator, expression: exp.DataType) -> str:
return self.datatype_sql(expression)
def _json_format_sql(self: generator.Generator, expression: exp.JSONFormat) -> str:
def _json_format_sql(self: DuckDB.Generator, expression: exp.JSONFormat) -> str:
sql = self.func("TO_JSON", expression.this, expression.args.get("options"))
return f"CAST({sql} AS TEXT)"
@ -134,6 +135,7 @@ class DuckDB(Dialect):
class Parser(parser.Parser):
CONCAT_NULL_OUTPUTS_STRING = True
SUPPORTS_USER_DEFINED_TYPES = False
BITWISE = {
**parser.Parser.BITWISE,
@ -183,18 +185,12 @@ class DuckDB(Dialect):
),
}
TYPE_TOKENS = {
*parser.Parser.TYPE_TOKENS,
TokenType.UBIGINT,
TokenType.UINT,
TokenType.USMALLINT,
TokenType.UTINYINT,
}
def _parse_types(
self, check_func: bool = False, schema: bool = False
self, check_func: bool = False, schema: bool = False, allow_identifiers: bool = True
) -> t.Optional[exp.Expression]:
this = super()._parse_types(check_func=check_func, schema=schema)
this = super()._parse_types(
check_func=check_func, schema=schema, allow_identifiers=allow_identifiers
)
# DuckDB treats NUMERIC and DECIMAL without precision as DECIMAL(18, 3)
# See: https://duckdb.org/docs/sql/data_types/numeric
@ -207,6 +203,9 @@ class DuckDB(Dialect):
return this
def _parse_struct_types(self) -> t.Optional[exp.Expression]:
return self._parse_field_def()
def _pivot_column_names(self, aggregations: t.List[exp.Expression]) -> t.List[str]:
if len(aggregations) == 1:
return super()._pivot_column_names(aggregations)
@ -219,13 +218,14 @@ class DuckDB(Dialect):
LIMIT_FETCH = "LIMIT"
STRUCT_DELIMITER = ("(", ")")
RENAME_TABLE_WITH_DB = False
NVL2_SUPPORTED = False
TRANSFORMS = {
**generator.Generator.TRANSFORMS,
exp.ApproxDistinct: approx_count_distinct_sql,
exp.Array: lambda self, e: self.func("ARRAY", e.expressions[0])
if e.expressions and e.expressions[0].find(exp.Select)
else rename_func("LIST_VALUE")(self, e),
else inline_array_sql(self, e),
exp.ArraySize: rename_func("ARRAY_LENGTH"),
exp.ArraySort: _array_sort_sql,
exp.ArraySum: rename_func("LIST_SUM"),


@ -50,7 +50,7 @@ TIME_DIFF_FACTOR = {
DIFF_MONTH_SWITCH = ("YEAR", "QUARTER", "MONTH")
def _add_date_sql(self: generator.Generator, expression: exp.DateAdd | exp.DateSub) -> str:
def _add_date_sql(self: Hive.Generator, expression: exp.DateAdd | exp.DateSub) -> str:
unit = expression.text("unit").upper()
func, multiplier = DATE_DELTA_INTERVAL.get(unit, ("DATE_ADD", 1))
@ -69,7 +69,7 @@ def _add_date_sql(self: generator.Generator, expression: exp.DateAdd | exp.DateS
return self.func(func, expression.this, modified_increment)
def _date_diff_sql(self: generator.Generator, expression: exp.DateDiff) -> str:
def _date_diff_sql(self: Hive.Generator, expression: exp.DateDiff) -> str:
unit = expression.text("unit").upper()
factor = TIME_DIFF_FACTOR.get(unit)
@ -87,7 +87,7 @@ def _date_diff_sql(self: generator.Generator, expression: exp.DateDiff) -> str:
return f"{diff_sql}{multiplier_sql}"
def _json_format_sql(self: generator.Generator, expression: exp.JSONFormat) -> str:
def _json_format_sql(self: Hive.Generator, expression: exp.JSONFormat) -> str:
this = expression.this
if isinstance(this, exp.Cast) and this.is_type("json") and this.this.is_string:
# Since FROM_JSON requires a nested type, we always wrap the json string with
@ -103,21 +103,21 @@ def _json_format_sql(self: generator.Generator, expression: exp.JSONFormat) -> s
return self.func("TO_JSON", this, expression.args.get("options"))
def _array_sort_sql(self: generator.Generator, expression: exp.ArraySort) -> str:
def _array_sort_sql(self: Hive.Generator, expression: exp.ArraySort) -> str:
if expression.expression:
self.unsupported("Hive SORT_ARRAY does not support a comparator")
return f"SORT_ARRAY({self.sql(expression, 'this')})"
def _property_sql(self: generator.Generator, expression: exp.Property) -> str:
def _property_sql(self: Hive.Generator, expression: exp.Property) -> str:
return f"'{expression.name}'={self.sql(expression, 'value')}"
def _str_to_unix_sql(self: generator.Generator, expression: exp.StrToUnix) -> str:
def _str_to_unix_sql(self: Hive.Generator, expression: exp.StrToUnix) -> str:
return self.func("UNIX_TIMESTAMP", expression.this, time_format("hive")(self, expression))
def _str_to_date_sql(self: generator.Generator, expression: exp.StrToDate) -> str:
def _str_to_date_sql(self: Hive.Generator, expression: exp.StrToDate) -> str:
this = self.sql(expression, "this")
time_format = self.format_time(expression)
if time_format not in (Hive.TIME_FORMAT, Hive.DATE_FORMAT):
@ -125,7 +125,7 @@ def _str_to_date_sql(self: generator.Generator, expression: exp.StrToDate) -> st
return f"CAST({this} AS DATE)"
def _str_to_time_sql(self: generator.Generator, expression: exp.StrToTime) -> str:
def _str_to_time_sql(self: Hive.Generator, expression: exp.StrToTime) -> str:
this = self.sql(expression, "this")
time_format = self.format_time(expression)
if time_format not in (Hive.TIME_FORMAT, Hive.DATE_FORMAT):
@ -133,13 +133,13 @@ def _str_to_time_sql(self: generator.Generator, expression: exp.StrToTime) -> st
return f"CAST({this} AS TIMESTAMP)"
def _time_to_str(self: generator.Generator, expression: exp.TimeToStr) -> str:
def _time_to_str(self: Hive.Generator, expression: exp.TimeToStr) -> str:
this = self.sql(expression, "this")
time_format = self.format_time(expression)
return f"DATE_FORMAT({this}, {time_format})"
def _to_date_sql(self: generator.Generator, expression: exp.TsOrDsToDate) -> str:
def _to_date_sql(self: Hive.Generator, expression: exp.TsOrDsToDate) -> str:
this = self.sql(expression, "this")
time_format = self.format_time(expression)
if time_format and time_format not in (Hive.TIME_FORMAT, Hive.DATE_FORMAT):
@ -206,6 +206,8 @@ class Hive(Dialect):
"MSCK REPAIR": TokenType.COMMAND,
"REFRESH": TokenType.COMMAND,
"WITH SERDEPROPERTIES": TokenType.SERDE_PROPERTIES,
"TIMESTAMP AS OF": TokenType.TIMESTAMP_SNAPSHOT,
"VERSION AS OF": TokenType.VERSION_SNAPSHOT,
}
NUMERIC_LITERALS = {
@ -220,6 +222,7 @@ class Hive(Dialect):
class Parser(parser.Parser):
LOG_DEFAULTS_TO_LN = True
STRICT_CAST = False
SUPPORTS_USER_DEFINED_TYPES = False
FUNCTIONS = {
**parser.Parser.FUNCTIONS,
@ -257,6 +260,11 @@ class Hive(Dialect):
),
"SIZE": exp.ArraySize.from_arg_list,
"SPLIT": exp.RegexpSplit.from_arg_list,
"STR_TO_MAP": lambda args: exp.StrToMap(
this=seq_get(args, 0),
pair_delim=seq_get(args, 1) or exp.Literal.string(","),
key_value_delim=seq_get(args, 2) or exp.Literal.string(":"),
),
"TO_DATE": format_time_lambda(exp.TsOrDsToDate, "hive"),
"TO_JSON": exp.JSONFormat.from_arg_list,
"UNBASE64": exp.FromBase64.from_arg_list,
@ -313,7 +321,7 @@ class Hive(Dialect):
)
def _parse_types(
self, check_func: bool = False, schema: bool = False
self, check_func: bool = False, schema: bool = False, allow_identifiers: bool = True
) -> t.Optional[exp.Expression]:
"""
Spark (and most likely Hive) treats casts to CHAR(length) and VARCHAR(length) as casts to
@ -333,7 +341,9 @@ class Hive(Dialect):
Reference: https://spark.apache.org/docs/latest/sql-ref-datatypes.html
"""
this = super()._parse_types(check_func=check_func, schema=schema)
this = super()._parse_types(
check_func=check_func, schema=schema, allow_identifiers=allow_identifiers
)
if this and not schema:
return this.transform(
@ -345,6 +355,16 @@ class Hive(Dialect):
return this
def _parse_partition_and_order(
self,
) -> t.Tuple[t.List[exp.Expression], t.Optional[exp.Expression]]:
return (
self._parse_csv(self._parse_conjunction)
if self._match_set({TokenType.PARTITION_BY, TokenType.DISTRIBUTE_BY})
else [],
super()._parse_order(skip_order_token=self._match(TokenType.SORT_BY)),
)
class Generator(generator.Generator):
LIMIT_FETCH = "LIMIT"
TABLESAMPLE_WITH_METHOD = False
@ -354,6 +374,7 @@ class Hive(Dialect):
QUERY_HINTS = False
INDEX_ON = "ON TABLE"
EXTRACT_ALLOWS_QUOTES = False
NVL2_SUPPORTED = False
TYPE_MAPPING = {
**generator.Generator.TYPE_MAPPING,
@ -376,6 +397,7 @@ class Hive(Dialect):
]
),
exp.Property: _property_sql,
exp.AnyValue: rename_func("FIRST"),
exp.ApproxDistinct: approx_count_distinct_sql,
exp.ArrayConcat: rename_func("CONCAT"),
exp.ArrayJoin: lambda self, e: self.func("CONCAT_WS", e.expression, e.this),
@ -402,6 +424,9 @@ class Hive(Dialect):
exp.MD5Digest: lambda self, e: self.func("UNHEX", self.func("MD5", e.this)),
exp.Min: min_or_least,
exp.MonthsBetween: lambda self, e: self.func("MONTHS_BETWEEN", e.this, e.expression),
exp.NotNullColumnConstraint: lambda self, e: ""
if e.args.get("allow_null")
else "NOT NULL",
exp.VarMap: var_map_sql,
exp.Create: create_with_partitions_sql,
exp.Quantile: rename_func("PERCENTILE"),
@ -472,7 +497,7 @@ class Hive(Dialect):
elif expression.this in exp.DataType.TEMPORAL_TYPES:
expression = exp.DataType.build(expression.this)
elif expression.is_type("float"):
size_expression = expression.find(exp.DataTypeSize)
size_expression = expression.find(exp.DataTypeParam)
if size_expression:
size = int(size_expression.name)
expression = (
@ -480,3 +505,7 @@ class Hive(Dialect):
)
return super().datatype_sql(expression)
def version_sql(self, expression: exp.Version) -> str:
sql = super().version_sql(expression)
return sql.replace("FOR ", "", 1)


@ -8,6 +8,7 @@ from sqlglot.dialects.dialect import (
arrow_json_extract_scalar_sql,
datestrtodate_sql,
format_time_lambda,
json_keyvalue_comma_sql,
locate_to_strposition,
max_or_greatest,
min_or_least,
@ -32,7 +33,7 @@ def _show_parser(*args: t.Any, **kwargs: t.Any) -> t.Callable[[MySQL.Parser], ex
return _parse
def _date_trunc_sql(self: generator.Generator, expression: exp.DateTrunc) -> str:
def _date_trunc_sql(self: MySQL.Generator, expression: exp.DateTrunc) -> str:
expr = self.sql(expression, "this")
unit = expression.text("unit")
@ -63,12 +64,12 @@ def _str_to_date(args: t.List) -> exp.StrToDate:
return exp.StrToDate(this=seq_get(args, 0), format=date_format)
def _str_to_date_sql(self: generator.Generator, expression: exp.StrToDate | exp.StrToTime) -> str:
def _str_to_date_sql(self: MySQL.Generator, expression: exp.StrToDate | exp.StrToTime) -> str:
date_format = self.format_time(expression)
return f"STR_TO_DATE({self.sql(expression.this)}, {date_format})"
def _trim_sql(self: generator.Generator, expression: exp.Trim) -> str:
def _trim_sql(self: MySQL.Generator, expression: exp.Trim) -> str:
target = self.sql(expression, "this")
trim_type = self.sql(expression, "position")
remove_chars = self.sql(expression, "expression")
@ -83,8 +84,8 @@ def _trim_sql(self: generator.Generator, expression: exp.Trim) -> str:
return f"TRIM({trim_type}{remove_chars}{from_part}{target})"
def _date_add_sql(kind: str) -> t.Callable[[generator.Generator, exp.DateAdd | exp.DateSub], str]:
def func(self: generator.Generator, expression: exp.DateAdd | exp.DateSub) -> str:
def _date_add_sql(kind: str) -> t.Callable[[MySQL.Generator, exp.DateAdd | exp.DateSub], str]:
def func(self: MySQL.Generator, expression: exp.DateAdd | exp.DateSub) -> str:
this = self.sql(expression, "this")
unit = expression.text("unit").upper() or "DAY"
return f"DATE_{kind}({this}, {self.sql(exp.Interval(this=expression.expression.copy(), unit=unit))})"
@ -93,6 +94,9 @@ def _date_add_sql(kind: str) -> t.Callable[[generator.Generator, exp.DateAdd | e
class MySQL(Dialect):
# https://dev.mysql.com/doc/refman/8.0/en/identifiers.html
IDENTIFIERS_CAN_START_WITH_DIGIT = True
TIME_FORMAT = "'%Y-%m-%d %T'"
DPIPE_IS_STRING_CONCAT = False
@ -129,6 +133,7 @@ class MySQL(Dialect):
"LONGTEXT": TokenType.LONGTEXT,
"MEDIUMBLOB": TokenType.MEDIUMBLOB,
"MEDIUMTEXT": TokenType.MEDIUMTEXT,
"MEDIUMINT": TokenType.MEDIUMINT,
"MEMBER OF": TokenType.MEMBER_OF,
"SEPARATOR": TokenType.SEPARATOR,
"START": TokenType.BEGIN,
@ -136,6 +141,7 @@ class MySQL(Dialect):
"SIGNED INTEGER": TokenType.BIGINT,
"UNSIGNED": TokenType.UBIGINT,
"UNSIGNED INTEGER": TokenType.UBIGINT,
"YEAR": TokenType.YEAR,
"_ARMSCII8": TokenType.INTRODUCER,
"_ASCII": TokenType.INTRODUCER,
"_BIG5": TokenType.INTRODUCER,
@ -185,6 +191,8 @@ class MySQL(Dialect):
COMMANDS = tokens.Tokenizer.COMMANDS - {TokenType.SHOW}
class Parser(parser.Parser):
SUPPORTS_USER_DEFINED_TYPES = False
FUNC_TOKENS = {
*parser.Parser.FUNC_TOKENS,
TokenType.DATABASE,
@ -492,6 +500,17 @@ class MySQL(Dialect):
return self.expression(exp.SetItem, this=charset, collate=collate, kind="NAMES")
def _parse_type(self) -> t.Optional[exp.Expression]:
# mysql binary is special and can work anywhere, even in order by operations
# it operates like a no paren func
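# e.g. (illustrative): ORDER BY BINARY col is parsed as ORDER BY CAST(col AS BINARY)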
if self._match(TokenType.BINARY, advance=False):
data_type = self._parse_types(check_func=True, allow_identifiers=False)
if isinstance(data_type, exp.DataType):
return self.expression(exp.Cast, this=self._parse_column(), to=data_type)
return super()._parse_type()
class Generator(generator.Generator):
LOCKING_READS_SUPPORTED = True
NULL_ORDERING_SUPPORTED = False
@ -500,6 +519,7 @@ class MySQL(Dialect):
DUPLICATE_KEY_UPDATE_WITH_SET = False
QUERY_HINT_SEP = " "
VALUES_AS_TABLE = False
NVL2_SUPPORTED = False
TRANSFORMS = {
**generator.Generator.TRANSFORMS,
@ -515,6 +535,7 @@ class MySQL(Dialect):
exp.GroupConcat: lambda self, e: f"""GROUP_CONCAT({self.sql(e, "this")} SEPARATOR {self.sql(e, "separator") or "','"})""",
exp.ILike: no_ilike_sql,
exp.JSONExtractScalar: arrow_json_extract_scalar_sql,
exp.JSONKeyValue: json_keyvalue_comma_sql,
exp.Max: max_or_greatest,
exp.Min: min_or_least,
exp.NullSafeEQ: lambda self, e: self.binary(e, "<=>"),
@ -524,6 +545,7 @@ class MySQL(Dialect):
exp.StrPosition: strposition_to_locate_sql,
exp.StrToDate: _str_to_date_sql,
exp.StrToTime: _str_to_date_sql,
exp.Stuff: rename_func("INSERT"),
exp.TableSample: no_tablesample_sql,
exp.TimeStrToUnix: rename_func("UNIX_TIMESTAMP"),
exp.TimeStrToTime: lambda self, e: self.sql(exp.cast(e.this, "datetime", copy=True)),


@ -8,7 +8,7 @@ from sqlglot.helper import seq_get
from sqlglot.tokens import TokenType
def _parse_xml_table(self: parser.Parser) -> exp.XMLTable:
def _parse_xml_table(self: Oracle.Parser) -> exp.XMLTable:
this = self._parse_string()
passing = None
@ -22,7 +22,7 @@ def _parse_xml_table(self: parser.Parser) -> exp.XMLTable:
by_ref = self._match_text_seq("RETURNING", "SEQUENCE", "BY", "REF")
if self._match_text_seq("COLUMNS"):
columns = self._parse_csv(lambda: self._parse_column_def(self._parse_field(any_token=True)))
columns = self._parse_csv(self._parse_field_def)
return self.expression(exp.XMLTable, this=this, passing=passing, columns=columns, by_ref=by_ref)
@ -78,6 +78,10 @@ class Oracle(Dialect):
)
}
# SELECT UNIQUE .. is old-style Oracle syntax for SELECT DISTINCT ..
# Reference: https://stackoverflow.com/a/336455
DISTINCT_TOKENS = {TokenType.DISTINCT, TokenType.UNIQUE}
def _parse_column(self) -> t.Optional[exp.Expression]:
column = super()._parse_column()
if column:
@ -129,7 +133,6 @@ class Oracle(Dialect):
),
exp.Group: transforms.preprocess([transforms.unalias_group]),
exp.ILike: no_ilike_sql,
exp.Coalesce: rename_func("NVL"),
exp.Select: transforms.preprocess([transforms.eliminate_distinct_on]),
exp.StrToTime: lambda self, e: f"TO_TIMESTAMP({self.sql(e, 'this')}, {self.format_time(e)})",
exp.Subquery: lambda self, e: self.subquery_sql(e, sep=" "),
@ -162,7 +165,7 @@ class Oracle(Dialect):
return f"XMLTABLE({self.sep('')}{self.indent(this + passing + by_ref + columns)}{self.seg(')', sep='')}"
class Tokenizer(tokens.Tokenizer):
VAR_SINGLE_TOKENS = {"@"}
VAR_SINGLE_TOKENS = {"@", "$", "#"}
KEYWORDS = {
**tokens.Tokenizer.KEYWORDS,


@ -5,6 +5,7 @@ import typing as t
from sqlglot import exp, generator, parser, tokens, transforms
from sqlglot.dialects.dialect import (
Dialect,
any_value_to_max_sql,
arrow_json_extract_scalar_sql,
arrow_json_extract_sql,
datestrtodate_sql,
@ -39,8 +40,8 @@ DATE_DIFF_FACTOR = {
}
def _date_add_sql(kind: str) -> t.Callable[[generator.Generator, exp.DateAdd | exp.DateSub], str]:
def func(self: generator.Generator, expression: exp.DateAdd | exp.DateSub) -> str:
def _date_add_sql(kind: str) -> t.Callable[[Postgres.Generator, exp.DateAdd | exp.DateSub], str]:
def func(self: Postgres.Generator, expression: exp.DateAdd | exp.DateSub) -> str:
expression = expression.copy()
this = self.sql(expression, "this")
@ -56,7 +57,7 @@ def _date_add_sql(kind: str) -> t.Callable[[generator.Generator, exp.DateAdd | e
return func
def _date_diff_sql(self: generator.Generator, expression: exp.DateDiff) -> str:
def _date_diff_sql(self: Postgres.Generator, expression: exp.DateDiff) -> str:
unit = expression.text("unit").upper()
factor = DATE_DIFF_FACTOR.get(unit)
@ -82,7 +83,7 @@ def _date_diff_sql(self: generator.Generator, expression: exp.DateDiff) -> str:
return f"CAST({unit} AS BIGINT)"
def _substring_sql(self: generator.Generator, expression: exp.Substring) -> str:
def _substring_sql(self: Postgres.Generator, expression: exp.Substring) -> str:
this = self.sql(expression, "this")
start = self.sql(expression, "start")
length = self.sql(expression, "length")
@ -93,7 +94,7 @@ def _substring_sql(self: generator.Generator, expression: exp.Substring) -> str:
return f"SUBSTRING({this}{from_part}{for_part})"
def _string_agg_sql(self: generator.Generator, expression: exp.GroupConcat) -> str:
def _string_agg_sql(self: Postgres.Generator, expression: exp.GroupConcat) -> str:
expression = expression.copy()
separator = expression.args.get("separator") or exp.Literal.string(",")
@ -107,7 +108,7 @@ def _string_agg_sql(self: generator.Generator, expression: exp.GroupConcat) -> s
return f"STRING_AGG({self.format_args(this, separator)}{order})"
def _datatype_sql(self: generator.Generator, expression: exp.DataType) -> str:
def _datatype_sql(self: Postgres.Generator, expression: exp.DataType) -> str:
if expression.is_type("array"):
return f"{self.expressions(expression, flat=True)}[]"
return self.datatype_sql(expression)
@ -254,6 +255,7 @@ class Postgres(Dialect):
"~~*": TokenType.ILIKE,
"~*": TokenType.IRLIKE,
"~": TokenType.RLIKE,
"@@": TokenType.DAT,
"@>": TokenType.AT_GT,
"<@": TokenType.LT_AT,
"BEGIN": TokenType.COMMAND,
@ -273,6 +275,18 @@ class Postgres(Dialect):
"SMALLSERIAL": TokenType.SMALLSERIAL,
"TEMP": TokenType.TEMPORARY,
"CSTRING": TokenType.PSEUDO_TYPE,
"OID": TokenType.OBJECT_IDENTIFIER,
"REGCLASS": TokenType.OBJECT_IDENTIFIER,
"REGCOLLATION": TokenType.OBJECT_IDENTIFIER,
"REGCONFIG": TokenType.OBJECT_IDENTIFIER,
"REGDICTIONARY": TokenType.OBJECT_IDENTIFIER,
"REGNAMESPACE": TokenType.OBJECT_IDENTIFIER,
"REGOPER": TokenType.OBJECT_IDENTIFIER,
"REGOPERATOR": TokenType.OBJECT_IDENTIFIER,
"REGPROC": TokenType.OBJECT_IDENTIFIER,
"REGPROCEDURE": TokenType.OBJECT_IDENTIFIER,
"REGROLE": TokenType.OBJECT_IDENTIFIER,
"REGTYPE": TokenType.OBJECT_IDENTIFIER,
}
SINGLE_TOKENS = {
@ -312,6 +326,9 @@ class Postgres(Dialect):
RANGE_PARSERS = {
**parser.Parser.RANGE_PARSERS,
TokenType.DAMP: binary_range_parser(exp.ArrayOverlaps),
TokenType.DAT: lambda self, this: self.expression(
exp.MatchAgainst, this=self._parse_bitwise(), expressions=[this]
),
TokenType.AT_GT: binary_range_parser(exp.ArrayContains),
TokenType.LT_AT: binary_range_parser(exp.ArrayContained),
}
@ -343,6 +360,7 @@ class Postgres(Dialect):
JOIN_HINTS = False
TABLE_HINTS = False
QUERY_HINTS = False
NVL2_SUPPORTED = False
PARAMETER_TOKEN = "$"
TYPE_MAPPING = {
@ -357,6 +375,8 @@ class Postgres(Dialect):
TRANSFORMS = {
**generator.Generator.TRANSFORMS,
exp.AnyValue: any_value_to_max_sql,
exp.ArrayConcat: rename_func("ARRAY_CAT"),
exp.BitwiseXor: lambda self, e: self.binary(e, "#"),
exp.ColumnDef: transforms.preprocess([_auto_increment_to_serial, _serial_to_generated]),
exp.Explode: rename_func("UNNEST"),
@ -416,3 +436,9 @@ class Postgres(Dialect):
expression.set("this", exp.paren(expression.this, copy=False))
return super().bracket_sql(expression)
def matchagainst_sql(self, expression: exp.MatchAgainst) -> str:
this = self.sql(expression, "this")
expressions = [f"{self.sql(e)} @@ {this}" for e in expression.expressions]
sql = " OR ".join(expressions)
return f"({sql})" if len(expressions) > 1 else sql


@ -26,13 +26,13 @@ from sqlglot.helper import apply_index_offset, seq_get
from sqlglot.tokens import TokenType
def _approx_distinct_sql(self: generator.Generator, expression: exp.ApproxDistinct) -> str:
def _approx_distinct_sql(self: Presto.Generator, expression: exp.ApproxDistinct) -> str:
accuracy = expression.args.get("accuracy")
accuracy = ", " + self.sql(accuracy) if accuracy else ""
return f"APPROX_DISTINCT({self.sql(expression, 'this')}{accuracy})"
def _explode_to_unnest_sql(self: generator.Generator, expression: exp.Lateral) -> str:
def _explode_to_unnest_sql(self: Presto.Generator, expression: exp.Lateral) -> str:
if isinstance(expression.this, (exp.Explode, exp.Posexplode)):
expression = expression.copy()
return self.sql(
@ -48,12 +48,12 @@ def _explode_to_unnest_sql(self: generator.Generator, expression: exp.Lateral) -
return self.lateral_sql(expression)
def _initcap_sql(self: generator.Generator, expression: exp.Initcap) -> str:
def _initcap_sql(self: Presto.Generator, expression: exp.Initcap) -> str:
regex = r"(\w)(\w*)"
return f"REGEXP_REPLACE({self.sql(expression, 'this')}, '{regex}', x -> UPPER(x[1]) || LOWER(x[2]))"
def _no_sort_array(self: generator.Generator, expression: exp.SortArray) -> str:
def _no_sort_array(self: Presto.Generator, expression: exp.SortArray) -> str:
if expression.args.get("asc") == exp.false():
comparator = "(a, b) -> CASE WHEN a < b THEN 1 WHEN a > b THEN -1 ELSE 0 END"
else:
@ -61,7 +61,7 @@ def _no_sort_array(self: generator.Generator, expression: exp.SortArray) -> str:
return self.func("ARRAY_SORT", expression.this, comparator)
def _schema_sql(self: generator.Generator, expression: exp.Schema) -> str:
def _schema_sql(self: Presto.Generator, expression: exp.Schema) -> str:
if isinstance(expression.parent, exp.Property):
columns = ", ".join(f"'{c.name}'" for c in expression.expressions)
return f"ARRAY[{columns}]"
@ -75,25 +75,25 @@ def _schema_sql(self: generator.Generator, expression: exp.Schema) -> str:
return self.schema_sql(expression)
def _quantile_sql(self: generator.Generator, expression: exp.Quantile) -> str:
def _quantile_sql(self: Presto.Generator, expression: exp.Quantile) -> str:
self.unsupported("Presto does not support exact quantiles")
return f"APPROX_PERCENTILE({self.sql(expression, 'this')}, {self.sql(expression, 'quantile')})"
def _str_to_time_sql(
self: generator.Generator, expression: exp.StrToDate | exp.StrToTime | exp.TsOrDsToDate
self: Presto.Generator, expression: exp.StrToDate | exp.StrToTime | exp.TsOrDsToDate
) -> str:
return f"DATE_PARSE({self.sql(expression, 'this')}, {self.format_time(expression)})"
def _ts_or_ds_to_date_sql(self: generator.Generator, expression: exp.TsOrDsToDate) -> str:
def _ts_or_ds_to_date_sql(self: Presto.Generator, expression: exp.TsOrDsToDate) -> str:
time_format = self.format_time(expression)
if time_format and time_format not in (Presto.TIME_FORMAT, Presto.DATE_FORMAT):
return exp.cast(_str_to_time_sql(self, expression), "DATE").sql(dialect="presto")
return exp.cast(exp.cast(expression.this, "TIMESTAMP", copy=True), "DATE").sql(dialect="presto")
def _ts_or_ds_add_sql(self: generator.Generator, expression: exp.TsOrDsAdd) -> str:
def _ts_or_ds_add_sql(self: Presto.Generator, expression: exp.TsOrDsAdd) -> str:
this = expression.this
if not isinstance(this, exp.CurrentDate):
@ -153,6 +153,20 @@ def _unnest_sequence(expression: exp.Expression) -> exp.Expression:
return expression
def _first_last_sql(self: Presto.Generator, expression: exp.First | exp.Last) -> str:
"""
Trino doesn't support FIRST / LAST as functions, but they're valid in the context
of MATCH_RECOGNIZE, so we need to preserve them in that case. In all other cases
they're converted into an ARBITRARY call.
Reference: https://trino.io/docs/current/sql/match-recognize.html#logical-navigation-functions
"""
if isinstance(expression.find_ancestor(exp.MatchRecognize, exp.Select), exp.MatchRecognize):
return self.function_fallback_sql(expression)
return rename_func("ARBITRARY")(self, expression)
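The docstring above pins down the behavior; here is a minimal transpilation sketch (the printed output is an assumption based on the ARBITRARY fallback, not taken from this diff):

```python
import sqlglot

# Outside of MATCH_RECOGNIZE, an aggregate FIRST(x) has no Presto/Trino
# equivalent, so it should come out as ARBITRARY(x).
print(sqlglot.transpile("SELECT FIRST(x) FROM t", read="spark", write="presto")[0])
# Assumed output: SELECT ARBITRARY(x) FROM t
```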
class Presto(Dialect):
INDEX_OFFSET = 1
NULL_ORDERING = "nulls_are_last"
@ -178,6 +192,7 @@ class Presto(Dialect):
class Parser(parser.Parser):
FUNCTIONS = {
**parser.Parser.FUNCTIONS,
"ARBITRARY": exp.AnyValue.from_arg_list,
"APPROX_DISTINCT": exp.ApproxDistinct.from_arg_list,
"APPROX_PERCENTILE": _approx_percentile,
"BITWISE_AND": binary_from_function(exp.BitwiseAnd),
@ -205,7 +220,14 @@ class Presto(Dialect):
"REGEXP_EXTRACT": lambda args: exp.RegexpExtract(
this=seq_get(args, 0), expression=seq_get(args, 1), group=seq_get(args, 2)
),
"REGEXP_REPLACE": lambda args: exp.RegexpReplace(
this=seq_get(args, 0),
expression=seq_get(args, 1),
replacement=seq_get(args, 2) or exp.Literal.string(""),
),
"ROW": exp.Struct.from_arg_list,
"SEQUENCE": exp.GenerateSeries.from_arg_list,
"SPLIT_TO_MAP": exp.StrToMap.from_arg_list,
"STRPOS": lambda args: exp.StrPosition(
this=seq_get(args, 0), substr=seq_get(args, 1), instance=seq_get(args, 2)
),
@ -225,6 +247,7 @@ class Presto(Dialect):
QUERY_HINTS = False
IS_BOOL_ALLOWED = False
TZ_TO_WITH_TIME_ZONE = True
NVL2_SUPPORTED = False
STRUCT_DELIMITER = ("(", ")")
PROPERTIES_LOCATION = {
@ -242,10 +265,13 @@ class Presto(Dialect):
exp.DataType.Type.TIMETZ: "TIME",
exp.DataType.Type.TIMESTAMPTZ: "TIMESTAMP",
exp.DataType.Type.STRUCT: "ROW",
exp.DataType.Type.DATETIME: "TIMESTAMP",
exp.DataType.Type.DATETIME64: "TIMESTAMP",
}
TRANSFORMS = {
**generator.Generator.TRANSFORMS,
exp.AnyValue: rename_func("ARBITRARY"),
exp.ApproxDistinct: _approx_distinct_sql,
exp.ApproxQuantile: rename_func("APPROX_PERCENTILE"),
exp.Array: lambda self, e: f"ARRAY[{self.expressions(e, flat=True)}]",
@ -268,15 +294,23 @@ class Presto(Dialect):
),
exp.DateStrToDate: lambda self, e: f"CAST(DATE_PARSE({self.sql(e, 'this')}, {Presto.DATE_FORMAT}) AS DATE)",
exp.DateToDi: lambda self, e: f"CAST(DATE_FORMAT({self.sql(e, 'this')}, {Presto.DATEINT_FORMAT}) AS INT)",
exp.DateSub: lambda self, e: self.func(
"DATE_ADD",
exp.Literal.string(e.text("unit") or "day"),
e.expression * -1,
e.this,
),
exp.Decode: lambda self, e: encode_decode_sql(self, e, "FROM_UTF8"),
exp.DiToDate: lambda self, e: f"CAST(DATE_PARSE(CAST({self.sql(e, 'this')} AS VARCHAR), {Presto.DATEINT_FORMAT}) AS DATE)",
exp.Encode: lambda self, e: encode_decode_sql(self, e, "TO_UTF8"),
exp.FileFormatProperty: lambda self, e: f"FORMAT='{e.name.upper()}'",
exp.First: _first_last_sql,
exp.Group: transforms.preprocess([transforms.unalias_group]),
exp.Hex: rename_func("TO_HEX"),
exp.If: if_sql,
exp.ILike: no_ilike_sql,
exp.Initcap: _initcap_sql,
exp.Last: _first_last_sql,
exp.Lateral: _explode_to_unnest_sql,
exp.Left: left_to_substring_sql,
exp.Levenshtein: rename_func("LEVENSHTEIN_DISTANCE"),
@ -301,8 +335,10 @@ class Presto(Dialect):
exp.SortArray: _no_sort_array,
exp.StrPosition: rename_func("STRPOS"),
exp.StrToDate: lambda self, e: f"CAST({_str_to_time_sql(self, e)} AS DATE)",
exp.StrToMap: rename_func("SPLIT_TO_MAP"),
exp.StrToTime: _str_to_time_sql,
exp.StrToUnix: lambda self, e: f"TO_UNIXTIME(DATE_PARSE({self.sql(e, 'this')}, {self.format_time(e)}))",
exp.Struct: rename_func("ROW"),
exp.StructExtract: struct_extract_sql,
exp.Table: transforms.preprocess([_unnest_sequence]),
exp.TimestampTrunc: timestamptrunc_sql,


@ -13,7 +13,7 @@ from sqlglot.helper import seq_get
from sqlglot.tokens import TokenType
def _json_sql(self: Postgres.Generator, expression: exp.JSONExtract | exp.JSONExtractScalar) -> str:
def _json_sql(self: Redshift.Generator, expression: exp.JSONExtract | exp.JSONExtractScalar) -> str:
return f'{self.sql(expression, "this")}."{expression.expression.name}"'
@ -37,6 +37,8 @@ class Redshift(Postgres):
}
class Parser(Postgres.Parser):
SUPPORTS_USER_DEFINED_TYPES = False
FUNCTIONS = {
**Postgres.Parser.FUNCTIONS,
"ADD_MONTHS": lambda args: exp.DateAdd(
@ -55,9 +57,11 @@ class Redshift(Postgres):
}
def _parse_types(
self, check_func: bool = False, schema: bool = False
self, check_func: bool = False, schema: bool = False, allow_identifiers: bool = True
) -> t.Optional[exp.Expression]:
this = super()._parse_types(check_func=check_func, schema=schema)
this = super()._parse_types(
check_func=check_func, schema=schema, allow_identifiers=allow_identifiers
)
if (
isinstance(this, exp.DataType)
@ -100,6 +104,7 @@ class Redshift(Postgres):
QUERY_HINTS = False
VALUES_AS_TABLE = False
TZ_TO_WITH_TIME_ZONE = True
NVL2_SUPPORTED = True
TYPE_MAPPING = {
**Postgres.Generator.TYPE_MAPPING,
@ -142,6 +147,9 @@ class Redshift(Postgres):
# Redshift uses the POW | POWER (expr1, expr2) syntax instead of expr1 ^ expr2 (postgres)
TRANSFORMS.pop(exp.Pow)
# Redshift supports ANY_VALUE(..)
TRANSFORMS.pop(exp.AnyValue)
RESERVED_KEYWORDS = {*Postgres.Generator.RESERVED_KEYWORDS, "snapshot", "type"}
def with_properties(self, properties: exp.Properties) -> str:


@ -90,7 +90,7 @@ def _parse_datediff(args: t.List) -> exp.DateDiff:
return exp.DateDiff(this=seq_get(args, 2), expression=seq_get(args, 1), unit=seq_get(args, 0))
def _unix_to_time_sql(self: generator.Generator, expression: exp.UnixToTime) -> str:
def _unix_to_time_sql(self: Snowflake.Generator, expression: exp.UnixToTime) -> str:
scale = expression.args.get("scale")
timestamp = self.sql(expression, "this")
if scale in [None, exp.UnixToTime.SECONDS]:
@ -105,7 +105,7 @@ def _unix_to_time_sql(self: generator.Generator, expression: exp.UnixToTime) ->
# https://docs.snowflake.com/en/sql-reference/functions/date_part.html
# https://docs.snowflake.com/en/sql-reference/functions-date-time.html#label-supported-date-time-parts
def _parse_date_part(self: parser.Parser) -> t.Optional[exp.Expression]:
def _parse_date_part(self: Snowflake.Parser) -> t.Optional[exp.Expression]:
this = self._parse_var() or self._parse_type()
if not this:
@ -156,7 +156,7 @@ def _nullifzero_to_if(args: t.List) -> exp.If:
return exp.If(this=cond, true=exp.Null(), false=seq_get(args, 0))
def _datatype_sql(self: generator.Generator, expression: exp.DataType) -> str:
def _datatype_sql(self: Snowflake.Generator, expression: exp.DataType) -> str:
if expression.is_type("array"):
return "ARRAY"
elif expression.is_type("map"):
@ -164,6 +164,17 @@ def _datatype_sql(self: generator.Generator, expression: exp.DataType) -> str:
return self.datatype_sql(expression)
def _regexpilike_sql(self: Snowflake.Generator, expression: exp.RegexpILike) -> str:
flag = expression.text("flag")
if "i" not in flag:
flag += "i"
return self.func(
"REGEXP_LIKE", expression.this, expression.expression, exp.Literal.string(flag)
)
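A sketch of the effect, assuming Postgres's case-insensitive regex operator `~*` still parses into exp.RegexpILike (the output string is an assumption):

```python
import sqlglot

# The "i" flag is appended so Snowflake's three-argument REGEXP_LIKE
# matches case-insensitively.
print(sqlglot.transpile("SELECT a ~* 'foo' FROM t", read="postgres", write="snowflake")[0])
# Assumed output: SELECT REGEXP_LIKE(a, 'foo', 'i') FROM t
```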
def _parse_convert_timezone(args: t.List) -> t.Union[exp.Anonymous, exp.AtTimeZone]:
if len(args) == 3:
return exp.Anonymous(this="CONVERT_TIMEZONE", expressions=args)
@ -179,6 +190,13 @@ def _parse_regexp_replace(args: t.List) -> exp.RegexpReplace:
return regexp_replace
def _show_parser(*args: t.Any, **kwargs: t.Any) -> t.Callable[[Snowflake.Parser], exp.Show]:
def _parse(self: Snowflake.Parser) -> exp.Show:
return self._parse_show_snowflake(*args, **kwargs)
return _parse
class Snowflake(Dialect):
# https://docs.snowflake.com/en/sql-reference/identifiers-syntax
RESOLVES_IDENTIFIERS_AS_UPPERCASE = True
@ -216,6 +234,7 @@ class Snowflake(Dialect):
class Parser(parser.Parser):
IDENTIFY_PIVOT_STRINGS = True
SUPPORTS_USER_DEFINED_TYPES = False
FUNCTIONS = {
**parser.Parser.FUNCTIONS,
@ -230,6 +249,7 @@ class Snowflake(Dialect):
"DATEDIFF": _parse_datediff,
"DIV0": _div0_to_if,
"IFF": exp.If.from_arg_list,
"LISTAGG": exp.GroupConcat.from_arg_list,
"NULLIFZERO": _nullifzero_to_if,
"OBJECT_CONSTRUCT": _parse_object_construct,
"REGEXP_REPLACE": _parse_regexp_replace,
@ -250,11 +270,6 @@ class Snowflake(Dialect):
}
FUNCTION_PARSERS.pop("TRIM")
FUNC_TOKENS = {
*parser.Parser.FUNC_TOKENS,
TokenType.TABLE,
}
COLUMN_OPERATORS = {
**parser.Parser.COLUMN_OPERATORS,
TokenType.COLON: lambda self, this, path: self.expression(
@ -281,6 +296,16 @@ class Snowflake(Dialect):
),
}
STATEMENT_PARSERS = {
**parser.Parser.STATEMENT_PARSERS,
TokenType.SHOW: lambda self: self._parse_show(),
}
SHOW_PARSERS = {
"PRIMARY KEYS": _show_parser("PRIMARY KEYS"),
"TERSE PRIMARY KEYS": _show_parser("PRIMARY KEYS"),
}
def _parse_id_var(
self,
any_token: bool = True,
@ -296,8 +321,24 @@ class Snowflake(Dialect):
return super()._parse_id_var(any_token=any_token, tokens=tokens)
def _parse_show_snowflake(self, this: str) -> exp.Show:
scope = None
scope_kind = None
if self._match(TokenType.IN):
if self._match_text_seq("ACCOUNT"):
scope_kind = "ACCOUNT"
elif self._match_set(self.DB_CREATABLES):
scope_kind = self._prev.text
if self._curr:
scope = self._parse_table()
elif self._curr:
scope_kind = "TABLE"
scope = self._parse_table()
return self.expression(exp.Show, this=this, scope=scope, scope_kind=scope_kind)
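A usage sketch for the new SHOW support; the statements below should round-trip through the parser above and the show_sql generator added further down (exact outputs are assumptions):

```python
import sqlglot

for sql in (
    "SHOW PRIMARY KEYS",
    "SHOW PRIMARY KEYS IN ACCOUNT",
    "SHOW PRIMARY KEYS IN TABLE foo",
):
    # Each statement parses into exp.Show with the scope / scope_kind args set.
    print(sqlglot.parse_one(sql, read="snowflake").sql(dialect="snowflake"))
```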
class Tokenizer(tokens.Tokenizer):
QUOTES = ["'"]
STRING_ESCAPES = ["\\", "'"]
HEX_STRINGS = [("x'", "'"), ("X'", "'")]
RAW_STRINGS = ["$$"]
@ -331,6 +372,8 @@ class Snowflake(Dialect):
VAR_SINGLE_TOKENS = {"$"}
COMMANDS = tokens.Tokenizer.COMMANDS - {TokenType.SHOW}
class Generator(generator.Generator):
PARAMETER_TOKEN = "$"
MATCHED_BY_SOURCE = False
@ -355,6 +398,7 @@ class Snowflake(Dialect):
exp.DataType: _datatype_sql,
exp.DayOfWeek: rename_func("DAYOFWEEK"),
exp.Extract: rename_func("DATE_PART"),
exp.GroupConcat: rename_func("LISTAGG"),
exp.If: rename_func("IFF"),
exp.LogicalAnd: rename_func("BOOLAND_AGG"),
exp.LogicalOr: rename_func("BOOLOR_AGG"),
@ -362,6 +406,7 @@ class Snowflake(Dialect):
exp.Max: max_or_greatest,
exp.Min: min_or_least,
exp.PartitionedByProperty: lambda self, e: f"PARTITION BY {self.sql(e, 'this')}",
exp.RegexpILike: _regexpilike_sql,
exp.Select: transforms.preprocess([transforms.eliminate_distinct_on]),
exp.StarMap: rename_func("OBJECT_CONSTRUCT"),
exp.StartsWith: rename_func("STARTSWITH"),
@ -373,6 +418,7 @@ class Snowflake(Dialect):
"OBJECT_CONSTRUCT",
*(arg for expression in e.expressions for arg in expression.flatten()),
),
exp.Stuff: rename_func("INSERT"),
exp.TimestampTrunc: timestamptrunc_sql,
exp.TimeStrToTime: timestrtotime_sql,
exp.TimeToStr: lambda self, e: self.func(
@ -403,6 +449,16 @@ class Snowflake(Dialect):
exp.VolatileProperty: exp.Properties.Location.UNSUPPORTED,
}
def show_sql(self, expression: exp.Show) -> str:
scope = self.sql(expression, "scope")
scope = f" {scope}" if scope else ""
scope_kind = self.sql(expression, "scope_kind")
if scope_kind:
scope_kind = f" IN {scope_kind}"
return f"SHOW {expression.name}{scope_kind}{scope}"
def regexpextract_sql(self, expression: exp.RegexpExtract) -> str:
# Other dialects don't support all of the following parameters, so we need to
# generate default values as necessary to ensure the transpilation is correct
@ -436,7 +492,9 @@ class Snowflake(Dialect):
kind_value = expression.args.get("kind") or "TABLE"
kind = f" {kind_value}" if kind_value else ""
this = f" {self.sql(expression, 'this')}"
return f"DESCRIBE{kind}{this}"
expressions = self.expressions(expression, flat=True)
expressions = f" {expressions}" if expressions else ""
return f"DESCRIBE{kind}{this}{expressions}"
def generatedasidentitycolumnconstraint_sql(
self, expression: exp.GeneratedAsIdentityColumnConstraint


@ -38,9 +38,15 @@ class Spark(Spark2):
class Parser(Spark2.Parser):
FUNCTIONS = {
**Spark2.Parser.FUNCTIONS,
"ANY_VALUE": lambda args: exp.AnyValue(
this=seq_get(args, 0), ignore_nulls=seq_get(args, 1)
),
"DATEDIFF": _parse_datediff,
}
FUNCTION_PARSERS = Spark2.Parser.FUNCTION_PARSERS.copy()
FUNCTION_PARSERS.pop("ANY_VALUE")
class Generator(Spark2.Generator):
TYPE_MAPPING = {
**Spark2.Generator.TYPE_MAPPING,
@ -56,9 +62,13 @@ class Spark(Spark2):
"DATEADD", e.args.get("unit") or "DAY", e.expression, e.this
),
}
TRANSFORMS.pop(exp.AnyValue)
TRANSFORMS.pop(exp.DateDiff)
TRANSFORMS.pop(exp.Group)
def anyvalue_sql(self, expression: exp.AnyValue) -> str:
return self.function_fallback_sql(expression)
def datediff_sql(self, expression: exp.DateDiff) -> str:
unit = self.sql(expression, "unit")
end = self.sql(expression, "this")


@ -15,7 +15,7 @@ from sqlglot.dialects.hive import Hive
from sqlglot.helper import seq_get
def _create_sql(self: Hive.Generator, e: exp.Create) -> str:
def _create_sql(self: Spark2.Generator, e: exp.Create) -> str:
kind = e.args["kind"]
properties = e.args.get("properties")
@ -31,17 +31,21 @@ def _create_sql(self: Hive.Generator, e: exp.Create) -> str:
return create_with_partitions_sql(self, e)
def _map_sql(self: Hive.Generator, expression: exp.Map) -> str:
keys = self.sql(expression.args["keys"])
values = self.sql(expression.args["values"])
return f"MAP_FROM_ARRAYS({keys}, {values})"
def _map_sql(self: Spark2.Generator, expression: exp.Map) -> str:
keys = expression.args.get("keys")
values = expression.args.get("values")
if not keys or not values:
return "MAP()"
return f"MAP_FROM_ARRAYS({self.sql(keys)}, {self.sql(values)})"
def _parse_as_cast(to_type: str) -> t.Callable[[t.List], exp.Expression]:
return lambda args: exp.Cast(this=seq_get(args, 0), to=exp.DataType.build(to_type))
def _str_to_date(self: Hive.Generator, expression: exp.StrToDate) -> str:
def _str_to_date(self: Spark2.Generator, expression: exp.StrToDate) -> str:
this = self.sql(expression, "this")
time_format = self.format_time(expression)
if time_format == Hive.DATE_FORMAT:
@ -49,7 +53,7 @@ def _str_to_date(self: Hive.Generator, expression: exp.StrToDate) -> str:
return f"TO_DATE({this}, {time_format})"
def _unix_to_time_sql(self: Hive.Generator, expression: exp.UnixToTime) -> str:
def _unix_to_time_sql(self: Spark2.Generator, expression: exp.UnixToTime) -> str:
scale = expression.args.get("scale")
timestamp = self.sql(expression, "this")
if scale is None:
@ -110,6 +114,13 @@ def _unqualify_pivot_columns(expression: exp.Expression) -> exp.Expression:
return expression
def _insert_sql(self: Spark2.Generator, expression: exp.Insert) -> str:
if expression.expression.args.get("with"):
expression = expression.copy()
expression.set("with", expression.expression.args.pop("with"))
return self.insert_sql(expression)
class Spark2(Hive):
class Parser(Hive.Parser):
FUNCTIONS = {
@ -169,10 +180,7 @@ class Spark2(Hive):
class Generator(Hive.Generator):
QUERY_HINTS = True
TYPE_MAPPING = {
**Hive.Generator.TYPE_MAPPING,
}
NVL2_SUPPORTED = True
PROPERTIES_LOCATION = {
**Hive.Generator.PROPERTIES_LOCATION,
@ -197,6 +205,7 @@ class Spark2(Hive):
exp.DayOfYear: rename_func("DAYOFYEAR"),
exp.FileFormatProperty: lambda self, e: f"USING {e.name.upper()}",
exp.From: transforms.preprocess([_unalias_pivot]),
exp.Insert: _insert_sql,
exp.LogicalAnd: rename_func("BOOL_AND"),
exp.LogicalOr: rename_func("BOOL_OR"),
exp.Map: _map_sql,


@ -5,6 +5,7 @@ import typing as t
from sqlglot import exp, generator, parser, tokens, transforms
from sqlglot.dialects.dialect import (
Dialect,
any_value_to_max_sql,
arrow_json_extract_scalar_sql,
arrow_json_extract_sql,
concat_to_dpipe_sql,
@ -18,7 +19,7 @@ from sqlglot.dialects.dialect import (
from sqlglot.tokens import TokenType
def _date_add_sql(self: generator.Generator, expression: exp.DateAdd) -> str:
def _date_add_sql(self: SQLite.Generator, expression: exp.DateAdd) -> str:
modifier = expression.expression
modifier = modifier.name if modifier.is_string else self.sql(modifier)
unit = expression.args.get("unit")
@ -78,6 +79,7 @@ class SQLite(Dialect):
JOIN_HINTS = False
TABLE_HINTS = False
QUERY_HINTS = False
NVL2_SUPPORTED = False
TYPE_MAPPING = {
**generator.Generator.TYPE_MAPPING,
@ -103,6 +105,7 @@ class SQLite(Dialect):
TRANSFORMS = {
**generator.Generator.TRANSFORMS,
exp.AnyValue: any_value_to_max_sql,
exp.Concat: concat_to_dpipe_sql,
exp.CountIf: count_if_to_sum,
exp.Create: transforms.preprocess([_transform_create]),


@ -95,6 +95,9 @@ class Teradata(Dialect):
STATEMENT_PARSERS = {
**parser.Parser.STATEMENT_PARSERS,
TokenType.DATABASE: lambda self: self.expression(
exp.Use, this=self._parse_table(schema=False)
),
TokenType.REPLACE: lambda self: self._parse_create(),
}
@ -165,6 +168,7 @@ class Teradata(Dialect):
exp.Select: transforms.preprocess([transforms.eliminate_distinct_on]),
exp.StrToDate: lambda self, e: f"CAST({self.sql(e, 'this')} AS DATE FORMAT {self.format_time(e)})",
exp.ToChar: lambda self, e: self.function_fallback_sql(e),
exp.Use: lambda self, e: f"DATABASE {self.sql(e, 'this')}",
}
def partitionedbyproperty_sql(self, expression: exp.PartitionedByProperty) -> str:


@ -13,3 +13,6 @@ class Trino(Presto):
class Tokenizer(Presto.Tokenizer):
HEX_STRINGS = [("X'", "'")]
class Parser(Presto.Parser):
SUPPORTS_USER_DEFINED_TYPES = False


@ -7,6 +7,7 @@ import typing as t
from sqlglot import exp, generator, parser, tokens, transforms
from sqlglot.dialects.dialect import (
Dialect,
any_value_to_max_sql,
max_or_greatest,
min_or_least,
parse_date_delta,
@ -79,22 +80,23 @@ def _format_time_lambda(
def _parse_format(args: t.List) -> exp.Expression:
assert len(args) == 2
this = seq_get(args, 0)
fmt = seq_get(args, 1)
culture = seq_get(args, 2)
fmt = args[1]
number_fmt = fmt.name in TRANSPILE_SAFE_NUMBER_FMT or not DATE_FMT_RE.search(fmt.name)
number_fmt = fmt and (fmt.name in TRANSPILE_SAFE_NUMBER_FMT or not DATE_FMT_RE.search(fmt.name))
if number_fmt:
return exp.NumberToStr(this=args[0], format=fmt)
return exp.NumberToStr(this=this, format=fmt, culture=culture)
return exp.TimeToStr(
this=args[0],
format=exp.Literal.string(
if fmt:
fmt = exp.Literal.string(
format_time(fmt.name, TSQL.FORMAT_TIME_MAPPING)
if len(fmt.name) == 1
else format_time(fmt.name, TSQL.TIME_MAPPING)
),
)
)
return exp.TimeToStr(this=this, format=fmt, culture=culture)
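A parsing sketch: the removed assert used to reject a third argument, while the rewrite routes numeric formats to NumberToStr and date-like formats to TimeToStr, both now carrying the optional culture (this assumes 'N' stays in TRANSPILE_SAFE_NUMBER_FMT):

```python
import sqlglot
from sqlglot import exp

num = sqlglot.parse_one("SELECT FORMAT(1234567, 'N', 'de-DE')", read="tsql")
assert num.find(exp.NumberToStr) is not None  # numeric format, culture preserved

dt = sqlglot.parse_one("SELECT FORMAT(d, 'yyyy-MM-dd')", read="tsql")
assert dt.find(exp.TimeToStr) is not None  # date-like format
```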
def _parse_eomonth(args: t.List) -> exp.Expression:
@ -130,13 +132,13 @@ def _parse_hashbytes(args: t.List) -> exp.Expression:
def generate_date_delta_with_unit_sql(
self: generator.Generator, expression: exp.DateAdd | exp.DateDiff
self: TSQL.Generator, expression: exp.DateAdd | exp.DateDiff
) -> str:
func = "DATEADD" if isinstance(expression, exp.DateAdd) else "DATEDIFF"
return self.func(func, expression.text("unit"), expression.expression, expression.this)
def _format_sql(self: generator.Generator, expression: exp.NumberToStr | exp.TimeToStr) -> str:
def _format_sql(self: TSQL.Generator, expression: exp.NumberToStr | exp.TimeToStr) -> str:
fmt = (
expression.args["format"]
if isinstance(expression, exp.NumberToStr)
@ -147,10 +149,10 @@ def _format_sql(self: generator.Generator, expression: exp.NumberToStr | exp.Tim
)
)
)
return self.func("FORMAT", expression.this, fmt)
return self.func("FORMAT", expression.this, fmt, expression.args.get("culture"))
def _string_agg_sql(self: generator.Generator, expression: exp.GroupConcat) -> str:
def _string_agg_sql(self: TSQL.Generator, expression: exp.GroupConcat) -> str:
expression = expression.copy()
this = expression.this
@ -332,10 +334,12 @@ class TSQL(Dialect):
"SQL_VARIANT": TokenType.VARIANT,
"TOP": TokenType.TOP,
"UNIQUEIDENTIFIER": TokenType.UNIQUEIDENTIFIER,
"UPDATE STATISTICS": TokenType.COMMAND,
"VARCHAR(MAX)": TokenType.TEXT,
"XML": TokenType.XML,
"OUTPUT": TokenType.RETURNING,
"SYSTEM_USER": TokenType.CURRENT_USER,
"FOR SYSTEM_TIME": TokenType.TIMESTAMP_SNAPSHOT,
}
class Parser(parser.Parser):
@ -395,7 +399,9 @@ class TSQL(Dialect):
CONCAT_NULL_OUTPUTS_STRING = True
def _parse_projections(self) -> t.List[t.Optional[exp.Expression]]:
ALTER_TABLE_ADD_COLUMN_KEYWORD = False
def _parse_projections(self) -> t.List[exp.Expression]:
"""
T-SQL supports the syntax alias = expression in the SELECT's projection list,
so we transform all parsed Selects to convert their EQ projections into Aliases.
@ -458,43 +464,6 @@ class TSQL(Dialect):
return self._parse_as_command(self._prev)
def _parse_system_time(self) -> t.Optional[exp.Expression]:
if not self._match_text_seq("FOR", "SYSTEM_TIME"):
return None
if self._match_text_seq("AS", "OF"):
system_time = self.expression(
exp.SystemTime, this=self._parse_bitwise(), kind="AS OF"
)
elif self._match_set((TokenType.FROM, TokenType.BETWEEN)):
kind = self._prev.text
this = self._parse_bitwise()
self._match_texts(("TO", "AND"))
expression = self._parse_bitwise()
system_time = self.expression(
exp.SystemTime, this=this, expression=expression, kind=kind
)
elif self._match_text_seq("CONTAINED", "IN"):
args = self._parse_wrapped_csv(self._parse_bitwise)
system_time = self.expression(
exp.SystemTime,
this=seq_get(args, 0),
expression=seq_get(args, 1),
kind="CONTAINED IN",
)
elif self._match(TokenType.ALL):
system_time = self.expression(exp.SystemTime, kind="ALL")
else:
system_time = None
self.raise_error("Unable to parse FOR SYSTEM_TIME clause")
return system_time
def _parse_table_parts(self, schema: bool = False) -> exp.Table:
table = super()._parse_table_parts(schema=schema)
table.set("system_time", self._parse_system_time())
return table
def _parse_returns(self) -> exp.ReturnsProperty:
table = self._parse_id_var(any_token=False, tokens=self.RETURNS_TABLE_TOKENS)
returns = super()._parse_returns()
@ -589,14 +558,36 @@ class TSQL(Dialect):
return create
def _parse_if(self) -> t.Optional[exp.Expression]:
index = self._index
if self._match_text_seq("OBJECT_ID"):
self._parse_wrapped_csv(self._parse_string)
if self._match_text_seq("IS", "NOT", "NULL") and self._match(TokenType.DROP):
return self._parse_drop(exists=True)
self._retreat(index)
return super()._parse_if()
def _parse_unique(self) -> exp.UniqueColumnConstraint:
return self.expression(
exp.UniqueColumnConstraint,
this=None
if self._curr and self._curr.text.upper() in {"CLUSTERED", "NONCLUSTERED"}
else self._parse_schema(self._parse_id_var(any_token=False)),
)
class Generator(generator.Generator):
LOCKING_READS_SUPPORTED = True
LIMIT_IS_TOP = True
QUERY_HINTS = False
RETURNING_END = False
NVL2_SUPPORTED = False
ALTER_TABLE_ADD_COLUMN_KEYWORD = False
TYPE_MAPPING = {
**generator.Generator.TYPE_MAPPING,
exp.DataType.Type.BOOLEAN: "BIT",
exp.DataType.Type.DECIMAL: "NUMERIC",
exp.DataType.Type.DATETIME: "DATETIME2",
exp.DataType.Type.INT: "INTEGER",
@ -607,6 +598,8 @@ class TSQL(Dialect):
TRANSFORMS = {
**generator.Generator.TRANSFORMS,
exp.AnyValue: any_value_to_max_sql,
exp.AutoIncrementColumnConstraint: lambda *_: "IDENTITY",
exp.DateAdd: generate_date_delta_with_unit_sql,
exp.DateDiff: generate_date_delta_with_unit_sql,
exp.CurrentDate: rename_func("GETDATE"),
@ -651,25 +644,44 @@ class TSQL(Dialect):
return sql
def create_sql(self, expression: exp.Create) -> str:
expression = expression.copy()
kind = self.sql(expression, "kind").upper()
exists = expression.args.pop("exists", None)
sql = super().create_sql(expression)
if exists:
table = expression.find(exp.Table)
identifier = self.sql(exp.Literal.string(exp.table_name(table) if table else ""))
if kind == "SCHEMA":
sql = f"""IF NOT EXISTS (SELECT * FROM information_schema.schemata WHERE schema_name = {identifier}) EXEC('{sql}')"""
elif kind == "TABLE":
sql = f"""IF NOT EXISTS (SELECT * FROM information_schema.tables WHERE table_name = {identifier}) EXEC('{sql}')"""
elif kind == "INDEX":
index = self.sql(exp.Literal.string(expression.this.text("this")))
sql = f"""IF NOT EXISTS (SELECT * FROM sys.indexes WHERE object_id = object_id({identifier}) AND name = {index}) EXEC('{sql}')"""
elif expression.args.get("replace"):
sql = sql.replace("CREATE OR REPLACE ", "CREATE OR ALTER ", 1)
return sql
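What the new branches produce, sketched from the f-strings above (exact output strings are assumptions):

```python
import sqlglot

# T-SQL has no CREATE ... IF NOT EXISTS, so the statement is guarded and EXEC'd.
print(sqlglot.transpile("CREATE TABLE IF NOT EXISTS t (c INT)", write="tsql")[0])
# Assumed: IF NOT EXISTS (SELECT * FROM information_schema.tables
#   WHERE table_name = 't') EXEC('CREATE TABLE t (c INTEGER)')

# CREATE OR REPLACE is rewritten to T-SQL's CREATE OR ALTER.
print(sqlglot.transpile("CREATE OR REPLACE VIEW v AS SELECT 1", write="tsql")[0])
# Assumed: CREATE OR ALTER VIEW v AS SELECT 1
```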
def offset_sql(self, expression: exp.Offset) -> str:
return f"{super().offset_sql(expression)} ROWS"
def systemtime_sql(self, expression: exp.SystemTime) -> str:
kind = expression.args["kind"]
if kind == "ALL":
return "FOR SYSTEM_TIME ALL"
def version_sql(self, expression: exp.Version) -> str:
name = "SYSTEM_TIME" if expression.name == "TIMESTAMP" else expression.name
this = f"FOR {name}"
expr = expression.expression
kind = expression.text("kind")
if kind in ("FROM", "BETWEEN"):
args = expr.expressions
sep = "TO" if kind == "FROM" else "AND"
expr_sql = f"{self.sql(seq_get(args, 0))} {sep} {self.sql(seq_get(args, 1))}"
else:
expr_sql = self.sql(expr)
start = self.sql(expression, "this")
if kind == "AS OF":
return f"FOR SYSTEM_TIME AS OF {start}"
end = self.sql(expression, "expression")
if kind == "FROM":
return f"FOR SYSTEM_TIME FROM {start} TO {end}"
if kind == "BETWEEN":
return f"FOR SYSTEM_TIME BETWEEN {start} AND {end}"
return f"FOR SYSTEM_TIME CONTAINED IN ({start}, {end})"
expr_sql = f" {expr_sql}" if expr_sql else ""
return f"{this} {kind}{expr_sql}"
def returnsproperty_sql(self, expression: exp.ReturnsProperty) -> str:
table = expression.args.get("table")
@ -713,3 +725,16 @@ class TSQL(Dialect):
identifier = f"#{identifier}"
return identifier
def constraint_sql(self, expression: exp.Constraint) -> str:
this = self.sql(expression, "this")
expressions = self.expressions(expression, flat=True, sep=" ")
return f"CONSTRAINT {this} {expressions}"
# https://learn.microsoft.com/en-us/answers/questions/448821/create-table-in-sql-server
def generatedasidentitycolumnconstraint_sql(
self, expression: exp.GeneratedAsIdentityColumnConstraint
) -> str:
start = self.sql(expression, "start") or "1"
increment = self.sql(expression, "increment") or "1"
return f"IDENTITY({start}, {increment})"


@ -1035,12 +1035,13 @@ class Clone(Expression):
"this": True,
"when": False,
"kind": False,
"shallow": False,
"expression": False,
}
class Describe(Expression):
arg_types = {"this": True, "kind": False}
arg_types = {"this": True, "kind": False, "expressions": False}
class Pragma(Expression):
@ -1070,6 +1071,8 @@ class Show(Expression):
"like": False,
"where": False,
"db": False,
"scope": False,
"scope_kind": False,
"full": False,
"mutex": False,
"query": False,
@ -1207,6 +1210,10 @@ class Comment(Expression):
arg_types = {"this": True, "kind": True, "expression": True, "exists": False}
class Comprehension(Expression):
arg_types = {"this": True, "expression": True, "iterator": True, "condition": False}
# https://clickhouse.com/docs/en/engines/table-engines/mergetree-family/mergetree#mergetree-table-ttl
class MergeTreeTTLAction(Expression):
arg_types = {
@ -1269,6 +1276,10 @@ class CheckColumnConstraint(ColumnConstraintKind):
pass
class ClusteredColumnConstraint(ColumnConstraintKind):
pass
class CollateColumnConstraint(ColumnConstraintKind):
pass
@ -1316,6 +1327,14 @@ class InlineLengthColumnConstraint(ColumnConstraintKind):
pass
class NonClusteredColumnConstraint(ColumnConstraintKind):
pass
class NotForReplicationColumnConstraint(ColumnConstraintKind):
arg_types = {}
class NotNullColumnConstraint(ColumnConstraintKind):
arg_types = {"allow_null": False}
@ -1345,6 +1364,12 @@ class PathColumnConstraint(ColumnConstraintKind):
pass
# computed column expression
# https://learn.microsoft.com/en-us/sql/t-sql/statements/create-table-transact-sql?view=sql-server-ver16
class ComputedColumnConstraint(ColumnConstraintKind):
arg_types = {"this": True, "persisted": False, "not_null": False}
class Constraint(Expression):
arg_types = {"this": True, "expressions": True}
@ -1489,6 +1514,15 @@ class Check(Expression):
pass
# https://docs.snowflake.com/en/sql-reference/constructs/connect-by
class Connect(Expression):
arg_types = {"start": False, "connect": True}
class Prior(Expression):
pass
class Directory(Expression):
# https://spark.apache.org/docs/3.0.0-preview/sql-ref-syntax-dml-insert-overwrite-directory-hive.html
arg_types = {"this": True, "local": False, "row_format": False}
@ -1578,6 +1612,7 @@ class Insert(DDL):
"alternative": False,
"where": False,
"ignore": False,
"by_name": False,
}
def with_(
@ -2045,8 +2080,12 @@ class NoPrimaryIndexProperty(Property):
arg_types = {}
class OnProperty(Property):
arg_types = {"this": True}
class OnCommitProperty(Property):
arg_type = {"delete": False}
arg_types = {"delete": False}
class PartitionedByProperty(Property):
@ -2282,6 +2321,16 @@ class Subqueryable(Unionable):
def named_selects(self) -> t.List[str]:
raise NotImplementedError("Subqueryable objects must implement `named_selects`")
def select(
self,
*expressions: t.Optional[ExpOrStr],
append: bool = True,
dialect: DialectType = None,
copy: bool = True,
**opts,
) -> Subqueryable:
raise NotImplementedError("Subqueryable objects must implement `select`")
def with_(
self,
alias: ExpOrStr,
@ -2323,6 +2372,7 @@ QUERY_MODIFIERS = {
"match": False,
"laterals": False,
"joins": False,
"connect": False,
"pivots": False,
"where": False,
"group": False,
@ -2363,6 +2413,7 @@ class Table(Expression):
"pivots": False,
"hints": False,
"system_time": False,
"version": False,
}
@property
@ -2403,21 +2454,13 @@ class Table(Expression):
return parts
# See the TSQL "Querying data in a system-versioned temporal table" page
class SystemTime(Expression):
arg_types = {
"this": False,
"expression": False,
"kind": True,
}
class Union(Subqueryable):
arg_types = {
"with": False,
"this": True,
"expression": True,
"distinct": False,
"by_name": False,
**QUERY_MODIFIERS,
}
@ -2529,6 +2572,7 @@ class Update(Expression):
"from": False,
"where": False,
"returning": False,
"order": False,
"limit": False,
}
@ -2545,6 +2589,20 @@ class Var(Expression):
pass
class Version(Expression):
"""
Time travel, iceberg, bigquery etc
https://trino.io/docs/current/connector/iceberg.html?highlight=snapshot#using-snapshots
https://www.databricks.com/blog/2019/02/04/introducing-delta-time-travel-for-large-scale-data-lakes.html
https://cloud.google.com/bigquery/docs/reference/standard-sql/query-syntax#for_system_time_as_of
https://learn.microsoft.com/en-us/sql/relational-databases/tables/querying-data-in-a-system-versioned-temporal-table?view=sql-server-ver16
this is either TIMESTAMP or VERSION
kind is one of ("AS OF", "BETWEEN")
"""
arg_types = {"this": True, "kind": True, "expression": False}
class Schema(Expression):
arg_types = {"this": False, "expressions": False}
@ -3263,6 +3321,23 @@ class Subquery(DerivedTable, Unionable):
expression = expression.this
return expression
def unwrap(self) -> Subquery:
expression = self
while expression.same_parent and expression.is_wrapper:
expression = t.cast(Subquery, expression.parent)
return expression
@property
def is_wrapper(self) -> bool:
"""
Whether this Subquery acts as a simple wrapper around another expression.
SELECT * FROM (((SELECT * FROM t)))
                  ^
This corresponds to a "wrapper" Subquery node
"""
return all(v is None for k, v in self.args.items() if k != "this")
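A small sketch of the two helpers together; find_all is breadth-first by default, so the outermost subquery comes first:

```python
from sqlglot import exp, parse_one

query = parse_one("SELECT * FROM (((SELECT * FROM t)))")
subqueries = list(query.find_all(exp.Subquery))

inner = subqueries[-1]  # the innermost wrapper around SELECT * FROM t
print(inner.is_wrapper)                 # True: no alias, joins, pivots, etc.
print(inner.unwrap() is subqueries[0])  # True: climbs to the outermost wrapper
```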
@property
def is_star(self) -> bool:
return self.this.is_star
@ -3313,7 +3388,7 @@ class Pivot(Expression):
}
class Window(Expression):
class Window(Condition):
arg_types = {
"this": True,
"partition_by": False,
@ -3375,7 +3450,7 @@ class Boolean(Condition):
pass
class DataTypeSize(Expression):
class DataTypeParam(Expression):
arg_types = {"this": True, "expression": False}
@ -3386,6 +3461,7 @@ class DataType(Expression):
"nested": False,
"values": False,
"prefix": False,
"kind": False,
}
class Type(AutoName):
@ -3432,6 +3508,7 @@ class DataType(Expression):
LOWCARDINALITY = auto()
MAP = auto()
MEDIUMBLOB = auto()
MEDIUMINT = auto()
MEDIUMTEXT = auto()
MONEY = auto()
NCHAR = auto()
@ -3475,6 +3552,7 @@ class DataType(Expression):
VARCHAR = auto()
VARIANT = auto()
XML = auto()
YEAR = auto()
TEXT_TYPES = {
Type.CHAR,
@ -3498,7 +3576,10 @@ class DataType(Expression):
Type.DOUBLE,
}
NUMERIC_TYPES = {*INTEGER_TYPES, *FLOAT_TYPES}
NUMERIC_TYPES = {
*INTEGER_TYPES,
*FLOAT_TYPES,
}
TEMPORAL_TYPES = {
Type.TIME,
@ -3511,23 +3592,39 @@ class DataType(Expression):
Type.DATETIME64,
}
META_TYPES = {"UNKNOWN", "NULL"}
@classmethod
def build(
cls, dtype: str | DataType | DataType.Type, dialect: DialectType = None, **kwargs
cls,
dtype: str | DataType | DataType.Type,
dialect: DialectType = None,
udt: bool = False,
**kwargs,
) -> DataType:
"""
Constructs a DataType object.
Args:
dtype: the data type of interest.
dialect: the dialect to use for parsing `dtype`, in case it's a string.
udt: when set to True, `dtype` will be used as-is if it can't be parsed into a
DataType, thus creating a user-defined type.
kwargs: additional arguments to pass to the constructor of DataType.
Returns:
The constructed DataType object.
"""
from sqlglot import parse_one
if isinstance(dtype, str):
upper = dtype.upper()
if upper in DataType.META_TYPES:
data_type_exp: t.Optional[Expression] = DataType(this=DataType.Type[upper])
else:
data_type_exp = parse_one(dtype, read=dialect, into=DataType)
if dtype.upper() == "UNKNOWN":
return DataType(this=DataType.Type.UNKNOWN, **kwargs)
if data_type_exp is None:
raise ValueError(f"Unparsable data type value: {dtype}")
try:
data_type_exp = parse_one(dtype, read=dialect, into=DataType)
except ParseError:
if udt:
return DataType(this=DataType.Type.USERDEFINED, kind=dtype, **kwargs)
raise
elif isinstance(dtype, DataType.Type):
data_type_exp = DataType(this=dtype)
elif isinstance(dtype, DataType):
@ -3538,7 +3635,31 @@ class DataType(Expression):
return DataType(**{**data_type_exp.args, **kwargs})
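A sketch of the new udt escape hatch; 'my_custom_type' is a hypothetical name, and the example assumes it does not parse as a built-in type:

```python
from sqlglot import exp

# Without udt=True this would raise a ParseError; with it, the raw
# string is preserved under the new "kind" arg.
dt = exp.DataType.build("my_custom_type", dialect="postgres", udt=True)
print(dt.this)          # assumed: Type.USERDEFINED
print(dt.args["kind"])  # assumed: 'my_custom_type'
```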
def is_type(self, *dtypes: str | DataType | DataType.Type) -> bool:
return any(self.this == DataType.build(dtype).this for dtype in dtypes)
"""
Checks whether this DataType matches one of the provided data types. Nested types or precision
will be compared using "structural equivalence" semantics, so e.g. array<int> != array<float>.
Args:
dtypes: the data types to compare this DataType to.
Returns:
True, if and only if there is a type in `dtypes` which is equal to this DataType.
"""
for dtype in dtypes:
other = DataType.build(dtype, udt=True)
if (
other.expressions
or self.this == DataType.Type.USERDEFINED
or other.this == DataType.Type.USERDEFINED
):
matches = self == other
else:
matches = self.this == other.this
if matches:
return True
return False
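The docstring's semantics, sketched:

```python
from sqlglot import exp

dt = exp.DataType.build("ARRAY<INT>")
print(dt.is_type("ARRAY"))         # True: bare target, only top-level types compared
print(dt.is_type("ARRAY<INT>"))    # True: structurally equal
print(dt.is_type("ARRAY<FLOAT>"))  # False: parameters force structural equality
```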
# https://www.postgresql.org/docs/15/datatype-pseudo.html
@ -3546,6 +3667,11 @@ class PseudoType(Expression):
pass
# https://www.postgresql.org/docs/15/datatype-oid.html
class ObjectIdentifier(Expression):
pass
# WHERE x <OP> EXISTS|ALL|ANY|SOME(SELECT ...)
class SubqueryPredicate(Predicate):
pass
@ -4005,6 +4131,7 @@ class ArrayAny(Func):
class ArrayConcat(Func):
_sql_names = ["ARRAY_CONCAT", "ARRAY_CAT"]
arg_types = {"this": True, "expressions": False}
is_var_len_args = True
@ -4047,7 +4174,15 @@ class Avg(AggFunc):
class AnyValue(AggFunc):
arg_types = {"this": True, "having": False, "max": False}
arg_types = {"this": True, "having": False, "max": False, "ignore_nulls": False}
class First(Func):
arg_types = {"this": True, "ignore_nulls": False}
class Last(Func):
arg_types = {"this": True, "ignore_nulls": False}
class Case(Func):
@ -4086,18 +4221,29 @@ class Cast(Func):
return self.name
def is_type(self, *dtypes: str | DataType | DataType.Type) -> bool:
"""
Checks whether this Cast's DataType matches one of the provided data types. Nested types
like arrays or structs will be compared using "structural equivalence" semantics, so e.g.
array<int> != array<float>.
Args:
dtypes: the data types to compare this Cast's DataType to.
Returns:
True, if and only if there is a type in `dtypes` which is equal to this Cast's DataType.
"""
return self.to.is_type(*dtypes)
class CastToStrType(Func):
arg_types = {"this": True, "expression": True}
class Collate(Binary):
class TryCast(Cast):
pass
class TryCast(Cast):
class CastToStrType(Func):
arg_types = {"this": True, "to": True}
class Collate(Binary):
pass
@ -4310,7 +4456,7 @@ class Greatest(Func):
is_var_len_args = True
class GroupConcat(Func):
class GroupConcat(AggFunc):
arg_types = {"this": True, "separator": False}
@ -4648,8 +4794,19 @@ class StrToUnix(Func):
arg_types = {"this": False, "format": False}
# https://prestodb.io/docs/current/functions/string.html
# https://spark.apache.org/docs/latest/api/sql/index.html#str_to_map
class StrToMap(Func):
arg_types = {
"this": True,
"pair_delim": False,
"key_value_delim": False,
"duplicate_resolution_callback": False,
}
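A transpilation sketch for the new node, relying on the exp.StrToMap -> SPLIT_TO_MAP mapping added to Presto earlier in this diff (output assumed):

```python
import sqlglot

print(sqlglot.transpile("SELECT STR_TO_MAP('a:1,b:2', ',', ':')", read="spark", write="presto")[0])
# Assumed: SELECT SPLIT_TO_MAP('a:1,b:2', ',', ':')
```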
class NumberToStr(Func):
arg_types = {"this": True, "format": True}
arg_types = {"this": True, "format": True, "culture": False}
class FromBase(Func):
@ -4665,6 +4822,13 @@ class StructExtract(Func):
arg_types = {"this": True, "expression": True}
# https://learn.microsoft.com/en-us/sql/t-sql/functions/stuff-transact-sql?view=sql-server-ver16
# https://docs.snowflake.com/en/sql-reference/functions/insert
class Stuff(Func):
_sql_names = ["STUFF", "INSERT"]
arg_types = {"this": True, "start": True, "length": True, "expression": True}
class Sum(AggFunc):
pass
@ -4686,7 +4850,7 @@ class StddevSamp(AggFunc):
class TimeToStr(Func):
arg_types = {"this": True, "format": True}
arg_types = {"this": True, "format": True, "culture": False}
class TimeToTimeStr(Func):
@ -5724,9 +5888,9 @@ def table_(
The new Table instance.
"""
return Table(
this=to_identifier(table, quoted=quoted),
db=to_identifier(db, quoted=quoted),
catalog=to_identifier(catalog, quoted=quoted),
this=to_identifier(table, quoted=quoted) if table else None,
db=to_identifier(db, quoted=quoted) if db else None,
catalog=to_identifier(catalog, quoted=quoted) if catalog else None,
alias=TableAlias(this=to_identifier(alias)) if alias else None,
)
@ -5844,8 +6008,8 @@ def convert(value: t.Any, copy: bool = False) -> Expression:
return Array(expressions=[convert(v, copy=copy) for v in value])
if isinstance(value, dict):
return Map(
keys=[convert(k, copy=copy) for k in value],
values=[convert(v, copy=copy) for v in value.values()],
keys=Array(expressions=[convert(k, copy=copy) for k in value]),
values=Array(expressions=[convert(v, copy=copy) for v in value.values()]),
)
raise ValueError(f"Cannot convert {value}")


@ -8,7 +8,7 @@ from sqlglot import exp
from sqlglot.errors import ErrorLevel, UnsupportedError, concat_messages
from sqlglot.helper import apply_index_offset, csv, seq_get
from sqlglot.time import format_time
from sqlglot.tokens import TokenType
from sqlglot.tokens import Tokenizer, TokenType
logger = logging.getLogger("sqlglot")
@ -61,6 +61,7 @@ class Generator:
exp.CharacterSetColumnConstraint: lambda self, e: f"CHARACTER SET {self.sql(e, 'this')}",
exp.CharacterSetProperty: lambda self, e: f"{'DEFAULT ' if e.args.get('default') else ''}CHARACTER SET={self.sql(e, 'this')}",
exp.CheckColumnConstraint: lambda self, e: f"CHECK ({self.sql(e, 'this')})",
exp.ClusteredColumnConstraint: lambda self, e: f"CLUSTERED ({self.expressions(e, 'this', indent=False)})",
exp.CollateColumnConstraint: lambda self, e: f"COLLATE {self.sql(e, 'this')}",
exp.CopyGrantsProperty: lambda self, e: "COPY GRANTS",
exp.CommentColumnConstraint: lambda self, e: f"COMMENT {self.sql(e, 'this')}",
@ -78,7 +79,10 @@ class Generator:
exp.LogProperty: lambda self, e: f"{'NO ' if e.args.get('no') else ''}LOG",
exp.MaterializedProperty: lambda self, e: "MATERIALIZED",
exp.NoPrimaryIndexProperty: lambda self, e: "NO PRIMARY INDEX",
exp.NonClusteredColumnConstraint: lambda self, e: f"NONCLUSTERED ({self.expressions(e, 'this', indent=False)})",
exp.NotForReplicationColumnConstraint: lambda self, e: "NOT FOR REPLICATION",
exp.OnCommitProperty: lambda self, e: f"ON COMMIT {'DELETE' if e.args.get('delete') else 'PRESERVE'} ROWS",
exp.OnProperty: lambda self, e: f"ON {self.sql(e, 'this')}",
exp.OnUpdateColumnConstraint: lambda self, e: f"ON UPDATE {self.sql(e, 'this')}",
exp.PathColumnConstraint: lambda self, e: f"PATH {self.sql(e, 'this')}",
exp.ReturnsProperty: lambda self, e: self.naked_property(e),
@ -171,6 +175,9 @@ class Generator:
# Whether or not TIMETZ / TIMESTAMPTZ will be generated using the "WITH TIME ZONE" syntax
TZ_TO_WITH_TIME_ZONE = False
# Whether or not the NVL2 function is supported
NVL2_SUPPORTED = True
# https://cloud.google.com/bigquery/docs/reference/standard-sql/query-syntax
SELECT_KINDS: t.Tuple[str, ...] = ("STRUCT", "VALUE")
@ -179,6 +186,9 @@ class Generator:
# SELECT * VALUES into SELECT UNION
VALUES_AS_TABLE = True
# Whether or not the word COLUMN is included when adding a column with ALTER TABLE
ALTER_TABLE_ADD_COLUMN_KEYWORD = True
TYPE_MAPPING = {
exp.DataType.Type.NCHAR: "CHAR",
exp.DataType.Type.NVARCHAR: "VARCHAR",
@ -245,6 +255,7 @@ class Generator:
exp.MaterializedProperty: exp.Properties.Location.POST_CREATE,
exp.MergeBlockRatioProperty: exp.Properties.Location.POST_NAME,
exp.NoPrimaryIndexProperty: exp.Properties.Location.POST_EXPRESSION,
exp.OnProperty: exp.Properties.Location.POST_SCHEMA,
exp.OnCommitProperty: exp.Properties.Location.POST_EXPRESSION,
exp.Order: exp.Properties.Location.POST_SCHEMA,
exp.PartitionedByProperty: exp.Properties.Location.POST_WITH,
@ -317,8 +328,7 @@ class Generator:
QUOTE_END = "'"
IDENTIFIER_START = '"'
IDENTIFIER_END = '"'
STRING_ESCAPE = "'"
IDENTIFIER_ESCAPE = '"'
TOKENIZER_CLASS = Tokenizer
# Delimiters for bit, hex, byte and raw literals
BIT_START: t.Optional[str] = None
@ -379,8 +389,10 @@ class Generator:
)
self.unsupported_messages: t.List[str] = []
self._escaped_quote_end: str = self.STRING_ESCAPE + self.QUOTE_END
self._escaped_identifier_end: str = self.IDENTIFIER_ESCAPE + self.IDENTIFIER_END
self._escaped_quote_end: str = self.TOKENIZER_CLASS.STRING_ESCAPES[0] + self.QUOTE_END
self._escaped_identifier_end: str = (
self.TOKENIZER_CLASS.IDENTIFIER_ESCAPES[0] + self.IDENTIFIER_END
)
self._cache: t.Optional[t.Dict[int, str]] = None
def generate(
@ -626,6 +638,16 @@ class Generator:
kind_sql = self.sql(expression, "kind").strip()
return f"CONSTRAINT {this} {kind_sql}" if this else kind_sql
def computedcolumnconstraint_sql(self, expression: exp.ComputedColumnConstraint) -> str:
this = self.sql(expression, "this")
if expression.args.get("not_null"):
persisted = " PERSISTED NOT NULL"
elif expression.args.get("persisted"):
persisted = " PERSISTED"
else:
persisted = ""
return f"AS {this}{persisted}"
def autoincrementcolumnconstraint_sql(self, _) -> str:
return self.token_sql(TokenType.AUTO_INCREMENT)
@ -642,8 +664,8 @@ class Generator:
) -> str:
this = ""
if expression.this is not None:
on_null = "ON NULL " if expression.args.get("on_null") else ""
this = " ALWAYS " if expression.this else f" BY DEFAULT {on_null}"
on_null = " ON NULL" if expression.args.get("on_null") else ""
this = " ALWAYS" if expression.this else f" BY DEFAULT{on_null}"
start = expression.args.get("start")
start = f"START WITH {start}" if start else ""
@ -668,7 +690,7 @@ class Generator:
expr = self.sql(expression, "expression")
expr = f"({expr})" if expr else "IDENTITY"
return f"GENERATED{this}AS {expr}{sequence_opts}"
return f"GENERATED{this} AS {expr}{sequence_opts}"
def notnullcolumnconstraint_sql(self, expression: exp.NotNullColumnConstraint) -> str:
return f"{'' if expression.args.get('allow_null') else 'NOT '}NULL"
@ -774,14 +796,16 @@ class Generator:
def clone_sql(self, expression: exp.Clone) -> str:
this = self.sql(expression, "this")
shallow = "SHALLOW " if expression.args.get("shallow") else ""
this = f"{shallow}CLONE {this}"
when = self.sql(expression, "when")
if when:
kind = self.sql(expression, "kind")
expr = self.sql(expression, "expression")
return f"CLONE {this} {when} ({kind} => {expr})"
return f"{this} {when} ({kind} => {expr})"
return f"CLONE {this}"
return this
def describe_sql(self, expression: exp.Describe) -> str:
return f"DESCRIBE {self.sql(expression, 'this')}"
@ -830,7 +854,7 @@ class Generator:
string = self.escape_str(expression.this.replace("\\", "\\\\"))
return f"{self.QUOTE_START}{string}{self.QUOTE_END}"
def datatypesize_sql(self, expression: exp.DataTypeSize) -> str:
def datatypeparam_sql(self, expression: exp.DataTypeParam) -> str:
this = self.sql(expression, "this")
specifier = self.sql(expression, "expression")
specifier = f" {specifier}" if specifier else ""
@ -839,11 +863,14 @@ class Generator:
def datatype_sql(self, expression: exp.DataType) -> str:
type_value = expression.this
type_sql = (
self.TYPE_MAPPING.get(type_value, type_value.value)
if isinstance(type_value, exp.DataType.Type)
else type_value
)
if type_value == exp.DataType.Type.USERDEFINED and expression.args.get("kind"):
type_sql = self.sql(expression, "kind")
else:
type_sql = (
self.TYPE_MAPPING.get(type_value, type_value.value)
if isinstance(type_value, exp.DataType.Type)
else type_value
)
nested = ""
interior = self.expressions(expression, flat=True)
@ -943,9 +970,9 @@ class Generator:
name = self.sql(expression, "this")
name = f"{name} " if name else ""
table = self.sql(expression, "table")
table = f"{self.INDEX_ON} {table} " if table else ""
table = f"{self.INDEX_ON} {table}" if table else ""
using = self.sql(expression, "using")
using = f"USING {using} " if using else ""
using = f" USING {using} " if using else ""
index = "INDEX " if not table else ""
columns = self.expressions(expression, key="columns", flat=True)
columns = f"({columns})" if columns else ""
@ -1171,6 +1198,7 @@ class Generator:
where = f"{self.sep()}REPLACE WHERE {where}" if where else ""
expression_sql = f"{self.sep()}{self.sql(expression, 'expression')}"
conflict = self.sql(expression, "conflict")
by_name = " BY NAME" if expression.args.get("by_name") else ""
returning = self.sql(expression, "returning")
if self.RETURNING_END:
@ -1178,7 +1206,7 @@ class Generator:
else:
expression_sql = f"{returning}{expression_sql}{conflict}"
sql = f"INSERT{alternative}{ignore}{this}{exists}{partition_sql}{where}{expression_sql}"
sql = f"INSERT{alternative}{ignore}{this}{by_name}{exists}{partition_sql}{where}{expression_sql}"
return self.prepend_ctes(expression, sql)
def intersect_sql(self, expression: exp.Intersect) -> str:
@ -1196,6 +1224,9 @@ class Generator:
def pseudotype_sql(self, expression: exp.PseudoType) -> str:
return expression.name.upper()
def objectidentifier_sql(self, expression: exp.ObjectIdentifier) -> str:
return expression.name.upper()
def onconflict_sql(self, expression: exp.OnConflict) -> str:
conflict = "ON DUPLICATE KEY" if expression.args.get("duplicate") else "ON CONFLICT"
constraint = self.sql(expression, "constraint")
@ -1248,6 +1279,8 @@ class Generator:
if part
)
version = self.sql(expression, "version")
version = f" {version}" if version else ""
alias = self.sql(expression, "alias")
alias = f"{sep}{alias}" if alias else ""
hints = self.expressions(expression, key="hints", sep=" ")
@ -1256,10 +1289,8 @@ class Generator:
pivots = f" {pivots}" if pivots else ""
joins = self.expressions(expression, key="joins", sep="", skip_first=True)
laterals = self.expressions(expression, key="laterals", sep="")
system_time = expression.args.get("system_time")
system_time = f" {self.sql(expression, 'system_time')}" if system_time else ""
return f"{table}{system_time}{alias}{hints}{pivots}{joins}{laterals}"
return f"{table}{version}{alias}{hints}{pivots}{joins}{laterals}"
def tablesample_sql(
self, expression: exp.TableSample, seed_prefix: str = "SEED", sep=" AS "
@ -1314,6 +1345,12 @@ class Generator:
nulls = ""
return f"{direction}{nulls}({expressions} FOR {field}){alias}"
def version_sql(self, expression: exp.Version) -> str:
this = f"FOR {expression.name}"
kind = expression.text("kind")
expr = self.sql(expression, "expression")
return f"{this} {kind} {expr}"
def tuple_sql(self, expression: exp.Tuple) -> str:
return f"({self.expressions(expression, flat=True)})"
@ -1323,12 +1360,13 @@ class Generator:
from_sql = self.sql(expression, "from")
where_sql = self.sql(expression, "where")
returning = self.sql(expression, "returning")
order = self.sql(expression, "order")
limit = self.sql(expression, "limit")
if self.RETURNING_END:
expression_sql = f"{from_sql}{where_sql}{returning}{limit}"
expression_sql = f"{from_sql}{where_sql}{returning}"
else:
expression_sql = f"{returning}{from_sql}{where_sql}{limit}"
sql = f"UPDATE {this} SET {set_sql}{expression_sql}"
expression_sql = f"{returning}{from_sql}{where_sql}"
sql = f"UPDATE {this} SET {set_sql}{expression_sql}{order}{limit}"
return self.prepend_ctes(expression, sql)
def values_sql(self, expression: exp.Values) -> str:
@ -1425,6 +1463,16 @@ class Generator:
this = self.indent(self.sql(expression, "this"))
return f"{self.seg('HAVING')}{self.sep()}{this}"
def connect_sql(self, expression: exp.Connect) -> str:
start = self.sql(expression, "start")
start = self.seg(f"START WITH {start}") if start else ""
connect = self.sql(expression, "connect")
connect = self.seg(f"CONNECT BY {connect}")
return start + connect
def prior_sql(self, expression: exp.Prior) -> str:
return f"PRIOR {self.sql(expression, 'this')}"
def join_sql(self, expression: exp.Join) -> str:
op_sql = " ".join(
op
@ -1667,6 +1715,7 @@ class Generator:
return csv(
*sqls,
*[self.sql(join) for join in expression.args.get("joins") or []],
self.sql(expression, "connect"),
self.sql(expression, "match"),
*[self.sql(lateral) for lateral in expression.args.get("laterals") or []],
self.sql(expression, "where"),
@ -1801,7 +1850,8 @@ class Generator:
def union_op(self, expression: exp.Union) -> str:
kind = " DISTINCT" if self.EXPLICIT_UNION else ""
kind = kind if expression.args.get("distinct") else " ALL"
return f"UNION{kind}"
by_name = " BY NAME" if expression.args.get("by_name") else ""
return f"UNION{kind}{by_name}"
def unnest_sql(self, expression: exp.Unnest) -> str:
args = self.expressions(expression, flat=True)
@ -2224,7 +2274,14 @@ class Generator:
actions = expression.args["actions"]
if isinstance(actions[0], exp.ColumnDef):
actions = self.expressions(expression, key="actions", prefix="ADD COLUMN ")
if self.ALTER_TABLE_ADD_COLUMN_KEYWORD:
actions = self.expressions(
expression,
key="actions",
prefix="ADD COLUMN ",
)
else:
actions = f"ADD {self.expressions(expression, key='actions')}"
elif isinstance(actions[0], exp.Schema):
actions = self.expressions(expression, key="actions", prefix="ADD COLUMNS ")
elif isinstance(actions[0], exp.Delete):
@ -2525,10 +2582,21 @@ class Generator:
return f"WHEN {matched}{source}{condition} THEN {then}"
def merge_sql(self, expression: exp.Merge) -> str:
this = self.sql(expression, "this")
table = expression.this
table_alias = ""
hints = table.args.get("hints")
if hints and table.alias and isinstance(hints[0], exp.WithTableHint):
# T-SQL syntax is MERGE ... <target_table> [WITH (<merge_hint>)] [[AS] table_alias]
table = table.copy()
table_alias = f" AS {self.sql(table.args['alias'].pop())}"
this = self.sql(table)
using = f"USING {self.sql(expression, 'using')}"
on = f"ON {self.sql(expression, 'on')}"
return f"MERGE INTO {this} {using} {on} {self.expressions(expression, sep=' ')}"
expressions = self.expressions(expression, sep=" ")
return f"MERGE INTO {this}{table_alias} {using} {on} {expressions}"
def tochar_sql(self, expression: exp.ToChar) -> str:
if expression.args.get("format"):
@ -2631,6 +2699,29 @@ class Generator:
options = f" {options}" if options else ""
return f"{kind}{this}{type_}{schema}{options}"
def nvl2_sql(self, expression: exp.Nvl2) -> str:
if self.NVL2_SUPPORTED:
return self.function_fallback_sql(expression)
case = exp.Case().when(
expression.this.is_(exp.null()).not_(copy=False),
expression.args["true"].copy(),
copy=False,
)
else_cond = expression.args.get("false")
if else_cond:
case.else_(else_cond.copy(), copy=False)
return self.sql(case)
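The rewrite in action, sketched; Presto sets NVL2_SUPPORTED = False earlier in this diff (outputs assumed):

```python
import sqlglot

print(sqlglot.transpile("SELECT NVL2(a, b, c) FROM t", write="presto")[0])
# Assumed: SELECT CASE WHEN NOT a IS NULL THEN b ELSE c END FROM t

# Without the "false" argument the ELSE branch is omitted.
print(sqlglot.transpile("SELECT NVL2(a, b) FROM t", write="presto")[0])
# Assumed: SELECT CASE WHEN NOT a IS NULL THEN b END FROM t
```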
def comprehension_sql(self, expression: exp.Comprehension) -> str:
this = self.sql(expression, "this")
expr = self.sql(expression, "expression")
iterator = self.sql(expression, "iterator")
condition = self.sql(expression, "condition")
condition = f" IF {condition}" if condition else ""
return f"{this} FOR {expr} IN {iterator}{condition}"
def cached_generator(
cache: t.Optional[t.Dict[int, str]] = None


@ -33,6 +33,15 @@ class AutoName(Enum):
return name
class classproperty(property):
"""
Similar to a normal property but works for class methods
"""
def __get__(self, obj: t.Any, owner: t.Any = None) -> t.Any:
return classmethod(self.fget).__get__(None, owner)() # type: ignore
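A minimal usage sketch:

```python
from sqlglot.helper import classproperty

class Example:
    _registry = ["a", "b"]

    @classproperty
    def registry(cls):
        # Computed on attribute access; no instance required.
        return list(cls._registry)

print(Example.registry)  # ['a', 'b']
```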
def seq_get(seq: t.Sequence[T], index: int) -> t.Optional[T]:
"""Returns the value in `seq` at position `index`, or `None` if `index` is out of bounds."""
try:
@ -137,9 +146,9 @@ def subclasses(
def apply_index_offset(
this: exp.Expression,
expressions: t.List[t.Optional[E]],
expressions: t.List[E],
offset: int,
) -> t.List[t.Optional[E]]:
) -> t.List[E]:
"""
Applies an offset to a given integer literal expression.
@ -170,15 +179,14 @@ def apply_index_offset(
):
return expressions
if expression:
if not expression.type:
annotate_types(expression)
if t.cast(exp.DataType, expression.type).this in exp.DataType.INTEGER_TYPES:
logger.warning("Applying array index offset (%s)", offset)
expression = simplify(
exp.Add(this=expression.copy(), expression=exp.Literal.number(offset))
)
return [expression]
if not expression.type:
annotate_types(expression)
if t.cast(exp.DataType, expression.type).this in exp.DataType.INTEGER_TYPES:
logger.warning("Applying array index offset (%s)", offset)
expression = simplify(
exp.Add(this=expression.copy(), expression=exp.Literal.number(offset))
)
return [expression]
return expressions


@ -1,2 +1,9 @@
from sqlglot.optimizer.optimizer import RULES, optimize
from sqlglot.optimizer.scope import Scope, build_scope, traverse_scope
from sqlglot.optimizer.scope import (
Scope,
build_scope,
find_all_in_scope,
find_in_scope,
traverse_scope,
walk_in_scope,
)


@ -203,10 +203,15 @@ class TypeAnnotator(metaclass=_TypeAnnotator):
for expr_type in expressions
},
exp.Anonymous: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.UNKNOWN),
exp.Array: lambda self, e: self._annotate_by_args(e, "expressions", array=True),
exp.ArrayAgg: lambda self, e: self._annotate_by_args(e, "this", array=True),
exp.ArrayConcat: lambda self, e: self._annotate_by_args(e, "this", "expressions"),
exp.Cast: lambda self, e: self._annotate_with_type(e, e.args["to"]),
exp.Case: lambda self, e: self._annotate_by_args(e, "default", "ifs"),
exp.Coalesce: lambda self, e: self._annotate_by_args(e, "this", "expressions"),
exp.DataType: lambda self, e: self._annotate_with_type(e, e.copy()),
exp.Distinct: lambda self, e: self._annotate_by_args(e, "expressions"),
exp.Filter: lambda self, e: self._annotate_by_args(e, "this"),
exp.If: lambda self, e: self._annotate_by_args(e, "true", "false"),
exp.Interval: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.INTERVAL),
exp.Least: lambda self, e: self._annotate_by_args(e, "expressions"),
@ -220,6 +225,10 @@ class TypeAnnotator(metaclass=_TypeAnnotator):
exp.VarMap: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.MAP),
}
NESTED_TYPES = {
exp.DataType.Type.ARRAY,
}
# Specifies what types a given type can be coerced into (autofilled)
COERCES_TO: t.Dict[exp.DataType.Type, t.Set[exp.DataType.Type]] = {}
@ -299,19 +308,22 @@ class TypeAnnotator(metaclass=_TypeAnnotator):
def _maybe_coerce(
self, type1: exp.DataType | exp.DataType.Type, type2: exp.DataType | exp.DataType.Type
) -> exp.DataType.Type:
# We propagate the NULL / UNKNOWN types upwards if found
if isinstance(type1, exp.DataType):
type1 = type1.this
if isinstance(type2, exp.DataType):
type2 = type2.this
) -> exp.DataType | exp.DataType.Type:
type1_value = type1.this if isinstance(type1, exp.DataType) else type1
type2_value = type2.this if isinstance(type2, exp.DataType) else type2
if exp.DataType.Type.NULL in (type1, type2):
# We propagate the NULL / UNKNOWN types upwards if found
if exp.DataType.Type.NULL in (type1_value, type2_value):
return exp.DataType.Type.NULL
if exp.DataType.Type.UNKNOWN in (type1, type2):
if exp.DataType.Type.UNKNOWN in (type1_value, type2_value):
return exp.DataType.Type.UNKNOWN
return type2 if type2 in self.coerces_to.get(type1, {}) else type1 # type: ignore
if type1_value in self.NESTED_TYPES:
return type1
if type2_value in self.NESTED_TYPES:
return type2
return type2_value if type2_value in self.coerces_to.get(type1_value, {}) else type1_value # type: ignore
# Note: the following "no_type_check" decorators were added because mypy was yelling due
# to assigning Type values to expression.type (since its getter returns Optional[DataType]).
@ -368,7 +380,9 @@ class TypeAnnotator(metaclass=_TypeAnnotator):
return self._annotate_args(expression)
@t.no_type_check
def _annotate_by_args(self, expression: E, *args: str, promote: bool = False) -> E:
def _annotate_by_args(
self, expression: E, *args: str, promote: bool = False, array: bool = False
) -> E:
self._annotate_args(expression)
expressions: t.List[exp.Expression] = []
@ -388,4 +402,9 @@ class TypeAnnotator(metaclass=_TypeAnnotator):
elif expression.type.this in exp.DataType.FLOAT_TYPES:
expression.type = exp.DataType.Type.DOUBLE
if array:
expression.type = exp.DataType(
this=exp.DataType.Type.ARRAY, expressions=[expression.type], nested=True
)
return expression
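
The array-aware annotation above can be sketched end to end with the public `annotate_types` helper (a minimal illustration, not part of this diff; it assumes the default dialect renders nested types as `ARRAY<INT>`):

```python
import sqlglot
from sqlglot.optimizer.annotate_types import annotate_types

# ARRAY(1, 2, 3) now gets a nested type: _annotate_by_args(..., array=True)
# wraps the coerced element type in exp.DataType(this=Type.ARRAY, nested=True).
ast = annotate_types(sqlglot.parse_one("SELECT ARRAY(1, 2, 3) AS xs FROM t"))
array_expr = ast.selects[0].this  # unwrap the alias to reach exp.Array
print(array_expr.type.sql())      # expected: ARRAY<INT>
```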

View file

@ -142,13 +142,14 @@ def _eliminate_derived_table(scope, existing_ctes, taken):
if scope.parent.pivots or isinstance(scope.parent.expression, exp.Lateral):
return None
parent = scope.expression.parent
# Get rid of redundant exp.Subquery expressions, i.e. those that are just used as wrappers
to_replace = scope.expression.parent.unwrap()
name, cte = _new_cte(scope, existing_ctes, taken)
table = exp.alias_(exp.table_(name), alias=to_replace.alias or name)
table.set("joins", to_replace.args.get("joins"))
table = exp.alias_(exp.table_(name), alias=parent.alias or name)
table.set("joins", parent.args.get("joins"))
to_replace.replace(table)
parent.replace(table)
return cte
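
A hedged sketch of the unwrapping change above, using the module's public entry point (the output shape is illustrative; CTE naming may differ):

```python
from sqlglot import parse_one
from sqlglot.optimizer.eliminate_subqueries import eliminate_subqueries

# Derived tables are hoisted into CTEs; redundant exp.Subquery wrappers are
# unwrapped first, so the replacement table lands at the right node.
ast = parse_one("SELECT * FROM (SELECT a FROM t) AS sub")
print(eliminate_subqueries(ast).sql())
# roughly: WITH sub AS (SELECT a FROM t) SELECT * FROM sub AS sub
```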

View file

@ -72,8 +72,13 @@ def normalize(expression):
if not any(join.args.get(k) for k in JOIN_ATTRS):
join.set("kind", "CROSS")
if join.kind != "CROSS":
if join.kind == "CROSS":
join.set("on", None)
else:
join.set("kind", None)
if not join.args.get("on") and not join.args.get("using"):
join.set("on", exp.true())
return expression
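
Assuming this hunk belongs to `sqlglot.optimizer.optimize_joins.normalize` (the header does not show the file name), the new behavior can be sketched as:

```python
from sqlglot import parse_one
from sqlglot.optimizer.optimize_joins import normalize

# A join with no ON/USING and no side/kind becomes an explicit CROSS JOIN,
# and CROSS joins now drop their ON clause instead of defaulting it to TRUE.
print(normalize(parse_one("SELECT * FROM a JOIN b")).sql())
# expected: SELECT * FROM a CROSS JOIN b
```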

View file

@ -1,6 +1,6 @@
from sqlglot import exp
from sqlglot.optimizer.normalize import normalized
from sqlglot.optimizer.scope import build_scope
from sqlglot.optimizer.scope import build_scope, find_in_scope
from sqlglot.optimizer.simplify import simplify
@ -81,7 +81,11 @@ def pushdown_cnf(predicates, scope, scope_ref_count):
break
if isinstance(node, exp.Select):
predicate.replace(exp.true())
node.where(replace_aliases(node, predicate), copy=False)
inner_predicate = replace_aliases(node, predicate)
if find_in_scope(inner_predicate, exp.AggFunc):
node.having(inner_predicate, copy=False)
else:
node.where(inner_predicate, copy=False)
def pushdown_dnf(predicates, scope, scope_ref_count):
@ -142,7 +146,11 @@ def pushdown_dnf(predicates, scope, scope_ref_count):
if isinstance(node, exp.Join):
node.on(predicate, copy=False)
elif isinstance(node, exp.Select):
node.where(replace_aliases(node, predicate), copy=False)
inner_predicate = replace_aliases(node, predicate)
if find_in_scope(inner_predicate, exp.AggFunc):
node.having(inner_predicate, copy=False)
else:
node.where(inner_predicate, copy=False)
def nodes_for_predicate(predicate, sources, scope_ref_count):
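
A sketch of the new HAVING routing (illustrative; the exact rewritten SQL may differ):

```python
from sqlglot import parse_one
from sqlglot.optimizer.pushdown_predicates import pushdown_predicates

# x.cnt aliases an aggregate, so after replace_aliases the pushed predicate
# contains an exp.AggFunc and must land in HAVING, not WHERE.
sql = (
    "SELECT x.a FROM (SELECT a, COUNT(*) AS cnt FROM t GROUP BY a) AS x "
    "WHERE x.cnt > 10"
)
print(pushdown_predicates(parse_one(sql)).sql())
# expected shape: ... GROUP BY a HAVING COUNT(*) > 10 ...
```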

View file

@ -6,7 +6,7 @@ from enum import Enum, auto
from sqlglot import exp
from sqlglot.errors import OptimizeError
from sqlglot.helper import find_new_name
from sqlglot.helper import ensure_collection, find_new_name
logger = logging.getLogger("sqlglot")
@ -141,38 +141,10 @@ class Scope:
return walk_in_scope(self.expression, bfs=bfs)
def find(self, *expression_types, bfs=True):
"""
Returns the first node in this scope which matches at least one of the specified types.
This does NOT traverse into subscopes.
Args:
expression_types (type): the expression type(s) to match.
bfs (bool): True to use breadth-first search, False to use depth-first.
Returns:
exp.Expression: the node which matches the criteria or None if no node matching
the criteria was found.
"""
return next(self.find_all(*expression_types, bfs=bfs), None)
return find_in_scope(self.expression, expression_types, bfs=bfs)
def find_all(self, *expression_types, bfs=True):
"""
Returns a generator object which visits all nodes in this scope and only yields those that
match at least one of the specified expression types.
This does NOT traverse into subscopes.
Args:
expression_types (type): the expression type(s) to match.
bfs (bool): True to use breadth-first search, False to use depth-first.
Yields:
exp.Expression: nodes
"""
for expression, *_ in self.walk(bfs=bfs):
if isinstance(expression, expression_types):
yield expression
return find_all_in_scope(self.expression, expression_types, bfs=bfs)
def replace(self, old, new):
"""
@ -800,3 +772,41 @@ def walk_in_scope(expression, bfs=True):
for key in ("joins", "laterals", "pivots"):
for arg in node.args.get(key) or []:
yield from walk_in_scope(arg, bfs=bfs)
def find_all_in_scope(expression, expression_types, bfs=True):
"""
Returns a generator object which visits all nodes in this scope and only yields those that
match at least one of the specified expression types.
This does NOT traverse into subscopes.
Args:
expression (exp.Expression):
expression_types (tuple[type]|type): the expression type(s) to match.
bfs (bool): True to use breadth-first search, False to use depth-first.
Yields:
exp.Expression: nodes
"""
for expression, *_ in walk_in_scope(expression, bfs=bfs):
if isinstance(expression, tuple(ensure_collection(expression_types))):
yield expression
def find_in_scope(expression, expression_types, bfs=True):
"""
Returns the first node in this scope which matches at least one of the specified types.
This does NOT traverse into subscopes.
Args:
expression (exp.Expression):
expression_types (tuple[type]|type): the expression type(s) to match.
bfs (bool): True to use breadth-first search, False to use depth-first.
Returns:
exp.Expression: the node which matches the criteria or None if no node matching
the criteria was found.
"""
return next(find_all_in_scope(expression, expression_types, bfs=bfs), None)
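
A minimal sketch of the two new module-level helpers (not part of the diff):

```python
import sqlglot
from sqlglot import exp
from sqlglot.optimizer.scope import find_in_scope

# find_in_scope does not cross scope boundaries, so the MAX inside the
# subquery below is invisible from the outer expression.
ast = sqlglot.parse_one("SELECT a, (SELECT MAX(b) FROM t2) AS m FROM t1")
print(find_in_scope(ast, exp.AggFunc))  # None
```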

View file

@ -69,10 +69,10 @@ def simplify(expression):
node = flatten(node)
node = simplify_connectors(node, root)
node = remove_compliments(node, root)
node = simplify_coalesce(node)
node.parent = expression.parent
node = simplify_literals(node, root)
node = simplify_parens(node)
node = simplify_coalesce(node)
if root:
expression.replace(node)
@ -350,7 +350,8 @@ def absorb_and_eliminate(expression, root=True):
def simplify_literals(expression, root=True):
if isinstance(expression, exp.Binary) and not isinstance(expression, exp.Connector):
return _flat_simplify(expression, _simplify_binary, root)
elif isinstance(expression, exp.Neg):
if isinstance(expression, exp.Neg):
this = expression.this
if this.is_number:
value = this.name
@ -430,13 +431,14 @@ def simplify_parens(expression):
if not isinstance(this, exp.Select) and (
not isinstance(parent, (exp.Condition, exp.Binary))
or isinstance(this, exp.Predicate)
or isinstance(parent, exp.Paren)
or not isinstance(this, exp.Binary)
or (isinstance(this, exp.Predicate) and not isinstance(parent, exp.Predicate))
or (isinstance(this, exp.Add) and isinstance(parent, exp.Add))
or (isinstance(this, exp.Mul) and isinstance(parent, exp.Mul))
or (isinstance(this, exp.Mul) and isinstance(parent, (exp.Add, exp.Sub)))
):
return expression.this
return this
return expression
@ -488,18 +490,20 @@ def simplify_coalesce(expression):
coalesce = coalesce if coalesce.expressions else coalesce.this
# This expression is more complex than when we started, but it will get simplified further
return exp.or_(
exp.and_(
coalesce.is_(exp.null()).not_(copy=False),
expression.copy(),
return exp.paren(
exp.or_(
exp.and_(
coalesce.is_(exp.null()).not_(copy=False),
expression.copy(),
copy=False,
),
exp.and_(
coalesce.is_(exp.null()),
type(expression)(this=arg.copy(), expression=other.copy()),
copy=False,
),
copy=False,
),
exp.and_(
coalesce.is_(exp.null()),
type(expression)(this=arg.copy(), expression=other.copy()),
copy=False,
),
copy=False,
)
)
@ -642,7 +646,7 @@ def _flat_simplify(expression, simplifier, root=True):
for b in queue:
result = simplifier(expression, a, b)
if result:
if result and result is not expression:
queue.remove(b)
queue.appendleft(result)
break
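
The tightened paren rules can be sketched as follows (a small illustration; column expressions are used so no literal folding kicks in):

```python
from sqlglot import parse_one
from sqlglot.optimizer.simplify import simplify

print(simplify(parse_one("a + (b + c)")).sql())  # a + b + c (Add inside Add)
print(simplify(parse_one("a - (b + c)")).sql())  # a - (b + c), parens preserved
```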

View file

@ -136,6 +136,7 @@ class Parser(metaclass=_Parser):
TokenType.UINT128,
TokenType.INT256,
TokenType.UINT256,
TokenType.MEDIUMINT,
TokenType.FIXEDSTRING,
TokenType.FLOAT,
TokenType.DOUBLE,
@ -186,6 +187,7 @@ class Parser(metaclass=_Parser):
TokenType.SMALLSERIAL,
TokenType.BIGSERIAL,
TokenType.XML,
TokenType.YEAR,
TokenType.UNIQUEIDENTIFIER,
TokenType.USERDEFINED,
TokenType.MONEY,
@ -194,9 +196,12 @@ class Parser(metaclass=_Parser):
TokenType.IMAGE,
TokenType.VARIANT,
TokenType.OBJECT,
TokenType.OBJECT_IDENTIFIER,
TokenType.INET,
TokenType.IPADDRESS,
TokenType.IPPREFIX,
TokenType.UNKNOWN,
TokenType.NULL,
*ENUM_TYPE_TOKENS,
*NESTED_TYPE_TOKENS,
}
@ -332,6 +337,7 @@ class Parser(metaclass=_Parser):
TokenType.INDEX,
TokenType.ISNULL,
TokenType.ILIKE,
TokenType.INSERT,
TokenType.LIKE,
TokenType.MERGE,
TokenType.OFFSET,
@ -487,7 +493,7 @@ class Parser(metaclass=_Parser):
exp.Cluster: lambda self: self._parse_sort(exp.Cluster, TokenType.CLUSTER_BY),
exp.Column: lambda self: self._parse_column(),
exp.Condition: lambda self: self._parse_conjunction(),
exp.DataType: lambda self: self._parse_types(),
exp.DataType: lambda self: self._parse_types(allow_identifiers=False),
exp.Expression: lambda self: self._parse_statement(),
exp.From: lambda self: self._parse_from(),
exp.Group: lambda self: self._parse_group(),
@ -523,9 +529,6 @@ class Parser(metaclass=_Parser):
TokenType.DESC: lambda self: self._parse_describe(),
TokenType.DESCRIBE: lambda self: self._parse_describe(),
TokenType.DROP: lambda self: self._parse_drop(),
TokenType.FROM: lambda self: exp.select("*").from_(
t.cast(exp.From, self._parse_from(skip_from_token=True))
),
TokenType.INSERT: lambda self: self._parse_insert(),
TokenType.LOAD: lambda self: self._parse_load(),
TokenType.MERGE: lambda self: self._parse_merge(),
@ -578,7 +581,7 @@ class Parser(metaclass=_Parser):
TokenType.PLACEHOLDER: lambda self: self.expression(exp.Placeholder),
TokenType.PARAMETER: lambda self: self._parse_parameter(),
TokenType.COLON: lambda self: self.expression(exp.Placeholder, this=self._prev.text)
if self._match_set((TokenType.NUMBER, TokenType.VAR))
if self._match(TokenType.NUMBER) or self._match_set(self.ID_VAR_TOKENS)
else None,
}
@ -593,6 +596,7 @@ class Parser(metaclass=_Parser):
TokenType.OVERLAPS: binary_range_parser(exp.Overlaps),
TokenType.RLIKE: binary_range_parser(exp.RegexpLike),
TokenType.SIMILAR_TO: binary_range_parser(exp.SimilarTo),
TokenType.FOR: lambda self, this: self._parse_comprehension(this),
}
PROPERTY_PARSERS: t.Dict[str, t.Callable] = {
@ -684,6 +688,12 @@ class Parser(metaclass=_Parser):
exp.CommentColumnConstraint, this=self._parse_string()
),
"COMPRESS": lambda self: self._parse_compress(),
"CLUSTERED": lambda self: self.expression(
exp.ClusteredColumnConstraint, this=self._parse_wrapped_csv(self._parse_ordered)
),
"NONCLUSTERED": lambda self: self.expression(
exp.NonClusteredColumnConstraint, this=self._parse_wrapped_csv(self._parse_ordered)
),
"DEFAULT": lambda self: self.expression(
exp.DefaultColumnConstraint, this=self._parse_bitwise()
),
@ -698,8 +708,11 @@ class Parser(metaclass=_Parser):
"LIKE": lambda self: self._parse_create_like(),
"NOT": lambda self: self._parse_not_constraint(),
"NULL": lambda self: self.expression(exp.NotNullColumnConstraint, allow_null=True),
"ON": lambda self: self._match(TokenType.UPDATE)
and self.expression(exp.OnUpdateColumnConstraint, this=self._parse_function()),
"ON": lambda self: (
self._match(TokenType.UPDATE)
and self.expression(exp.OnUpdateColumnConstraint, this=self._parse_function())
)
or self.expression(exp.OnProperty, this=self._parse_id_var()),
"PATH": lambda self: self.expression(exp.PathColumnConstraint, this=self._parse_string()),
"PRIMARY KEY": lambda self: self._parse_primary_key(),
"REFERENCES": lambda self: self._parse_references(match=False),
@ -709,6 +722,9 @@ class Parser(metaclass=_Parser):
"TTL": lambda self: self.expression(exp.MergeTreeTTL, expressions=[self._parse_bitwise()]),
"UNIQUE": lambda self: self._parse_unique(),
"UPPERCASE": lambda self: self.expression(exp.UppercaseColumnConstraint),
"WITH": lambda self: self.expression(
exp.Properties, expressions=self._parse_wrapped_csv(self._parse_property)
),
}
ALTER_PARSERS = {
@ -728,6 +744,11 @@ class Parser(metaclass=_Parser):
"NEXT": lambda self: self._parse_next_value_for(),
}
INVALID_FUNC_NAME_TOKENS = {
TokenType.IDENTIFIER,
TokenType.STRING,
}
FUNCTIONS_WITH_ALIASED_ARGS = {"STRUCT"}
FUNCTION_PARSERS = {
@ -774,6 +795,8 @@ class Parser(metaclass=_Parser):
self._parse_sort(exp.Distribute, TokenType.DISTRIBUTE_BY),
),
TokenType.SORT_BY: lambda self: ("sort", self._parse_sort(exp.Sort, TokenType.SORT_BY)),
TokenType.CONNECT_BY: lambda self: ("connect", self._parse_connect(skip_start_token=True)),
TokenType.START_WITH: lambda self: ("connect", self._parse_connect()),
}
SET_PARSERS = {
@ -815,6 +838,8 @@ class Parser(metaclass=_Parser):
ADD_CONSTRAINT_TOKENS = {TokenType.CONSTRAINT, TokenType.PRIMARY_KEY, TokenType.FOREIGN_KEY}
DISTINCT_TOKENS = {TokenType.DISTINCT}
STRICT_CAST = True
# A NULL arg in CONCAT yields NULL by default
@ -826,6 +851,11 @@ class Parser(metaclass=_Parser):
LOG_BASE_FIRST = True
LOG_DEFAULTS_TO_LN = False
SUPPORTS_USER_DEFINED_TYPES = True
# Whether or not ADD is present for each column added by ALTER TABLE
ALTER_TABLE_ADD_COLUMN_KEYWORD = True
__slots__ = (
"error_level",
"error_message_context",
@ -838,9 +868,11 @@ class Parser(metaclass=_Parser):
"_next",
"_prev",
"_prev_comments",
"_tokenizer",
)
# Autofilled
TOKENIZER_CLASS: t.Type[Tokenizer] = Tokenizer
INDEX_OFFSET: int = 0
UNNEST_COLUMN_ONLY: bool = False
ALIAS_POST_TABLESAMPLE: bool = False
@ -863,6 +895,7 @@ class Parser(metaclass=_Parser):
self.error_level = error_level or ErrorLevel.IMMEDIATE
self.error_message_context = error_message_context
self.max_errors = max_errors
self._tokenizer = self.TOKENIZER_CLASS()
self.reset()
def reset(self):
@ -1148,7 +1181,7 @@ class Parser(metaclass=_Parser):
expression = self._parse_set_operations(expression) if expression else self._parse_select()
return self._parse_query_modifiers(expression)
def _parse_drop(self) -> exp.Drop | exp.Command:
def _parse_drop(self, exists: bool = False) -> exp.Drop | exp.Command:
start = self._prev
temporary = self._match(TokenType.TEMPORARY)
materialized = self._match_text_seq("MATERIALIZED")
@ -1160,7 +1193,7 @@ class Parser(metaclass=_Parser):
return self.expression(
exp.Drop,
comments=start.comments,
exists=self._parse_exists(),
exists=exists or self._parse_exists(),
this=self._parse_table(schema=True),
kind=kind,
temporary=temporary,
@ -1274,6 +1307,8 @@ class Parser(metaclass=_Parser):
if self._match_text_seq("WITH", "NO", "SCHEMA", "BINDING"):
no_schema_binding = True
shallow = self._match_text_seq("SHALLOW")
if self._match_text_seq("CLONE"):
clone = self._parse_table(schema=True)
when = self._match_texts({"AT", "BEFORE"}) and self._prev.text.upper()
@ -1285,7 +1320,12 @@ class Parser(metaclass=_Parser):
clone_expression = self._match(TokenType.FARROW) and self._parse_bitwise()
self._match(TokenType.R_PAREN)
clone = self.expression(
exp.Clone, this=clone, when=when, kind=clone_kind, expression=clone_expression
exp.Clone,
this=clone,
when=when,
kind=clone_kind,
shallow=shallow,
expression=clone_expression,
)
return self.expression(
@ -1349,7 +1389,11 @@ class Parser(metaclass=_Parser):
if assignment:
key = self._parse_var_or_string()
self._match(TokenType.EQ)
return self.expression(exp.Property, this=key, value=self._parse_column())
return self.expression(
exp.Property,
this=key,
value=self._parse_column() or self._parse_var(any_token=True),
)
return None
@ -1409,7 +1453,7 @@ class Parser(metaclass=_Parser):
def _parse_with_property(
self,
) -> t.Optional[exp.Expression] | t.List[t.Optional[exp.Expression]]:
) -> t.Optional[exp.Expression] | t.List[exp.Expression]:
if self._match(TokenType.L_PAREN, advance=False):
return self._parse_wrapped_csv(self._parse_property)
@ -1622,7 +1666,7 @@ class Parser(metaclass=_Parser):
override=override,
)
def _parse_partition_by(self) -> t.List[t.Optional[exp.Expression]]:
def _parse_partition_by(self) -> t.List[exp.Expression]:
if self._match(TokenType.PARTITION_BY):
return self._parse_csv(self._parse_conjunction)
return []
@ -1652,9 +1696,9 @@ class Parser(metaclass=_Parser):
def _parse_on_property(self) -> t.Optional[exp.Expression]:
if self._match_text_seq("COMMIT", "PRESERVE", "ROWS"):
return exp.OnCommitProperty()
elif self._match_text_seq("COMMIT", "DELETE", "ROWS"):
if self._match_text_seq("COMMIT", "DELETE", "ROWS"):
return exp.OnCommitProperty(delete=True)
return None
return self.expression(exp.OnProperty, this=self._parse_schema(self._parse_id_var()))
def _parse_distkey(self) -> exp.DistKeyProperty:
return self.expression(exp.DistKeyProperty, this=self._parse_wrapped(self._parse_id_var))
@ -1709,8 +1753,10 @@ class Parser(metaclass=_Parser):
def _parse_describe(self) -> exp.Describe:
kind = self._match_set(self.CREATABLES) and self._prev.text
this = self._parse_table()
return self.expression(exp.Describe, this=this, kind=kind)
this = self._parse_table(schema=True)
properties = self._parse_properties()
expressions = properties.expressions if properties else None
return self.expression(exp.Describe, this=this, kind=kind, expressions=expressions)
def _parse_insert(self) -> exp.Insert:
comments = ensure_list(self._prev_comments)
@ -1741,6 +1787,7 @@ class Parser(metaclass=_Parser):
exp.Insert,
comments=comments,
this=this,
by_name=self._match_text_seq("BY", "NAME"),
exists=self._parse_exists(),
partition=self._parse_partition(),
where=self._match_pair(TokenType.REPLACE, TokenType.WHERE)
@ -1895,6 +1942,7 @@ class Parser(metaclass=_Parser):
"from": self._parse_from(joins=True),
"where": self._parse_where(),
"returning": returning or self._parse_returning(),
"order": self._parse_order(),
"limit": self._parse_limit(),
},
)
@ -1948,13 +1996,14 @@ class Parser(metaclass=_Parser):
# https://prestodb.io/docs/current/sql/values.html
return self.expression(exp.Tuple, expressions=[self._parse_conjunction()])
def _parse_projections(self) -> t.List[t.Optional[exp.Expression]]:
def _parse_projections(self) -> t.List[exp.Expression]:
return self._parse_expressions()
def _parse_select(
self, nested: bool = False, table: bool = False, parse_subquery_alias: bool = True
) -> t.Optional[exp.Expression]:
cte = self._parse_with()
if cte:
this = self._parse_statement()
@ -1967,12 +2016,18 @@ class Parser(metaclass=_Parser):
else:
self.raise_error(f"{this.key} does not support CTE")
this = cte
elif self._match(TokenType.SELECT):
return this
# duckdb supports leading with FROM x
from_ = self._parse_from() if self._match(TokenType.FROM, advance=False) else None
if self._match(TokenType.SELECT):
comments = self._prev_comments
hint = self._parse_hint()
all_ = self._match(TokenType.ALL)
distinct = self._match(TokenType.DISTINCT)
distinct = self._match_set(self.DISTINCT_TOKENS)
kind = (
self._match(TokenType.ALIAS)
@ -2006,7 +2061,9 @@ class Parser(metaclass=_Parser):
if into:
this.set("into", into)
from_ = self._parse_from()
if not from_:
from_ = self._parse_from()
if from_:
this.set("from", from_)
@ -2033,6 +2090,8 @@ class Parser(metaclass=_Parser):
expressions=self._parse_csv(self._parse_value),
alias=self._parse_table_alias(),
)
elif from_:
this = exp.select("*").from_(from_.this, copy=False)
else:
this = None
@ -2491,6 +2550,11 @@ class Parser(metaclass=_Parser):
if schema:
return self._parse_schema(this=this)
version = self._parse_version()
if version:
this.set("version", version)
if self.ALIAS_POST_TABLESAMPLE:
table_sample = self._parse_table_sample()
@ -2498,11 +2562,11 @@ class Parser(metaclass=_Parser):
if alias:
this.set("alias", alias)
this.set("hints", self._parse_table_hints())
if not this.args.get("pivots"):
this.set("pivots", self._parse_pivots())
this.set("hints", self._parse_table_hints())
if not self.ALIAS_POST_TABLESAMPLE:
table_sample = self._parse_table_sample()
@ -2516,6 +2580,37 @@ class Parser(metaclass=_Parser):
return this
def _parse_version(self) -> t.Optional[exp.Version]:
if self._match(TokenType.TIMESTAMP_SNAPSHOT):
this = "TIMESTAMP"
elif self._match(TokenType.VERSION_SNAPSHOT):
this = "VERSION"
else:
return None
if self._match_set((TokenType.FROM, TokenType.BETWEEN)):
kind = self._prev.text.upper()
start = self._parse_bitwise()
self._match_texts(("TO", "AND"))
end = self._parse_bitwise()
expression: t.Optional[exp.Expression] = self.expression(
exp.Tuple, expressions=[start, end]
)
elif self._match_text_seq("CONTAINED", "IN"):
kind = "CONTAINED IN"
expression = self.expression(
exp.Tuple, expressions=self._parse_wrapped_csv(self._parse_bitwise)
)
elif self._match(TokenType.ALL):
kind = "ALL"
expression = None
else:
self._match_text_seq("AS", "OF")
kind = "AS OF"
expression = self._parse_type()
return self.expression(exp.Version, this=this, expression=expression, kind=kind)
def _parse_unnest(self, with_alias: bool = True) -> t.Optional[exp.Unnest]:
if not self._match(TokenType.UNNEST):
return None
@ -2760,7 +2855,7 @@ class Parser(metaclass=_Parser):
return self.expression(exp.Group, **elements) # type: ignore
def _parse_grouping_sets(self) -> t.Optional[t.List[t.Optional[exp.Expression]]]:
def _parse_grouping_sets(self) -> t.Optional[t.List[exp.Expression]]:
if not self._match(TokenType.GROUPING_SETS):
return None
@ -2784,6 +2879,22 @@ class Parser(metaclass=_Parser):
return None
return self.expression(exp.Qualify, this=self._parse_conjunction())
def _parse_connect(self, skip_start_token: bool = False) -> t.Optional[exp.Connect]:
if skip_start_token:
start = None
elif self._match(TokenType.START_WITH):
start = self._parse_conjunction()
else:
return None
self._match(TokenType.CONNECT_BY)
self.NO_PAREN_FUNCTION_PARSERS["PRIOR"] = lambda self: self.expression(
exp.Prior, this=self._parse_bitwise()
)
connect = self._parse_conjunction()
self.NO_PAREN_FUNCTION_PARSERS.pop("PRIOR")
return self.expression(exp.Connect, start=start, connect=connect)
def _parse_order(
self, this: t.Optional[exp.Expression] = None, skip_order_token: bool = False
) -> t.Optional[exp.Expression]:
@ -2929,6 +3040,7 @@ class Parser(metaclass=_Parser):
expression,
this=this,
distinct=self._match(TokenType.DISTINCT) or not self._match(TokenType.ALL),
by_name=self._match_text_seq("BY", "NAME"),
expression=self._parse_set_operations(self._parse_select(nested=True)),
)
@ -3017,6 +3129,8 @@ class Parser(metaclass=_Parser):
return self.expression(exp.Escape, this=this, expression=self._parse_string())
def _parse_interval(self) -> t.Optional[exp.Interval]:
index = self._index
if not self._match(TokenType.INTERVAL):
return None
@ -3025,7 +3139,11 @@ class Parser(metaclass=_Parser):
else:
this = self._parse_term()
unit = self._parse_function() or self._parse_var()
if not this:
self._retreat(index)
return None
unit = self._parse_function() or self._parse_var(any_token=True)
# Most dialects support, e.g., the form INTERVAL '5' day, thus we try to parse
# each INTERVAL expression into this canonical form so it's easy to transpile
@ -3036,12 +3154,12 @@ class Parser(metaclass=_Parser):
if len(parts) == 2:
if unit:
# this is not actually a unit, it's something else
# This is not actually a unit, it's something else (e.g. a "window side")
unit = None
self._retreat(self._index - 1)
else:
this = exp.Literal.string(parts[0])
unit = self.expression(exp.Var, this=parts[1])
this = exp.Literal.string(parts[0])
unit = self.expression(exp.Var, this=parts[1])
return self.expression(exp.Interval, this=this, unit=unit)
@ -3087,7 +3205,7 @@ class Parser(metaclass=_Parser):
return interval
index = self._index
data_type = self._parse_types(check_func=True)
data_type = self._parse_types(check_func=True, allow_identifiers=False)
this = self._parse_column()
if data_type:
@ -3103,30 +3221,50 @@ class Parser(metaclass=_Parser):
return this
def _parse_type_size(self) -> t.Optional[exp.DataTypeSize]:
def _parse_type_size(self) -> t.Optional[exp.DataTypeParam]:
this = self._parse_type()
if not this:
return None
return self.expression(
exp.DataTypeSize, this=this, expression=self._parse_var(any_token=True)
exp.DataTypeParam, this=this, expression=self._parse_var(any_token=True)
)
def _parse_types(
self, check_func: bool = False, schema: bool = False
self, check_func: bool = False, schema: bool = False, allow_identifiers: bool = True
) -> t.Optional[exp.Expression]:
index = self._index
prefix = self._match_text_seq("SYSUDTLIB", ".")
if not self._match_set(self.TYPE_TOKENS):
return None
identifier = allow_identifiers and self._parse_id_var(
any_token=False, tokens=(TokenType.VAR,)
)
if identifier:
tokens = self._tokenizer.tokenize(identifier.name)
if len(tokens) != 1:
self.raise_error("Unexpected identifier", self._prev)
if tokens[0].token_type in self.TYPE_TOKENS:
self._prev = tokens[0]
elif self.SUPPORTS_USER_DEFINED_TYPES:
return identifier
else:
return None
else:
return None
type_token = self._prev.token_type
if type_token == TokenType.PSEUDO_TYPE:
return self.expression(exp.PseudoType, this=self._prev.text)
if type_token == TokenType.OBJECT_IDENTIFIER:
return self.expression(exp.ObjectIdentifier, this=self._prev.text)
nested = type_token in self.NESTED_TYPE_TOKENS
is_struct = type_token in self.STRUCT_TYPE_TOKENS
expressions = None
@ -3137,7 +3275,9 @@ class Parser(metaclass=_Parser):
expressions = self._parse_csv(self._parse_struct_types)
elif nested:
expressions = self._parse_csv(
lambda: self._parse_types(check_func=check_func, schema=schema)
lambda: self._parse_types(
check_func=check_func, schema=schema, allow_identifiers=allow_identifiers
)
)
elif type_token in self.ENUM_TYPE_TOKENS:
expressions = self._parse_csv(self._parse_equality)
@ -3151,14 +3291,16 @@ class Parser(metaclass=_Parser):
maybe_func = True
this: t.Optional[exp.Expression] = None
values: t.Optional[t.List[t.Optional[exp.Expression]]] = None
values: t.Optional[t.List[exp.Expression]] = None
if nested and self._match(TokenType.LT):
if is_struct:
expressions = self._parse_csv(self._parse_struct_types)
else:
expressions = self._parse_csv(
lambda: self._parse_types(check_func=check_func, schema=schema)
lambda: self._parse_types(
check_func=check_func, schema=schema, allow_identifiers=allow_identifiers
)
)
if not self._match(TokenType.GT):
@ -3355,7 +3497,7 @@ class Parser(metaclass=_Parser):
upper = this.upper()
parser = self.NO_PAREN_FUNCTION_PARSERS.get(upper)
if optional_parens and parser:
if optional_parens and parser and token_type not in self.INVALID_FUNC_NAME_TOKENS:
self._advance()
return parser(self)
@ -3442,7 +3584,9 @@ class Parser(metaclass=_Parser):
index = self._index
if self._match(TokenType.L_PAREN):
expressions = self._parse_csv(self._parse_id_var)
expressions = t.cast(
t.List[t.Optional[exp.Expression]], self._parse_csv(self._parse_id_var)
)
if not self._match(TokenType.R_PAREN):
self._retreat(index)
@ -3481,14 +3625,14 @@ class Parser(metaclass=_Parser):
if not self._match(TokenType.L_PAREN):
return this
args = self._parse_csv(
lambda: self._parse_constraint()
or self._parse_column_def(self._parse_field(any_token=True))
)
args = self._parse_csv(lambda: self._parse_constraint() or self._parse_field_def())
self._match_r_paren()
return self.expression(exp.Schema, this=this, expressions=args)
def _parse_field_def(self) -> t.Optional[exp.Expression]:
return self._parse_column_def(self._parse_field(any_token=True))
def _parse_column_def(self, this: t.Optional[exp.Expression]) -> t.Optional[exp.Expression]:
# column defs are not really columns, they're identifiers
if isinstance(this, exp.Column):
@ -3499,7 +3643,18 @@ class Parser(metaclass=_Parser):
if self._match_text_seq("FOR", "ORDINALITY"):
return self.expression(exp.ColumnDef, this=this, ordinality=True)
constraints = []
constraints: t.List[exp.Expression] = []
if not kind and self._match(TokenType.ALIAS):
constraints.append(
self.expression(
exp.ComputedColumnConstraint,
this=self._parse_conjunction(),
persisted=self._match_text_seq("PERSISTED"),
not_null=self._match_pair(TokenType.NOT, TokenType.NULL),
)
)
while True:
constraint = self._parse_column_constraint()
if not constraint:
@ -3553,7 +3708,7 @@ class Parser(metaclass=_Parser):
identity = self._match_text_seq("IDENTITY")
if self._match(TokenType.L_PAREN):
if self._match_text_seq("START", "WITH"):
if self._match(TokenType.START_WITH):
this.set("start", self._parse_bitwise())
if self._match_text_seq("INCREMENT", "BY"):
this.set("increment", self._parse_bitwise())
@ -3580,11 +3735,13 @@ class Parser(metaclass=_Parser):
def _parse_not_constraint(
self,
) -> t.Optional[exp.NotNullColumnConstraint | exp.CaseSpecificColumnConstraint]:
) -> t.Optional[exp.Expression]:
if self._match_text_seq("NULL"):
return self.expression(exp.NotNullColumnConstraint)
if self._match_text_seq("CASESPECIFIC"):
return self.expression(exp.CaseSpecificColumnConstraint, not_=True)
if self._match_text_seq("FOR", "REPLICATION"):
return self.expression(exp.NotForReplicationColumnConstraint)
return None
def _parse_column_constraint(self) -> t.Optional[exp.Expression]:
@ -3729,7 +3886,7 @@ class Parser(metaclass=_Parser):
bracket_kind = self._prev.token_type
if self._match(TokenType.COLON):
expressions: t.List[t.Optional[exp.Expression]] = [
expressions: t.List[exp.Expression] = [
self.expression(exp.Slice, expression=self._parse_conjunction())
]
else:
@ -3844,17 +4001,17 @@ class Parser(metaclass=_Parser):
if not self._match(TokenType.ALIAS):
if self._match(TokenType.COMMA):
return self.expression(
exp.CastToStrType, this=this, expression=self._parse_string()
)
else:
self.raise_error("Expected AS after CAST")
return self.expression(exp.CastToStrType, this=this, to=self._parse_string())
self.raise_error("Expected AS after CAST")
fmt = None
to = self._parse_types()
if not to:
self.raise_error("Expected TYPE after CAST")
elif isinstance(to, exp.Identifier):
to = exp.DataType.build(to.name, udt=True)
elif to.this == exp.DataType.Type.CHAR:
if self._match(TokenType.CHARACTER_SET):
to = self.expression(exp.CharacterSet, this=self._parse_var_or_string())
@ -3908,7 +4065,7 @@ class Parser(metaclass=_Parser):
if self._match(TokenType.COMMA):
args.extend(self._parse_csv(self._parse_conjunction))
else:
args = self._parse_csv(self._parse_conjunction)
args = self._parse_csv(self._parse_conjunction) # type: ignore
index = self._index
if not self._match(TokenType.R_PAREN) and args:
@ -3991,10 +4148,10 @@ class Parser(metaclass=_Parser):
def _parse_json_key_value(self) -> t.Optional[exp.JSONKeyValue]:
self._match_text_seq("KEY")
key = self._parse_field()
self._match(TokenType.COLON)
key = self._parse_column()
self._match_set((TokenType.COLON, TokenType.COMMA))
self._match_text_seq("VALUE")
value = self._parse_field()
value = self._parse_bitwise()
if not key and not value:
return None
@ -4116,7 +4273,7 @@ class Parser(metaclass=_Parser):
# Postgres supports the form: substring(string [from int] [for int])
# https://www.postgresql.org/docs/9.1/functions-string.html @ Table 9-6
args = self._parse_csv(self._parse_bitwise)
args = t.cast(t.List[t.Optional[exp.Expression]], self._parse_csv(self._parse_bitwise))
if self._match(TokenType.FROM):
args.append(self._parse_bitwise())
@ -4149,7 +4306,7 @@ class Parser(metaclass=_Parser):
exp.Trim, this=this, position=position, expression=expression, collation=collation
)
def _parse_window_clause(self) -> t.Optional[t.List[t.Optional[exp.Expression]]]:
def _parse_window_clause(self) -> t.Optional[t.List[exp.Expression]]:
return self._match(TokenType.WINDOW) and self._parse_csv(self._parse_named_window)
def _parse_named_window(self) -> t.Optional[exp.Expression]:
@ -4216,8 +4373,7 @@ class Parser(metaclass=_Parser):
if self._match_text_seq("LAST"):
first = False
partition = self._parse_partition_by()
order = self._parse_order()
partition, order = self._parse_partition_and_order()
kind = self._match_set((TokenType.ROWS, TokenType.RANGE)) and self._prev.text
if kind:
@ -4256,6 +4412,11 @@ class Parser(metaclass=_Parser):
return window
def _parse_partition_and_order(
self,
) -> t.Tuple[t.List[exp.Expression], t.Optional[exp.Expression]]:
return self._parse_partition_by(), self._parse_order()
def _parse_window_spec(self) -> t.Dict[str, t.Optional[str | exp.Expression]]:
self._match(TokenType.BETWEEN)
@ -4377,14 +4538,14 @@ class Parser(metaclass=_Parser):
self._advance(-1)
return None
def _parse_except(self) -> t.Optional[t.List[t.Optional[exp.Expression]]]:
def _parse_except(self) -> t.Optional[t.List[exp.Expression]]:
if not self._match(TokenType.EXCEPT):
return None
if self._match(TokenType.L_PAREN, advance=False):
return self._parse_wrapped_csv(self._parse_column)
return self._parse_csv(self._parse_column)
def _parse_replace(self) -> t.Optional[t.List[t.Optional[exp.Expression]]]:
def _parse_replace(self) -> t.Optional[t.List[exp.Expression]]:
if not self._match(TokenType.REPLACE):
return None
if self._match(TokenType.L_PAREN, advance=False):
@ -4393,7 +4554,7 @@ class Parser(metaclass=_Parser):
def _parse_csv(
self, parse_method: t.Callable, sep: TokenType = TokenType.COMMA
) -> t.List[t.Optional[exp.Expression]]:
) -> t.List[exp.Expression]:
parse_result = parse_method()
items = [parse_result] if parse_result is not None else []
@ -4420,12 +4581,12 @@ class Parser(metaclass=_Parser):
return this
def _parse_wrapped_id_vars(self, optional: bool = False) -> t.List[t.Optional[exp.Expression]]:
def _parse_wrapped_id_vars(self, optional: bool = False) -> t.List[exp.Expression]:
return self._parse_wrapped_csv(self._parse_id_var, optional=optional)
def _parse_wrapped_csv(
self, parse_method: t.Callable, sep: TokenType = TokenType.COMMA, optional: bool = False
) -> t.List[t.Optional[exp.Expression]]:
) -> t.List[exp.Expression]:
return self._parse_wrapped(
lambda: self._parse_csv(parse_method, sep=sep), optional=optional
)
@ -4439,7 +4600,7 @@ class Parser(metaclass=_Parser):
self._match_r_paren()
return parse_result
def _parse_expressions(self) -> t.List[t.Optional[exp.Expression]]:
def _parse_expressions(self) -> t.List[exp.Expression]:
return self._parse_csv(self._parse_expression)
def _parse_select_or_expression(self, alias: bool = False) -> t.Optional[exp.Expression]:
@ -4498,7 +4659,7 @@ class Parser(metaclass=_Parser):
self._match(TokenType.COLUMN)
exists_column = self._parse_exists(not_=True)
expression = self._parse_column_def(self._parse_field(any_token=True))
expression = self._parse_field_def()
if expression:
expression.set("exists", exists_column)
@ -4549,13 +4710,16 @@ class Parser(metaclass=_Parser):
return self.expression(exp.AddConstraint, this=this, expression=expression)
def _parse_alter_table_add(self) -> t.List[t.Optional[exp.Expression]]:
def _parse_alter_table_add(self) -> t.List[exp.Expression]:
index = self._index - 1
if self._match_set(self.ADD_CONSTRAINT_TOKENS):
return self._parse_csv(self._parse_add_constraint)
self._retreat(index)
if not self.ALTER_TABLE_ADD_COLUMN_KEYWORD and self._match_text_seq("ADD"):
return self._parse_csv(self._parse_field_def)
return self._parse_csv(self._parse_add_column)
def _parse_alter_table_alter(self) -> exp.AlterColumn:
@ -4576,7 +4740,7 @@ class Parser(metaclass=_Parser):
using=self._match(TokenType.USING) and self._parse_conjunction(),
)
def _parse_alter_table_drop(self) -> t.List[t.Optional[exp.Expression]]:
def _parse_alter_table_drop(self) -> t.List[exp.Expression]:
index = self._index - 1
partition_exists = self._parse_exists()
@ -4619,6 +4783,9 @@ class Parser(metaclass=_Parser):
self._match(TokenType.INTO)
target = self._parse_table()
if target and self._match(TokenType.ALIAS, advance=False):
target.set("alias", self._parse_table_alias())
self._match(TokenType.USING)
using = self._parse_table()
@ -4685,8 +4852,7 @@ class Parser(metaclass=_Parser):
parser = self._find_parser(self.SHOW_PARSERS, self.SHOW_TRIE)
if parser:
return parser(self)
self._advance()
return self.expression(exp.Show, this=self._prev.text.upper())
return self._parse_as_command(self._prev)
def _parse_set_item_assignment(
self, kind: t.Optional[str] = None
@ -4786,6 +4952,19 @@ class Parser(metaclass=_Parser):
self._match_r_paren()
return self.expression(exp.DictRange, this=this, min=min, max=max)
def _parse_comprehension(self, this: exp.Expression) -> exp.Comprehension:
expression = self._parse_column()
self._match(TokenType.IN)
iterator = self._parse_column()
condition = self._parse_conjunction() if self._match_text_seq("IF") else None
return self.expression(
exp.Comprehension,
this=this,
expression=expression,
iterator=iterator,
condition=condition,
)
def _find_parser(
self, parsers: t.Dict[str, t.Callable], trie: t.Dict
) -> t.Optional[t.Callable]:
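
A few of the parser additions in this file, sketched with `parse_one` (illustrative; generation for these new nodes is assumed to be wired up elsewhere in this release):

```python
import sqlglot

# DuckDB-style leading FROM now parses as an implicit SELECT *.
print(sqlglot.parse_one("FROM x").sql())  # expected: SELECT * FROM x

# START WITH ... CONNECT BY hierarchical queries parse into exp.Connect.
sql = "SELECT id FROM emp START WITH mgr IS NULL CONNECT BY PRIOR id = mgr"
print(sqlglot.parse_one(sql).sql())

# Time-travel clauses parse into exp.Version.
print(sqlglot.parse_one("SELECT * FROM t FOR TIMESTAMP AS OF '2023-01-01'").sql())
```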

View file

@ -48,6 +48,7 @@ class TokenType(AutoName):
HASH_ARROW = auto()
DHASH_ARROW = auto()
LR_ARROW = auto()
DAT = auto()
LT_AT = auto()
AT_GT = auto()
DOLLAR = auto()
@ -84,6 +85,7 @@ class TokenType(AutoName):
UTINYINT = auto()
SMALLINT = auto()
USMALLINT = auto()
MEDIUMINT = auto()
INT = auto()
UINT = auto()
BIGINT = auto()
@ -140,6 +142,7 @@ class TokenType(AutoName):
SMALLSERIAL = auto()
BIGSERIAL = auto()
XML = auto()
YEAR = auto()
UNIQUEIDENTIFIER = auto()
USERDEFINED = auto()
MONEY = auto()
@ -157,6 +160,7 @@ class TokenType(AutoName):
FIXEDSTRING = auto()
LOWCARDINALITY = auto()
NESTED = auto()
UNKNOWN = auto()
# keywords
ALIAS = auto()
@ -180,6 +184,7 @@ class TokenType(AutoName):
COMMAND = auto()
COMMENT = auto()
COMMIT = auto()
CONNECT_BY = auto()
CONSTRAINT = auto()
CREATE = auto()
CROSS = auto()
@ -256,6 +261,7 @@ class TokenType(AutoName):
NEXT = auto()
NOTNULL = auto()
NULL = auto()
OBJECT_IDENTIFIER = auto()
OFFSET = auto()
ON = auto()
ORDER_BY = auto()
@ -298,6 +304,7 @@ class TokenType(AutoName):
SIMILAR_TO = auto()
SOME = auto()
SORT_BY = auto()
START_WITH = auto()
STRUCT = auto()
TABLE_SAMPLE = auto()
TEMPORARY = auto()
@ -319,6 +326,8 @@ class TokenType(AutoName):
WINDOW = auto()
WITH = auto()
UNIQUE = auto()
VERSION_SNAPSHOT = auto()
TIMESTAMP_SNAPSHOT = auto()
class Token:
@ -530,6 +539,7 @@ class Tokenizer(metaclass=_Tokenizer):
"COLLATE": TokenType.COLLATE,
"COLUMN": TokenType.COLUMN,
"COMMIT": TokenType.COMMIT,
"CONNECT BY": TokenType.CONNECT_BY,
"CONSTRAINT": TokenType.CONSTRAINT,
"CREATE": TokenType.CREATE,
"CROSS": TokenType.CROSS,
@ -636,6 +646,7 @@ class Tokenizer(metaclass=_Tokenizer):
"SIMILAR TO": TokenType.SIMILAR_TO,
"SOME": TokenType.SOME,
"SORT BY": TokenType.SORT_BY,
"START WITH": TokenType.START_WITH,
"TABLE": TokenType.TABLE,
"TABLESAMPLE": TokenType.TABLE_SAMPLE,
"TEMP": TokenType.TEMPORARY,
@ -643,6 +654,7 @@ class Tokenizer(metaclass=_Tokenizer):
"THEN": TokenType.THEN,
"TRUE": TokenType.TRUE,
"UNION": TokenType.UNION,
"UNKNOWN": TokenType.UNKNOWN,
"UNNEST": TokenType.UNNEST,
"UNPIVOT": TokenType.UNPIVOT,
"UPDATE": TokenType.UPDATE,
@ -739,6 +751,8 @@ class Tokenizer(metaclass=_Tokenizer):
"TRUNCATE": TokenType.COMMAND,
"VACUUM": TokenType.COMMAND,
"USER-DEFINED": TokenType.USERDEFINED,
"FOR VERSION": TokenType.VERSION_SNAPSHOT,
"FOR TIMESTAMP": TokenType.TIMESTAMP_SNAPSHOT,
}
WHITE_SPACE: t.Dict[t.Optional[str], TokenType] = {
@ -941,8 +955,8 @@ class Tokenizer(metaclass=_Tokenizer):
if result == TrieResult.EXISTS:
word = chars
end = self._current + size
size += 1
end = self._current - 1 + size
if end < self.size:
char = self.sql[end]
@ -961,21 +975,20 @@ class Tokenizer(metaclass=_Tokenizer):
char = ""
chars = " "
if not word:
if self._char in self.SINGLE_TOKENS:
self._add(self.SINGLE_TOKENS[self._char], text=self._char)
if word:
if self._scan_string(word):
return
self._scan_var()
if self._scan_comment(word):
return
if prev_space or single_token or not char:
self._advance(size - 1)
word = word.upper()
self._add(self.KEYWORDS[word], text=word)
return
if self._char in self.SINGLE_TOKENS:
self._add(self.SINGLE_TOKENS[self._char], text=self._char)
return
if self._scan_string(word):
return
if self._scan_comment(word):
return
self._advance(size - 1)
word = word.upper()
self._add(self.KEYWORDS[word], text=word)
self._scan_var()
def _scan_comment(self, comment_start: str) -> bool:
if comment_start not in self._COMMENTS:
@ -1053,8 +1066,8 @@ class Tokenizer(metaclass=_Tokenizer):
elif self.IDENTIFIERS_CAN_START_WITH_DIGIT:
return self._add(TokenType.VAR)
self._add(TokenType.NUMBER, number_text)
return self._advance(-len(literal))
self._advance(-len(literal))
return self._add(TokenType.NUMBER, number_text)
else:
return self._add(TokenType.NUMBER)
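
The keyword-trie changes can be checked directly with the tokenizer (a sketch):

```python
from sqlglot.tokens import Tokenizer

# Multi-word keywords such as START WITH and CONNECT BY now come back as
# single tokens.
for token in Tokenizer().tokenize("START WITH CONNECT BY"):
    print(token.token_type, token.text)
# expected: TokenType.START_WITH, then TokenType.CONNECT_BY
```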

View file

@ -68,11 +68,17 @@ def eliminate_distinct_on(expression: exp.Expression) -> exp.Expression:
if order:
window.set("order", order.pop().copy())
else:
window.set("order", exp.Order(expressions=[c.copy() for c in distinct_cols]))
window = exp.alias_(window, row_number)
expression.select(window, copy=False)
return exp.select(*outer_selects).from_(expression.subquery()).where(f'"{row_number}" = 1')
return (
exp.select(*outer_selects)
.from_(expression.subquery())
.where(exp.column(row_number).eq(1))
)
return expression
@ -126,7 +132,7 @@ def remove_precision_parameterized_types(expression: exp.Expression) -> exp.Expr
"""
for node in expression.find_all(exp.DataType):
node.set(
"expressions", [e for e in node.expressions if not isinstance(e, exp.DataTypeSize)]
"expressions", [e for e in node.expressions if not isinstance(e, exp.DataTypeParam)]
)
return expression
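
A sketch of the transform after this change (the outer filter is now built as a proper `exp.EQ` on the row-number column rather than a raw string):

```python
from sqlglot import parse_one
from sqlglot.transforms import eliminate_distinct_on

sql = "SELECT DISTINCT ON (a) a, b FROM t ORDER BY a, b DESC"
print(eliminate_distinct_on(parse_one(sql, read="postgres")).sql())
# expected shape: SELECT a, b FROM (SELECT a, b, ROW_NUMBER() OVER
#   (PARTITION BY a ORDER BY a, b DESC) AS _row_number FROM t)
#   WHERE _row_number = 1
```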