Merging upstream version 18.2.0.
Signed-off-by: Daniel Baumann <daniel@debian.org>
parent 985db29269
commit 53cf4a81a6
124 changed files with 60313 additions and 50346 deletions
@@ -21,10 +21,12 @@ Currently many of the common operations are covered and more functionality will
   * Ex: `['cola', 'colb']`
   * The lack of types may limit functionality in future releases.
   * See [Registering Custom Schema](#registering-custom-schema-class) for information on how to skip this step if the information is stored externally.
+* If your output SQL dialect is not Spark, then configure the SparkSession to use that dialect
+  * Ex: `SparkSession().builder.config("sqlframe.dialect", "bigquery").getOrCreate()`
+  * See [dialects](https://github.com/tobymao/sqlglot/tree/main/sqlglot/dialects) for a full list of dialects.
 * Add `.sql(pretty=True)` to your final DataFrame command to return a list of sql statements to run that command.
-  * In most cases a single SQL statement is returned. Currently the only exception is when caching DataFrames, which isn't supported in other dialects.
-  * Spark is the default output dialect. See [dialects](https://github.com/tobymao/sqlglot/tree/main/sqlglot/dialects) for a full list of dialects.
-  * Ex: `.sql(pretty=True, dialect='bigquery')`
+  * In most cases a single SQL statement is returned. Currently the only exception is when caching DataFrames, which isn't supported in other dialects.
+  * Ex: `.sql(pretty=True)`
 
 ## Examples
 
@@ -33,6 +35,8 @@ import sqlglot
 from sqlglot.dataframe.sql.session import SparkSession
 from sqlglot.dataframe.sql import functions as F
 
+dialect = "spark"
+
 sqlglot.schema.add_table(
   'employee',
   {
@@ -41,10 +45,10 @@ sqlglot.schema.add_table(
     'lname': 'STRING',
     'age': 'INT',
   },
-  dialect="spark",
+  dialect=dialect,
 ) # Register the table structure prior to reading from the table
 
-spark = SparkSession()
+spark = SparkSession.builder.config("sqlframe.dialect", dialect).getOrCreate()
 
 df = (
     spark
@@ -53,7 +57,7 @@ df = (
     .agg(F.countDistinct(F.col("employee_id")).alias("num_employees"))
 )
 
-print(df.sql(pretty=True)) # Spark will be the dialect used by default
+print(df.sql(pretty=True))
 ```
 
 ```sparksql
@@ -81,7 +85,7 @@ class ExternalSchema(Schema):
 
 sqlglot.schema = ExternalSchema()
 
-spark = SparkSession()
+spark = SparkSession() # Spark will be used by default if not specified in the SparkSession config
 
 df = (
     spark
@@ -119,11 +123,14 @@ schema = types.StructType([
 ])
 
 sql_statements = (
-    SparkSession()
+    SparkSession
+    .builder
+    .config("sqlframe.dialect", "bigquery")
+    .getOrCreate()
     .createDataFrame(data, schema)
    .groupBy(F.col("age"))
     .agg(F.countDistinct(F.col("employee_id")).alias("num_employees"))
-    .sql(dialect="bigquery")
+    .sql()
 )
 
 result = None
@@ -166,11 +173,14 @@ schema = types.StructType([
 ])
 
 sql_statements = (
-    SparkSession()
+    SparkSession
+    .builder
+    .config("sqlframe.dialect", "snowflake")
+    .getOrCreate()
    .createDataFrame(data, schema)
     .groupBy(F.col("age"))
     .agg(F.countDistinct(F.col("lname")).alias("num_employees"))
-    .sql(dialect="snowflake")
+    .sql()
 )
 
 try:
@@ -210,7 +220,7 @@ sql_statements = (
     .createDataFrame(data, schema)
     .groupBy(F.col("age"))
     .agg(F.countDistinct(F.col("employee_id")).alias("num_employees"))
-    .sql(dialect="spark")
+    .sql()
 )
 
 pyspark = PySparkSession.builder.master("local[*]").getOrCreate()
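The README hunks above move dialect selection from the `.sql(dialect=...)` call to a one-time `SparkSession` builder config. A minimal end-to-end sketch of that workflow, assuming the 18.x `sqlglot.dataframe` API shown above (table name and columns are illustrative):

```python
import sqlglot
from sqlglot.dataframe.sql.session import SparkSession
from sqlglot.dataframe.sql import functions as F

# Register the table schema up front, as the README requires.
sqlglot.schema.add_table("employee", {"employee_id": "INT", "age": "INT"}, dialect="spark")

# Configure the output dialect once; no dialect= argument at .sql() time.
spark = SparkSession.builder.config("sqlframe.dialect", "bigquery").getOrCreate()

df = (
    spark.read.table("employee")
    .groupBy(F.col("age"))
    .agg(F.countDistinct(F.col("employee_id")).alias("num_employees"))
)
print(df.sql(pretty=True)[0])  # a single BigQuery SELECT statement
```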
@@ -5,7 +5,6 @@ import typing as t
 import sqlglot
 from sqlglot import expressions as exp
 from sqlglot.dataframe.sql.types import DataType
-from sqlglot.dialects import Spark
 from sqlglot.helper import flatten, is_iterable
 
 if t.TYPE_CHECKING:
@@ -15,19 +14,20 @@ if t.TYPE_CHECKING:
 
 class Column:
     def __init__(self, expression: t.Optional[t.Union[ColumnOrLiteral, exp.Expression]]):
+        from sqlglot.dataframe.sql.session import SparkSession
+
         if isinstance(expression, Column):
             expression = expression.expression  # type: ignore
         elif expression is None or not isinstance(expression, (str, exp.Expression)):
             expression = self._lit(expression).expression  # type: ignore
-
-        expression = sqlglot.maybe_parse(expression, dialect="spark")
+        elif not isinstance(expression, exp.Column):
+            expression = sqlglot.maybe_parse(expression, dialect=SparkSession().dialect).transform(
+                SparkSession().dialect.normalize_identifier, copy=False
+            )
         if expression is None:
             raise ValueError(f"Could not parse {expression}")
-
-        if isinstance(expression, exp.Column):
-            expression.transform(Spark.normalize_identifier, copy=False)
-
-        self.expression: exp.Expression = expression
+        self.expression: exp.Expression = expression  # type: ignore
 
     def __repr__(self):
         return repr(self.expression)
@@ -207,7 +207,9 @@ class Column:
         return Column(expression)
 
     def sql(self, **kwargs) -> str:
-        return self.expression.sql(**{"dialect": "spark", **kwargs})
+        from sqlglot.dataframe.sql.session import SparkSession
+
+        return self.expression.sql(**{"dialect": SparkSession().dialect, **kwargs})
 
     def alias(self, name: str) -> Column:
         new_expression = exp.alias_(self.column_expression, name)
@@ -264,9 +266,11 @@ class Column:
         Functionality Difference: PySpark cast accepts a datatype instance of the datatype class
         Sqlglot doesn't currently replicate this class so it only accepts a string
         """
+        from sqlglot.dataframe.sql.session import SparkSession
+
         if isinstance(dataType, DataType):
             dataType = dataType.simpleString()
-        return Column(exp.cast(self.column_expression, dataType, dialect="spark"))
+        return Column(exp.cast(self.column_expression, dataType, dialect=SparkSession().dialect))
 
     def startswith(self, value: t.Union[str, Column]) -> Column:
         value = self._lit(value) if not isinstance(value, Column) else value
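With these hunks, `Column` resolves its dialect through the `SparkSession` singleton instead of the hard-coded Spark dialect. A small sketch of the observable effect (the snowflake choice is illustrative):

```python
from sqlglot.dataframe.sql.session import SparkSession
from sqlglot.dataframe.sql.column import Column

# The dialect configured on the session now drives parsing and identifier
# normalization inside Column.__init__ and Column.sql, per the change above.
SparkSession.builder.config("sqlframe.dialect", "snowflake").getOrCreate()
print(Column("employee_id").sql())  # rendered with the session dialect
```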
@@ -1,12 +1,13 @@
 from __future__ import annotations
 
 import functools
+import logging
 import typing as t
 import zlib
 from copy import copy
 
 import sqlglot
-from sqlglot import expressions as exp
+from sqlglot import Dialect, expressions as exp
 from sqlglot.dataframe.sql import functions as F
 from sqlglot.dataframe.sql.column import Column
 from sqlglot.dataframe.sql.group import GroupedData
@@ -18,6 +19,7 @@ from sqlglot.dataframe.sql.util import get_tables_from_expression_with_join
 from sqlglot.dataframe.sql.window import Window
 from sqlglot.helper import ensure_list, object_to_dict, seq_get
 from sqlglot.optimizer import optimize as optimize_func
+from sqlglot.optimizer.qualify_columns import quote_identifiers
 
 if t.TYPE_CHECKING:
     from sqlglot.dataframe.sql._typing import (
@@ -27,7 +29,9 @@ if t.TYPE_CHECKING:
         OutputExpressionContainer,
     )
     from sqlglot.dataframe.sql.session import SparkSession
+    from sqlglot.dialects.dialect import DialectType
 
+logger = logging.getLogger("sqlglot")
 
 JOIN_HINTS = {
     "BROADCAST",
@@ -264,7 +268,9 @@ class DataFrame:
 
     @classmethod
     def _create_hash_from_expression(cls, expression: exp.Expression) -> str:
-        value = expression.sql(dialect="spark").encode("utf-8")
+        from sqlglot.dataframe.sql.session import SparkSession
+
+        value = expression.sql(dialect=SparkSession().dialect).encode("utf-8")
         return f"t{zlib.crc32(value)}"[:6]
 
     def _get_select_expressions(
@@ -291,7 +297,15 @@ class DataFrame:
             select_expressions.append(expression_select_pair)  # type: ignore
         return select_expressions
 
-    def sql(self, dialect="spark", optimize=True, **kwargs) -> t.List[str]:
+    def sql(
+        self, dialect: t.Optional[DialectType] = None, optimize: bool = True, **kwargs
+    ) -> t.List[str]:
+        from sqlglot.dataframe.sql.session import SparkSession
+
+        if dialect and Dialect.get_or_raise(dialect)() != SparkSession().dialect:
+            logger.warning(
+                f"The recommended way of defining a dialect is by doing `SparkSession.builder.config('sqlframe.dialect', '{dialect}').getOrCreate()`. It is no longer needed then when calling `sql`. If you run into issues try updating your query to use this pattern."
+            )
         df = self._resolve_pending_hints()
         select_expressions = df._get_select_expressions()
         output_expressions: t.List[t.Union[exp.Select, exp.Cache, exp.Drop]] = []
@@ -299,7 +313,10 @@ class DataFrame:
         for expression_type, select_expression in select_expressions:
             select_expression = select_expression.transform(replace_id_value, replacement_mapping)
             if optimize:
-                select_expression = t.cast(exp.Select, optimize_func(select_expression))
+                quote_identifiers(select_expression)
+                select_expression = t.cast(
+                    exp.Select, optimize_func(select_expression, dialect=SparkSession().dialect)
+                )
             select_expression = df._replace_cte_names_with_hashes(select_expression)
             expression: t.Union[exp.Select, exp.Cache, exp.Drop]
             if expression_type == exp.Cache:
@@ -313,10 +330,12 @@ class DataFrame:
                 sqlglot.schema.add_table(
                     cache_table_name,
                     {
-                        expression.alias_or_name: expression.type.sql("spark")
+                        expression.alias_or_name: expression.type.sql(
+                            dialect=SparkSession().dialect
+                        )
                         for expression in select_expression.expressions
                     },
-                    dialect="spark",
+                    dialect=SparkSession().dialect,
                 )
                 cache_storage_level = select_expression.args["cache_storage_level"]
                 options = [
@@ -345,7 +364,8 @@ class DataFrame:
             output_expressions.append(expression)
 
         return [
-            expression.sql(**{"dialect": dialect, **kwargs}) for expression in output_expressions
+            expression.sql(**{"dialect": SparkSession().dialect, **kwargs})
+            for expression in output_expressions
         ]
 
     def copy(self, **kwargs) -> DataFrame:
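`DataFrame.sql()` now defaults to the session's dialect and only warns when an explicit `dialect=` argument disagrees with it. A sketch of both call styles (schema and table name are illustrative):

```python
import sqlglot
from sqlglot.dataframe.sql.session import SparkSession

sqlglot.schema.add_table("employee", {"age": "INT"}, dialect="spark")
spark = SparkSession.builder.config("sqlframe.dialect", "bigquery").getOrCreate()
df = spark.read.table("employee")

df.sql()                 # rendered with the session dialect (bigquery)
df.sql(dialect="spark")  # still works, but logs the warning added above
```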
@@ -368,9 +368,7 @@ def covar_samp(col1: ColumnOrName, col2: ColumnOrName) -> Column:
 
 
 def first(col: ColumnOrName, ignorenulls: t.Optional[bool] = None) -> Column:
-    if ignorenulls is not None:
-        return Column.invoke_anonymous_function(col, "FIRST", ignorenulls)
-    return Column.invoke_anonymous_function(col, "FIRST")
+    return Column.invoke_expression_over_column(col, expression.First, ignore_nulls=ignorenulls)
 
 
 def grouping_id(*cols: ColumnOrName) -> Column:
@@ -394,9 +392,7 @@ def isnull(col: ColumnOrName) -> Column:
 
 
 def last(col: ColumnOrName, ignorenulls: t.Optional[bool] = None) -> Column:
-    if ignorenulls is not None:
-        return Column.invoke_anonymous_function(col, "LAST", ignorenulls)
-    return Column.invoke_anonymous_function(col, "LAST")
+    return Column.invoke_expression_over_column(col, expression.Last, ignore_nulls=ignorenulls)
 
 
 def monotonically_increasing_id() -> Column:
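`first`/`last` now build typed `exp.First`/`exp.Last` nodes instead of anonymous function calls, so each generator can decide how to render the ignore-nulls flag. A quick illustration (the exact rendered string depends on the active dialect):

```python
from sqlglot.dataframe.sql import functions as F

# Wraps exp.First(ignore_nulls=True) rather than an opaque FIRST(...) call.
print(F.first("employee_id", ignorenulls=True).sql())
```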
@@ -5,7 +5,6 @@ import typing as t
 from sqlglot import expressions as exp
 from sqlglot.dataframe.sql.column import Column
 from sqlglot.dataframe.sql.util import get_tables_from_expression_with_join
-from sqlglot.dialects import Spark
 from sqlglot.helper import ensure_list
 
 NORMALIZE_INPUT = t.TypeVar("NORMALIZE_INPUT", bound=t.Union[str, exp.Expression, Column])
@@ -20,7 +19,7 @@ def normalize(spark: SparkSession, expression_context: exp.Select, expr: t.List[
     for expression in expressions:
         identifiers = expression.find_all(exp.Identifier)
         for identifier in identifiers:
-            Spark.normalize_identifier(identifier)
+            identifier.transform(spark.dialect.normalize_identifier)
             replace_alias_name_with_cte_name(spark, expression_context, identifier)
             replace_branch_and_sequence_ids_with_cte_name(spark, expression_context, identifier)
@@ -4,7 +4,6 @@ import typing as t
 
 import sqlglot
 from sqlglot import expressions as exp
-from sqlglot.dialects import Spark
 from sqlglot.helper import object_to_dict
 
 if t.TYPE_CHECKING:
@@ -18,15 +17,25 @@ class DataFrameReader:
 
     def table(self, tableName: str) -> DataFrame:
         from sqlglot.dataframe.sql.dataframe import DataFrame
+        from sqlglot.dataframe.sql.session import SparkSession
 
-        sqlglot.schema.add_table(tableName, dialect="spark")
+        sqlglot.schema.add_table(tableName, dialect=SparkSession().dialect)
 
         return DataFrame(
             self.spark,
             exp.Select()
-            .from_(exp.to_table(tableName, dialect="spark").transform(Spark.normalize_identifier))
+            .from_(
+                exp.to_table(tableName, dialect=SparkSession().dialect).transform(
+                    SparkSession().dialect.normalize_identifier
+                )
+            )
             .select(
-                *(column for column in sqlglot.schema.column_names(tableName, dialect="spark"))
+                *(
+                    column
+                    for column in sqlglot.schema.column_names(
+                        tableName, dialect=SparkSession().dialect
+                    )
+                )
             ),
         )
 
@@ -63,6 +72,8 @@ class DataFrameWriter:
         return self.copy(by_name=True)
 
     def insertInto(self, tableName: str, overwrite: t.Optional[bool] = None) -> DataFrameWriter:
+        from sqlglot.dataframe.sql.session import SparkSession
+
         output_expression_container = exp.Insert(
             **{
                 "this": exp.to_table(tableName),
@@ -71,7 +82,9 @@ class DataFrameWriter:
         )
         df = self._df.copy(output_expression_container=output_expression_container)
         if self._by_name:
-            columns = sqlglot.schema.column_names(tableName, only_visible=True, dialect="spark")
+            columns = sqlglot.schema.column_names(
+                tableName, only_visible=True, dialect=SparkSession().dialect
+            )
             df = df._convert_leaf_to_cte().select(*columns)
 
         return self.copy(_df=df)
@@ -5,31 +5,35 @@ import uuid
 from collections import defaultdict
 
 import sqlglot
-from sqlglot import expressions as exp
+from sqlglot import Dialect, expressions as exp
 from sqlglot.dataframe.sql import functions as F
 from sqlglot.dataframe.sql.dataframe import DataFrame
 from sqlglot.dataframe.sql.readwriter import DataFrameReader
 from sqlglot.dataframe.sql.types import StructType
 from sqlglot.dataframe.sql.util import get_column_mapping_from_schema_input
+from sqlglot.helper import classproperty
 
 if t.TYPE_CHECKING:
     from sqlglot.dataframe.sql._typing import ColumnLiterals, SchemaInput
 
 
 class SparkSession:
-    known_ids: t.ClassVar[t.Set[str]] = set()
-    known_branch_ids: t.ClassVar[t.Set[str]] = set()
-    known_sequence_ids: t.ClassVar[t.Set[str]] = set()
-    name_to_sequence_id_mapping: t.ClassVar[t.Dict[str, t.List[str]]] = defaultdict(list)
+    DEFAULT_DIALECT = "spark"
+    _instance = None
 
     def __init__(self):
-        self.incrementing_id = 1
+        if not hasattr(self, "known_ids"):
+            self.known_ids = set()
+            self.known_branch_ids = set()
+            self.known_sequence_ids = set()
+            self.name_to_sequence_id_mapping = defaultdict(list)
+            self.incrementing_id = 1
+            self.dialect = Dialect.get_or_raise(self.DEFAULT_DIALECT)()
 
-    def __getattr__(self, name: str) -> SparkSession:
-        return self
-
-    def __call__(self, *args, **kwargs) -> SparkSession:
-        return self
+    def __new__(cls, *args, **kwargs) -> SparkSession:
+        if cls._instance is None:
+            cls._instance = super().__new__(cls)
+        return cls._instance
 
     @property
     def read(self) -> DataFrameReader:
@@ -101,7 +105,7 @@ class SparkSession:
         return DataFrame(self, sel_expression)
 
     def sql(self, sqlQuery: str) -> DataFrame:
-        expression = sqlglot.parse_one(sqlQuery, read="spark")
+        expression = sqlglot.parse_one(sqlQuery, read=self.dialect)
         if isinstance(expression, exp.Select):
             df = DataFrame(self, expression)
             df = df._convert_leaf_to_cte()
@@ -149,3 +153,38 @@ class SparkSession:
 
     def _add_alias_to_mapping(self, name: str, sequence_id: str):
         self.name_to_sequence_id_mapping[name].append(sequence_id)
+
+    class Builder:
+        SQLFRAME_DIALECT_KEY = "sqlframe.dialect"
+
+        def __init__(self):
+            self.dialect = "spark"
+
+        def __getattr__(self, item) -> SparkSession.Builder:
+            return self
+
+        def __call__(self, *args, **kwargs):
+            return self
+
+        def config(
+            self,
+            key: t.Optional[str] = None,
+            value: t.Optional[t.Any] = None,
+            *,
+            map: t.Optional[t.Dict[str, t.Any]] = None,
+            **kwargs: t.Any,
+        ) -> SparkSession.Builder:
+            if key == self.SQLFRAME_DIALECT_KEY:
+                self.dialect = value
+            elif map and self.SQLFRAME_DIALECT_KEY in map:
+                self.dialect = map[self.SQLFRAME_DIALECT_KEY]
+            return self
+
+        def getOrCreate(self) -> SparkSession:
+            spark = SparkSession()
+            spark.dialect = Dialect.get_or_raise(self.dialect)()
+            return spark
+
    @classproperty
     def builder(cls) -> Builder:
         return cls.Builder()
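The session is now a process-wide singleton whose dialect the new `Builder` mutates; a minimal sketch built directly on the code added above (the duckdb choice is illustrative):

```python
from sqlglot.dataframe.sql.session import SparkSession

a = SparkSession()
b = SparkSession.builder.config("sqlframe.dialect", "duckdb").getOrCreate()

# __new__ returns the cached instance, so both names point at one session
# whose .dialect is now an instantiated Dialect object, not a string.
assert a is b
print(type(b.dialect).__name__)
```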
@@ -48,7 +48,9 @@ class WindowSpec:
         return WindowSpec(self.expression.copy())
 
     def sql(self, **kwargs) -> str:
-        return self.expression.sql(dialect="spark", **kwargs)
+        from sqlglot.dataframe.sql.session import SparkSession
+
+        return self.expression.sql(dialect=SparkSession().dialect, **kwargs)
 
     def partitionBy(self, *cols: t.Union[ColumnOrName, t.List[ColumnOrName]]) -> WindowSpec:
         from sqlglot.dataframe.sql.column import Column
@@ -12,6 +12,7 @@ from sqlglot.dialects.dialect import (
     datestrtodate_sql,
     format_time_lambda,
     inline_array_sql,
+    json_keyvalue_comma_sql,
     max_or_greatest,
     min_or_least,
     no_ilike_sql,
@@ -29,8 +30,8 @@ logger = logging.getLogger("sqlglot")
 
 def _date_add_sql(
     data_type: str, kind: str
-) -> t.Callable[[generator.Generator, exp.Expression], str]:
-    def func(self, expression):
+) -> t.Callable[[BigQuery.Generator, exp.Expression], str]:
+    def func(self: BigQuery.Generator, expression: exp.Expression) -> str:
         this = self.sql(expression, "this")
         unit = expression.args.get("unit")
         unit = exp.var(unit.name.upper() if unit else "DAY")
@@ -40,7 +41,7 @@ def _date_add_sql(
     return func
 
 
-def _derived_table_values_to_unnest(self: generator.Generator, expression: exp.Values) -> str:
+def _derived_table_values_to_unnest(self: BigQuery.Generator, expression: exp.Values) -> str:
     if not expression.find_ancestor(exp.From, exp.Join):
         return self.values_sql(expression)
 
@@ -64,7 +65,7 @@ def _derived_table_values_to_unnest(self: generator.Generator, expression: exp.V
     return self.unnest_sql(exp.Unnest(expressions=[exp.Array(expressions=structs)]))
 
 
-def _returnsproperty_sql(self: generator.Generator, expression: exp.ReturnsProperty) -> str:
+def _returnsproperty_sql(self: BigQuery.Generator, expression: exp.ReturnsProperty) -> str:
     this = expression.this
     if isinstance(this, exp.Schema):
         this = f"{this.this} <{self.expressions(this)}>"
@@ -73,7 +74,7 @@ def _returnsproperty_sql(self: generator.Generator, expression: exp.ReturnsPrope
     return f"RETURNS {this}"
 
 
-def _create_sql(self: generator.Generator, expression: exp.Create) -> str:
+def _create_sql(self: BigQuery.Generator, expression: exp.Create) -> str:
     kind = expression.args["kind"]
     returns = expression.find(exp.ReturnsProperty)
 
@@ -94,14 +95,20 @@ def _unqualify_unnest(expression: exp.Expression) -> exp.Expression:
 
     These are added by the optimizer's qualify_column step.
     """
-    from sqlglot.optimizer.scope import Scope
+    from sqlglot.optimizer.scope import find_all_in_scope
 
     if isinstance(expression, exp.Select):
-        for unnest in expression.find_all(exp.Unnest):
-            if isinstance(unnest.parent, (exp.From, exp.Join)) and unnest.alias:
-                for column in Scope(expression).find_all(exp.Column):
-                    if column.table == unnest.alias:
-                        column.set("table", None)
+        unnest_aliases = {
+            unnest.alias
+            for unnest in find_all_in_scope(expression, exp.Unnest)
+            if isinstance(unnest.parent, (exp.From, exp.Join))
+        }
+        if unnest_aliases:
+            for column in expression.find_all(exp.Column):
+                if column.table in unnest_aliases:
+                    column.set("table", None)
+                elif column.db in unnest_aliases:
+                    column.set("db", None)
 
     return expression
 
@@ -261,6 +268,7 @@ class BigQuery(Dialect):
             "TIMESTAMP": TokenType.TIMESTAMPTZ,
             "NOT DETERMINISTIC": TokenType.VOLATILE,
             "UNKNOWN": TokenType.NULL,
+            "FOR SYSTEM_TIME": TokenType.TIMESTAMP_SNAPSHOT,
         }
         KEYWORDS.pop("DIV")
 
@@ -270,6 +278,8 @@ class BigQuery(Dialect):
         LOG_BASE_FIRST = False
         LOG_DEFAULTS_TO_LN = True
 
+        SUPPORTS_USER_DEFINED_TYPES = False
+
         FUNCTIONS = {
             **parser.Parser.FUNCTIONS,
             "DATE": _parse_date,
@@ -299,6 +309,8 @@ class BigQuery(Dialect):
                 if re.compile(str(seq_get(args, 1))).groups == 1
                 else None,
             ),
+            "SHA256": lambda args: exp.SHA2(this=seq_get(args, 0), length=exp.Literal.number(256)),
+            "SHA512": lambda args: exp.SHA2(this=seq_get(args, 0), length=exp.Literal.number(512)),
             "SPLIT": lambda args: exp.Split(
                 # https://cloud.google.com/bigquery/docs/reference/standard-sql/string_functions#split
                 this=seq_get(args, 0),
@@ -346,7 +358,7 @@ class BigQuery(Dialect):
         }
 
         def _parse_table_part(self, schema: bool = False) -> t.Optional[exp.Expression]:
-            this = super()._parse_table_part(schema=schema)
+            this = super()._parse_table_part(schema=schema) or self._parse_number()
 
             # https://cloud.google.com/bigquery/docs/reference/standard-sql/lexical#table_names
             if isinstance(this, exp.Identifier):
@@ -356,6 +368,17 @@ class BigQuery(Dialect):
                     table_name += f"-{self._prev.text}"
 
                 this = exp.Identifier(this=table_name, quoted=this.args.get("quoted"))
+            elif isinstance(this, exp.Literal):
+                table_name = this.name
+
+                if (
+                    self._curr
+                    and self._prev.end == self._curr.start - 1
+                    and self._parse_var(any_token=True)
+                ):
+                    table_name += self._prev.text
+
+                this = exp.Identifier(this=table_name, quoted=True)
 
             return this
 
@@ -374,6 +397,27 @@ class BigQuery(Dialect):
 
             return table
 
+        def _parse_json_object(self) -> exp.JSONObject:
+            json_object = super()._parse_json_object()
+            array_kv_pair = seq_get(json_object.expressions, 0)
+
+            # Converts BQ's "signature 2" of JSON_OBJECT into SQLGlot's canonical representation
+            # https://cloud.google.com/bigquery/docs/reference/standard-sql/json_functions#json_object_signature2
+            if (
+                array_kv_pair
+                and isinstance(array_kv_pair.this, exp.Array)
+                and isinstance(array_kv_pair.expression, exp.Array)
+            ):
+                keys = array_kv_pair.this.expressions
+                values = array_kv_pair.expression.expressions
+
+                json_object.set(
+                    "expressions",
+                    [exp.JSONKeyValue(this=k, expression=v) for k, v in zip(keys, values)],
+                )
+
+            return json_object
+
     class Generator(generator.Generator):
         EXPLICIT_UNION = True
         INTERVAL_ALLOWS_PLURAL_FORM = False
@@ -383,6 +427,7 @@ class BigQuery(Dialect):
         LIMIT_FETCH = "LIMIT"
         RENAME_TABLE_WITH_DB = False
         ESCAPE_LINE_BREAK = True
+        NVL2_SUPPORTED = False
 
         TRANSFORMS = {
             **generator.Generator.TRANSFORMS,
@@ -405,6 +450,7 @@ class BigQuery(Dialect):
             exp.ILike: no_ilike_sql,
             exp.IntDiv: rename_func("DIV"),
             exp.JSONFormat: rename_func("TO_JSON_STRING"),
+            exp.JSONKeyValue: json_keyvalue_comma_sql,
             exp.Max: max_or_greatest,
             exp.MD5: lambda self, e: self.func("TO_HEX", self.func("MD5", e.this)),
             exp.MD5Digest: rename_func("MD5"),
@@ -428,6 +474,9 @@ class BigQuery(Dialect):
                     _alias_ordered_group,
                 ]
             ),
+            exp.SHA2: lambda self, e: self.func(
+                f"SHA256" if e.text("length") == "256" else "SHA512", e.this
+            ),
             exp.StabilityProperty: lambda self, e: f"DETERMINISTIC"
             if e.name == "IMMUTABLE"
             else "NOT DETERMINISTIC",
@@ -591,6 +640,13 @@ class BigQuery(Dialect):
 
             return super().attimezone_sql(expression)
 
+        def cast_sql(self, expression: exp.Cast, safe_prefix: t.Optional[str] = None) -> str:
+            # https://cloud.google.com/bigquery/docs/reference/standard-sql/lexical#json_literals
+            if expression.is_type("json"):
+                return f"JSON {self.sql(expression, 'this')}"
+
+            return super().cast_sql(expression, safe_prefix=safe_prefix)
+
         def trycast_sql(self, expression: exp.TryCast) -> str:
             return self.cast_sql(expression, safe_prefix="SAFE_")
 
@@ -630,3 +686,9 @@ class BigQuery(Dialect):
 
         def with_properties(self, properties: exp.Properties) -> str:
             return self.properties(properties, prefix=self.seg("OPTIONS"))
+
+        def version_sql(self, expression: exp.Version) -> str:
+            if expression.name == "TIMESTAMP":
+                expression = expression.copy()
+                expression.set("this", "SYSTEM_TIME")
+            return super().version_sql(expression)
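Two of the BigQuery changes above are easiest to see through `transpile`; a small sketch grounded in the new `cast_sql` and the SHA256/SHA512 hooks (expected outputs follow from the code shown, not from a run):

```python
import sqlglot

# JSON casts now use BigQuery's literal form, per the new cast_sql override.
print(sqlglot.transpile("SELECT CAST('[1]' AS JSON)", write="bigquery")[0])
# SELECT JSON '[1]'

# SHA256 parses into exp.SHA2(length=256) and is regenerated by length.
print(sqlglot.transpile("SELECT SHA256(x)", read="bigquery", write="bigquery")[0])
# SELECT SHA256(x)
```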
@@ -11,6 +11,7 @@ from sqlglot.dialects.dialect import (
     var_map_sql,
 )
+from sqlglot.errors import ParseError
 from sqlglot.helper import seq_get
 from sqlglot.parser import parse_var_map
 from sqlglot.tokens import Token, TokenType
 
 
@@ -63,9 +64,23 @@ class ClickHouse(Dialect):
         }
 
     class Parser(parser.Parser):
+        SUPPORTS_USER_DEFINED_TYPES = False
+
         FUNCTIONS = {
             **parser.Parser.FUNCTIONS,
             "ANY": exp.AnyValue.from_arg_list,
+            "DATE_ADD": lambda args: exp.DateAdd(
+                this=seq_get(args, 2), expression=seq_get(args, 1), unit=seq_get(args, 0)
+            ),
+            "DATEADD": lambda args: exp.DateAdd(
+                this=seq_get(args, 2), expression=seq_get(args, 1), unit=seq_get(args, 0)
+            ),
+            "DATE_DIFF": lambda args: exp.DateDiff(
+                this=seq_get(args, 2), expression=seq_get(args, 1), unit=seq_get(args, 0)
+            ),
+            "DATEDIFF": lambda args: exp.DateDiff(
+                this=seq_get(args, 2), expression=seq_get(args, 1), unit=seq_get(args, 0)
+            ),
             "MAP": parse_var_map,
             "MATCH": exp.RegexpLike.from_arg_list,
             "UNIQ": exp.ApproxDistinct.from_arg_list,
@@ -147,7 +162,7 @@ class ClickHouse(Dialect):
 
             this = self._parse_id_var()
             self._match(TokenType.COLON)
-            kind = self._parse_types(check_func=False) or (
+            kind = self._parse_types(check_func=False, allow_identifiers=False) or (
                 self._match_text_seq("IDENTIFIER") and "Identifier"
             )
 
@@ -249,7 +264,7 @@ class ClickHouse(Dialect):
 
         def _parse_func_params(
             self, this: t.Optional[exp.Func] = None
-        ) -> t.Optional[t.List[t.Optional[exp.Expression]]]:
+        ) -> t.Optional[t.List[exp.Expression]]:
             if self._match_pair(TokenType.R_PAREN, TokenType.L_PAREN):
                 return self._parse_csv(self._parse_lambda)
 
@@ -267,9 +282,7 @@ class ClickHouse(Dialect):
                 return self.expression(exp.Quantile, this=params[0], quantile=this)
             return self.expression(exp.Quantile, this=this, quantile=exp.Literal.number(0.5))
 
-        def _parse_wrapped_id_vars(
-            self, optional: bool = False
-        ) -> t.List[t.Optional[exp.Expression]]:
+        def _parse_wrapped_id_vars(self, optional: bool = False) -> t.List[exp.Expression]:
             return super()._parse_wrapped_id_vars(optional=True)
 
         def _parse_primary_key(
@@ -292,9 +305,22 @@ class ClickHouse(Dialect):
     class Generator(generator.Generator):
         QUERY_HINTS = False
         STRUCT_DELIMITER = ("(", ")")
+        NVL2_SUPPORTED = False
+
+        STRING_TYPE_MAPPING = {
+            exp.DataType.Type.CHAR: "String",
+            exp.DataType.Type.LONGBLOB: "String",
+            exp.DataType.Type.LONGTEXT: "String",
+            exp.DataType.Type.MEDIUMBLOB: "String",
+            exp.DataType.Type.MEDIUMTEXT: "String",
+            exp.DataType.Type.TEXT: "String",
+            exp.DataType.Type.VARBINARY: "String",
+            exp.DataType.Type.VARCHAR: "String",
+        }
 
         TYPE_MAPPING = {
             **generator.Generator.TYPE_MAPPING,
+            **STRING_TYPE_MAPPING,
             exp.DataType.Type.ARRAY: "Array",
             exp.DataType.Type.BIGINT: "Int64",
             exp.DataType.Type.DATETIME64: "DateTime64",
@@ -328,6 +354,12 @@ class ClickHouse(Dialect):
             exp.ApproxDistinct: rename_func("uniq"),
             exp.Array: inline_array_sql,
             exp.CastToStrType: rename_func("CAST"),
+            exp.DateAdd: lambda self, e: self.func(
+                "DATE_ADD", exp.Literal.string(e.text("unit") or "day"), e.expression, e.this
+            ),
+            exp.DateDiff: lambda self, e: self.func(
+                "DATE_DIFF", exp.Literal.string(e.text("unit") or "day"), e.expression, e.this
+            ),
             exp.Final: lambda self, e: f"{self.sql(e, 'this')} FINAL",
             exp.Map: lambda self, e: _lower_func(var_map_sql(self, e)),
             exp.PartitionedByProperty: lambda self, e: f"PARTITION BY {self.sql(e, 'this')}",
@@ -364,6 +396,16 @@ class ClickHouse(Dialect):
             "NAMED COLLECTION",
         }
 
+        def datatype_sql(self, expression: exp.DataType) -> str:
+            # String is the standard ClickHouse type, every other variant is just an alias.
+            # Additionally, any supplied length parameter will be ignored.
+            #
+            # https://clickhouse.com/docs/en/sql-reference/data-types/string
+            if expression.this in self.STRING_TYPE_MAPPING:
+                return "String"
+
+            return super().datatype_sql(expression)
+
        def safeconcat_sql(self, expression: exp.SafeConcat) -> str:
             # Clickhouse errors out if we try to cast a NULL value to TEXT
             expression = expression.copy()
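The `STRING_TYPE_MAPPING`/`datatype_sql` pair above collapses every textual type to ClickHouse's single `String` type and drops any length parameter; a one-line check of the intended behavior:

```python
import sqlglot

print(sqlglot.transpile("CREATE TABLE t (c VARCHAR(50))", write="clickhouse")[0])
# CREATE TABLE t (c String)
```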
@@ -1,7 +1,7 @@
 from __future__ import annotations
 
 from sqlglot import exp, transforms
-from sqlglot.dialects.dialect import parse_date_delta
+from sqlglot.dialects.dialect import parse_date_delta, timestamptrunc_sql
 from sqlglot.dialects.spark import Spark
 from sqlglot.dialects.tsql import generate_date_delta_with_unit_sql
 from sqlglot.tokens import TokenType
@@ -28,6 +28,19 @@ class Databricks(Spark):
             **Spark.Generator.TRANSFORMS,
             exp.DateAdd: generate_date_delta_with_unit_sql,
             exp.DateDiff: generate_date_delta_with_unit_sql,
+            exp.DatetimeAdd: lambda self, e: self.func(
+                "TIMESTAMPADD", e.text("unit"), e.expression, e.this
+            ),
+            exp.DatetimeSub: lambda self, e: self.func(
+                "TIMESTAMPADD",
+                e.text("unit"),
+                exp.Mul(this=e.expression.copy(), expression=exp.Literal.number(-1)),
+                e.this,
+            ),
+            exp.DatetimeDiff: lambda self, e: self.func(
+                "TIMESTAMPDIFF", e.text("unit"), e.expression, e.this
+            ),
+            exp.DatetimeTrunc: timestamptrunc_sql,
             exp.JSONExtract: lambda self, e: self.binary(e, ":"),
             exp.Select: transforms.preprocess(
                 [
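The new `exp.Datetime*` transforms render Databricks' `TIMESTAMPADD`/`TIMESTAMPDIFF` functions; a sketch built directly on the expression nodes (the column name is illustrative):

```python
from sqlglot import exp

node = exp.DatetimeAdd(
    this=exp.column("event_ts"),
    expression=exp.Literal.number(1),
    unit=exp.var("HOUR"),
)
print(node.sql(dialect="databricks"))  # TIMESTAMPADD(HOUR, 1, event_ts)
```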
@@ -109,8 +109,7 @@ class _Dialect(type):
             for k, v in vars(klass).items()
             if not callable(v) and not isinstance(v, classmethod) and not k.startswith("__")
         },
-        "STRING_ESCAPE": klass.tokenizer_class.STRING_ESCAPES[0],
-        "IDENTIFIER_ESCAPE": klass.tokenizer_class.IDENTIFIER_ESCAPES[0],
+        "TOKENIZER_CLASS": klass.tokenizer_class,
     }
 
     if enum not in ("", "bigquery"):
@@ -345,7 +344,7 @@ def arrow_json_extract_scalar_sql(
 
 
 def inline_array_sql(self: Generator, expression: exp.Array) -> str:
-    return f"[{self.expressions(expression)}]"
+    return f"[{self.expressions(expression, flat=True)}]"
 
 
 def no_ilike_sql(self: Generator, expression: exp.ILike) -> str:
@@ -415,9 +414,9 @@ def str_position_sql(self: Generator, expression: exp.StrPosition) -> str:
 
 
 def struct_extract_sql(self: Generator, expression: exp.StructExtract) -> str:
-    this = self.sql(expression, "this")
-    struct_key = self.sql(exp.Identifier(this=expression.expression.copy(), quoted=True))
-    return f"{this}.{struct_key}"
+    return (
+        f"{self.sql(expression, 'this')}.{self.sql(exp.to_identifier(expression.expression.name))}"
+    )
 
 
 def var_map_sql(
@@ -722,3 +721,12 @@ def binary_from_function(expr_type: t.Type[B]) -> t.Callable[[t.List], B]:
 # Used to represent DATE_TRUNC in Doris, Postgres and Starrocks dialects
 def parse_timestamp_trunc(args: t.List) -> exp.TimestampTrunc:
     return exp.TimestampTrunc(this=seq_get(args, 1), unit=seq_get(args, 0))
+
+
+def any_value_to_max_sql(self: Generator, expression: exp.AnyValue) -> str:
+    return self.func("MAX", expression.this)
+
+
+# Used to generate JSON_OBJECT with a comma in BigQuery and MySQL instead of colon
+def json_keyvalue_comma_sql(self, expression: exp.JSONKeyValue) -> str:
+    return f"{self.sql(expression, 'this')}, {self.sql(expression, 'expression')}"
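`any_value_to_max_sql` gives dialects without a native `ANY_VALUE` a `MAX` fallback. A quick check of the intended behavior, assuming Postgres wires the helper into its TRANSFORMS as its new import below suggests:

```python
import sqlglot

print(sqlglot.transpile("SELECT ANY_VALUE(x) FROM t", write="postgres")[0])
# SELECT MAX(x) FROM t
```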
@@ -37,7 +37,6 @@ class Doris(MySQL):
             **MySQL.Generator.TRANSFORMS,
             exp.ApproxDistinct: approx_count_distinct_sql,
             exp.ArrayAgg: rename_func("COLLECT_LIST"),
-            exp.Coalesce: rename_func("NVL"),
             exp.CurrentTimestamp: lambda *_: "NOW()",
             exp.DateTrunc: lambda self, e: self.func(
                 "DATE_TRUNC", e.this, "'" + e.text("unit") + "'"
@@ -16,8 +16,8 @@ from sqlglot.dialects.dialect import (
 )
 
 
-def _date_add_sql(kind: str) -> t.Callable[[generator.Generator, exp.DateAdd | exp.DateSub], str]:
-    def func(self: generator.Generator, expression: exp.DateAdd | exp.DateSub) -> str:
+def _date_add_sql(kind: str) -> t.Callable[[Drill.Generator, exp.DateAdd | exp.DateSub], str]:
+    def func(self: Drill.Generator, expression: exp.DateAdd | exp.DateSub) -> str:
         this = self.sql(expression, "this")
         unit = exp.var(expression.text("unit").upper() or "DAY")
         return f"DATE_{kind}({this}, {self.sql(exp.Interval(this=expression.expression.copy(), unit=unit))})"
@@ -25,7 +25,7 @@ def _date_add_sql(kind: str) -> t.Callable[[generator.Generator, exp.DateAdd | e
     return func
 
 
-def _str_to_date(self: generator.Generator, expression: exp.StrToDate) -> str:
+def _str_to_date(self: Drill.Generator, expression: exp.StrToDate) -> str:
     this = self.sql(expression, "this")
     time_format = self.format_time(expression)
     if time_format == Drill.DATE_FORMAT:
@@ -73,7 +73,6 @@ class Drill(Dialect):
     }
 
     class Tokenizer(tokens.Tokenizer):
         QUOTES = ["'"]
         IDENTIFIERS = ["`"]
         STRING_ESCAPES = ["\\"]
-        ENCODE = "utf-8"
 
@@ -81,6 +80,7 @@ class Drill(Dialect):
     class Parser(parser.Parser):
         STRICT_CAST = False
         CONCAT_NULL_OUTPUTS_STRING = True
+        SUPPORTS_USER_DEFINED_TYPES = False
 
         FUNCTIONS = {
             **parser.Parser.FUNCTIONS,
@@ -95,6 +95,7 @@ class Drill(Dialect):
         JOIN_HINTS = False
         TABLE_HINTS = False
         QUERY_HINTS = False
+        NVL2_SUPPORTED = False
 
         TYPE_MAPPING = {
             **generator.Generator.TYPE_MAPPING,
@@ -13,6 +13,7 @@ from sqlglot.dialects.dialect import (
     datestrtodate_sql,
     encode_decode_sql,
     format_time_lambda,
+    inline_array_sql,
     no_comment_column_constraint_sql,
     no_properties_sql,
     no_safe_divide_sql,
@@ -30,13 +31,13 @@ from sqlglot.helper import seq_get
 from sqlglot.tokens import TokenType
 
 
-def _ts_or_ds_add_sql(self: generator.Generator, expression: exp.TsOrDsAdd) -> str:
+def _ts_or_ds_add_sql(self: DuckDB.Generator, expression: exp.TsOrDsAdd) -> str:
     this = self.sql(expression, "this")
     unit = self.sql(expression, "unit").strip("'") or "DAY"
     return f"CAST({this} AS DATE) + {self.sql(exp.Interval(this=expression.expression.copy(), unit=unit))}"
 
 
-def _date_delta_sql(self: generator.Generator, expression: exp.DateAdd | exp.DateSub) -> str:
+def _date_delta_sql(self: DuckDB.Generator, expression: exp.DateAdd | exp.DateSub) -> str:
     this = self.sql(expression, "this")
     unit = self.sql(expression, "unit").strip("'") or "DAY"
     op = "+" if isinstance(expression, exp.DateAdd) else "-"
@@ -44,7 +45,7 @@ def _date_delta_sql(self: generator.Generator, expression: exp.DateAdd | exp.Dat
 
 
 # BigQuery -> DuckDB conversion for the DATE function
-def _date_sql(self: generator.Generator, expression: exp.Date) -> str:
+def _date_sql(self: DuckDB.Generator, expression: exp.Date) -> str:
     result = f"CAST({self.sql(expression, 'this')} AS DATE)"
     zone = self.sql(expression, "zone")
 
@@ -58,13 +59,13 @@ def _date_sql(self: generator.Generator, expression: exp.Date) -> str:
     return result
 
 
-def _array_sort_sql(self: generator.Generator, expression: exp.ArraySort) -> str:
+def _array_sort_sql(self: DuckDB.Generator, expression: exp.ArraySort) -> str:
     if expression.expression:
         self.unsupported("DUCKDB ARRAY_SORT does not support a comparator")
     return f"ARRAY_SORT({self.sql(expression, 'this')})"
 
 
-def _sort_array_sql(self: generator.Generator, expression: exp.SortArray) -> str:
+def _sort_array_sql(self: DuckDB.Generator, expression: exp.SortArray) -> str:
     this = self.sql(expression, "this")
     if expression.args.get("asc") == exp.false():
         return f"ARRAY_REVERSE_SORT({this})"
@@ -79,14 +80,14 @@ def _parse_date_diff(args: t.List) -> exp.Expression:
     return exp.DateDiff(this=seq_get(args, 2), expression=seq_get(args, 1), unit=seq_get(args, 0))
 
 
-def _struct_sql(self: generator.Generator, expression: exp.Struct) -> str:
+def _struct_sql(self: DuckDB.Generator, expression: exp.Struct) -> str:
     args = [
         f"'{e.name or e.this.name}': {self.sql(e, 'expression')}" for e in expression.expressions
     ]
     return f"{{{', '.join(args)}}}"
 
 
-def _datatype_sql(self: generator.Generator, expression: exp.DataType) -> str:
+def _datatype_sql(self: DuckDB.Generator, expression: exp.DataType) -> str:
     if expression.is_type("array"):
         return f"{self.expressions(expression, flat=True)}[]"
 
@@ -97,7 +98,7 @@ def _datatype_sql(self: generator.Generator, expression: exp.DataType) -> str:
     return self.datatype_sql(expression)
 
 
-def _json_format_sql(self: generator.Generator, expression: exp.JSONFormat) -> str:
+def _json_format_sql(self: DuckDB.Generator, expression: exp.JSONFormat) -> str:
     sql = self.func("TO_JSON", expression.this, expression.args.get("options"))
     return f"CAST({sql} AS TEXT)"
 
@@ -134,6 +135,7 @@ class DuckDB(Dialect):
 
     class Parser(parser.Parser):
         CONCAT_NULL_OUTPUTS_STRING = True
+        SUPPORTS_USER_DEFINED_TYPES = False
 
         BITWISE = {
             **parser.Parser.BITWISE,
@@ -183,18 +185,12 @@ class DuckDB(Dialect):
             ),
         }
 
-        TYPE_TOKENS = {
-            *parser.Parser.TYPE_TOKENS,
-            TokenType.UBIGINT,
-            TokenType.UINT,
-            TokenType.USMALLINT,
-            TokenType.UTINYINT,
-        }
-
         def _parse_types(
-            self, check_func: bool = False, schema: bool = False
+            self, check_func: bool = False, schema: bool = False, allow_identifiers: bool = True
         ) -> t.Optional[exp.Expression]:
-            this = super()._parse_types(check_func=check_func, schema=schema)
+            this = super()._parse_types(
+                check_func=check_func, schema=schema, allow_identifiers=allow_identifiers
+            )
 
             # DuckDB treats NUMERIC and DECIMAL without precision as DECIMAL(18, 3)
             # See: https://duckdb.org/docs/sql/data_types/numeric
@@ -207,6 +203,9 @@ class DuckDB(Dialect):
 
             return this
 
+        def _parse_struct_types(self) -> t.Optional[exp.Expression]:
+            return self._parse_field_def()
+
         def _pivot_column_names(self, aggregations: t.List[exp.Expression]) -> t.List[str]:
             if len(aggregations) == 1:
                 return super()._pivot_column_names(aggregations)
@@ -219,13 +218,14 @@ class DuckDB(Dialect):
         LIMIT_FETCH = "LIMIT"
         STRUCT_DELIMITER = ("(", ")")
         RENAME_TABLE_WITH_DB = False
+        NVL2_SUPPORTED = False
 
         TRANSFORMS = {
             **generator.Generator.TRANSFORMS,
             exp.ApproxDistinct: approx_count_distinct_sql,
             exp.Array: lambda self, e: self.func("ARRAY", e.expressions[0])
             if e.expressions and e.expressions[0].find(exp.Select)
-            else rename_func("LIST_VALUE")(self, e),
+            else inline_array_sql(self, e),
             exp.ArraySize: rename_func("ARRAY_LENGTH"),
             exp.ArraySort: _array_sort_sql,
             exp.ArraySum: rename_func("LIST_SUM"),
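Plain DuckDB arrays are now rendered inline instead of through `LIST_VALUE`, per the `exp.Array` transform above; a small check:

```python
import sqlglot

print(sqlglot.transpile("SELECT ARRAY[1, 2, 3]", read="postgres", write="duckdb")[0])
# SELECT [1, 2, 3]
```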
@@ -50,7 +50,7 @@ TIME_DIFF_FACTOR = {
 DIFF_MONTH_SWITCH = ("YEAR", "QUARTER", "MONTH")
 
 
-def _add_date_sql(self: generator.Generator, expression: exp.DateAdd | exp.DateSub) -> str:
+def _add_date_sql(self: Hive.Generator, expression: exp.DateAdd | exp.DateSub) -> str:
     unit = expression.text("unit").upper()
     func, multiplier = DATE_DELTA_INTERVAL.get(unit, ("DATE_ADD", 1))
 
@@ -69,7 +69,7 @@ def _add_date_sql(self: generator.Generator, expression: exp.DateAdd | exp.DateS
     return self.func(func, expression.this, modified_increment)
 
 
-def _date_diff_sql(self: generator.Generator, expression: exp.DateDiff) -> str:
+def _date_diff_sql(self: Hive.Generator, expression: exp.DateDiff) -> str:
     unit = expression.text("unit").upper()
 
     factor = TIME_DIFF_FACTOR.get(unit)
@@ -87,7 +87,7 @@ def _date_diff_sql(self: generator.Generator, expression: exp.DateDiff) -> str:
     return f"{diff_sql}{multiplier_sql}"
 
 
-def _json_format_sql(self: generator.Generator, expression: exp.JSONFormat) -> str:
+def _json_format_sql(self: Hive.Generator, expression: exp.JSONFormat) -> str:
     this = expression.this
     if isinstance(this, exp.Cast) and this.is_type("json") and this.this.is_string:
         # Since FROM_JSON requires a nested type, we always wrap the json string with
@@ -103,21 +103,21 @@ def _json_format_sql(self: generator.Generator, expression: exp.JSONFormat) -> s
     return self.func("TO_JSON", this, expression.args.get("options"))
 
 
-def _array_sort_sql(self: generator.Generator, expression: exp.ArraySort) -> str:
+def _array_sort_sql(self: Hive.Generator, expression: exp.ArraySort) -> str:
     if expression.expression:
         self.unsupported("Hive SORT_ARRAY does not support a comparator")
     return f"SORT_ARRAY({self.sql(expression, 'this')})"
 
 
-def _property_sql(self: generator.Generator, expression: exp.Property) -> str:
+def _property_sql(self: Hive.Generator, expression: exp.Property) -> str:
     return f"'{expression.name}'={self.sql(expression, 'value')}"
 
 
-def _str_to_unix_sql(self: generator.Generator, expression: exp.StrToUnix) -> str:
+def _str_to_unix_sql(self: Hive.Generator, expression: exp.StrToUnix) -> str:
     return self.func("UNIX_TIMESTAMP", expression.this, time_format("hive")(self, expression))
 
 
-def _str_to_date_sql(self: generator.Generator, expression: exp.StrToDate) -> str:
+def _str_to_date_sql(self: Hive.Generator, expression: exp.StrToDate) -> str:
     this = self.sql(expression, "this")
     time_format = self.format_time(expression)
     if time_format not in (Hive.TIME_FORMAT, Hive.DATE_FORMAT):
@@ -125,7 +125,7 @@ def _str_to_date_sql(self: generator.Generator, expression: exp.StrToDate) -> st
     return f"CAST({this} AS DATE)"
 
 
-def _str_to_time_sql(self: generator.Generator, expression: exp.StrToTime) -> str:
+def _str_to_time_sql(self: Hive.Generator, expression: exp.StrToTime) -> str:
     this = self.sql(expression, "this")
     time_format = self.format_time(expression)
     if time_format not in (Hive.TIME_FORMAT, Hive.DATE_FORMAT):
@@ -133,13 +133,13 @@ def _str_to_time_sql(self: generator.Generator, expression: exp.StrToTime) -> st
     return f"CAST({this} AS TIMESTAMP)"
 
 
-def _time_to_str(self: generator.Generator, expression: exp.TimeToStr) -> str:
+def _time_to_str(self: Hive.Generator, expression: exp.TimeToStr) -> str:
     this = self.sql(expression, "this")
     time_format = self.format_time(expression)
     return f"DATE_FORMAT({this}, {time_format})"
 
 
-def _to_date_sql(self: generator.Generator, expression: exp.TsOrDsToDate) -> str:
+def _to_date_sql(self: Hive.Generator, expression: exp.TsOrDsToDate) -> str:
     this = self.sql(expression, "this")
     time_format = self.format_time(expression)
     if time_format and time_format not in (Hive.TIME_FORMAT, Hive.DATE_FORMAT):
@@ -206,6 +206,8 @@ class Hive(Dialect):
             "MSCK REPAIR": TokenType.COMMAND,
             "REFRESH": TokenType.COMMAND,
             "WITH SERDEPROPERTIES": TokenType.SERDE_PROPERTIES,
+            "TIMESTAMP AS OF": TokenType.TIMESTAMP_SNAPSHOT,
+            "VERSION AS OF": TokenType.VERSION_SNAPSHOT,
         }
 
         NUMERIC_LITERALS = {
@@ -220,6 +222,7 @@ class Hive(Dialect):
     class Parser(parser.Parser):
         LOG_DEFAULTS_TO_LN = True
         STRICT_CAST = False
+        SUPPORTS_USER_DEFINED_TYPES = False
 
         FUNCTIONS = {
             **parser.Parser.FUNCTIONS,
@@ -257,6 +260,11 @@ class Hive(Dialect):
             ),
             "SIZE": exp.ArraySize.from_arg_list,
             "SPLIT": exp.RegexpSplit.from_arg_list,
+            "STR_TO_MAP": lambda args: exp.StrToMap(
+                this=seq_get(args, 0),
+                pair_delim=seq_get(args, 1) or exp.Literal.string(","),
+                key_value_delim=seq_get(args, 2) or exp.Literal.string(":"),
+            ),
             "TO_DATE": format_time_lambda(exp.TsOrDsToDate, "hive"),
             "TO_JSON": exp.JSONFormat.from_arg_list,
             "UNBASE64": exp.FromBase64.from_arg_list,
@@ -313,7 +321,7 @@ class Hive(Dialect):
             )
 
         def _parse_types(
-            self, check_func: bool = False, schema: bool = False
+            self, check_func: bool = False, schema: bool = False, allow_identifiers: bool = True
         ) -> t.Optional[exp.Expression]:
             """
             Spark (and most likely Hive) treats casts to CHAR(length) and VARCHAR(length) as casts to
@@ -333,7 +341,9 @@ class Hive(Dialect):
 
             Reference: https://spark.apache.org/docs/latest/sql-ref-datatypes.html
             """
-            this = super()._parse_types(check_func=check_func, schema=schema)
+            this = super()._parse_types(
+                check_func=check_func, schema=schema, allow_identifiers=allow_identifiers
+            )
 
             if this and not schema:
                 return this.transform(
@@ -345,6 +355,16 @@ class Hive(Dialect):
 
             return this
 
+        def _parse_partition_and_order(
+            self,
+        ) -> t.Tuple[t.List[exp.Expression], t.Optional[exp.Expression]]:
+            return (
+                self._parse_csv(self._parse_conjunction)
+                if self._match_set({TokenType.PARTITION_BY, TokenType.DISTRIBUTE_BY})
+                else [],
+                super()._parse_order(skip_order_token=self._match(TokenType.SORT_BY)),
+            )
+
     class Generator(generator.Generator):
         LIMIT_FETCH = "LIMIT"
         TABLESAMPLE_WITH_METHOD = False
@@ -354,6 +374,7 @@ class Hive(Dialect):
         QUERY_HINTS = False
         INDEX_ON = "ON TABLE"
         EXTRACT_ALLOWS_QUOTES = False
+        NVL2_SUPPORTED = False
 
         TYPE_MAPPING = {
             **generator.Generator.TYPE_MAPPING,
@@ -376,6 +397,7 @@ class Hive(Dialect):
                 ]
             ),
             exp.Property: _property_sql,
+            exp.AnyValue: rename_func("FIRST"),
             exp.ApproxDistinct: approx_count_distinct_sql,
             exp.ArrayConcat: rename_func("CONCAT"),
             exp.ArrayJoin: lambda self, e: self.func("CONCAT_WS", e.expression, e.this),
@@ -402,6 +424,9 @@ class Hive(Dialect):
             exp.MD5Digest: lambda self, e: self.func("UNHEX", self.func("MD5", e.this)),
             exp.Min: min_or_least,
             exp.MonthsBetween: lambda self, e: self.func("MONTHS_BETWEEN", e.this, e.expression),
+            exp.NotNullColumnConstraint: lambda self, e: ""
+            if e.args.get("allow_null")
+            else "NOT NULL",
             exp.VarMap: var_map_sql,
             exp.Create: create_with_partitions_sql,
             exp.Quantile: rename_func("PERCENTILE"),
@@ -472,7 +497,7 @@ class Hive(Dialect):
             elif expression.this in exp.DataType.TEMPORAL_TYPES:
                 expression = exp.DataType.build(expression.this)
             elif expression.is_type("float"):
-                size_expression = expression.find(exp.DataTypeSize)
+                size_expression = expression.find(exp.DataTypeParam)
                 if size_expression:
                     size = int(size_expression.name)
                     expression = (
@@ -480,3 +505,7 @@ class Hive(Dialect):
             )
 
             return super().datatype_sql(expression)
+
+        def version_sql(self, expression: exp.Version) -> str:
+            sql = super().version_sql(expression)
+            return sql.replace("FOR ", "", 1)
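The new snapshot keywords plus `version_sql` above let Hive-family dialects round-trip Spark-style time travel; a sketch (spark builds on hive, and the exact output assumes the round trip preserves the clause):

```python
import sqlglot

# VERSION AS OF / TIMESTAMP AS OF now tokenize as snapshot clauses, and
# version_sql strips the leading FOR that other dialects emit.
print(sqlglot.transpile("SELECT * FROM t VERSION AS OF 123", read="spark", write="spark")[0])
# SELECT * FROM t VERSION AS OF 123
```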
@@ -8,6 +8,7 @@ from sqlglot.dialects.dialect import (
     arrow_json_extract_scalar_sql,
     datestrtodate_sql,
     format_time_lambda,
+    json_keyvalue_comma_sql,
     locate_to_strposition,
     max_or_greatest,
     min_or_least,
@@ -32,7 +33,7 @@ def _show_parser(*args: t.Any, **kwargs: t.Any) -> t.Callable[[MySQL.Parser], ex
     return _parse
 
 
-def _date_trunc_sql(self: generator.Generator, expression: exp.DateTrunc) -> str:
+def _date_trunc_sql(self: MySQL.Generator, expression: exp.DateTrunc) -> str:
     expr = self.sql(expression, "this")
     unit = expression.text("unit")
 
@@ -63,12 +64,12 @@ def _str_to_date(args: t.List) -> exp.StrToDate:
     return exp.StrToDate(this=seq_get(args, 0), format=date_format)
 
 
-def _str_to_date_sql(self: generator.Generator, expression: exp.StrToDate | exp.StrToTime) -> str:
+def _str_to_date_sql(self: MySQL.Generator, expression: exp.StrToDate | exp.StrToTime) -> str:
     date_format = self.format_time(expression)
     return f"STR_TO_DATE({self.sql(expression.this)}, {date_format})"
 
 
-def _trim_sql(self: generator.Generator, expression: exp.Trim) -> str:
+def _trim_sql(self: MySQL.Generator, expression: exp.Trim) -> str:
     target = self.sql(expression, "this")
     trim_type = self.sql(expression, "position")
     remove_chars = self.sql(expression, "expression")
@@ -83,8 +84,8 @@ def _trim_sql(self: generator.Generator, expression: exp.Trim) -> str:
     return f"TRIM({trim_type}{remove_chars}{from_part}{target})"
 
 
-def _date_add_sql(kind: str) -> t.Callable[[generator.Generator, exp.DateAdd | exp.DateSub], str]:
-    def func(self: generator.Generator, expression: exp.DateAdd | exp.DateSub) -> str:
+def _date_add_sql(kind: str) -> t.Callable[[MySQL.Generator, exp.DateAdd | exp.DateSub], str]:
+    def func(self: MySQL.Generator, expression: exp.DateAdd | exp.DateSub) -> str:
         this = self.sql(expression, "this")
         unit = expression.text("unit").upper() or "DAY"
         return f"DATE_{kind}({this}, {self.sql(exp.Interval(this=expression.expression.copy(), unit=unit))})"
@@ -93,6 +94,9 @@ def _date_add_sql(kind: str) -> t.Callable[[generator.Generator, exp.DateAdd | e
 
 
 class MySQL(Dialect):
+    # https://dev.mysql.com/doc/refman/8.0/en/identifiers.html
+    IDENTIFIERS_CAN_START_WITH_DIGIT = True
+
     TIME_FORMAT = "'%Y-%m-%d %T'"
     DPIPE_IS_STRING_CONCAT = False
 
@@ -129,6 +133,7 @@ class MySQL(Dialect):
             "LONGTEXT": TokenType.LONGTEXT,
             "MEDIUMBLOB": TokenType.MEDIUMBLOB,
             "MEDIUMTEXT": TokenType.MEDIUMTEXT,
+            "MEDIUMINT": TokenType.MEDIUMINT,
             "MEMBER OF": TokenType.MEMBER_OF,
             "SEPARATOR": TokenType.SEPARATOR,
             "START": TokenType.BEGIN,
@@ -136,6 +141,7 @@ class MySQL(Dialect):
             "SIGNED INTEGER": TokenType.BIGINT,
             "UNSIGNED": TokenType.UBIGINT,
             "UNSIGNED INTEGER": TokenType.UBIGINT,
+            "YEAR": TokenType.YEAR,
             "_ARMSCII8": TokenType.INTRODUCER,
             "_ASCII": TokenType.INTRODUCER,
             "_BIG5": TokenType.INTRODUCER,
@@ -185,6 +191,8 @@ class MySQL(Dialect):
         COMMANDS = tokens.Tokenizer.COMMANDS - {TokenType.SHOW}
 
     class Parser(parser.Parser):
+        SUPPORTS_USER_DEFINED_TYPES = False
+
         FUNC_TOKENS = {
             *parser.Parser.FUNC_TOKENS,
             TokenType.DATABASE,
@@ -492,6 +500,17 @@ class MySQL(Dialect):
 
             return self.expression(exp.SetItem, this=charset, collate=collate, kind="NAMES")
 
+        def _parse_type(self) -> t.Optional[exp.Expression]:
+            # mysql binary is special and can work anywhere, even in order by operations
+            # it operates like a no paren func
+            if self._match(TokenType.BINARY, advance=False):
+                data_type = self._parse_types(check_func=True, allow_identifiers=False)
+
+                if isinstance(data_type, exp.DataType):
+                    return self.expression(exp.Cast, this=self._parse_column(), to=data_type)
+
+            return super()._parse_type()
+
     class Generator(generator.Generator):
         LOCKING_READS_SUPPORTED = True
         NULL_ORDERING_SUPPORTED = False
@@ -500,6 +519,7 @@ class MySQL(Dialect):
         DUPLICATE_KEY_UPDATE_WITH_SET = False
         QUERY_HINT_SEP = " "
         VALUES_AS_TABLE = False
+        NVL2_SUPPORTED = False
 
         TRANSFORMS = {
             **generator.Generator.TRANSFORMS,
@@ -515,6 +535,7 @@ class MySQL(Dialect):
             exp.GroupConcat: lambda self, e: f"""GROUP_CONCAT({self.sql(e, "this")} SEPARATOR {self.sql(e, "separator") or "','"})""",
             exp.ILike: no_ilike_sql,
             exp.JSONExtractScalar: arrow_json_extract_scalar_sql,
+            exp.JSONKeyValue: json_keyvalue_comma_sql,
             exp.Max: max_or_greatest,
             exp.Min: min_or_least,
             exp.NullSafeEQ: lambda self, e: self.binary(e, "<=>"),
@@ -524,6 +545,7 @@ class MySQL(Dialect):
             exp.StrPosition: strposition_to_locate_sql,
             exp.StrToDate: _str_to_date_sql,
             exp.StrToTime: _str_to_date_sql,
+            exp.Stuff: rename_func("INSERT"),
             exp.TableSample: no_tablesample_sql,
             exp.TimeStrToUnix: rename_func("UNIX_TIMESTAMP"),
             exp.TimeStrToTime: lambda self, e: self.sql(exp.cast(e.this, "datetime", copy=True)),
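MySQL's `BINARY` now parses as a no-paren cast operator, per the `_parse_type` override above; a sketch of the expected translation (output shown for sqlglot's default write dialect):

```python
import sqlglot

print(sqlglot.transpile("SELECT BINARY col1 FROM t", read="mysql")[0])
# SELECT CAST(col1 AS BINARY) FROM t
```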
diff --git a/sqlglot/dialects/oracle.py b/sqlglot/dialects/oracle.py

@@ -8,7 +8,7 @@ from sqlglot.helper import seq_get
from sqlglot.tokens import TokenType


def _parse_xml_table(self: parser.Parser) -> exp.XMLTable:
def _parse_xml_table(self: Oracle.Parser) -> exp.XMLTable:
    this = self._parse_string()

    passing = None

@@ -22,7 +22,7 @@ def _parse_xml_table(self: parser.Parser) -> exp.XMLTable:
    by_ref = self._match_text_seq("RETURNING", "SEQUENCE", "BY", "REF")

    if self._match_text_seq("COLUMNS"):
        columns = self._parse_csv(lambda: self._parse_column_def(self._parse_field(any_token=True)))
        columns = self._parse_csv(self._parse_field_def)

    return self.expression(exp.XMLTable, this=this, passing=passing, columns=columns, by_ref=by_ref)

@@ -78,6 +78,10 @@ class Oracle(Dialect):
            )
        }

        # SELECT UNIQUE .. is old-style Oracle syntax for SELECT DISTINCT ..
        # Reference: https://stackoverflow.com/a/336455
        DISTINCT_TOKENS = {TokenType.DISTINCT, TokenType.UNIQUE}

        def _parse_column(self) -> t.Optional[exp.Expression]:
            column = super()._parse_column()
            if column:

@@ -129,7 +133,6 @@ class Oracle(Dialect):
            ),
            exp.Group: transforms.preprocess([transforms.unalias_group]),
            exp.ILike: no_ilike_sql,
            exp.Coalesce: rename_func("NVL"),
            exp.Select: transforms.preprocess([transforms.eliminate_distinct_on]),
            exp.StrToTime: lambda self, e: f"TO_TIMESTAMP({self.sql(e, 'this')}, {self.format_time(e)})",
            exp.Subquery: lambda self, e: self.subquery_sql(e, sep=" "),

@@ -162,7 +165,7 @@ class Oracle(Dialect):
        return f"XMLTABLE({self.sep('')}{self.indent(this + passing + by_ref + columns)}{self.seg(')', sep='')}"

    class Tokenizer(tokens.Tokenizer):
        VAR_SINGLE_TOKENS = {"@"}
        VAR_SINGLE_TOKENS = {"@", "$", "#"}

        KEYWORDS = {
            **tokens.Tokenizer.KEYWORDS,
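The new `DISTINCT_TOKENS` entry means legacy `SELECT UNIQUE` parses as a plain distinct select. A quick hedged check (the default write dialect is used here):

```python
import sqlglot

# SELECT UNIQUE is old-style Oracle syntax for SELECT DISTINCT.
print(sqlglot.transpile("SELECT UNIQUE col FROM t", read="oracle")[0])
# expected: SELECT DISTINCT col FROM t
```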
diff --git a/sqlglot/dialects/postgres.py b/sqlglot/dialects/postgres.py

@@ -5,6 +5,7 @@ import typing as t
from sqlglot import exp, generator, parser, tokens, transforms
from sqlglot.dialects.dialect import (
    Dialect,
    any_value_to_max_sql,
    arrow_json_extract_scalar_sql,
    arrow_json_extract_sql,
    datestrtodate_sql,

@@ -39,8 +40,8 @@ DATE_DIFF_FACTOR = {
}


def _date_add_sql(kind: str) -> t.Callable[[generator.Generator, exp.DateAdd | exp.DateSub], str]:
    def func(self: generator.Generator, expression: exp.DateAdd | exp.DateSub) -> str:
def _date_add_sql(kind: str) -> t.Callable[[Postgres.Generator, exp.DateAdd | exp.DateSub], str]:
    def func(self: Postgres.Generator, expression: exp.DateAdd | exp.DateSub) -> str:
        expression = expression.copy()

        this = self.sql(expression, "this")

@@ -56,7 +57,7 @@ def _date_add_sql(kind: str) -> t.Callable[[generator.Generator, exp.DateAdd | e
    return func


def _date_diff_sql(self: generator.Generator, expression: exp.DateDiff) -> str:
def _date_diff_sql(self: Postgres.Generator, expression: exp.DateDiff) -> str:
    unit = expression.text("unit").upper()
    factor = DATE_DIFF_FACTOR.get(unit)

@@ -82,7 +83,7 @@ def _date_diff_sql(self: generator.Generator, expression: exp.DateDiff) -> str:
    return f"CAST({unit} AS BIGINT)"


def _substring_sql(self: generator.Generator, expression: exp.Substring) -> str:
def _substring_sql(self: Postgres.Generator, expression: exp.Substring) -> str:
    this = self.sql(expression, "this")
    start = self.sql(expression, "start")
    length = self.sql(expression, "length")

@@ -93,7 +94,7 @@ def _substring_sql(self: generator.Generator, expression: exp.Substring) -> str:
    return f"SUBSTRING({this}{from_part}{for_part})"


def _string_agg_sql(self: generator.Generator, expression: exp.GroupConcat) -> str:
def _string_agg_sql(self: Postgres.Generator, expression: exp.GroupConcat) -> str:
    expression = expression.copy()
    separator = expression.args.get("separator") or exp.Literal.string(",")

@@ -107,7 +108,7 @@ def _string_agg_sql(self: generator.Generator, expression: exp.GroupConcat) -> s
    return f"STRING_AGG({self.format_args(this, separator)}{order})"


def _datatype_sql(self: generator.Generator, expression: exp.DataType) -> str:
def _datatype_sql(self: Postgres.Generator, expression: exp.DataType) -> str:
    if expression.is_type("array"):
        return f"{self.expressions(expression, flat=True)}[]"
    return self.datatype_sql(expression)

@@ -254,6 +255,7 @@ class Postgres(Dialect):
            "~~*": TokenType.ILIKE,
            "~*": TokenType.IRLIKE,
            "~": TokenType.RLIKE,
            "@@": TokenType.DAT,
            "@>": TokenType.AT_GT,
            "<@": TokenType.LT_AT,
            "BEGIN": TokenType.COMMAND,

@@ -273,6 +275,18 @@ class Postgres(Dialect):
            "SMALLSERIAL": TokenType.SMALLSERIAL,
            "TEMP": TokenType.TEMPORARY,
            "CSTRING": TokenType.PSEUDO_TYPE,
            "OID": TokenType.OBJECT_IDENTIFIER,
            "REGCLASS": TokenType.OBJECT_IDENTIFIER,
            "REGCOLLATION": TokenType.OBJECT_IDENTIFIER,
            "REGCONFIG": TokenType.OBJECT_IDENTIFIER,
            "REGDICTIONARY": TokenType.OBJECT_IDENTIFIER,
            "REGNAMESPACE": TokenType.OBJECT_IDENTIFIER,
            "REGOPER": TokenType.OBJECT_IDENTIFIER,
            "REGOPERATOR": TokenType.OBJECT_IDENTIFIER,
            "REGPROC": TokenType.OBJECT_IDENTIFIER,
            "REGPROCEDURE": TokenType.OBJECT_IDENTIFIER,
            "REGROLE": TokenType.OBJECT_IDENTIFIER,
            "REGTYPE": TokenType.OBJECT_IDENTIFIER,
        }

        SINGLE_TOKENS = {

@@ -312,6 +326,9 @@ class Postgres(Dialect):
        RANGE_PARSERS = {
            **parser.Parser.RANGE_PARSERS,
            TokenType.DAMP: binary_range_parser(exp.ArrayOverlaps),
            TokenType.DAT: lambda self, this: self.expression(
                exp.MatchAgainst, this=self._parse_bitwise(), expressions=[this]
            ),
            TokenType.AT_GT: binary_range_parser(exp.ArrayContains),
            TokenType.LT_AT: binary_range_parser(exp.ArrayContained),
        }

@@ -343,6 +360,7 @@ class Postgres(Dialect):
        JOIN_HINTS = False
        TABLE_HINTS = False
        QUERY_HINTS = False
        NVL2_SUPPORTED = False
        PARAMETER_TOKEN = "$"

        TYPE_MAPPING = {

@@ -357,6 +375,8 @@ class Postgres(Dialect):

        TRANSFORMS = {
            **generator.Generator.TRANSFORMS,
            exp.AnyValue: any_value_to_max_sql,
            exp.ArrayConcat: rename_func("ARRAY_CAT"),
            exp.BitwiseXor: lambda self, e: self.binary(e, "#"),
            exp.ColumnDef: transforms.preprocess([_auto_increment_to_serial, _serial_to_generated]),
            exp.Explode: rename_func("UNNEST"),

@@ -416,3 +436,9 @@ class Postgres(Dialect):
            expression.set("this", exp.paren(expression.this, copy=False))

            return super().bracket_sql(expression)

        def matchagainst_sql(self, expression: exp.MatchAgainst) -> str:
            this = self.sql(expression, "this")
            expressions = [f"{self.sql(e)} @@ {this}" for e in expression.expressions]
            sql = " OR ".join(expressions)
            return f"({sql})" if len(expressions) > 1 else sql
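With `@@` tokenized as `TokenType.DAT` and the new `matchagainst_sql`, Postgres text-search predicates should round-trip. A sketch; the exact casing of function names in the output is generator-dependent:

```python
import sqlglot

# The @@ operator is parsed into exp.MatchAgainst and re-rendered by matchagainst_sql.
sql = "SELECT * FROM docs WHERE to_tsvector(body) @@ to_tsquery('hello')"
print(sqlglot.transpile(sql, read="postgres", write="postgres")[0])
```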
diff --git a/sqlglot/dialects/presto.py b/sqlglot/dialects/presto.py

@@ -26,13 +26,13 @@ from sqlglot.helper import apply_index_offset, seq_get
from sqlglot.tokens import TokenType


def _approx_distinct_sql(self: generator.Generator, expression: exp.ApproxDistinct) -> str:
def _approx_distinct_sql(self: Presto.Generator, expression: exp.ApproxDistinct) -> str:
    accuracy = expression.args.get("accuracy")
    accuracy = ", " + self.sql(accuracy) if accuracy else ""
    return f"APPROX_DISTINCT({self.sql(expression, 'this')}{accuracy})"


def _explode_to_unnest_sql(self: generator.Generator, expression: exp.Lateral) -> str:
def _explode_to_unnest_sql(self: Presto.Generator, expression: exp.Lateral) -> str:
    if isinstance(expression.this, (exp.Explode, exp.Posexplode)):
        expression = expression.copy()
        return self.sql(

@@ -48,12 +48,12 @@ def _explode_to_unnest_sql(self: generator.Generator, expression: exp.Lateral) -
    return self.lateral_sql(expression)


def _initcap_sql(self: generator.Generator, expression: exp.Initcap) -> str:
def _initcap_sql(self: Presto.Generator, expression: exp.Initcap) -> str:
    regex = r"(\w)(\w*)"
    return f"REGEXP_REPLACE({self.sql(expression, 'this')}, '{regex}', x -> UPPER(x[1]) || LOWER(x[2]))"


def _no_sort_array(self: generator.Generator, expression: exp.SortArray) -> str:
def _no_sort_array(self: Presto.Generator, expression: exp.SortArray) -> str:
    if expression.args.get("asc") == exp.false():
        comparator = "(a, b) -> CASE WHEN a < b THEN 1 WHEN a > b THEN -1 ELSE 0 END"
    else:

@@ -61,7 +61,7 @@ def _no_sort_array(self: generator.Generator, expression: exp.SortArray) -> str:
    return self.func("ARRAY_SORT", expression.this, comparator)


def _schema_sql(self: generator.Generator, expression: exp.Schema) -> str:
def _schema_sql(self: Presto.Generator, expression: exp.Schema) -> str:
    if isinstance(expression.parent, exp.Property):
        columns = ", ".join(f"'{c.name}'" for c in expression.expressions)
        return f"ARRAY[{columns}]"

@@ -75,25 +75,25 @@ def _schema_sql(self: generator.Generator, expression: exp.Schema) -> str:
    return self.schema_sql(expression)


def _quantile_sql(self: generator.Generator, expression: exp.Quantile) -> str:
def _quantile_sql(self: Presto.Generator, expression: exp.Quantile) -> str:
    self.unsupported("Presto does not support exact quantiles")
    return f"APPROX_PERCENTILE({self.sql(expression, 'this')}, {self.sql(expression, 'quantile')})"


def _str_to_time_sql(
    self: generator.Generator, expression: exp.StrToDate | exp.StrToTime | exp.TsOrDsToDate
    self: Presto.Generator, expression: exp.StrToDate | exp.StrToTime | exp.TsOrDsToDate
) -> str:
    return f"DATE_PARSE({self.sql(expression, 'this')}, {self.format_time(expression)})"


def _ts_or_ds_to_date_sql(self: generator.Generator, expression: exp.TsOrDsToDate) -> str:
def _ts_or_ds_to_date_sql(self: Presto.Generator, expression: exp.TsOrDsToDate) -> str:
    time_format = self.format_time(expression)
    if time_format and time_format not in (Presto.TIME_FORMAT, Presto.DATE_FORMAT):
        return exp.cast(_str_to_time_sql(self, expression), "DATE").sql(dialect="presto")
    return exp.cast(exp.cast(expression.this, "TIMESTAMP", copy=True), "DATE").sql(dialect="presto")


def _ts_or_ds_add_sql(self: generator.Generator, expression: exp.TsOrDsAdd) -> str:
def _ts_or_ds_add_sql(self: Presto.Generator, expression: exp.TsOrDsAdd) -> str:
    this = expression.this

    if not isinstance(this, exp.CurrentDate):

@@ -153,6 +153,20 @@ def _unnest_sequence(expression: exp.Expression) -> exp.Expression:
    return expression


def _first_last_sql(self: Presto.Generator, expression: exp.First | exp.Last) -> str:
    """
    Trino doesn't support FIRST / LAST as functions, but they're valid in the context
    of MATCH_RECOGNIZE, so we need to preserve them in that case. In all other cases
    they're converted into an ARBITRARY call.

    Reference: https://trino.io/docs/current/sql/match-recognize.html#logical-navigation-functions
    """
    if isinstance(expression.find_ancestor(exp.MatchRecognize, exp.Select), exp.MatchRecognize):
        return self.function_fallback_sql(expression)

    return rename_func("ARBITRARY")(self, expression)


class Presto(Dialect):
    INDEX_OFFSET = 1
    NULL_ORDERING = "nulls_are_last"

@@ -178,6 +192,7 @@ class Presto(Dialect):
    class Parser(parser.Parser):
        FUNCTIONS = {
            **parser.Parser.FUNCTIONS,
            "ARBITRARY": exp.AnyValue.from_arg_list,
            "APPROX_DISTINCT": exp.ApproxDistinct.from_arg_list,
            "APPROX_PERCENTILE": _approx_percentile,
            "BITWISE_AND": binary_from_function(exp.BitwiseAnd),

@@ -205,7 +220,14 @@ class Presto(Dialect):
            "REGEXP_EXTRACT": lambda args: exp.RegexpExtract(
                this=seq_get(args, 0), expression=seq_get(args, 1), group=seq_get(args, 2)
            ),
            "REGEXP_REPLACE": lambda args: exp.RegexpReplace(
                this=seq_get(args, 0),
                expression=seq_get(args, 1),
                replacement=seq_get(args, 2) or exp.Literal.string(""),
            ),
            "ROW": exp.Struct.from_arg_list,
            "SEQUENCE": exp.GenerateSeries.from_arg_list,
            "SPLIT_TO_MAP": exp.StrToMap.from_arg_list,
            "STRPOS": lambda args: exp.StrPosition(
                this=seq_get(args, 0), substr=seq_get(args, 1), instance=seq_get(args, 2)
            ),

@@ -225,6 +247,7 @@ class Presto(Dialect):
        QUERY_HINTS = False
        IS_BOOL_ALLOWED = False
        TZ_TO_WITH_TIME_ZONE = True
        NVL2_SUPPORTED = False
        STRUCT_DELIMITER = ("(", ")")

        PROPERTIES_LOCATION = {

@@ -242,10 +265,13 @@ class Presto(Dialect):
            exp.DataType.Type.TIMETZ: "TIME",
            exp.DataType.Type.TIMESTAMPTZ: "TIMESTAMP",
            exp.DataType.Type.STRUCT: "ROW",
            exp.DataType.Type.DATETIME: "TIMESTAMP",
            exp.DataType.Type.DATETIME64: "TIMESTAMP",
        }

        TRANSFORMS = {
            **generator.Generator.TRANSFORMS,
            exp.AnyValue: rename_func("ARBITRARY"),
            exp.ApproxDistinct: _approx_distinct_sql,
            exp.ApproxQuantile: rename_func("APPROX_PERCENTILE"),
            exp.Array: lambda self, e: f"ARRAY[{self.expressions(e, flat=True)}]",

@@ -268,15 +294,23 @@ class Presto(Dialect):
            ),
            exp.DateStrToDate: lambda self, e: f"CAST(DATE_PARSE({self.sql(e, 'this')}, {Presto.DATE_FORMAT}) AS DATE)",
            exp.DateToDi: lambda self, e: f"CAST(DATE_FORMAT({self.sql(e, 'this')}, {Presto.DATEINT_FORMAT}) AS INT)",
            exp.DateSub: lambda self, e: self.func(
                "DATE_ADD",
                exp.Literal.string(e.text("unit") or "day"),
                e.expression * -1,
                e.this,
            ),
            exp.Decode: lambda self, e: encode_decode_sql(self, e, "FROM_UTF8"),
            exp.DiToDate: lambda self, e: f"CAST(DATE_PARSE(CAST({self.sql(e, 'this')} AS VARCHAR), {Presto.DATEINT_FORMAT}) AS DATE)",
            exp.Encode: lambda self, e: encode_decode_sql(self, e, "TO_UTF8"),
            exp.FileFormatProperty: lambda self, e: f"FORMAT='{e.name.upper()}'",
            exp.First: _first_last_sql,
            exp.Group: transforms.preprocess([transforms.unalias_group]),
            exp.Hex: rename_func("TO_HEX"),
            exp.If: if_sql,
            exp.ILike: no_ilike_sql,
            exp.Initcap: _initcap_sql,
            exp.Last: _first_last_sql,
            exp.Lateral: _explode_to_unnest_sql,
            exp.Left: left_to_substring_sql,
            exp.Levenshtein: rename_func("LEVENSHTEIN_DISTANCE"),

@@ -301,8 +335,10 @@ class Presto(Dialect):
            exp.SortArray: _no_sort_array,
            exp.StrPosition: rename_func("STRPOS"),
            exp.StrToDate: lambda self, e: f"CAST({_str_to_time_sql(self, e)} AS DATE)",
            exp.StrToMap: rename_func("SPLIT_TO_MAP"),
            exp.StrToTime: _str_to_time_sql,
            exp.StrToUnix: lambda self, e: f"TO_UNIXTIME(DATE_PARSE({self.sql(e, 'this')}, {self.format_time(e)}))",
            exp.Struct: rename_func("ROW"),
            exp.StructExtract: struct_extract_sql,
            exp.Table: transforms.preprocess([_unnest_sequence]),
            exp.TimestampTrunc: timestamptrunc_sql,
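Outside of MATCH_RECOGNIZE, `_first_last_sql` rewrites `FIRST` / `LAST` to `ARBITRARY`. A hedged sketch of the expected behavior (the input uses the generic `FIRST` function name, which the base parser maps to `exp.First`):

```python
import sqlglot

# FIRST is preserved only inside MATCH_RECOGNIZE; elsewhere it becomes ARBITRARY.
print(sqlglot.transpile("SELECT FIRST(x) FROM t", read="spark", write="trino")[0])
# expected: SELECT ARBITRARY(x) FROM t
```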
diff --git a/sqlglot/dialects/redshift.py b/sqlglot/dialects/redshift.py

@@ -13,7 +13,7 @@ from sqlglot.helper import seq_get
from sqlglot.tokens import TokenType


def _json_sql(self: Postgres.Generator, expression: exp.JSONExtract | exp.JSONExtractScalar) -> str:
def _json_sql(self: Redshift.Generator, expression: exp.JSONExtract | exp.JSONExtractScalar) -> str:
    return f'{self.sql(expression, "this")}."{expression.expression.name}"'

@@ -37,6 +37,8 @@ class Redshift(Postgres):
    }

    class Parser(Postgres.Parser):
        SUPPORTS_USER_DEFINED_TYPES = False

        FUNCTIONS = {
            **Postgres.Parser.FUNCTIONS,
            "ADD_MONTHS": lambda args: exp.DateAdd(

@@ -55,9 +57,11 @@ class Redshift(Postgres):
        }

        def _parse_types(
            self, check_func: bool = False, schema: bool = False
            self, check_func: bool = False, schema: bool = False, allow_identifiers: bool = True
        ) -> t.Optional[exp.Expression]:
            this = super()._parse_types(check_func=check_func, schema=schema)
            this = super()._parse_types(
                check_func=check_func, schema=schema, allow_identifiers=allow_identifiers
            )

            if (
                isinstance(this, exp.DataType)

@@ -100,6 +104,7 @@ class Redshift(Postgres):
        QUERY_HINTS = False
        VALUES_AS_TABLE = False
        TZ_TO_WITH_TIME_ZONE = True
        NVL2_SUPPORTED = True

        TYPE_MAPPING = {
            **Postgres.Generator.TYPE_MAPPING,

@@ -142,6 +147,9 @@ class Redshift(Postgres):
        # Redshift uses the POW | POWER (expr1, expr2) syntax instead of expr1 ^ expr2 (postgres)
        TRANSFORMS.pop(exp.Pow)

        # Redshift supports ANY_VALUE(..)
        TRANSFORMS.pop(exp.AnyValue)

        RESERVED_KEYWORDS = {*Postgres.Generator.RESERVED_KEYWORDS, "snapshot", "type"}

        def with_properties(self, properties: exp.Properties) -> str:
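Dropping `exp.AnyValue` from the inherited TRANSFORMS means Redshift keeps `ANY_VALUE` instead of rewriting it to `MAX` the way Postgres does. A quick check, assuming this revision:

```python
import sqlglot

# Redshift supports ANY_VALUE natively, so no MAX() rewrite should happen.
print(sqlglot.transpile("SELECT ANY_VALUE(x) FROM t", read="redshift", write="redshift")[0])
# expected: SELECT ANY_VALUE(x) FROM t
```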
diff --git a/sqlglot/dialects/snowflake.py b/sqlglot/dialects/snowflake.py

@@ -90,7 +90,7 @@ def _parse_datediff(args: t.List) -> exp.DateDiff:
    return exp.DateDiff(this=seq_get(args, 2), expression=seq_get(args, 1), unit=seq_get(args, 0))


def _unix_to_time_sql(self: generator.Generator, expression: exp.UnixToTime) -> str:
def _unix_to_time_sql(self: Snowflake.Generator, expression: exp.UnixToTime) -> str:
    scale = expression.args.get("scale")
    timestamp = self.sql(expression, "this")
    if scale in [None, exp.UnixToTime.SECONDS]:

@@ -105,7 +105,7 @@ def _unix_to_time_sql(self: generator.Generator, expression: exp.UnixToTime) ->

# https://docs.snowflake.com/en/sql-reference/functions/date_part.html
# https://docs.snowflake.com/en/sql-reference/functions-date-time.html#label-supported-date-time-parts
def _parse_date_part(self: parser.Parser) -> t.Optional[exp.Expression]:
def _parse_date_part(self: Snowflake.Parser) -> t.Optional[exp.Expression]:
    this = self._parse_var() or self._parse_type()

    if not this:

@@ -156,7 +156,7 @@ def _nullifzero_to_if(args: t.List) -> exp.If:
    return exp.If(this=cond, true=exp.Null(), false=seq_get(args, 0))


def _datatype_sql(self: generator.Generator, expression: exp.DataType) -> str:
def _datatype_sql(self: Snowflake.Generator, expression: exp.DataType) -> str:
    if expression.is_type("array"):
        return "ARRAY"
    elif expression.is_type("map"):

@@ -164,6 +164,17 @@ def _datatype_sql(self: generator.Generator, expression: exp.DataType) -> str:
    return self.datatype_sql(expression)


def _regexpilike_sql(self: Snowflake.Generator, expression: exp.RegexpILike) -> str:
    flag = expression.text("flag")

    if "i" not in flag:
        flag += "i"

    return self.func(
        "REGEXP_LIKE", expression.this, expression.expression, exp.Literal.string(flag)
    )


def _parse_convert_timezone(args: t.List) -> t.Union[exp.Anonymous, exp.AtTimeZone]:
    if len(args) == 3:
        return exp.Anonymous(this="CONVERT_TIMEZONE", expressions=args)

@@ -179,6 +190,13 @@ def _parse_regexp_replace(args: t.List) -> exp.RegexpReplace:
    return regexp_replace


def _show_parser(*args: t.Any, **kwargs: t.Any) -> t.Callable[[Snowflake.Parser], exp.Show]:
    def _parse(self: Snowflake.Parser) -> exp.Show:
        return self._parse_show_snowflake(*args, **kwargs)

    return _parse


class Snowflake(Dialect):
    # https://docs.snowflake.com/en/sql-reference/identifiers-syntax
    RESOLVES_IDENTIFIERS_AS_UPPERCASE = True

@@ -216,6 +234,7 @@ class Snowflake(Dialect):

    class Parser(parser.Parser):
        IDENTIFY_PIVOT_STRINGS = True
        SUPPORTS_USER_DEFINED_TYPES = False

        FUNCTIONS = {
            **parser.Parser.FUNCTIONS,

@@ -230,6 +249,7 @@ class Snowflake(Dialect):
            "DATEDIFF": _parse_datediff,
            "DIV0": _div0_to_if,
            "IFF": exp.If.from_arg_list,
            "LISTAGG": exp.GroupConcat.from_arg_list,
            "NULLIFZERO": _nullifzero_to_if,
            "OBJECT_CONSTRUCT": _parse_object_construct,
            "REGEXP_REPLACE": _parse_regexp_replace,

@@ -250,11 +270,6 @@ class Snowflake(Dialect):
        }
        FUNCTION_PARSERS.pop("TRIM")

        FUNC_TOKENS = {
            *parser.Parser.FUNC_TOKENS,
            TokenType.TABLE,
        }

        COLUMN_OPERATORS = {
            **parser.Parser.COLUMN_OPERATORS,
            TokenType.COLON: lambda self, this, path: self.expression(

@@ -281,6 +296,16 @@ class Snowflake(Dialect):
            ),
        }

        STATEMENT_PARSERS = {
            **parser.Parser.STATEMENT_PARSERS,
            TokenType.SHOW: lambda self: self._parse_show(),
        }

        SHOW_PARSERS = {
            "PRIMARY KEYS": _show_parser("PRIMARY KEYS"),
            "TERSE PRIMARY KEYS": _show_parser("PRIMARY KEYS"),
        }

        def _parse_id_var(
            self,
            any_token: bool = True,

@@ -296,8 +321,24 @@ class Snowflake(Dialect):

            return super()._parse_id_var(any_token=any_token, tokens=tokens)

        def _parse_show_snowflake(self, this: str) -> exp.Show:
            scope = None
            scope_kind = None

            if self._match(TokenType.IN):
                if self._match_text_seq("ACCOUNT"):
                    scope_kind = "ACCOUNT"
                elif self._match_set(self.DB_CREATABLES):
                    scope_kind = self._prev.text
                    if self._curr:
                        scope = self._parse_table()
                elif self._curr:
                    scope_kind = "TABLE"
                    scope = self._parse_table()

            return self.expression(exp.Show, this=this, scope=scope, scope_kind=scope_kind)

    class Tokenizer(tokens.Tokenizer):
        QUOTES = ["'"]
        STRING_ESCAPES = ["\\", "'"]
        HEX_STRINGS = [("x'", "'"), ("X'", "'")]
        RAW_STRINGS = ["$$"]

@@ -331,6 +372,8 @@ class Snowflake(Dialect):

        VAR_SINGLE_TOKENS = {"$"}

        COMMANDS = tokens.Tokenizer.COMMANDS - {TokenType.SHOW}

    class Generator(generator.Generator):
        PARAMETER_TOKEN = "$"
        MATCHED_BY_SOURCE = False

@@ -355,6 +398,7 @@ class Snowflake(Dialect):
            exp.DataType: _datatype_sql,
            exp.DayOfWeek: rename_func("DAYOFWEEK"),
            exp.Extract: rename_func("DATE_PART"),
            exp.GroupConcat: rename_func("LISTAGG"),
            exp.If: rename_func("IFF"),
            exp.LogicalAnd: rename_func("BOOLAND_AGG"),
            exp.LogicalOr: rename_func("BOOLOR_AGG"),

@@ -362,6 +406,7 @@ class Snowflake(Dialect):
            exp.Max: max_or_greatest,
            exp.Min: min_or_least,
            exp.PartitionedByProperty: lambda self, e: f"PARTITION BY {self.sql(e, 'this')}",
            exp.RegexpILike: _regexpilike_sql,
            exp.Select: transforms.preprocess([transforms.eliminate_distinct_on]),
            exp.StarMap: rename_func("OBJECT_CONSTRUCT"),
            exp.StartsWith: rename_func("STARTSWITH"),

@@ -373,6 +418,7 @@ class Snowflake(Dialect):
                "OBJECT_CONSTRUCT",
                *(arg for expression in e.expressions for arg in expression.flatten()),
            ),
            exp.Stuff: rename_func("INSERT"),
            exp.TimestampTrunc: timestamptrunc_sql,
            exp.TimeStrToTime: timestrtotime_sql,
            exp.TimeToStr: lambda self, e: self.func(

@@ -403,6 +449,16 @@ class Snowflake(Dialect):
            exp.VolatileProperty: exp.Properties.Location.UNSUPPORTED,
        }

        def show_sql(self, expression: exp.Show) -> str:
            scope = self.sql(expression, "scope")
            scope = f" {scope}" if scope else ""

            scope_kind = self.sql(expression, "scope_kind")
            if scope_kind:
                scope_kind = f" IN {scope_kind}"

            return f"SHOW {expression.name}{scope_kind}{scope}"

        def regexpextract_sql(self, expression: exp.RegexpExtract) -> str:
            # Other dialects don't support all of the following parameters, so we need to
            # generate default values as necessary to ensure the transpilation is correct

@@ -436,7 +492,9 @@ class Snowflake(Dialect):
            kind_value = expression.args.get("kind") or "TABLE"
            kind = f" {kind_value}" if kind_value else ""
            this = f" {self.sql(expression, 'this')}"
            return f"DESCRIBE{kind}{this}"
            expressions = self.expressions(expression, flat=True)
            expressions = f" {expressions}" if expressions else ""
            return f"DESCRIBE{kind}{this}{expressions}"

        def generatedasidentitycolumnconstraint_sql(
            self, expression: exp.GeneratedAsIdentityColumnConstraint
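The new SHOW machinery should let `SHOW PRIMARY KEYS` round-trip through `exp.Show` and `show_sql`. A hedged sketch; this assumes the base parser dispatches `_parse_show` through the `SHOW_PARSERS` table at this revision:

```python
import sqlglot

# SHOW is no longer a raw command for Snowflake; it parses into exp.Show.
print(sqlglot.transpile("SHOW PRIMARY KEYS IN ACCOUNT", read="snowflake", write="snowflake")[0])
# expected: SHOW PRIMARY KEYS IN ACCOUNT
```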
diff --git a/sqlglot/dialects/spark.py b/sqlglot/dialects/spark.py

@@ -38,9 +38,15 @@ class Spark(Spark2):
    class Parser(Spark2.Parser):
        FUNCTIONS = {
            **Spark2.Parser.FUNCTIONS,
            "ANY_VALUE": lambda args: exp.AnyValue(
                this=seq_get(args, 0), ignore_nulls=seq_get(args, 1)
            ),
            "DATEDIFF": _parse_datediff,
        }

        FUNCTION_PARSERS = Spark2.Parser.FUNCTION_PARSERS.copy()
        FUNCTION_PARSERS.pop("ANY_VALUE")

    class Generator(Spark2.Generator):
        TYPE_MAPPING = {
            **Spark2.Generator.TYPE_MAPPING,

@@ -56,9 +62,13 @@ class Spark(Spark2):
                "DATEADD", e.args.get("unit") or "DAY", e.expression, e.this
            ),
        }
        TRANSFORMS.pop(exp.AnyValue)
        TRANSFORMS.pop(exp.DateDiff)
        TRANSFORMS.pop(exp.Group)

        def anyvalue_sql(self, expression: exp.AnyValue) -> str:
            return self.function_fallback_sql(expression)

        def datediff_sql(self, expression: exp.DateDiff) -> str:
            unit = self.sql(expression, "unit")
            end = self.sql(expression, "this")
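Spark's two-argument `ANY_VALUE(expr, ignoreNulls)` is now parsed as a regular function (the `ANY_VALUE` FUNCTION_PARSER is popped so the second argument lands in `ignore_nulls`). A sketch, assuming this revision:

```python
import sqlglot

# The second argument is captured in AnyValue.ignore_nulls and emitted back verbatim.
print(sqlglot.transpile("SELECT ANY_VALUE(x, TRUE) FROM t", read="spark", write="spark")[0])
# expected: SELECT ANY_VALUE(x, TRUE) FROM t
```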
diff --git a/sqlglot/dialects/spark2.py b/sqlglot/dialects/spark2.py

@@ -15,7 +15,7 @@ from sqlglot.dialects.hive import Hive
from sqlglot.helper import seq_get


def _create_sql(self: Hive.Generator, e: exp.Create) -> str:
def _create_sql(self: Spark2.Generator, e: exp.Create) -> str:
    kind = e.args["kind"]
    properties = e.args.get("properties")

@@ -31,17 +31,21 @@ def _create_sql(self: Hive.Generator, e: exp.Create) -> str:
    return create_with_partitions_sql(self, e)


def _map_sql(self: Hive.Generator, expression: exp.Map) -> str:
    keys = self.sql(expression.args["keys"])
    values = self.sql(expression.args["values"])
    return f"MAP_FROM_ARRAYS({keys}, {values})"
def _map_sql(self: Spark2.Generator, expression: exp.Map) -> str:
    keys = expression.args.get("keys")
    values = expression.args.get("values")

    if not keys or not values:
        return "MAP()"

    return f"MAP_FROM_ARRAYS({self.sql(keys)}, {self.sql(values)})"


def _parse_as_cast(to_type: str) -> t.Callable[[t.List], exp.Expression]:
    return lambda args: exp.Cast(this=seq_get(args, 0), to=exp.DataType.build(to_type))


def _str_to_date(self: Hive.Generator, expression: exp.StrToDate) -> str:
def _str_to_date(self: Spark2.Generator, expression: exp.StrToDate) -> str:
    this = self.sql(expression, "this")
    time_format = self.format_time(expression)
    if time_format == Hive.DATE_FORMAT:

@@ -49,7 +53,7 @@ def _str_to_date(self: Hive.Generator, expression: exp.StrToDate) -> str:
    return f"TO_DATE({this}, {time_format})"


def _unix_to_time_sql(self: Hive.Generator, expression: exp.UnixToTime) -> str:
def _unix_to_time_sql(self: Spark2.Generator, expression: exp.UnixToTime) -> str:
    scale = expression.args.get("scale")
    timestamp = self.sql(expression, "this")
    if scale is None:

@@ -110,6 +114,13 @@ def _unqualify_pivot_columns(expression: exp.Expression) -> exp.Expression:
    return expression


def _insert_sql(self: Spark2.Generator, expression: exp.Insert) -> str:
    if expression.expression.args.get("with"):
        expression = expression.copy()
        expression.set("with", expression.expression.args.pop("with"))
    return self.insert_sql(expression)


class Spark2(Hive):
    class Parser(Hive.Parser):
        FUNCTIONS = {

@@ -169,10 +180,7 @@ class Spark2(Hive):

    class Generator(Hive.Generator):
        QUERY_HINTS = True

        TYPE_MAPPING = {
            **Hive.Generator.TYPE_MAPPING,
        }
        NVL2_SUPPORTED = True

        PROPERTIES_LOCATION = {
            **Hive.Generator.PROPERTIES_LOCATION,

@@ -197,6 +205,7 @@ class Spark2(Hive):
            exp.DayOfYear: rename_func("DAYOFYEAR"),
            exp.FileFormatProperty: lambda self, e: f"USING {e.name.upper()}",
            exp.From: transforms.preprocess([_unalias_pivot]),
            exp.Insert: _insert_sql,
            exp.LogicalAnd: rename_func("BOOL_AND"),
            exp.LogicalOr: rename_func("BOOL_OR"),
            exp.Map: _map_sql,
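`_map_sql` now tolerates `exp.Map` nodes without keys/values and expects `Array` arguments, which matches the `convert()` change to `expressions.py` later in this commit. A sketch:

```python
from sqlglot import exp

# dict literals now build Map(keys=Array(...), values=Array(...)).
m = exp.convert({"a": 1})
print(m.sql(dialect="spark2"))
# expected: MAP_FROM_ARRAYS(ARRAY('a'), ARRAY(1))
```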
diff --git a/sqlglot/dialects/sqlite.py b/sqlglot/dialects/sqlite.py

@@ -5,6 +5,7 @@ import typing as t
from sqlglot import exp, generator, parser, tokens, transforms
from sqlglot.dialects.dialect import (
    Dialect,
    any_value_to_max_sql,
    arrow_json_extract_scalar_sql,
    arrow_json_extract_sql,
    concat_to_dpipe_sql,

@@ -18,7 +19,7 @@ from sqlglot.dialects.dialect import (
from sqlglot.tokens import TokenType


def _date_add_sql(self: generator.Generator, expression: exp.DateAdd) -> str:
def _date_add_sql(self: SQLite.Generator, expression: exp.DateAdd) -> str:
    modifier = expression.expression
    modifier = modifier.name if modifier.is_string else self.sql(modifier)
    unit = expression.args.get("unit")

@@ -78,6 +79,7 @@ class SQLite(Dialect):
        JOIN_HINTS = False
        TABLE_HINTS = False
        QUERY_HINTS = False
        NVL2_SUPPORTED = False

        TYPE_MAPPING = {
            **generator.Generator.TYPE_MAPPING,

@@ -103,6 +105,7 @@ class SQLite(Dialect):

        TRANSFORMS = {
            **generator.Generator.TRANSFORMS,
            exp.AnyValue: any_value_to_max_sql,
            exp.Concat: concat_to_dpipe_sql,
            exp.CountIf: count_if_to_sum,
            exp.Create: transforms.preprocess([_transform_create]),
diff --git a/sqlglot/dialects/teradata.py b/sqlglot/dialects/teradata.py

@@ -95,6 +95,9 @@ class Teradata(Dialect):

        STATEMENT_PARSERS = {
            **parser.Parser.STATEMENT_PARSERS,
            TokenType.DATABASE: lambda self: self.expression(
                exp.Use, this=self._parse_table(schema=False)
            ),
            TokenType.REPLACE: lambda self: self._parse_create(),
        }

@@ -165,6 +168,7 @@ class Teradata(Dialect):
            exp.Select: transforms.preprocess([transforms.eliminate_distinct_on]),
            exp.StrToDate: lambda self, e: f"CAST({self.sql(e, 'this')} AS DATE FORMAT {self.format_time(e)})",
            exp.ToChar: lambda self, e: self.function_fallback_sql(e),
            exp.Use: lambda self, e: f"DATABASE {self.sql(e, 'this')}",
        }

        def partitionedbyproperty_sql(self, expression: exp.PartitionedByProperty) -> str:
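`USE`-style statements are now parsed from Teradata's `DATABASE` keyword and generated back as `DATABASE <name>`. A sketch:

```python
import sqlglot

# exp.Use renders as Teradata's DATABASE statement.
print(sqlglot.transpile("USE my_db", read="mysql", write="teradata")[0])
# expected: DATABASE my_db
```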
diff --git a/sqlglot/dialects/trino.py b/sqlglot/dialects/trino.py

@@ -13,3 +13,6 @@ class Trino(Presto):

    class Tokenizer(Presto.Tokenizer):
        HEX_STRINGS = [("X'", "'")]

    class Parser(Presto.Parser):
        SUPPORTS_USER_DEFINED_TYPES = False
diff --git a/sqlglot/dialects/tsql.py b/sqlglot/dialects/tsql.py

@@ -7,6 +7,7 @@ import typing as t
from sqlglot import exp, generator, parser, tokens, transforms
from sqlglot.dialects.dialect import (
    Dialect,
    any_value_to_max_sql,
    max_or_greatest,
    min_or_least,
    parse_date_delta,

@@ -79,22 +80,23 @@ def _format_time_lambda(

def _parse_format(args: t.List) -> exp.Expression:
    assert len(args) == 2
    this = seq_get(args, 0)
    fmt = seq_get(args, 1)
    culture = seq_get(args, 2)

    fmt = args[1]
    number_fmt = fmt.name in TRANSPILE_SAFE_NUMBER_FMT or not DATE_FMT_RE.search(fmt.name)
    number_fmt = fmt and (fmt.name in TRANSPILE_SAFE_NUMBER_FMT or not DATE_FMT_RE.search(fmt.name))

    if number_fmt:
        return exp.NumberToStr(this=args[0], format=fmt)
        return exp.NumberToStr(this=this, format=fmt, culture=culture)

    return exp.TimeToStr(
        this=args[0],
        format=exp.Literal.string(
    if fmt:
        fmt = exp.Literal.string(
            format_time(fmt.name, TSQL.FORMAT_TIME_MAPPING)
            if len(fmt.name) == 1
            else format_time(fmt.name, TSQL.TIME_MAPPING)
        ),
    )
        )

    return exp.TimeToStr(this=this, format=fmt, culture=culture)


def _parse_eomonth(args: t.List) -> exp.Expression:

@@ -130,13 +132,13 @@ def _parse_hashbytes(args: t.List) -> exp.Expression:

def generate_date_delta_with_unit_sql(
    self: generator.Generator, expression: exp.DateAdd | exp.DateDiff
    self: TSQL.Generator, expression: exp.DateAdd | exp.DateDiff
) -> str:
    func = "DATEADD" if isinstance(expression, exp.DateAdd) else "DATEDIFF"
    return self.func(func, expression.text("unit"), expression.expression, expression.this)


def _format_sql(self: generator.Generator, expression: exp.NumberToStr | exp.TimeToStr) -> str:
def _format_sql(self: TSQL.Generator, expression: exp.NumberToStr | exp.TimeToStr) -> str:
    fmt = (
        expression.args["format"]
        if isinstance(expression, exp.NumberToStr)

@@ -147,10 +149,10 @@ def _format_sql(self: generator.Generator, expression: exp.NumberToStr | exp.Tim
            )
        )
    )
    return self.func("FORMAT", expression.this, fmt)
    return self.func("FORMAT", expression.this, fmt, expression.args.get("culture"))


def _string_agg_sql(self: generator.Generator, expression: exp.GroupConcat) -> str:
def _string_agg_sql(self: TSQL.Generator, expression: exp.GroupConcat) -> str:
    expression = expression.copy()

    this = expression.this

@@ -332,10 +334,12 @@ class TSQL(Dialect):
            "SQL_VARIANT": TokenType.VARIANT,
            "TOP": TokenType.TOP,
            "UNIQUEIDENTIFIER": TokenType.UNIQUEIDENTIFIER,
            "UPDATE STATISTICS": TokenType.COMMAND,
            "VARCHAR(MAX)": TokenType.TEXT,
            "XML": TokenType.XML,
            "OUTPUT": TokenType.RETURNING,
            "SYSTEM_USER": TokenType.CURRENT_USER,
            "FOR SYSTEM_TIME": TokenType.TIMESTAMP_SNAPSHOT,
        }

    class Parser(parser.Parser):

@@ -395,7 +399,9 @@ class TSQL(Dialect):

        CONCAT_NULL_OUTPUTS_STRING = True

        def _parse_projections(self) -> t.List[t.Optional[exp.Expression]]:
        ALTER_TABLE_ADD_COLUMN_KEYWORD = False

        def _parse_projections(self) -> t.List[exp.Expression]:
            """
            T-SQL supports the syntax alias = expression in the SELECT's projection list,
            so we transform all parsed Selects to convert their EQ projections into Aliases.

@@ -458,43 +464,6 @@ class TSQL(Dialect):

            return self._parse_as_command(self._prev)

        def _parse_system_time(self) -> t.Optional[exp.Expression]:
            if not self._match_text_seq("FOR", "SYSTEM_TIME"):
                return None

            if self._match_text_seq("AS", "OF"):
                system_time = self.expression(
                    exp.SystemTime, this=self._parse_bitwise(), kind="AS OF"
                )
            elif self._match_set((TokenType.FROM, TokenType.BETWEEN)):
                kind = self._prev.text
                this = self._parse_bitwise()
                self._match_texts(("TO", "AND"))
                expression = self._parse_bitwise()
                system_time = self.expression(
                    exp.SystemTime, this=this, expression=expression, kind=kind
                )
            elif self._match_text_seq("CONTAINED", "IN"):
                args = self._parse_wrapped_csv(self._parse_bitwise)
                system_time = self.expression(
                    exp.SystemTime,
                    this=seq_get(args, 0),
                    expression=seq_get(args, 1),
                    kind="CONTAINED IN",
                )
            elif self._match(TokenType.ALL):
                system_time = self.expression(exp.SystemTime, kind="ALL")
            else:
                system_time = None
                self.raise_error("Unable to parse FOR SYSTEM_TIME clause")

            return system_time

        def _parse_table_parts(self, schema: bool = False) -> exp.Table:
            table = super()._parse_table_parts(schema=schema)
            table.set("system_time", self._parse_system_time())
            return table

        def _parse_returns(self) -> exp.ReturnsProperty:
            table = self._parse_id_var(any_token=False, tokens=self.RETURNS_TABLE_TOKENS)
            returns = super()._parse_returns()

@@ -589,14 +558,36 @@ class TSQL(Dialect):

            return create

        def _parse_if(self) -> t.Optional[exp.Expression]:
            index = self._index

            if self._match_text_seq("OBJECT_ID"):
                self._parse_wrapped_csv(self._parse_string)
                if self._match_text_seq("IS", "NOT", "NULL") and self._match(TokenType.DROP):
                    return self._parse_drop(exists=True)
                self._retreat(index)

            return super()._parse_if()

        def _parse_unique(self) -> exp.UniqueColumnConstraint:
            return self.expression(
                exp.UniqueColumnConstraint,
                this=None
                if self._curr and self._curr.text.upper() in {"CLUSTERED", "NONCLUSTERED"}
                else self._parse_schema(self._parse_id_var(any_token=False)),
            )

    class Generator(generator.Generator):
        LOCKING_READS_SUPPORTED = True
        LIMIT_IS_TOP = True
        QUERY_HINTS = False
        RETURNING_END = False
        NVL2_SUPPORTED = False
        ALTER_TABLE_ADD_COLUMN_KEYWORD = False

        TYPE_MAPPING = {
            **generator.Generator.TYPE_MAPPING,
            exp.DataType.Type.BOOLEAN: "BIT",
            exp.DataType.Type.DECIMAL: "NUMERIC",
            exp.DataType.Type.DATETIME: "DATETIME2",
            exp.DataType.Type.INT: "INTEGER",

@@ -607,6 +598,8 @@ class TSQL(Dialect):

        TRANSFORMS = {
            **generator.Generator.TRANSFORMS,
            exp.AnyValue: any_value_to_max_sql,
            exp.AutoIncrementColumnConstraint: lambda *_: "IDENTITY",
            exp.DateAdd: generate_date_delta_with_unit_sql,
            exp.DateDiff: generate_date_delta_with_unit_sql,
            exp.CurrentDate: rename_func("GETDATE"),

@@ -651,25 +644,44 @@ class TSQL(Dialect):

            return sql

        def create_sql(self, expression: exp.Create) -> str:
            expression = expression.copy()
            kind = self.sql(expression, "kind").upper()
            exists = expression.args.pop("exists", None)
            sql = super().create_sql(expression)

            if exists:
                table = expression.find(exp.Table)
                identifier = self.sql(exp.Literal.string(exp.table_name(table) if table else ""))
                if kind == "SCHEMA":
                    sql = f"""IF NOT EXISTS (SELECT * FROM information_schema.schemata WHERE schema_name = {identifier}) EXEC('{sql}')"""
                elif kind == "TABLE":
                    sql = f"""IF NOT EXISTS (SELECT * FROM information_schema.tables WHERE table_name = {identifier}) EXEC('{sql}')"""
                elif kind == "INDEX":
                    index = self.sql(exp.Literal.string(expression.this.text("this")))
                    sql = f"""IF NOT EXISTS (SELECT * FROM sys.indexes WHERE object_id = object_id({identifier}) AND name = {index}) EXEC('{sql}')"""
            elif expression.args.get("replace"):
                sql = sql.replace("CREATE OR REPLACE ", "CREATE OR ALTER ", 1)

            return sql

        def offset_sql(self, expression: exp.Offset) -> str:
            return f"{super().offset_sql(expression)} ROWS"

        def systemtime_sql(self, expression: exp.SystemTime) -> str:
            kind = expression.args["kind"]
            if kind == "ALL":
                return "FOR SYSTEM_TIME ALL"
        def version_sql(self, expression: exp.Version) -> str:
            name = "SYSTEM_TIME" if expression.name == "TIMESTAMP" else expression.name
            this = f"FOR {name}"
            expr = expression.expression
            kind = expression.text("kind")
            if kind in ("FROM", "BETWEEN"):
                args = expr.expressions
                sep = "TO" if kind == "FROM" else "AND"
                expr_sql = f"{self.sql(seq_get(args, 0))} {sep} {self.sql(seq_get(args, 1))}"
            else:
                expr_sql = self.sql(expr)

            start = self.sql(expression, "this")
            if kind == "AS OF":
                return f"FOR SYSTEM_TIME AS OF {start}"

            end = self.sql(expression, "expression")
            if kind == "FROM":
                return f"FOR SYSTEM_TIME FROM {start} TO {end}"
            if kind == "BETWEEN":
                return f"FOR SYSTEM_TIME BETWEEN {start} AND {end}"

            return f"FOR SYSTEM_TIME CONTAINED IN ({start}, {end})"
            expr_sql = f" {expr_sql}" if expr_sql else ""
            return f"{this} {kind}{expr_sql}"

        def returnsproperty_sql(self, expression: exp.ReturnsProperty) -> str:
            table = expression.args.get("table")

@@ -713,3 +725,16 @@ class TSQL(Dialect):
                identifier = f"#{identifier}"

            return identifier

        def constraint_sql(self, expression: exp.Constraint) -> str:
            this = self.sql(expression, "this")
            expressions = self.expressions(expression, flat=True, sep=" ")
            return f"CONSTRAINT {this} {expressions}"

        # https://learn.microsoft.com/en-us/answers/questions/448821/create-table-in-sql-server
        def generatedasidentitycolumnconstraint_sql(
            self, expression: exp.GeneratedAsIdentityColumnConstraint
        ) -> str:
            start = self.sql(expression, "start") or "1"
            increment = self.sql(expression, "increment") or "1"
            return f"IDENTITY({start}, {increment})"
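`systemtime_sql` is replaced by the dialect-agnostic `version_sql`, driven by the new `exp.Version` node (`FOR SYSTEM_TIME` now tokenizes as `TIMESTAMP_SNAPSHOT`). A hedged round-trip sketch:

```python
import sqlglot

# Temporal-table queries should survive a tsql -> tsql round trip via exp.Version.
sql = "SELECT * FROM t FOR SYSTEM_TIME AS OF '2023-01-01'"
print(sqlglot.transpile(sql, read="tsql", write="tsql")[0])
# expected shape: SELECT * FROM t FOR SYSTEM_TIME AS OF '2023-01-01'
```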
@ -1035,12 +1035,13 @@ class Clone(Expression):
|
|||
"this": True,
|
||||
"when": False,
|
||||
"kind": False,
|
||||
"shallow": False,
|
||||
"expression": False,
|
||||
}
|
||||
|
||||
|
||||
class Describe(Expression):
|
||||
arg_types = {"this": True, "kind": False}
|
||||
arg_types = {"this": True, "kind": False, "expressions": False}
|
||||
|
||||
|
||||
class Pragma(Expression):
|
||||
|
@ -1070,6 +1071,8 @@ class Show(Expression):
|
|||
"like": False,
|
||||
"where": False,
|
||||
"db": False,
|
||||
"scope": False,
|
||||
"scope_kind": False,
|
||||
"full": False,
|
||||
"mutex": False,
|
||||
"query": False,
|
||||
|
@ -1207,6 +1210,10 @@ class Comment(Expression):
|
|||
arg_types = {"this": True, "kind": True, "expression": True, "exists": False}
|
||||
|
||||
|
||||
class Comprehension(Expression):
|
||||
arg_types = {"this": True, "expression": True, "iterator": True, "condition": False}
|
||||
|
||||
|
||||
# https://clickhouse.com/docs/en/engines/table-engines/mergetree-family/mergetree#mergetree-table-ttl
|
||||
class MergeTreeTTLAction(Expression):
|
||||
arg_types = {
|
||||
|
@ -1269,6 +1276,10 @@ class CheckColumnConstraint(ColumnConstraintKind):
|
|||
pass
|
||||
|
||||
|
||||
class ClusteredColumnConstraint(ColumnConstraintKind):
|
||||
pass
|
||||
|
||||
|
||||
class CollateColumnConstraint(ColumnConstraintKind):
|
||||
pass
|
||||
|
||||
|
@ -1316,6 +1327,14 @@ class InlineLengthColumnConstraint(ColumnConstraintKind):
|
|||
pass
|
||||
|
||||
|
||||
class NonClusteredColumnConstraint(ColumnConstraintKind):
|
||||
pass
|
||||
|
||||
|
||||
class NotForReplicationColumnConstraint(ColumnConstraintKind):
|
||||
arg_types = {}
|
||||
|
||||
|
||||
class NotNullColumnConstraint(ColumnConstraintKind):
|
||||
arg_types = {"allow_null": False}
|
||||
|
||||
|
@ -1345,6 +1364,12 @@ class PathColumnConstraint(ColumnConstraintKind):
|
|||
pass
|
||||
|
||||
|
||||
# computed column expression
|
||||
# https://learn.microsoft.com/en-us/sql/t-sql/statements/create-table-transact-sql?view=sql-server-ver16
|
||||
class ComputedColumnConstraint(ColumnConstraintKind):
|
||||
arg_types = {"this": True, "persisted": False, "not_null": False}
|
||||
|
||||
|
||||
class Constraint(Expression):
|
||||
arg_types = {"this": True, "expressions": True}
|
||||
|
||||
|
@ -1489,6 +1514,15 @@ class Check(Expression):
|
|||
pass
|
||||
|
||||
|
||||
# https://docs.snowflake.com/en/sql-reference/constructs/connect-by
|
||||
class Connect(Expression):
|
||||
arg_types = {"start": False, "connect": True}
|
||||
|
||||
|
||||
class Prior(Expression):
|
||||
pass
|
||||
|
||||
|
||||
class Directory(Expression):
|
||||
# https://spark.apache.org/docs/3.0.0-preview/sql-ref-syntax-dml-insert-overwrite-directory-hive.html
|
||||
arg_types = {"this": True, "local": False, "row_format": False}
|
||||
|
@ -1578,6 +1612,7 @@ class Insert(DDL):
|
|||
"alternative": False,
|
||||
"where": False,
|
||||
"ignore": False,
|
||||
"by_name": False,
|
||||
}
|
||||
|
||||
def with_(
|
||||
|
@ -2045,8 +2080,12 @@ class NoPrimaryIndexProperty(Property):
|
|||
arg_types = {}
|
||||
|
||||
|
||||
class OnProperty(Property):
|
||||
arg_types = {"this": True}
|
||||
|
||||
|
||||
class OnCommitProperty(Property):
|
||||
arg_type = {"delete": False}
|
||||
arg_types = {"delete": False}
|
||||
|
||||
|
||||
class PartitionedByProperty(Property):
|
||||
|
@ -2282,6 +2321,16 @@ class Subqueryable(Unionable):
|
|||
def named_selects(self) -> t.List[str]:
|
||||
raise NotImplementedError("Subqueryable objects must implement `named_selects`")
|
||||
|
||||
def select(
|
||||
self,
|
||||
*expressions: t.Optional[ExpOrStr],
|
||||
append: bool = True,
|
||||
dialect: DialectType = None,
|
||||
copy: bool = True,
|
||||
**opts,
|
||||
) -> Subqueryable:
|
||||
raise NotImplementedError("Subqueryable objects must implement `select`")
|
||||
|
||||
def with_(
|
||||
self,
|
||||
alias: ExpOrStr,
|
||||
|
@ -2323,6 +2372,7 @@ QUERY_MODIFIERS = {
|
|||
"match": False,
|
||||
"laterals": False,
|
||||
"joins": False,
|
||||
"connect": False,
|
||||
"pivots": False,
|
||||
"where": False,
|
||||
"group": False,
|
||||
|
@ -2363,6 +2413,7 @@ class Table(Expression):
|
|||
"pivots": False,
|
||||
"hints": False,
|
||||
"system_time": False,
|
||||
"version": False,
|
||||
}
|
||||
|
||||
@property
|
||||
|
@ -2403,21 +2454,13 @@ class Table(Expression):
|
|||
return parts
|
||||
|
||||
|
||||
# See the TSQL "Querying data in a system-versioned temporal table" page
|
||||
class SystemTime(Expression):
|
||||
arg_types = {
|
||||
"this": False,
|
||||
"expression": False,
|
||||
"kind": True,
|
||||
}
|
||||
|
||||
|
||||
class Union(Subqueryable):
|
||||
arg_types = {
|
||||
"with": False,
|
||||
"this": True,
|
||||
"expression": True,
|
||||
"distinct": False,
|
||||
"by_name": False,
|
||||
**QUERY_MODIFIERS,
|
||||
}
|
||||
|
||||
|
@ -2529,6 +2572,7 @@ class Update(Expression):
|
|||
"from": False,
|
||||
"where": False,
|
||||
"returning": False,
|
||||
"order": False,
|
||||
"limit": False,
|
||||
}
|
||||
|
||||
|
@ -2545,6 +2589,20 @@ class Var(Expression):
|
|||
pass
|
||||
|
||||
|
||||
class Version(Expression):
|
||||
"""
|
||||
Time travel, iceberg, bigquery etc
|
||||
https://trino.io/docs/current/connector/iceberg.html?highlight=snapshot#using-snapshots
|
||||
https://www.databricks.com/blog/2019/02/04/introducing-delta-time-travel-for-large-scale-data-lakes.html
|
||||
https://cloud.google.com/bigquery/docs/reference/standard-sql/query-syntax#for_system_time_as_of
|
||||
https://learn.microsoft.com/en-us/sql/relational-databases/tables/querying-data-in-a-system-versioned-temporal-table?view=sql-server-ver16
|
||||
this is either TIMESTAMP or VERSION
|
||||
kind is ("AS OF", "BETWEEN")
|
||||
"""
|
||||
|
||||
arg_types = {"this": True, "kind": True, "expression": False}
|
||||
|
||||
|
||||
class Schema(Expression):
|
||||
arg_types = {"this": False, "expressions": False}
|
||||
|
||||
|
@ -3263,6 +3321,23 @@ class Subquery(DerivedTable, Unionable):
|
|||
expression = expression.this
|
||||
return expression
|
||||
|
||||
def unwrap(self) -> Subquery:
|
||||
expression = self
|
||||
while expression.same_parent and expression.is_wrapper:
|
||||
expression = t.cast(Subquery, expression.parent)
|
||||
return expression
|
||||
|
||||
@property
|
||||
def is_wrapper(self) -> bool:
|
||||
"""
|
||||
Whether this Subquery acts as a simple wrapper around another expression.
|
||||
|
||||
SELECT * FROM (((SELECT * FROM t)))
|
||||
^
|
||||
This corresponds to a "wrapper" Subquery node
|
||||
"""
|
||||
return all(v is None for k, v in self.args.items() if k != "this")
|
||||
|
||||
@property
|
||||
def is_star(self) -> bool:
|
||||
return self.this.is_star
|
||||
|
@ -3313,7 +3388,7 @@ class Pivot(Expression):
|
|||
}
|
||||
|
||||
|
||||
class Window(Expression):
|
||||
class Window(Condition):
|
||||
arg_types = {
|
||||
"this": True,
|
||||
"partition_by": False,
|
||||
|
@ -3375,7 +3450,7 @@ class Boolean(Condition):
|
|||
pass
|
||||
|
||||
|
||||
class DataTypeSize(Expression):
|
||||
class DataTypeParam(Expression):
|
||||
arg_types = {"this": True, "expression": False}
|
||||
|
||||
|
||||
|
@ -3386,6 +3461,7 @@ class DataType(Expression):
|
|||
"nested": False,
|
||||
"values": False,
|
||||
"prefix": False,
|
||||
"kind": False,
|
||||
}
|
||||
|
||||
class Type(AutoName):
|
||||
|
@ -3432,6 +3508,7 @@ class DataType(Expression):
|
|||
LOWCARDINALITY = auto()
|
||||
MAP = auto()
|
||||
MEDIUMBLOB = auto()
|
||||
MEDIUMINT = auto()
|
||||
MEDIUMTEXT = auto()
|
||||
MONEY = auto()
|
||||
NCHAR = auto()
|
||||
|
@ -3475,6 +3552,7 @@ class DataType(Expression):
|
|||
VARCHAR = auto()
|
||||
VARIANT = auto()
|
||||
XML = auto()
|
||||
YEAR = auto()
|
||||
|
||||
TEXT_TYPES = {
|
||||
Type.CHAR,
|
||||
|
@ -3498,7 +3576,10 @@ class DataType(Expression):
|
|||
Type.DOUBLE,
|
||||
}
|
||||
|
||||
NUMERIC_TYPES = {*INTEGER_TYPES, *FLOAT_TYPES}
|
||||
NUMERIC_TYPES = {
|
||||
*INTEGER_TYPES,
|
||||
*FLOAT_TYPES,
|
||||
}
|
||||
|
||||
TEMPORAL_TYPES = {
|
||||
Type.TIME,
|
||||
|
@ -3511,23 +3592,39 @@ class DataType(Expression):
|
|||
Type.DATETIME64,
|
||||
}
|
||||
|
||||
META_TYPES = {"UNKNOWN", "NULL"}
|
||||
|
||||
@classmethod
|
||||
def build(
|
||||
cls, dtype: str | DataType | DataType.Type, dialect: DialectType = None, **kwargs
|
||||
cls,
|
||||
dtype: str | DataType | DataType.Type,
|
||||
dialect: DialectType = None,
|
||||
udt: bool = False,
|
||||
**kwargs,
|
||||
) -> DataType:
|
||||
"""
|
||||
Constructs a DataType object.
|
||||
|
||||
Args:
|
||||
dtype: the data type of interest.
|
||||
dialect: the dialect to use for parsing `dtype`, in case it's a string.
|
||||
udt: when set to True, `dtype` will be used as-is if it can't be parsed into a
|
||||
DataType, thus creating a user-defined type.
|
||||
kawrgs: additional arguments to pass in the constructor of DataType.
|
||||
|
||||
Returns:
|
||||
The constructed DataType object.
|
||||
"""
|
||||
from sqlglot import parse_one
|
||||
|
||||
if isinstance(dtype, str):
|
||||
upper = dtype.upper()
|
||||
if upper in DataType.META_TYPES:
|
||||
data_type_exp: t.Optional[Expression] = DataType(this=DataType.Type[upper])
|
||||
else:
|
||||
data_type_exp = parse_one(dtype, read=dialect, into=DataType)
|
||||
if dtype.upper() == "UNKNOWN":
|
||||
return DataType(this=DataType.Type.UNKNOWN, **kwargs)
|
||||
|
||||
if data_type_exp is None:
|
||||
raise ValueError(f"Unparsable data type value: {dtype}")
|
||||
try:
|
||||
data_type_exp = parse_one(dtype, read=dialect, into=DataType)
|
||||
except ParseError:
|
||||
if udt:
|
||||
return DataType(this=DataType.Type.USERDEFINED, kind=dtype, **kwargs)
|
||||
raise
|
||||
elif isinstance(dtype, DataType.Type):
|
||||
data_type_exp = DataType(this=dtype)
|
||||
elif isinstance(dtype, DataType):
|
||||
|
@ -3538,7 +3635,31 @@ class DataType(Expression):
|
|||
return DataType(**{**data_type_exp.args, **kwargs})
|
||||
|
||||
def is_type(self, *dtypes: str | DataType | DataType.Type) -> bool:
|
||||
return any(self.this == DataType.build(dtype).this for dtype in dtypes)
|
||||
"""
|
||||
Checks whether this DataType matches one of the provided data types. Nested types or precision
|
||||
will be compared using "structural equivalence" semantics, so e.g. array<int> != array<float>.
|
||||
|
||||
Args:
|
||||
dtypes: the data types to compare this DataType to.
|
||||
|
||||
Returns:
|
||||
True, if and only if there is a type in `dtypes` which is equal to this DataType.
|
||||
"""
|
||||
for dtype in dtypes:
|
||||
other = DataType.build(dtype, udt=True)
|
||||
|
||||
if (
|
||||
other.expressions
|
||||
or self.this == DataType.Type.USERDEFINED
|
||||
or other.this == DataType.Type.USERDEFINED
|
||||
):
|
||||
matches = self == other
|
||||
else:
|
||||
matches = self.this == other.this
|
||||
|
||||
if matches:
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
# https://www.postgresql.org/docs/15/datatype-pseudo.html
|
||||
|
@ -3546,6 +3667,11 @@ class PseudoType(Expression):
|
|||
pass
|
||||
|
||||
|
||||
# https://www.postgresql.org/docs/15/datatype-oid.html
|
||||
class ObjectIdentifier(Expression):
|
||||
pass
|
||||
|
||||
|
||||
# WHERE x <OP> EXISTS|ALL|ANY|SOME(SELECT ...)
|
||||
class SubqueryPredicate(Predicate):
|
||||
pass
|
||||
|
@ -4005,6 +4131,7 @@ class ArrayAny(Func):
|
|||
|
||||
|
||||
class ArrayConcat(Func):
|
||||
_sql_names = ["ARRAY_CONCAT", "ARRAY_CAT"]
|
||||
arg_types = {"this": True, "expressions": False}
|
||||
is_var_len_args = True
|
||||
|
||||
|
@ -4047,7 +4174,15 @@ class Avg(AggFunc):
|
|||
|
||||
|
||||
class AnyValue(AggFunc):
|
||||
arg_types = {"this": True, "having": False, "max": False}
|
||||
arg_types = {"this": True, "having": False, "max": False, "ignore_nulls": False}
|
||||
|
||||
|
||||
class First(Func):
|
||||
arg_types = {"this": True, "ignore_nulls": False}
|
||||
|
||||
|
||||
class Last(Func):
|
||||
arg_types = {"this": True, "ignore_nulls": False}
|
||||
|
||||
|
||||
class Case(Func):
|
||||
|
@ -4086,18 +4221,29 @@ class Cast(Func):
|
|||
return self.name
|
||||
|
||||
def is_type(self, *dtypes: str | DataType | DataType.Type) -> bool:
|
||||
"""
|
||||
Checks whether this Cast's DataType matches one of the provided data types. Nested types
|
||||
like arrays or structs will be compared using "structural equivalence" semantics, so e.g.
|
||||
array<int> != array<float>.
|
||||
|
||||
Args:
|
||||
dtypes: the data types to compare this Cast's DataType to.
|
||||
|
||||
Returns:
|
||||
True, if and only if there is a type in `dtypes` which is equal to this Cast's DataType.
|
||||
"""
|
||||
return self.to.is_type(*dtypes)
|
||||
|
||||
|
||||
class CastToStrType(Func):
|
||||
arg_types = {"this": True, "expression": True}
|
||||
|
||||
|
||||
class Collate(Binary):
|
||||
class TryCast(Cast):
|
||||
pass
|
||||
|
||||
|
||||
class TryCast(Cast):
|
||||
class CastToStrType(Func):
|
||||
arg_types = {"this": True, "to": True}
|
||||
|
||||
|
||||
class Collate(Binary):
|
||||
pass
|
||||
|
||||
|
||||
|
@ -4310,7 +4456,7 @@ class Greatest(Func):
|
|||
is_var_len_args = True
|
||||
|
||||
|
||||
class GroupConcat(Func):
|
||||
class GroupConcat(AggFunc):
|
||||
arg_types = {"this": True, "separator": False}
|
||||
|
||||
|
||||
|
@ -4648,8 +4794,19 @@ class StrToUnix(Func):
|
|||
arg_types = {"this": False, "format": False}
|
||||
|
||||
|
||||
# https://prestodb.io/docs/current/functions/string.html
|
||||
# https://spark.apache.org/docs/latest/api/sql/index.html#str_to_map
|
||||
class StrToMap(Func):
|
||||
arg_types = {
|
||||
"this": True,
|
||||
"pair_delim": False,
|
||||
"key_value_delim": False,
|
||||
"duplicate_resolution_callback": False,
|
||||
}
|
||||
|
||||
|
||||
class NumberToStr(Func):
|
||||
arg_types = {"this": True, "format": True}
|
||||
arg_types = {"this": True, "format": True, "culture": False}
|
||||
|
||||
|
||||
class FromBase(Func):
|
||||
|
@ -4665,6 +4822,13 @@ class StructExtract(Func):
|
|||
arg_types = {"this": True, "expression": True}
|
||||
|
||||
|
||||
# https://learn.microsoft.com/en-us/sql/t-sql/functions/stuff-transact-sql?view=sql-server-ver16
|
||||
# https://docs.snowflake.com/en/sql-reference/functions/insert
|
||||
class Stuff(Func):
|
||||
_sql_names = ["STUFF", "INSERT"]
|
||||
arg_types = {"this": True, "start": True, "length": True, "expression": True}
|
||||
|
||||
|
||||
class Sum(AggFunc):
|
||||
pass
|
||||
|
||||
|
@ -4686,7 +4850,7 @@ class StddevSamp(AggFunc):


class TimeToStr(Func):
    arg_types = {"this": True, "format": True}
    arg_types = {"this": True, "format": True, "culture": False}


class TimeToTimeStr(Func):

@ -5724,9 +5888,9 @@ def table_(
        The new Table instance.
    """
    return Table(
        this=to_identifier(table, quoted=quoted),
        db=to_identifier(db, quoted=quoted),
        catalog=to_identifier(catalog, quoted=quoted),
        this=to_identifier(table, quoted=quoted) if table else None,
        db=to_identifier(db, quoted=quoted) if db else None,
        catalog=to_identifier(catalog, quoted=quoted) if catalog else None,
        alias=TableAlias(this=to_identifier(alias)) if alias else None,
    )
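A brief sketch of the guarded `table_` helper: empty name parts are now dropped instead of becoming empty identifiers (assumed behavior, inferred from the diff):

```python
from sqlglot import exp

print(exp.table_("employees", db="hr").sql())  # hr.employees
print(exp.table_("employees").sql())           # employees, no leading dots
```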
@ -5844,8 +6008,8 @@ def convert(value: t.Any, copy: bool = False) -> Expression:
        return Array(expressions=[convert(v, copy=copy) for v in value])
    if isinstance(value, dict):
        return Map(
            keys=[convert(k, copy=copy) for k in value],
            values=[convert(v, copy=copy) for v in value.values()],
            keys=Array(expressions=[convert(k, copy=copy) for k in value]),
            values=Array(expressions=[convert(v, copy=copy) for v in value.values()]),
        )
    raise ValueError(f"Cannot convert {value}")
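A minimal sketch (assumed usage) of `convert` after the change: dict keys and values are now wrapped in `Array` expressions rather than plain Python lists:

```python
from sqlglot import exp

mapping = exp.convert({"a": 1, "b": 2})
assert isinstance(mapping.args["keys"], exp.Array)
print(mapping.sql())  # e.g. MAP(ARRAY('a', 'b'), ARRAY(1, 2))
```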
@ -8,7 +8,7 @@ from sqlglot import exp
from sqlglot.errors import ErrorLevel, UnsupportedError, concat_messages
from sqlglot.helper import apply_index_offset, csv, seq_get
from sqlglot.time import format_time
from sqlglot.tokens import TokenType
from sqlglot.tokens import Tokenizer, TokenType

logger = logging.getLogger("sqlglot")

@ -61,6 +61,7 @@ class Generator:
        exp.CharacterSetColumnConstraint: lambda self, e: f"CHARACTER SET {self.sql(e, 'this')}",
        exp.CharacterSetProperty: lambda self, e: f"{'DEFAULT ' if e.args.get('default') else ''}CHARACTER SET={self.sql(e, 'this')}",
        exp.CheckColumnConstraint: lambda self, e: f"CHECK ({self.sql(e, 'this')})",
        exp.ClusteredColumnConstraint: lambda self, e: f"CLUSTERED ({self.expressions(e, 'this', indent=False)})",
        exp.CollateColumnConstraint: lambda self, e: f"COLLATE {self.sql(e, 'this')}",
        exp.CopyGrantsProperty: lambda self, e: "COPY GRANTS",
        exp.CommentColumnConstraint: lambda self, e: f"COMMENT {self.sql(e, 'this')}",

@ -78,7 +79,10 @@ class Generator:
        exp.LogProperty: lambda self, e: f"{'NO ' if e.args.get('no') else ''}LOG",
        exp.MaterializedProperty: lambda self, e: "MATERIALIZED",
        exp.NoPrimaryIndexProperty: lambda self, e: "NO PRIMARY INDEX",
        exp.NonClusteredColumnConstraint: lambda self, e: f"NONCLUSTERED ({self.expressions(e, 'this', indent=False)})",
        exp.NotForReplicationColumnConstraint: lambda self, e: "NOT FOR REPLICATION",
        exp.OnCommitProperty: lambda self, e: f"ON COMMIT {'DELETE' if e.args.get('delete') else 'PRESERVE'} ROWS",
        exp.OnProperty: lambda self, e: f"ON {self.sql(e, 'this')}",
        exp.OnUpdateColumnConstraint: lambda self, e: f"ON UPDATE {self.sql(e, 'this')}",
        exp.PathColumnConstraint: lambda self, e: f"PATH {self.sql(e, 'this')}",
        exp.ReturnsProperty: lambda self, e: self.naked_property(e),

@ -171,6 +175,9 @@ class Generator:
    # Whether or not TIMETZ / TIMESTAMPTZ will be generated using the "WITH TIME ZONE" syntax
    TZ_TO_WITH_TIME_ZONE = False

    # Whether or not the NVL2 function is supported
    NVL2_SUPPORTED = True

    # https://cloud.google.com/bigquery/docs/reference/standard-sql/query-syntax
    SELECT_KINDS: t.Tuple[str, ...] = ("STRUCT", "VALUE")

@ -179,6 +186,9 @@ class Generator:
    # SELECT * VALUES into SELECT UNION
    VALUES_AS_TABLE = True

    # Whether or not the word COLUMN is included when adding a column with ALTER TABLE
    ALTER_TABLE_ADD_COLUMN_KEYWORD = True

    TYPE_MAPPING = {
        exp.DataType.Type.NCHAR: "CHAR",
        exp.DataType.Type.NVARCHAR: "VARCHAR",

@ -245,6 +255,7 @@ class Generator:
        exp.MaterializedProperty: exp.Properties.Location.POST_CREATE,
        exp.MergeBlockRatioProperty: exp.Properties.Location.POST_NAME,
        exp.NoPrimaryIndexProperty: exp.Properties.Location.POST_EXPRESSION,
        exp.OnProperty: exp.Properties.Location.POST_SCHEMA,
        exp.OnCommitProperty: exp.Properties.Location.POST_EXPRESSION,
        exp.Order: exp.Properties.Location.POST_SCHEMA,
        exp.PartitionedByProperty: exp.Properties.Location.POST_WITH,

@ -317,8 +328,7 @@ class Generator:
    QUOTE_END = "'"
    IDENTIFIER_START = '"'
    IDENTIFIER_END = '"'
    STRING_ESCAPE = "'"
    IDENTIFIER_ESCAPE = '"'
    TOKENIZER_CLASS = Tokenizer

    # Delimiters for bit, hex, byte and raw literals
    BIT_START: t.Optional[str] = None

@ -379,8 +389,10 @@ class Generator:
        )

        self.unsupported_messages: t.List[str] = []
        self._escaped_quote_end: str = self.STRING_ESCAPE + self.QUOTE_END
        self._escaped_identifier_end: str = self.IDENTIFIER_ESCAPE + self.IDENTIFIER_END
        self._escaped_quote_end: str = self.TOKENIZER_CLASS.STRING_ESCAPES[0] + self.QUOTE_END
        self._escaped_identifier_end: str = (
            self.TOKENIZER_CLASS.IDENTIFIER_ESCAPES[0] + self.IDENTIFIER_END
        )
        self._cache: t.Optional[t.Dict[int, str]] = None

    def generate(

@ -626,6 +638,16 @@ class Generator:
        kind_sql = self.sql(expression, "kind").strip()
        return f"CONSTRAINT {this} {kind_sql}" if this else kind_sql

    def computedcolumnconstraint_sql(self, expression: exp.ComputedColumnConstraint) -> str:
        this = self.sql(expression, "this")
        if expression.args.get("not_null"):
            persisted = " PERSISTED NOT NULL"
        elif expression.args.get("persisted"):
            persisted = " PERSISTED"
        else:
            persisted = ""
        return f"AS {this}{persisted}"

    def autoincrementcolumnconstraint_sql(self, _) -> str:
        return self.token_sql(TokenType.AUTO_INCREMENT)
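A hedged round-trip sketch of the new computed-column support, assuming the T-SQL dialect wires `_parse_column_def` and `computedcolumnconstraint_sql` together (the parser side appears later in this diff):

```python
import sqlglot
from sqlglot import exp

expr = sqlglot.parse_one("CREATE TABLE t (a INT, b AS (a * 2) PERSISTED)", read="tsql")
computed = expr.find(exp.ComputedColumnConstraint)
print(computed.this.sql())             # (a * 2)
print(computed.args.get("persisted"))  # True
```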
@ -642,8 +664,8 @@ class Generator:
    ) -> str:
        this = ""
        if expression.this is not None:
            on_null = "ON NULL " if expression.args.get("on_null") else ""
            this = " ALWAYS " if expression.this else f" BY DEFAULT {on_null}"
            on_null = " ON NULL" if expression.args.get("on_null") else ""
            this = " ALWAYS" if expression.this else f" BY DEFAULT{on_null}"

        start = expression.args.get("start")
        start = f"START WITH {start}" if start else ""

@ -668,7 +690,7 @@ class Generator:
        expr = self.sql(expression, "expression")
        expr = f"({expr})" if expr else "IDENTITY"

        return f"GENERATED{this}AS {expr}{sequence_opts}"
        return f"GENERATED{this} AS {expr}{sequence_opts}"

    def notnullcolumnconstraint_sql(self, expression: exp.NotNullColumnConstraint) -> str:
        return f"{'' if expression.args.get('allow_null') else 'NOT '}NULL"

@ -774,14 +796,16 @@ class Generator:

    def clone_sql(self, expression: exp.Clone) -> str:
        this = self.sql(expression, "this")
        shallow = "SHALLOW " if expression.args.get("shallow") else ""
        this = f"{shallow}CLONE {this}"
        when = self.sql(expression, "when")

        if when:
            kind = self.sql(expression, "kind")
            expr = self.sql(expression, "expression")
            return f"CLONE {this} {when} ({kind} => {expr})"
            return f"{this} {when} ({kind} => {expr})"

        return f"CLONE {this}"
        return this

    def describe_sql(self, expression: exp.Describe) -> str:
        return f"DESCRIBE {self.sql(expression, 'this')}"

@ -830,7 +854,7 @@ class Generator:
        string = self.escape_str(expression.this.replace("\\", "\\\\"))
        return f"{self.QUOTE_START}{string}{self.QUOTE_END}"

    def datatypesize_sql(self, expression: exp.DataTypeSize) -> str:
    def datatypeparam_sql(self, expression: exp.DataTypeParam) -> str:
        this = self.sql(expression, "this")
        specifier = self.sql(expression, "expression")
        specifier = f" {specifier}" if specifier else ""

@ -839,11 +863,14 @@ class Generator:
    def datatype_sql(self, expression: exp.DataType) -> str:
        type_value = expression.this

        type_sql = (
            self.TYPE_MAPPING.get(type_value, type_value.value)
            if isinstance(type_value, exp.DataType.Type)
            else type_value
        )
        if type_value == exp.DataType.Type.USERDEFINED and expression.args.get("kind"):
            type_sql = self.sql(expression, "kind")
        else:
            type_sql = (
                self.TYPE_MAPPING.get(type_value, type_value.value)
                if isinstance(type_value, exp.DataType.Type)
                else type_value
            )

        nested = ""
        interior = self.expressions(expression, flat=True)

@ -943,9 +970,9 @@ class Generator:
        name = self.sql(expression, "this")
        name = f"{name} " if name else ""
        table = self.sql(expression, "table")
        table = f"{self.INDEX_ON} {table} " if table else ""
        table = f"{self.INDEX_ON} {table}" if table else ""
        using = self.sql(expression, "using")
        using = f"USING {using} " if using else ""
        using = f" USING {using} " if using else ""
        index = "INDEX " if not table else ""
        columns = self.expressions(expression, key="columns", flat=True)
        columns = f"({columns})" if columns else ""

@ -1171,6 +1198,7 @@ class Generator:
        where = f"{self.sep()}REPLACE WHERE {where}" if where else ""
        expression_sql = f"{self.sep()}{self.sql(expression, 'expression')}"
        conflict = self.sql(expression, "conflict")
        by_name = " BY NAME" if expression.args.get("by_name") else ""
        returning = self.sql(expression, "returning")

        if self.RETURNING_END:

@ -1178,7 +1206,7 @@ class Generator:
        else:
            expression_sql = f"{returning}{expression_sql}{conflict}"

        sql = f"INSERT{alternative}{ignore}{this}{exists}{partition_sql}{where}{expression_sql}"
        sql = f"INSERT{alternative}{ignore}{this}{by_name}{exists}{partition_sql}{where}{expression_sql}"
        return self.prepend_ctes(expression, sql)

    def intersect_sql(self, expression: exp.Intersect) -> str:

@ -1196,6 +1224,9 @@ class Generator:
    def pseudotype_sql(self, expression: exp.PseudoType) -> str:
        return expression.name.upper()

    def objectidentifier_sql(self, expression: exp.ObjectIdentifier) -> str:
        return expression.name.upper()

    def onconflict_sql(self, expression: exp.OnConflict) -> str:
        conflict = "ON DUPLICATE KEY" if expression.args.get("duplicate") else "ON CONFLICT"
        constraint = self.sql(expression, "constraint")

@ -1248,6 +1279,8 @@ class Generator:
            if part
        )

        version = self.sql(expression, "version")
        version = f" {version}" if version else ""
        alias = self.sql(expression, "alias")
        alias = f"{sep}{alias}" if alias else ""
        hints = self.expressions(expression, key="hints", sep=" ")

@ -1256,10 +1289,8 @@ class Generator:
        pivots = f" {pivots}" if pivots else ""
        joins = self.expressions(expression, key="joins", sep="", skip_first=True)
        laterals = self.expressions(expression, key="laterals", sep="")
        system_time = expression.args.get("system_time")
        system_time = f" {self.sql(expression, 'system_time')}" if system_time else ""

        return f"{table}{system_time}{alias}{hints}{pivots}{joins}{laterals}"
        return f"{table}{version}{alias}{hints}{pivots}{joins}{laterals}"

    def tablesample_sql(
        self, expression: exp.TableSample, seed_prefix: str = "SEED", sep=" AS "

@ -1314,6 +1345,12 @@ class Generator:
            nulls = ""
        return f"{direction}{nulls}({expressions} FOR {field}){alias}"

    def version_sql(self, expression: exp.Version) -> str:
        this = f"FOR {expression.name}"
        kind = expression.text("kind")
        expr = self.sql(expression, "expression")
        return f"{this} {kind} {expr}"

    def tuple_sql(self, expression: exp.Tuple) -> str:
        return f"({self.expressions(expression, flat=True)})"
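A hedged sketch of the new temporal-table plumbing (`exp.Version` plus `version_sql`), assuming the T-SQL dialect tokenizes `FOR SYSTEM_TIME` (the matching `_parse_version` appears later in this diff):

```python
import sqlglot
from sqlglot import exp

expr = sqlglot.parse_one(
    "SELECT * FROM t FOR SYSTEM_TIME AS OF '2023-01-01'", read="tsql"
)
version = expr.find(exp.Version)
print(version.text("kind"))  # AS OF
```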
@ -1323,12 +1360,13 @@ class Generator:
        from_sql = self.sql(expression, "from")
        where_sql = self.sql(expression, "where")
        returning = self.sql(expression, "returning")
        order = self.sql(expression, "order")
        limit = self.sql(expression, "limit")
        if self.RETURNING_END:
            expression_sql = f"{from_sql}{where_sql}{returning}{limit}"
            expression_sql = f"{from_sql}{where_sql}{returning}"
        else:
            expression_sql = f"{returning}{from_sql}{where_sql}{limit}"
        sql = f"UPDATE {this} SET {set_sql}{expression_sql}"
            expression_sql = f"{returning}{from_sql}{where_sql}"
        sql = f"UPDATE {this} SET {set_sql}{expression_sql}{order}{limit}"
        return self.prepend_ctes(expression, sql)

    def values_sql(self, expression: exp.Values) -> str:

@ -1425,6 +1463,16 @@ class Generator:
        this = self.indent(self.sql(expression, "this"))
        return f"{self.seg('HAVING')}{self.sep()}{this}"

    def connect_sql(self, expression: exp.Connect) -> str:
        start = self.sql(expression, "start")
        start = self.seg(f"START WITH {start}") if start else ""
        connect = self.sql(expression, "connect")
        connect = self.seg(f"CONNECT BY {connect}")
        return start + connect

    def prior_sql(self, expression: exp.Prior) -> str:
        return f"PRIOR {self.sql(expression, 'this')}"

    def join_sql(self, expression: exp.Join) -> str:
        op_sql = " ".join(
            op
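A hedged example of the new hierarchical-query support (`connect_sql` / `prior_sql`): Oracle's `START WITH ... CONNECT BY` should now round-trip:

```python
import sqlglot

sql = """
SELECT employee_id
FROM employees
START WITH manager_id IS NULL
CONNECT BY PRIOR employee_id = manager_id
"""
print(sqlglot.transpile(sql, read="oracle", write="oracle")[0])
```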
@ -1667,6 +1715,7 @@ class Generator:
        return csv(
            *sqls,
            *[self.sql(join) for join in expression.args.get("joins") or []],
            self.sql(expression, "connect"),
            self.sql(expression, "match"),
            *[self.sql(lateral) for lateral in expression.args.get("laterals") or []],
            self.sql(expression, "where"),

@ -1801,7 +1850,8 @@ class Generator:
    def union_op(self, expression: exp.Union) -> str:
        kind = " DISTINCT" if self.EXPLICIT_UNION else ""
        kind = kind if expression.args.get("distinct") else " ALL"
        return f"UNION{kind}"
        by_name = " BY NAME" if expression.args.get("by_name") else ""
        return f"UNION{kind}{by_name}"

    def unnest_sql(self, expression: exp.Unnest) -> str:
        args = self.expressions(expression, flat=True)
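A hedged sketch of the new `by_name` flag on unions (DuckDB's `UNION BY NAME` syntax); the exact output formatting is an assumption:

```python
import sqlglot

sql = "SELECT 1 AS a UNION ALL BY NAME SELECT 2 AS a"
print(sqlglot.transpile(sql, read="duckdb", write="duckdb")[0])
```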
@ -2224,7 +2274,14 @@ class Generator:
        actions = expression.args["actions"]

        if isinstance(actions[0], exp.ColumnDef):
            actions = self.expressions(expression, key="actions", prefix="ADD COLUMN ")
            if self.ALTER_TABLE_ADD_COLUMN_KEYWORD:
                actions = self.expressions(
                    expression,
                    key="actions",
                    prefix="ADD COLUMN ",
                )
            else:
                actions = f"ADD {self.expressions(expression, key='actions')}"
        elif isinstance(actions[0], exp.Schema):
            actions = self.expressions(expression, key="actions", prefix="ADD COLUMNS ")
        elif isinstance(actions[0], exp.Delete):

@ -2525,10 +2582,21 @@ class Generator:
        return f"WHEN {matched}{source}{condition} THEN {then}"

    def merge_sql(self, expression: exp.Merge) -> str:
        this = self.sql(expression, "this")
        table = expression.this
        table_alias = ""

        hints = table.args.get("hints")
        if hints and table.alias and isinstance(hints[0], exp.WithTableHint):
            # T-SQL syntax is MERGE ... <target_table> [WITH (<merge_hint>)] [[AS] table_alias]
            table = table.copy()
            table_alias = f" AS {self.sql(table.args['alias'].pop())}"

        this = self.sql(table)
        using = f"USING {self.sql(expression, 'using')}"
        on = f"ON {self.sql(expression, 'on')}"
        return f"MERGE INTO {this} {using} {on} {self.expressions(expression, sep=' ')}"
        expressions = self.expressions(expression, sep=" ")

        return f"MERGE INTO {this}{table_alias} {using} {on} {expressions}"

    def tochar_sql(self, expression: exp.ToChar) -> str:
        if expression.args.get("format"):

@ -2631,6 +2699,29 @@ class Generator:
        options = f" {options}" if options else ""
        return f"{kind}{this}{type_}{schema}{options}"

    def nvl2_sql(self, expression: exp.Nvl2) -> str:
        if self.NVL2_SUPPORTED:
            return self.function_fallback_sql(expression)

        case = exp.Case().when(
            expression.this.is_(exp.null()).not_(copy=False),
            expression.args["true"].copy(),
            copy=False,
        )
        else_cond = expression.args.get("false")
        if else_cond:
            case.else_(else_cond.copy(), copy=False)

        return self.sql(case)

    def comprehension_sql(self, expression: exp.Comprehension) -> str:
        this = self.sql(expression, "this")
        expr = self.sql(expression, "expression")
        iterator = self.sql(expression, "iterator")
        condition = self.sql(expression, "condition")
        condition = f" IF {condition}" if condition else ""
        return f"{this} FOR {expr} IN {iterator}{condition}"


def cached_generator(
    cache: t.Optional[t.Dict[int, str]] = None
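An illustration of the `nvl2_sql` fallback: when a target dialect sets `NVL2_SUPPORTED = False`, `NVL2(a, b, c)` is rewritten as a `CASE` expression. Which dialects disable the flag is not shown here, so the target below is an assumption:

```python
import sqlglot

# Assuming the write dialect has NVL2_SUPPORTED = False:
print(sqlglot.transpile("SELECT NVL2(a, b, c)", read="oracle", write="presto")[0])
# e.g. SELECT CASE WHEN NOT a IS NULL THEN b ELSE c END
```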
@ -33,6 +33,15 @@ class AutoName(Enum):
        return name


class classproperty(property):
    """
    Similar to a normal property but works for class methods
    """

    def __get__(self, obj: t.Any, owner: t.Any = None) -> t.Any:
        return classmethod(self.fget).__get__(None, owner)()  # type: ignore


def seq_get(seq: t.Sequence[T], index: int) -> t.Optional[T]:
    """Returns the value in `seq` at position `index`, or `None` if `index` is out of bounds."""
    try:
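A minimal, self-contained sketch of the new `classproperty` helper (the `Config` class below is illustrative, not from the codebase):

```python
from sqlglot.helper import classproperty

class Config:
    _name = "default"

    @classproperty
    def name(cls) -> str:
        # Accessed on the class itself, no instance required
        return cls._name

print(Config.name)  # default
```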
@ -137,9 +146,9 @@ def subclasses(

def apply_index_offset(
    this: exp.Expression,
    expressions: t.List[t.Optional[E]],
    expressions: t.List[E],
    offset: int,
) -> t.List[t.Optional[E]]:
) -> t.List[E]:
    """
    Applies an offset to a given integer literal expression.

@ -170,15 +179,14 @@ def apply_index_offset(
    ):
        return expressions

    if expression:
        if not expression.type:
            annotate_types(expression)
        if t.cast(exp.DataType, expression.type).this in exp.DataType.INTEGER_TYPES:
            logger.warning("Applying array index offset (%s)", offset)
            expression = simplify(
                exp.Add(this=expression.copy(), expression=exp.Literal.number(offset))
            )
        return [expression]
    if not expression.type:
        annotate_types(expression)
    if t.cast(exp.DataType, expression.type).this in exp.DataType.INTEGER_TYPES:
        logger.warning("Applying array index offset (%s)", offset)
        expression = simplify(
            exp.Add(this=expression.copy(), expression=exp.Literal.number(offset))
        )
        return [expression]

    return expressions

@ -1,2 +1,9 @@
from sqlglot.optimizer.optimizer import RULES, optimize
from sqlglot.optimizer.scope import Scope, build_scope, traverse_scope
from sqlglot.optimizer.scope import (
    Scope,
    build_scope,
    find_all_in_scope,
    find_in_scope,
    traverse_scope,
    walk_in_scope,
)

@ -203,10 +203,15 @@ class TypeAnnotator(metaclass=_TypeAnnotator):
            for expr_type in expressions
        },
        exp.Anonymous: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.UNKNOWN),
        exp.Array: lambda self, e: self._annotate_by_args(e, "expressions", array=True),
        exp.ArrayAgg: lambda self, e: self._annotate_by_args(e, "this", array=True),
        exp.ArrayConcat: lambda self, e: self._annotate_by_args(e, "this", "expressions"),
        exp.Cast: lambda self, e: self._annotate_with_type(e, e.args["to"]),
        exp.Case: lambda self, e: self._annotate_by_args(e, "default", "ifs"),
        exp.Coalesce: lambda self, e: self._annotate_by_args(e, "this", "expressions"),
        exp.DataType: lambda self, e: self._annotate_with_type(e, e.copy()),
        exp.Distinct: lambda self, e: self._annotate_by_args(e, "expressions"),
        exp.Filter: lambda self, e: self._annotate_by_args(e, "this"),
        exp.If: lambda self, e: self._annotate_by_args(e, "true", "false"),
        exp.Interval: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.INTERVAL),
        exp.Least: lambda self, e: self._annotate_by_args(e, "expressions"),

@ -220,6 +225,10 @@ class TypeAnnotator(metaclass=_TypeAnnotator):
        exp.VarMap: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.MAP),
    }

    NESTED_TYPES = {
        exp.DataType.Type.ARRAY,
    }

    # Specifies what types a given type can be coerced into (autofilled)
    COERCES_TO: t.Dict[exp.DataType.Type, t.Set[exp.DataType.Type]] = {}

@ -299,19 +308,22 @@ class TypeAnnotator(metaclass=_TypeAnnotator):

    def _maybe_coerce(
        self, type1: exp.DataType | exp.DataType.Type, type2: exp.DataType | exp.DataType.Type
    ) -> exp.DataType.Type:
        # We propagate the NULL / UNKNOWN types upwards if found
        if isinstance(type1, exp.DataType):
            type1 = type1.this
        if isinstance(type2, exp.DataType):
            type2 = type2.this
    ) -> exp.DataType | exp.DataType.Type:
        type1_value = type1.this if isinstance(type1, exp.DataType) else type1
        type2_value = type2.this if isinstance(type2, exp.DataType) else type2

        if exp.DataType.Type.NULL in (type1, type2):
        # We propagate the NULL / UNKNOWN types upwards if found
        if exp.DataType.Type.NULL in (type1_value, type2_value):
            return exp.DataType.Type.NULL
        if exp.DataType.Type.UNKNOWN in (type1, type2):
        if exp.DataType.Type.UNKNOWN in (type1_value, type2_value):
            return exp.DataType.Type.UNKNOWN

        return type2 if type2 in self.coerces_to.get(type1, {}) else type1  # type: ignore
        if type1_value in self.NESTED_TYPES:
            return type1
        if type2_value in self.NESTED_TYPES:
            return type2

        return type2_value if type2_value in self.coerces_to.get(type1_value, {}) else type1_value  # type: ignore

    # Note: the following "no_type_check" decorators were added because mypy was yelling due
    # to assigning Type values to expression.type (since its getter returns Optional[DataType]).

@ -368,7 +380,9 @@ class TypeAnnotator(metaclass=_TypeAnnotator):
        return self._annotate_args(expression)

    @t.no_type_check
    def _annotate_by_args(self, expression: E, *args: str, promote: bool = False) -> E:
    def _annotate_by_args(
        self, expression: E, *args: str, promote: bool = False, array: bool = False
    ) -> E:
        self._annotate_args(expression)

        expressions: t.List[exp.Expression] = []

@ -388,4 +402,9 @@ class TypeAnnotator(metaclass=_TypeAnnotator):
        elif expression.type.this in exp.DataType.FLOAT_TYPES:
            expression.type = exp.DataType.Type.DOUBLE

        if array:
            expression.type = exp.DataType(
                this=exp.DataType.Type.ARRAY, expressions=[expression.type], nested=True
            )

        return expression
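A hedged sketch of the new array-aware annotation: with `array=True`, an `ARRAY` literal should now be typed as `ARRAY<element type>` rather than the bare element type:

```python
import sqlglot
from sqlglot.optimizer.annotate_types import annotate_types

expr = annotate_types(sqlglot.parse_one("SELECT ARRAY(1, 2, 3) AS a"))
print(expr.selects[0].type.sql())  # e.g. ARRAY<INT>
```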
@ -142,13 +142,14 @@ def _eliminate_derived_table(scope, existing_ctes, taken):
    if scope.parent.pivots or isinstance(scope.parent.expression, exp.Lateral):
        return None

    parent = scope.expression.parent
    # Get rid of redundant exp.Subquery expressions, i.e. those that are just used as wrappers
    to_replace = scope.expression.parent.unwrap()
    name, cte = _new_cte(scope, existing_ctes, taken)
    table = exp.alias_(exp.table_(name), alias=to_replace.alias or name)
    table.set("joins", to_replace.args.get("joins"))

    table = exp.alias_(exp.table_(name), alias=parent.alias or name)
    table.set("joins", parent.args.get("joins"))
    to_replace.replace(table)

    parent.replace(table)
    return cte

@ -72,8 +72,13 @@ def normalize(expression):
        if not any(join.args.get(k) for k in JOIN_ATTRS):
            join.set("kind", "CROSS")

        if join.kind != "CROSS":
        if join.kind == "CROSS":
            join.set("on", None)
        else:
            join.set("kind", None)

            if not join.args.get("on") and not join.args.get("using"):
                join.set("on", exp.true())
    return expression

@ -1,6 +1,6 @@
from sqlglot import exp
from sqlglot.optimizer.normalize import normalized
from sqlglot.optimizer.scope import build_scope
from sqlglot.optimizer.scope import build_scope, find_in_scope
from sqlglot.optimizer.simplify import simplify

@ -81,7 +81,11 @@ def pushdown_cnf(predicates, scope, scope_ref_count):
            break
        if isinstance(node, exp.Select):
            predicate.replace(exp.true())
            node.where(replace_aliases(node, predicate), copy=False)
            inner_predicate = replace_aliases(node, predicate)
            if find_in_scope(inner_predicate, exp.AggFunc):
                node.having(inner_predicate, copy=False)
            else:
                node.where(inner_predicate, copy=False)


def pushdown_dnf(predicates, scope, scope_ref_count):

@ -142,7 +146,11 @@ def pushdown_dnf(predicates, scope, scope_ref_count):
        if isinstance(node, exp.Join):
            node.on(predicate, copy=False)
        elif isinstance(node, exp.Select):
            node.where(replace_aliases(node, predicate), copy=False)
            inner_predicate = replace_aliases(node, predicate)
            if find_in_scope(inner_predicate, exp.AggFunc):
                node.having(inner_predicate, copy=False)
            else:
                node.where(inner_predicate, copy=False)


def nodes_for_predicate(predicate, sources, scope_ref_count):
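A hedged illustration of the pushdown change: a predicate that references an aggregate should now be pushed into `HAVING` rather than `WHERE` (the exact output is not verified here):

```python
import sqlglot
from sqlglot.optimizer.pushdown_predicates import pushdown_predicates

sql = "SELECT * FROM (SELECT a, SUM(b) AS s FROM t GROUP BY a) AS sub WHERE s > 10"
print(pushdown_predicates(sqlglot.parse_one(sql)).sql())
```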
@ -6,7 +6,7 @@ from enum import Enum, auto

from sqlglot import exp
from sqlglot.errors import OptimizeError
from sqlglot.helper import find_new_name
from sqlglot.helper import ensure_collection, find_new_name

logger = logging.getLogger("sqlglot")

@ -141,38 +141,10 @@ class Scope:
        return walk_in_scope(self.expression, bfs=bfs)

    def find(self, *expression_types, bfs=True):
        """
        Returns the first node in this scope which matches at least one of the specified types.

        This does NOT traverse into subscopes.

        Args:
            expression_types (type): the expression type(s) to match.
            bfs (bool): True to use breadth-first search, False to use depth-first.

        Returns:
            exp.Expression: the node which matches the criteria or None if no node matching
            the criteria was found.
        """
        return next(self.find_all(*expression_types, bfs=bfs), None)
        return find_in_scope(self.expression, expression_types, bfs=bfs)

    def find_all(self, *expression_types, bfs=True):
        """
        Returns a generator object which visits all nodes in this scope and only yields those that
        match at least one of the specified expression types.

        This does NOT traverse into subscopes.

        Args:
            expression_types (type): the expression type(s) to match.
            bfs (bool): True to use breadth-first search, False to use depth-first.

        Yields:
            exp.Expression: nodes
        """
        for expression, *_ in self.walk(bfs=bfs):
            if isinstance(expression, expression_types):
                yield expression
        return find_all_in_scope(self.expression, expression_types, bfs=bfs)

    def replace(self, old, new):
        """

@ -800,3 +772,41 @@ def walk_in_scope(expression, bfs=True):
        for key in ("joins", "laterals", "pivots"):
            for arg in node.args.get(key) or []:
                yield from walk_in_scope(arg, bfs=bfs)


def find_all_in_scope(expression, expression_types, bfs=True):
    """
    Returns a generator object which visits all nodes in this scope and only yields those that
    match at least one of the specified expression types.

    This does NOT traverse into subscopes.

    Args:
        expression (exp.Expression):
        expression_types (tuple[type]|type): the expression type(s) to match.
        bfs (bool): True to use breadth-first search, False to use depth-first.

    Yields:
        exp.Expression: nodes
    """
    for expression, *_ in walk_in_scope(expression, bfs=bfs):
        if isinstance(expression, tuple(ensure_collection(expression_types))):
            yield expression


def find_in_scope(expression, expression_types, bfs=True):
    """
    Returns the first node in this scope which matches at least one of the specified types.

    This does NOT traverse into subscopes.

    Args:
        expression (exp.Expression):
        expression_types (tuple[type]|type): the expression type(s) to match.
        bfs (bool): True to use breadth-first search, False to use depth-first.

    Returns:
        exp.Expression: the node which matches the criteria or None if no node matching
        the criteria was found.
    """
    return next(find_all_in_scope(expression, expression_types, bfs=bfs), None)
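A usage sketch of the new module-level scope helpers, which now back the `Scope.find` / `Scope.find_all` methods:

```python
import sqlglot
from sqlglot import exp
from sqlglot.optimizer.scope import build_scope, find_all_in_scope

root = build_scope(sqlglot.parse_one("SELECT a FROM (SELECT a FROM t) AS sub"))
# Subscopes (the inner SELECT) are not traversed.
for column in find_all_in_scope(root.expression, exp.Column):
    print(column.sql())
```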
@ -69,10 +69,10 @@ def simplify(expression):
        node = flatten(node)
        node = simplify_connectors(node, root)
        node = remove_compliments(node, root)
        node = simplify_coalesce(node)
        node.parent = expression.parent
        node = simplify_literals(node, root)
        node = simplify_parens(node)
        node = simplify_coalesce(node)

        if root:
            expression.replace(node)

@ -350,7 +350,8 @@ def absorb_and_eliminate(expression, root=True):
def simplify_literals(expression, root=True):
    if isinstance(expression, exp.Binary) and not isinstance(expression, exp.Connector):
        return _flat_simplify(expression, _simplify_binary, root)
    elif isinstance(expression, exp.Neg):

    if isinstance(expression, exp.Neg):
        this = expression.this
        if this.is_number:
            value = this.name

@ -430,13 +431,14 @@ def simplify_parens(expression):

    if not isinstance(this, exp.Select) and (
        not isinstance(parent, (exp.Condition, exp.Binary))
        or isinstance(this, exp.Predicate)
        or isinstance(parent, exp.Paren)
        or not isinstance(this, exp.Binary)
        or (isinstance(this, exp.Predicate) and not isinstance(parent, exp.Predicate))
        or (isinstance(this, exp.Add) and isinstance(parent, exp.Add))
        or (isinstance(this, exp.Mul) and isinstance(parent, exp.Mul))
        or (isinstance(this, exp.Mul) and isinstance(parent, (exp.Add, exp.Sub)))
    ):
        return expression.this
        return this
    return expression

@ -488,18 +490,20 @@ def simplify_coalesce(expression):
    coalesce = coalesce if coalesce.expressions else coalesce.this

    # This expression is more complex than when we started, but it will get simplified further
    return exp.or_(
        exp.and_(
            coalesce.is_(exp.null()).not_(copy=False),
            expression.copy(),
            copy=False,
        ),
        exp.and_(
            coalesce.is_(exp.null()),
            type(expression)(this=arg.copy(), expression=other.copy()),
            copy=False,
        ),
        copy=False,
    )
    return exp.paren(
        exp.or_(
            exp.and_(
                coalesce.is_(exp.null()).not_(copy=False),
                expression.copy(),
                copy=False,
            ),
            exp.and_(
                coalesce.is_(exp.null()),
                type(expression)(this=arg.copy(), expression=other.copy()),
                copy=False,
            ),
            copy=False,
        )
    )

@ -642,7 +646,7 @@ def _flat_simplify(expression, simplifier, root=True):
            for b in queue:
                result = simplifier(expression, a, b)

                if result:
                if result and result is not expression:
                    queue.remove(b)
                    queue.appendleft(result)
                    break

@ -136,6 +136,7 @@ class Parser(metaclass=_Parser):
        TokenType.UINT128,
        TokenType.INT256,
        TokenType.UINT256,
        TokenType.MEDIUMINT,
        TokenType.FIXEDSTRING,
        TokenType.FLOAT,
        TokenType.DOUBLE,

@ -186,6 +187,7 @@ class Parser(metaclass=_Parser):
        TokenType.SMALLSERIAL,
        TokenType.BIGSERIAL,
        TokenType.XML,
        TokenType.YEAR,
        TokenType.UNIQUEIDENTIFIER,
        TokenType.USERDEFINED,
        TokenType.MONEY,

@ -194,9 +196,12 @@ class Parser(metaclass=_Parser):
        TokenType.IMAGE,
        TokenType.VARIANT,
        TokenType.OBJECT,
        TokenType.OBJECT_IDENTIFIER,
        TokenType.INET,
        TokenType.IPADDRESS,
        TokenType.IPPREFIX,
        TokenType.UNKNOWN,
        TokenType.NULL,
        *ENUM_TYPE_TOKENS,
        *NESTED_TYPE_TOKENS,
    }

@ -332,6 +337,7 @@ class Parser(metaclass=_Parser):
        TokenType.INDEX,
        TokenType.ISNULL,
        TokenType.ILIKE,
        TokenType.INSERT,
        TokenType.LIKE,
        TokenType.MERGE,
        TokenType.OFFSET,

@ -487,7 +493,7 @@ class Parser(metaclass=_Parser):
        exp.Cluster: lambda self: self._parse_sort(exp.Cluster, TokenType.CLUSTER_BY),
        exp.Column: lambda self: self._parse_column(),
        exp.Condition: lambda self: self._parse_conjunction(),
        exp.DataType: lambda self: self._parse_types(),
        exp.DataType: lambda self: self._parse_types(allow_identifiers=False),
        exp.Expression: lambda self: self._parse_statement(),
        exp.From: lambda self: self._parse_from(),
        exp.Group: lambda self: self._parse_group(),

@ -523,9 +529,6 @@ class Parser(metaclass=_Parser):
        TokenType.DESC: lambda self: self._parse_describe(),
        TokenType.DESCRIBE: lambda self: self._parse_describe(),
        TokenType.DROP: lambda self: self._parse_drop(),
        TokenType.FROM: lambda self: exp.select("*").from_(
            t.cast(exp.From, self._parse_from(skip_from_token=True))
        ),
        TokenType.INSERT: lambda self: self._parse_insert(),
        TokenType.LOAD: lambda self: self._parse_load(),
        TokenType.MERGE: lambda self: self._parse_merge(),

@ -578,7 +581,7 @@ class Parser(metaclass=_Parser):
        TokenType.PLACEHOLDER: lambda self: self.expression(exp.Placeholder),
        TokenType.PARAMETER: lambda self: self._parse_parameter(),
        TokenType.COLON: lambda self: self.expression(exp.Placeholder, this=self._prev.text)
        if self._match_set((TokenType.NUMBER, TokenType.VAR))
        if self._match(TokenType.NUMBER) or self._match_set(self.ID_VAR_TOKENS)
        else None,
    }

@ -593,6 +596,7 @@ class Parser(metaclass=_Parser):
        TokenType.OVERLAPS: binary_range_parser(exp.Overlaps),
        TokenType.RLIKE: binary_range_parser(exp.RegexpLike),
        TokenType.SIMILAR_TO: binary_range_parser(exp.SimilarTo),
        TokenType.FOR: lambda self, this: self._parse_comprehension(this),
    }

    PROPERTY_PARSERS: t.Dict[str, t.Callable] = {

@ -684,6 +688,12 @@ class Parser(metaclass=_Parser):
            exp.CommentColumnConstraint, this=self._parse_string()
        ),
        "COMPRESS": lambda self: self._parse_compress(),
        "CLUSTERED": lambda self: self.expression(
            exp.ClusteredColumnConstraint, this=self._parse_wrapped_csv(self._parse_ordered)
        ),
        "NONCLUSTERED": lambda self: self.expression(
            exp.NonClusteredColumnConstraint, this=self._parse_wrapped_csv(self._parse_ordered)
        ),
        "DEFAULT": lambda self: self.expression(
            exp.DefaultColumnConstraint, this=self._parse_bitwise()
        ),

@ -698,8 +708,11 @@ class Parser(metaclass=_Parser):
        "LIKE": lambda self: self._parse_create_like(),
        "NOT": lambda self: self._parse_not_constraint(),
        "NULL": lambda self: self.expression(exp.NotNullColumnConstraint, allow_null=True),
        "ON": lambda self: self._match(TokenType.UPDATE)
        and self.expression(exp.OnUpdateColumnConstraint, this=self._parse_function()),
        "ON": lambda self: (
            self._match(TokenType.UPDATE)
            and self.expression(exp.OnUpdateColumnConstraint, this=self._parse_function())
        )
        or self.expression(exp.OnProperty, this=self._parse_id_var()),
        "PATH": lambda self: self.expression(exp.PathColumnConstraint, this=self._parse_string()),
        "PRIMARY KEY": lambda self: self._parse_primary_key(),
        "REFERENCES": lambda self: self._parse_references(match=False),

@ -709,6 +722,9 @@ class Parser(metaclass=_Parser):
        "TTL": lambda self: self.expression(exp.MergeTreeTTL, expressions=[self._parse_bitwise()]),
        "UNIQUE": lambda self: self._parse_unique(),
        "UPPERCASE": lambda self: self.expression(exp.UppercaseColumnConstraint),
        "WITH": lambda self: self.expression(
            exp.Properties, expressions=self._parse_wrapped_csv(self._parse_property)
        ),
    }

    ALTER_PARSERS = {

@ -728,6 +744,11 @@ class Parser(metaclass=_Parser):
        "NEXT": lambda self: self._parse_next_value_for(),
    }

    INVALID_FUNC_NAME_TOKENS = {
        TokenType.IDENTIFIER,
        TokenType.STRING,
    }

    FUNCTIONS_WITH_ALIASED_ARGS = {"STRUCT"}

    FUNCTION_PARSERS = {

@ -774,6 +795,8 @@ class Parser(metaclass=_Parser):
            self._parse_sort(exp.Distribute, TokenType.DISTRIBUTE_BY),
        ),
        TokenType.SORT_BY: lambda self: ("sort", self._parse_sort(exp.Sort, TokenType.SORT_BY)),
        TokenType.CONNECT_BY: lambda self: ("connect", self._parse_connect(skip_start_token=True)),
        TokenType.START_WITH: lambda self: ("connect", self._parse_connect()),
    }

    SET_PARSERS = {

@ -815,6 +838,8 @@ class Parser(metaclass=_Parser):

    ADD_CONSTRAINT_TOKENS = {TokenType.CONSTRAINT, TokenType.PRIMARY_KEY, TokenType.FOREIGN_KEY}

    DISTINCT_TOKENS = {TokenType.DISTINCT}

    STRICT_CAST = True

    # A NULL arg in CONCAT yields NULL by default

@ -826,6 +851,11 @@ class Parser(metaclass=_Parser):
    LOG_BASE_FIRST = True
    LOG_DEFAULTS_TO_LN = False

    SUPPORTS_USER_DEFINED_TYPES = True

    # Whether or not ADD is present for each column added by ALTER TABLE
    ALTER_TABLE_ADD_COLUMN_KEYWORD = True

    __slots__ = (
        "error_level",
        "error_message_context",

@ -838,9 +868,11 @@ class Parser(metaclass=_Parser):
        "_next",
        "_prev",
        "_prev_comments",
        "_tokenizer",
    )

    # Autofilled
    TOKENIZER_CLASS: t.Type[Tokenizer] = Tokenizer
    INDEX_OFFSET: int = 0
    UNNEST_COLUMN_ONLY: bool = False
    ALIAS_POST_TABLESAMPLE: bool = False

@ -863,6 +895,7 @@ class Parser(metaclass=_Parser):
        self.error_level = error_level or ErrorLevel.IMMEDIATE
        self.error_message_context = error_message_context
        self.max_errors = max_errors
        self._tokenizer = self.TOKENIZER_CLASS()
        self.reset()

    def reset(self):

@ -1148,7 +1181,7 @@ class Parser(metaclass=_Parser):
        expression = self._parse_set_operations(expression) if expression else self._parse_select()
        return self._parse_query_modifiers(expression)

    def _parse_drop(self) -> exp.Drop | exp.Command:
    def _parse_drop(self, exists: bool = False) -> exp.Drop | exp.Command:
        start = self._prev
        temporary = self._match(TokenType.TEMPORARY)
        materialized = self._match_text_seq("MATERIALIZED")

@ -1160,7 +1193,7 @@ class Parser(metaclass=_Parser):
        return self.expression(
            exp.Drop,
            comments=start.comments,
            exists=self._parse_exists(),
            exists=exists or self._parse_exists(),
            this=self._parse_table(schema=True),
            kind=kind,
            temporary=temporary,

@ -1274,6 +1307,8 @@ class Parser(metaclass=_Parser):
        if self._match_text_seq("WITH", "NO", "SCHEMA", "BINDING"):
            no_schema_binding = True

        shallow = self._match_text_seq("SHALLOW")

        if self._match_text_seq("CLONE"):
            clone = self._parse_table(schema=True)
            when = self._match_texts({"AT", "BEFORE"}) and self._prev.text.upper()

@ -1285,7 +1320,12 @@ class Parser(metaclass=_Parser):
            clone_expression = self._match(TokenType.FARROW) and self._parse_bitwise()
            self._match(TokenType.R_PAREN)
            clone = self.expression(
                exp.Clone, this=clone, when=when, kind=clone_kind, expression=clone_expression
                exp.Clone,
                this=clone,
                when=when,
                kind=clone_kind,
                shallow=shallow,
                expression=clone_expression,
            )

        return self.expression(

@ -1349,7 +1389,11 @@ class Parser(metaclass=_Parser):
        if assignment:
            key = self._parse_var_or_string()
            self._match(TokenType.EQ)
            return self.expression(exp.Property, this=key, value=self._parse_column())
            return self.expression(
                exp.Property,
                this=key,
                value=self._parse_column() or self._parse_var(any_token=True),
            )

        return None

@ -1409,7 +1453,7 @@ class Parser(metaclass=_Parser):

    def _parse_with_property(
        self,
    ) -> t.Optional[exp.Expression] | t.List[t.Optional[exp.Expression]]:
    ) -> t.Optional[exp.Expression] | t.List[exp.Expression]:
        if self._match(TokenType.L_PAREN, advance=False):
            return self._parse_wrapped_csv(self._parse_property)

@ -1622,7 +1666,7 @@ class Parser(metaclass=_Parser):
            override=override,
        )

    def _parse_partition_by(self) -> t.List[t.Optional[exp.Expression]]:
    def _parse_partition_by(self) -> t.List[exp.Expression]:
        if self._match(TokenType.PARTITION_BY):
            return self._parse_csv(self._parse_conjunction)
        return []

@ -1652,9 +1696,9 @@ class Parser(metaclass=_Parser):
    def _parse_on_property(self) -> t.Optional[exp.Expression]:
        if self._match_text_seq("COMMIT", "PRESERVE", "ROWS"):
            return exp.OnCommitProperty()
        elif self._match_text_seq("COMMIT", "DELETE", "ROWS"):
        if self._match_text_seq("COMMIT", "DELETE", "ROWS"):
            return exp.OnCommitProperty(delete=True)
        return None
        return self.expression(exp.OnProperty, this=self._parse_schema(self._parse_id_var()))

    def _parse_distkey(self) -> exp.DistKeyProperty:
        return self.expression(exp.DistKeyProperty, this=self._parse_wrapped(self._parse_id_var))

@ -1709,8 +1753,10 @@ class Parser(metaclass=_Parser):

    def _parse_describe(self) -> exp.Describe:
        kind = self._match_set(self.CREATABLES) and self._prev.text
        this = self._parse_table()
        return self.expression(exp.Describe, this=this, kind=kind)
        this = self._parse_table(schema=True)
        properties = self._parse_properties()
        expressions = properties.expressions if properties else None
        return self.expression(exp.Describe, this=this, kind=kind, expressions=expressions)

    def _parse_insert(self) -> exp.Insert:
        comments = ensure_list(self._prev_comments)

@ -1741,6 +1787,7 @@ class Parser(metaclass=_Parser):
            exp.Insert,
            comments=comments,
            this=this,
            by_name=self._match_text_seq("BY", "NAME"),
            exists=self._parse_exists(),
            partition=self._parse_partition(),
            where=self._match_pair(TokenType.REPLACE, TokenType.WHERE)

@ -1895,6 +1942,7 @@ class Parser(metaclass=_Parser):
                "from": self._parse_from(joins=True),
                "where": self._parse_where(),
                "returning": returning or self._parse_returning(),
                "order": self._parse_order(),
                "limit": self._parse_limit(),
            },
        )

@ -1948,13 +1996,14 @@ class Parser(metaclass=_Parser):
        # https://prestodb.io/docs/current/sql/values.html
        return self.expression(exp.Tuple, expressions=[self._parse_conjunction()])

    def _parse_projections(self) -> t.List[t.Optional[exp.Expression]]:
    def _parse_projections(self) -> t.List[exp.Expression]:
        return self._parse_expressions()

    def _parse_select(
        self, nested: bool = False, table: bool = False, parse_subquery_alias: bool = True
    ) -> t.Optional[exp.Expression]:
        cte = self._parse_with()

        if cte:
            this = self._parse_statement()

@ -1967,12 +2016,18 @@ class Parser(metaclass=_Parser):
            else:
                self.raise_error(f"{this.key} does not support CTE")
            this = cte
        elif self._match(TokenType.SELECT):

            return this

        # duckdb supports leading with FROM x
        from_ = self._parse_from() if self._match(TokenType.FROM, advance=False) else None

        if self._match(TokenType.SELECT):
            comments = self._prev_comments

            hint = self._parse_hint()
            all_ = self._match(TokenType.ALL)
            distinct = self._match(TokenType.DISTINCT)
            distinct = self._match_set(self.DISTINCT_TOKENS)

            kind = (
                self._match(TokenType.ALIAS)

@ -2006,7 +2061,9 @@ class Parser(metaclass=_Parser):
            if into:
                this.set("into", into)

            from_ = self._parse_from()
            if not from_:
                from_ = self._parse_from()

            if from_:
                this.set("from", from_)

@ -2033,6 +2090,8 @@ class Parser(metaclass=_Parser):
                expressions=self._parse_csv(self._parse_value),
                alias=self._parse_table_alias(),
            )
        elif from_:
            this = exp.select("*").from_(from_.this, copy=False)
        else:
            this = None
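A hedged example of the DuckDB leading-`FROM` syntax now handled in `_parse_select`: a bare `FROM t` becomes `SELECT * FROM t`:

```python
import sqlglot

print(sqlglot.transpile("FROM employees", read="duckdb")[0])
# e.g. SELECT * FROM employees
```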
@ -2491,6 +2550,11 @@ class Parser(metaclass=_Parser):
        if schema:
            return self._parse_schema(this=this)

        version = self._parse_version()

        if version:
            this.set("version", version)

        if self.ALIAS_POST_TABLESAMPLE:
            table_sample = self._parse_table_sample()

@ -2498,11 +2562,11 @@ class Parser(metaclass=_Parser):
        if alias:
            this.set("alias", alias)

        this.set("hints", self._parse_table_hints())

        if not this.args.get("pivots"):
            this.set("pivots", self._parse_pivots())

        this.set("hints", self._parse_table_hints())

        if not self.ALIAS_POST_TABLESAMPLE:
            table_sample = self._parse_table_sample()

@ -2516,6 +2580,37 @@ class Parser(metaclass=_Parser):

        return this

    def _parse_version(self) -> t.Optional[exp.Version]:
        if self._match(TokenType.TIMESTAMP_SNAPSHOT):
            this = "TIMESTAMP"
        elif self._match(TokenType.VERSION_SNAPSHOT):
            this = "VERSION"
        else:
            return None

        if self._match_set((TokenType.FROM, TokenType.BETWEEN)):
            kind = self._prev.text.upper()
            start = self._parse_bitwise()
            self._match_texts(("TO", "AND"))
            end = self._parse_bitwise()
            expression: t.Optional[exp.Expression] = self.expression(
                exp.Tuple, expressions=[start, end]
            )
        elif self._match_text_seq("CONTAINED", "IN"):
            kind = "CONTAINED IN"
            expression = self.expression(
                exp.Tuple, expressions=self._parse_wrapped_csv(self._parse_bitwise)
            )
        elif self._match(TokenType.ALL):
            kind = "ALL"
            expression = None
        else:
            self._match_text_seq("AS", "OF")
            kind = "AS OF"
            expression = self._parse_type()

        return self.expression(exp.Version, this=this, expression=expression, kind=kind)

    def _parse_unnest(self, with_alias: bool = True) -> t.Optional[exp.Unnest]:
        if not self._match(TokenType.UNNEST):
            return None

@ -2760,7 +2855,7 @@ class Parser(metaclass=_Parser):

        return self.expression(exp.Group, **elements)  # type: ignore

    def _parse_grouping_sets(self) -> t.Optional[t.List[t.Optional[exp.Expression]]]:
    def _parse_grouping_sets(self) -> t.Optional[t.List[exp.Expression]]:
        if not self._match(TokenType.GROUPING_SETS):
            return None

@ -2784,6 +2879,22 @@ class Parser(metaclass=_Parser):
            return None
        return self.expression(exp.Qualify, this=self._parse_conjunction())

    def _parse_connect(self, skip_start_token: bool = False) -> t.Optional[exp.Connect]:
        if skip_start_token:
            start = None
        elif self._match(TokenType.START_WITH):
            start = self._parse_conjunction()
        else:
            return None

        self._match(TokenType.CONNECT_BY)
        self.NO_PAREN_FUNCTION_PARSERS["PRIOR"] = lambda self: self.expression(
            exp.Prior, this=self._parse_bitwise()
        )
        connect = self._parse_conjunction()
        self.NO_PAREN_FUNCTION_PARSERS.pop("PRIOR")
        return self.expression(exp.Connect, start=start, connect=connect)

    def _parse_order(
        self, this: t.Optional[exp.Expression] = None, skip_order_token: bool = False
    ) -> t.Optional[exp.Expression]:

@ -2929,6 +3040,7 @@ class Parser(metaclass=_Parser):
            expression,
            this=this,
            distinct=self._match(TokenType.DISTINCT) or not self._match(TokenType.ALL),
            by_name=self._match_text_seq("BY", "NAME"),
            expression=self._parse_set_operations(self._parse_select(nested=True)),
        )

@ -3017,6 +3129,8 @@ class Parser(metaclass=_Parser):
        return self.expression(exp.Escape, this=this, expression=self._parse_string())

    def _parse_interval(self) -> t.Optional[exp.Interval]:
        index = self._index

        if not self._match(TokenType.INTERVAL):
            return None

@ -3025,7 +3139,11 @@ class Parser(metaclass=_Parser):
        else:
            this = self._parse_term()

        unit = self._parse_function() or self._parse_var()
        if not this:
            self._retreat(index)
            return None

        unit = self._parse_function() or self._parse_var(any_token=True)

        # Most dialects support, e.g., the form INTERVAL '5' day, thus we try to parse
        # each INTERVAL expression into this canonical form so it's easy to transpile

@ -3036,12 +3154,12 @@ class Parser(metaclass=_Parser):

        if len(parts) == 2:
            if unit:
                # this is not actually a unit, it's something else
                # This is not actually a unit, it's something else (e.g. a "window side")
                unit = None
                self._retreat(self._index - 1)
            else:
                this = exp.Literal.string(parts[0])
                unit = self.expression(exp.Var, this=parts[1])

            this = exp.Literal.string(parts[0])
            unit = self.expression(exp.Var, this=parts[1])

        return self.expression(exp.Interval, this=this, unit=unit)

@ -3087,7 +3205,7 @@ class Parser(metaclass=_Parser):
            return interval

        index = self._index
        data_type = self._parse_types(check_func=True)
        data_type = self._parse_types(check_func=True, allow_identifiers=False)
        this = self._parse_column()

        if data_type:

@ -3103,30 +3221,50 @@ class Parser(metaclass=_Parser):

        return this

    def _parse_type_size(self) -> t.Optional[exp.DataTypeSize]:
    def _parse_type_size(self) -> t.Optional[exp.DataTypeParam]:
        this = self._parse_type()
        if not this:
            return None

        return self.expression(
            exp.DataTypeSize, this=this, expression=self._parse_var(any_token=True)
            exp.DataTypeParam, this=this, expression=self._parse_var(any_token=True)
        )

    def _parse_types(
        self, check_func: bool = False, schema: bool = False
        self, check_func: bool = False, schema: bool = False, allow_identifiers: bool = True
    ) -> t.Optional[exp.Expression]:
        index = self._index

        prefix = self._match_text_seq("SYSUDTLIB", ".")

        if not self._match_set(self.TYPE_TOKENS):
            return None
            identifier = allow_identifiers and self._parse_id_var(
                any_token=False, tokens=(TokenType.VAR,)
            )

            if identifier:
                tokens = self._tokenizer.tokenize(identifier.name)

                if len(tokens) != 1:
                    self.raise_error("Unexpected identifier", self._prev)

                if tokens[0].token_type in self.TYPE_TOKENS:
                    self._prev = tokens[0]
                elif self.SUPPORTS_USER_DEFINED_TYPES:
                    return identifier
                else:
                    return None
            else:
                return None

        type_token = self._prev.token_type

        if type_token == TokenType.PSEUDO_TYPE:
            return self.expression(exp.PseudoType, this=self._prev.text)

        if type_token == TokenType.OBJECT_IDENTIFIER:
            return self.expression(exp.ObjectIdentifier, this=self._prev.text)

        nested = type_token in self.NESTED_TYPE_TOKENS
        is_struct = type_token in self.STRUCT_TYPE_TOKENS
        expressions = None
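A hedged sketch of the new user-defined-type path in `_parse_types`: dialects with `SUPPORTS_USER_DEFINED_TYPES` (Postgres is an assumed example) can now cast to custom type names instead of erroring on an unknown type:

```python
import sqlglot

print(sqlglot.transpile("SELECT CAST(a AS my_custom_type) FROM t", read="postgres")[0])
# e.g. SELECT CAST(a AS my_custom_type) FROM t
```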
@ -3137,7 +3275,9 @@ class Parser(metaclass=_Parser):
|
|||
expressions = self._parse_csv(self._parse_struct_types)
|
||||
elif nested:
|
||||
expressions = self._parse_csv(
|
||||
lambda: self._parse_types(check_func=check_func, schema=schema)
|
||||
lambda: self._parse_types(
|
||||
check_func=check_func, schema=schema, allow_identifiers=allow_identifiers
|
||||
)
|
||||
)
|
||||
elif type_token in self.ENUM_TYPE_TOKENS:
|
||||
expressions = self._parse_csv(self._parse_equality)
|
||||
|
@ -3151,14 +3291,16 @@ class Parser(metaclass=_Parser):
|
|||
maybe_func = True
|
||||
|
||||
this: t.Optional[exp.Expression] = None
|
||||
values: t.Optional[t.List[t.Optional[exp.Expression]]] = None
|
||||
values: t.Optional[t.List[exp.Expression]] = None
|
||||
|
||||
if nested and self._match(TokenType.LT):
|
||||
if is_struct:
|
||||
expressions = self._parse_csv(self._parse_struct_types)
|
||||
else:
|
||||
expressions = self._parse_csv(
|
||||
lambda: self._parse_types(check_func=check_func, schema=schema)
|
||||
lambda: self._parse_types(
|
||||
check_func=check_func, schema=schema, allow_identifiers=allow_identifiers
|
||||
)
|
||||
)
|
||||
|
||||
if not self._match(TokenType.GT):
|
||||
|
@ -3355,7 +3497,7 @@ class Parser(metaclass=_Parser):
|
|||
upper = this.upper()
|
||||
|
||||
parser = self.NO_PAREN_FUNCTION_PARSERS.get(upper)
|
||||
if optional_parens and parser:
|
||||
if optional_parens and parser and token_type not in self.INVALID_FUNC_NAME_TOKENS:
|
||||
self._advance()
|
||||
return parser(self)
|
||||
|
||||
|
@ -3442,7 +3584,9 @@ class Parser(metaclass=_Parser):
|
|||
index = self._index
|
||||
|
||||
if self._match(TokenType.L_PAREN):
|
||||
expressions = self._parse_csv(self._parse_id_var)
|
||||
expressions = t.cast(
|
||||
t.List[t.Optional[exp.Expression]], self._parse_csv(self._parse_id_var)
|
||||
)
|
||||
|
||||
if not self._match(TokenType.R_PAREN):
|
||||
self._retreat(index)
|
||||
|
@ -3481,14 +3625,14 @@ class Parser(metaclass=_Parser):
|
|||
if not self._match(TokenType.L_PAREN):
|
||||
return this
|
||||
|
||||
args = self._parse_csv(
|
||||
lambda: self._parse_constraint()
|
||||
or self._parse_column_def(self._parse_field(any_token=True))
|
||||
)
|
||||
args = self._parse_csv(lambda: self._parse_constraint() or self._parse_field_def())
|
||||
|
||||
self._match_r_paren()
|
||||
return self.expression(exp.Schema, this=this, expressions=args)
|
||||
|
||||
def _parse_field_def(self) -> t.Optional[exp.Expression]:
|
||||
return self._parse_column_def(self._parse_field(any_token=True))
|
||||
|
||||
def _parse_column_def(self, this: t.Optional[exp.Expression]) -> t.Optional[exp.Expression]:
|
||||
# column defs are not really columns, they're identifiers
|
||||
if isinstance(this, exp.Column):
|
||||
|
@ -3499,7 +3643,18 @@ class Parser(metaclass=_Parser):
|
|||
if self._match_text_seq("FOR", "ORDINALITY"):
|
||||
return self.expression(exp.ColumnDef, this=this, ordinality=True)
|
||||
|
||||
constraints = []
|
||||
constraints: t.List[exp.Expression] = []
|
||||
|
||||
if not kind and self._match(TokenType.ALIAS):
|
||||
constraints.append(
|
||||
self.expression(
|
||||
exp.ComputedColumnConstraint,
|
||||
this=self._parse_conjunction(),
|
||||
persisted=self._match_text_seq("PERSISTED"),
|
||||
not_null=self._match_pair(TokenType.NOT, TokenType.NULL),
|
||||
)
|
||||
)
|
||||
|
||||
while True:
|
||||
constraint = self._parse_column_constraint()
|
||||
if not constraint:
|
||||
|
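Note: the new `exp.ComputedColumnConstraint` branch means computed columns (T-SQL style `col AS (expr) [PERSISTED]`) now parse into a dedicated constraint node. A minimal sketch of what this enables, assuming the `tsql` dialect shipped with this version wires it up (the table and column names are illustrative):

```python
import sqlglot
from sqlglot import exp

# A T-SQL computed column should now parse into a ColumnDef carrying a
# ComputedColumnConstraint, with persisted=True when PERSISTED is present.
ast = sqlglot.parse_one(
    "CREATE TABLE t (a INT, b AS (a * 2) PERSISTED)",
    read="tsql",
)
computed = ast.find(exp.ComputedColumnConstraint)
print(computed.args.get("persisted"))  # expected: True
```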
@@ -3553,7 +3708,7 @@ class Parser(metaclass=_Parser):
         identity = self._match_text_seq("IDENTITY")

         if self._match(TokenType.L_PAREN):
-            if self._match_text_seq("START", "WITH"):
+            if self._match(TokenType.START_WITH):
                 this.set("start", self._parse_bitwise())
             if self._match_text_seq("INCREMENT", "BY"):
                 this.set("increment", self._parse_bitwise())
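With `START WITH` now tokenized as a single `START_WITH` token (added to the tokenizer further down), identity column options should still parse as before. A quick sketch:

```python
import sqlglot
from sqlglot import exp

# Identity options on a column; START WITH is now matched as one token.
ast = sqlglot.parse_one(
    "CREATE TABLE t (id INT GENERATED ALWAYS AS IDENTITY (START WITH 10 INCREMENT BY 5))"
)
identity = ast.find(exp.GeneratedAsIdentityColumnConstraint)
print(identity.args.get("start"), identity.args.get("increment"))  # expected: 10 5
```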
@@ -3580,11 +3735,13 @@ class Parser(metaclass=_Parser):

     def _parse_not_constraint(
         self,
-    ) -> t.Optional[exp.NotNullColumnConstraint | exp.CaseSpecificColumnConstraint]:
+    ) -> t.Optional[exp.Expression]:
         if self._match_text_seq("NULL"):
             return self.expression(exp.NotNullColumnConstraint)
         if self._match_text_seq("CASESPECIFIC"):
             return self.expression(exp.CaseSpecificColumnConstraint, not_=True)
+        if self._match_text_seq("FOR", "REPLICATION"):
+            return self.expression(exp.NotForReplicationColumnConstraint)
         return None

     def _parse_column_constraint(self) -> t.Optional[exp.Expression]:

@@ -3729,7 +3886,7 @@ class Parser(metaclass=_Parser):
         bracket_kind = self._prev.token_type

         if self._match(TokenType.COLON):
-            expressions: t.List[t.Optional[exp.Expression]] = [
+            expressions: t.List[exp.Expression] = [
                 self.expression(exp.Slice, expression=self._parse_conjunction())
             ]
         else:

@@ -3844,17 +4001,17 @@ class Parser(metaclass=_Parser):

         if not self._match(TokenType.ALIAS):
             if self._match(TokenType.COMMA):
-                return self.expression(
-                    exp.CastToStrType, this=this, expression=self._parse_string()
-                )
-            else:
-                self.raise_error("Expected AS after CAST")
+                return self.expression(exp.CastToStrType, this=this, to=self._parse_string())
+
+            self.raise_error("Expected AS after CAST")

         fmt = None
         to = self._parse_types()

         if not to:
             self.raise_error("Expected TYPE after CAST")
+        elif isinstance(to, exp.Identifier):
+            to = exp.DataType.build(to.name, udt=True)
         elif to.this == exp.DataType.Type.CHAR:
             if self._match(TokenType.CHARACTER_SET):
                 to = self.expression(exp.CharacterSet, this=self._parse_var_or_string())
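The `isinstance(to, exp.Identifier)` branch, together with `allow_identifiers` above, lets an unrecognized type name in a CAST come back as a user-defined type instead of a parse error. A sketch under that assumption, using Postgres (the dialect where custom types are common); `my_custom_type` is an illustrative name:

```python
import sqlglot
from sqlglot import exp

# An unknown type name after CAST is built via exp.DataType.build(..., udt=True)
# and should round-trip unchanged.
ast = sqlglot.parse_one("SELECT CAST(x AS my_custom_type) FROM t", read="postgres")
print(isinstance(ast.find(exp.Cast).to, exp.DataType))  # expected: True
print(ast.sql(dialect="postgres"))  # expected: SELECT CAST(x AS my_custom_type) FROM t
```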
@@ -3908,7 +4065,7 @@ class Parser(metaclass=_Parser):
         if self._match(TokenType.COMMA):
             args.extend(self._parse_csv(self._parse_conjunction))
         else:
-            args = self._parse_csv(self._parse_conjunction)
+            args = self._parse_csv(self._parse_conjunction)  # type: ignore

         index = self._index
         if not self._match(TokenType.R_PAREN) and args:

@@ -3991,10 +4148,10 @@ class Parser(metaclass=_Parser):

     def _parse_json_key_value(self) -> t.Optional[exp.JSONKeyValue]:
         self._match_text_seq("KEY")
-        key = self._parse_field()
-        self._match(TokenType.COLON)
+        key = self._parse_column()
+        self._match_set((TokenType.COLON, TokenType.COMMA))
         self._match_text_seq("VALUE")
-        value = self._parse_field()
+        value = self._parse_bitwise()

         if not key and not value:
             return None
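`_parse_json_key_value` now accepts column expressions as keys and either `:` or `,` as the key/value separator, so both the SQL-standard `KEY ... VALUE ...` form and the shorthand `key: value` style should parse. A sketch (exact generated SQL depends on the output dialect):

```python
import sqlglot

# Both spellings should produce a JSONObject with JSONKeyValue children.
for query in (
    "SELECT JSON_OBJECT('a' VALUE 1, 'b' VALUE 2)",
    "SELECT JSON_OBJECT('a': 1, 'b': 2)",
):
    print(sqlglot.parse_one(query).sql())
```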
@@ -4116,7 +4273,7 @@ class Parser(metaclass=_Parser):
         # Postgres supports the form: substring(string [from int] [for int])
         # https://www.postgresql.org/docs/9.1/functions-string.html @ Table 9-6

-        args = self._parse_csv(self._parse_bitwise)
+        args = t.cast(t.List[t.Optional[exp.Expression]], self._parse_csv(self._parse_bitwise))

         if self._match(TokenType.FROM):
             args.append(self._parse_bitwise())

@@ -4149,7 +4306,7 @@ class Parser(metaclass=_Parser):
             exp.Trim, this=this, position=position, expression=expression, collation=collation
         )

-    def _parse_window_clause(self) -> t.Optional[t.List[t.Optional[exp.Expression]]]:
+    def _parse_window_clause(self) -> t.Optional[t.List[exp.Expression]]:
         return self._match(TokenType.WINDOW) and self._parse_csv(self._parse_named_window)

     def _parse_named_window(self) -> t.Optional[exp.Expression]:

@@ -4216,8 +4373,7 @@ class Parser(metaclass=_Parser):
         if self._match_text_seq("LAST"):
             first = False

-        partition = self._parse_partition_by()
-        order = self._parse_order()
+        partition, order = self._parse_partition_and_order()
         kind = self._match_set((TokenType.ROWS, TokenType.RANGE)) and self._prev.text

         if kind:

@@ -4256,6 +4412,11 @@ class Parser(metaclass=_Parser):

         return window

+    def _parse_partition_and_order(
+        self,
+    ) -> t.Tuple[t.List[exp.Expression], t.Optional[exp.Expression]]:
+        return self._parse_partition_by(), self._parse_order()
+
     def _parse_window_spec(self) -> t.Dict[str, t.Optional[str | exp.Expression]]:
         self._match(TokenType.BETWEEN)
@@ -4377,14 +4538,14 @@ class Parser(metaclass=_Parser):
             self._advance(-1)
             return None

-    def _parse_except(self) -> t.Optional[t.List[t.Optional[exp.Expression]]]:
+    def _parse_except(self) -> t.Optional[t.List[exp.Expression]]:
         if not self._match(TokenType.EXCEPT):
             return None
         if self._match(TokenType.L_PAREN, advance=False):
             return self._parse_wrapped_csv(self._parse_column)
         return self._parse_csv(self._parse_column)

-    def _parse_replace(self) -> t.Optional[t.List[t.Optional[exp.Expression]]]:
+    def _parse_replace(self) -> t.Optional[t.List[exp.Expression]]:
         if not self._match(TokenType.REPLACE):
             return None
         if self._match(TokenType.L_PAREN, advance=False):

@@ -4393,7 +4554,7 @@ class Parser(metaclass=_Parser):

     def _parse_csv(
         self, parse_method: t.Callable, sep: TokenType = TokenType.COMMA
-    ) -> t.List[t.Optional[exp.Expression]]:
+    ) -> t.List[exp.Expression]:
         parse_result = parse_method()
         items = [parse_result] if parse_result is not None else []

@@ -4420,12 +4581,12 @@ class Parser(metaclass=_Parser):

         return this

-    def _parse_wrapped_id_vars(self, optional: bool = False) -> t.List[t.Optional[exp.Expression]]:
+    def _parse_wrapped_id_vars(self, optional: bool = False) -> t.List[exp.Expression]:
         return self._parse_wrapped_csv(self._parse_id_var, optional=optional)

     def _parse_wrapped_csv(
         self, parse_method: t.Callable, sep: TokenType = TokenType.COMMA, optional: bool = False
-    ) -> t.List[t.Optional[exp.Expression]]:
+    ) -> t.List[exp.Expression]:
         return self._parse_wrapped(
             lambda: self._parse_csv(parse_method, sep=sep), optional=optional
         )

@@ -4439,7 +4600,7 @@ class Parser(metaclass=_Parser):
         self._match_r_paren()
         return parse_result

-    def _parse_expressions(self) -> t.List[t.Optional[exp.Expression]]:
+    def _parse_expressions(self) -> t.List[exp.Expression]:
         return self._parse_csv(self._parse_expression)

     def _parse_select_or_expression(self, alias: bool = False) -> t.Optional[exp.Expression]:

@@ -4498,7 +4659,7 @@ class Parser(metaclass=_Parser):

         self._match(TokenType.COLUMN)
         exists_column = self._parse_exists(not_=True)
-        expression = self._parse_column_def(self._parse_field(any_token=True))
+        expression = self._parse_field_def()

         if expression:
             expression.set("exists", exists_column)

@@ -4549,13 +4710,16 @@ class Parser(metaclass=_Parser):

         return self.expression(exp.AddConstraint, this=this, expression=expression)

-    def _parse_alter_table_add(self) -> t.List[t.Optional[exp.Expression]]:
+    def _parse_alter_table_add(self) -> t.List[exp.Expression]:
         index = self._index - 1

         if self._match_set(self.ADD_CONSTRAINT_TOKENS):
             return self._parse_csv(self._parse_add_constraint)

         self._retreat(index)
+        if not self.ALTER_TABLE_ADD_COLUMN_KEYWORD and self._match_text_seq("ADD"):
+            return self._parse_csv(self._parse_field_def)
+
         return self._parse_csv(self._parse_add_column)

     def _parse_alter_table_alter(self) -> exp.AlterColumn:

@@ -4576,7 +4740,7 @@ class Parser(metaclass=_Parser):
             using=self._match(TokenType.USING) and self._parse_conjunction(),
         )

-    def _parse_alter_table_drop(self) -> t.List[t.Optional[exp.Expression]]:
+    def _parse_alter_table_drop(self) -> t.List[exp.Expression]:
         index = self._index - 1

         partition_exists = self._parse_exists()

@@ -4619,6 +4783,9 @@ class Parser(metaclass=_Parser):
         self._match(TokenType.INTO)
         target = self._parse_table()

+        if target and self._match(TokenType.ALIAS, advance=False):
+            target.set("alias", self._parse_table_alias())
+
         self._match(TokenType.USING)
         using = self._parse_table()
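The MERGE target can now carry an explicit `AS` alias. A quick check (sketch; the table and column names are illustrative):

```python
import sqlglot

# The alias on the merge target is now attached to the target table
# instead of being rejected by the parser.
sql = """
MERGE INTO employee AS tgt
USING employee_updates AS src
  ON tgt.employee_id = src.employee_id
WHEN MATCHED THEN UPDATE SET tgt.fname = src.fname
"""
ast = sqlglot.parse_one(sql)
print(ast.this.alias)  # expected: tgt
```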
@@ -4685,8 +4852,7 @@ class Parser(metaclass=_Parser):
         parser = self._find_parser(self.SHOW_PARSERS, self.SHOW_TRIE)
         if parser:
             return parser(self)
-        self._advance()
-        return self.expression(exp.Show, this=self._prev.text.upper())
+        return self._parse_as_command(self._prev)

     def _parse_set_item_assignment(
         self, kind: t.Optional[str] = None

@@ -4786,6 +4952,19 @@ class Parser(metaclass=_Parser):
         self._match_r_paren()
         return self.expression(exp.DictRange, this=this, min=min, max=max)

+    def _parse_comprehension(self, this: exp.Expression) -> exp.Comprehension:
+        expression = self._parse_column()
+        self._match(TokenType.IN)
+        iterator = self._parse_column()
+        condition = self._parse_conjunction() if self._match_text_seq("IF") else None
+        return self.expression(
+            exp.Comprehension,
+            this=this,
+            expression=expression,
+            iterator=iterator,
+            condition=condition,
+        )
+
     def _find_parser(
         self, parsers: t.Dict[str, t.Callable], trie: t.Dict
     ) -> t.Optional[t.Callable]:
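`_parse_comprehension` is new and backs DuckDB-style list comprehensions (`[expr FOR x IN list IF cond]`), presumably hooked in from the DuckDB dialect's bracket parsing. A sketch under that assumption:

```python
import sqlglot

# DuckDB list comprehension should now parse and round-trip.
ast = sqlglot.parse_one("SELECT [x * 2 FOR x IN [1, 2, 3] IF x > 1]", read="duckdb")
print(ast.sql(dialect="duckdb"))
# expected: SELECT [x * 2 FOR x IN [1, 2, 3] IF x > 1]
```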
sqlglot/tokens.py

@@ -48,6 +48,7 @@ class TokenType(AutoName):
     HASH_ARROW = auto()
     DHASH_ARROW = auto()
     LR_ARROW = auto()
+    DAT = auto()
     LT_AT = auto()
     AT_GT = auto()
     DOLLAR = auto()

@@ -84,6 +85,7 @@ class TokenType(AutoName):
     UTINYINT = auto()
     SMALLINT = auto()
     USMALLINT = auto()
+    MEDIUMINT = auto()
     INT = auto()
     UINT = auto()
     BIGINT = auto()

@@ -140,6 +142,7 @@ class TokenType(AutoName):
     SMALLSERIAL = auto()
     BIGSERIAL = auto()
     XML = auto()
+    YEAR = auto()
     UNIQUEIDENTIFIER = auto()
     USERDEFINED = auto()
     MONEY = auto()

@@ -157,6 +160,7 @@ class TokenType(AutoName):
     FIXEDSTRING = auto()
     LOWCARDINALITY = auto()
     NESTED = auto()
+    UNKNOWN = auto()

     # keywords
     ALIAS = auto()

@@ -180,6 +184,7 @@ class TokenType(AutoName):
     COMMAND = auto()
     COMMENT = auto()
     COMMIT = auto()
+    CONNECT_BY = auto()
     CONSTRAINT = auto()
     CREATE = auto()
     CROSS = auto()

@@ -256,6 +261,7 @@ class TokenType(AutoName):
     NEXT = auto()
     NOTNULL = auto()
     NULL = auto()
+    OBJECT_IDENTIFIER = auto()
     OFFSET = auto()
     ON = auto()
     ORDER_BY = auto()

@@ -298,6 +304,7 @@ class TokenType(AutoName):
     SIMILAR_TO = auto()
     SOME = auto()
     SORT_BY = auto()
+    START_WITH = auto()
     STRUCT = auto()
     TABLE_SAMPLE = auto()
     TEMPORARY = auto()

@@ -319,6 +326,8 @@ class TokenType(AutoName):
     WINDOW = auto()
     WITH = auto()
     UNIQUE = auto()
+    VERSION_SNAPSHOT = auto()
+    TIMESTAMP_SNAPSHOT = auto()


 class Token:

@@ -530,6 +539,7 @@ class Tokenizer(metaclass=_Tokenizer):
         "COLLATE": TokenType.COLLATE,
         "COLUMN": TokenType.COLUMN,
         "COMMIT": TokenType.COMMIT,
+        "CONNECT BY": TokenType.CONNECT_BY,
         "CONSTRAINT": TokenType.CONSTRAINT,
         "CREATE": TokenType.CREATE,
         "CROSS": TokenType.CROSS,

@@ -636,6 +646,7 @@ class Tokenizer(metaclass=_Tokenizer):
         "SIMILAR TO": TokenType.SIMILAR_TO,
         "SOME": TokenType.SOME,
         "SORT BY": TokenType.SORT_BY,
+        "START WITH": TokenType.START_WITH,
         "TABLE": TokenType.TABLE,
         "TABLESAMPLE": TokenType.TABLE_SAMPLE,
         "TEMP": TokenType.TEMPORARY,
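`CONNECT BY` and `START WITH` becoming single multi-word tokens is what enables hierarchical-query parsing, e.g. Oracle's syntax. A sketch (illustrative table/column names):

```python
import sqlglot
from sqlglot import exp

# Oracle hierarchical query; CONNECT BY / START WITH now tokenize as
# single keywords instead of two identifiers.
ast = sqlglot.parse_one(
    """
    SELECT employee_id, manager_id
    FROM employee
    START WITH manager_id IS NULL
    CONNECT BY PRIOR employee_id = manager_id
    """,
    read="oracle",
)
print(ast.find(exp.Connect) is not None)  # expected: True
```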
@@ -643,6 +654,7 @@ class Tokenizer(metaclass=_Tokenizer):
         "THEN": TokenType.THEN,
         "TRUE": TokenType.TRUE,
         "UNION": TokenType.UNION,
+        "UNKNOWN": TokenType.UNKNOWN,
         "UNNEST": TokenType.UNNEST,
         "UNPIVOT": TokenType.UNPIVOT,
         "UPDATE": TokenType.UPDATE,

@@ -739,6 +751,8 @@ class Tokenizer(metaclass=_Tokenizer):
         "TRUNCATE": TokenType.COMMAND,
         "VACUUM": TokenType.COMMAND,
         "USER-DEFINED": TokenType.USERDEFINED,
+        "FOR VERSION": TokenType.VERSION_SNAPSHOT,
+        "FOR TIMESTAMP": TokenType.TIMESTAMP_SNAPSHOT,
     }

     WHITE_SPACE: t.Dict[t.Optional[str], TokenType] = {
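The `FOR VERSION` / `FOR TIMESTAMP` tokens back time-travel (`AS OF`) clauses in table references, e.g. Trino/Iceberg style. A sketch under that assumption:

```python
import sqlglot
from sqlglot import exp

# Table time travel: FOR TIMESTAMP AS OF / FOR VERSION AS OF.
ast = sqlglot.parse_one(
    "SELECT * FROM employee FOR TIMESTAMP AS OF TIMESTAMP '2023-01-01 00:00:00'",
    read="trino",
)
print(ast.find(exp.Version) is not None)  # expected: True
```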
@@ -941,8 +955,8 @@ class Tokenizer(metaclass=_Tokenizer):
             if result == TrieResult.EXISTS:
                 word = chars

-            end = self._current + size
             size += 1
+            end = self._current - 1 + size

             if end < self.size:
                 char = self.sql[end]

@@ -961,21 +975,20 @@ class Tokenizer(metaclass=_Tokenizer):
             char = ""
             chars = " "

-        if not word:
-            if self._char in self.SINGLE_TOKENS:
-                self._add(self.SINGLE_TOKENS[self._char], text=self._char)
-                return
-            self._scan_var()
-            return
-
-        if self._scan_string(word):
-            return
-        if self._scan_comment(word):
-            return
-
-        self._advance(size - 1)
-        word = word.upper()
-        self._add(self.KEYWORDS[word], text=word)
+        if word:
+            if self._scan_string(word):
+                return
+            if self._scan_comment(word):
+                return
+            if prev_space or single_token or not char:
+                self._advance(size - 1)
+                word = word.upper()
+                self._add(self.KEYWORDS[word], text=word)
+                return
+        if self._char in self.SINGLE_TOKENS:
+            self._add(self.SINGLE_TOKENS[self._char], text=self._char)
+            return
+        self._scan_var()

     def _scan_comment(self, comment_start: str) -> bool:
         if comment_start not in self._COMMENTS:

@@ -1053,8 +1066,8 @@ class Tokenizer(metaclass=_Tokenizer):
             elif self.IDENTIFIERS_CAN_START_WITH_DIGIT:
                 return self._add(TokenType.VAR)

-            self._add(TokenType.NUMBER, number_text)
-            return self._advance(-len(literal))
+            self._advance(-len(literal))
+            return self._add(TokenType.NUMBER, number_text)
         else:
             return self._add(TokenType.NUMBER)
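The keyword-scan rewrite changes dispatch order: a matched keyword is only committed when it ends at a word boundary (whitespace, a single-char token, or end of input); otherwise the scanner falls back to single-char tokens and finally `_scan_var`. The behavior is observable through the public tokenizer; a sketch:

```python
import sqlglot

# "GROUP BY" is a multi-word entry in Tokenizer.KEYWORDS, so it should be
# emitted as a single GROUP_BY token with text "GROUP BY".
for token in sqlglot.tokenize("SELECT a FROM t GROUP BY a"):
    print(token.token_type, repr(token.text))
```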
@ -68,11 +68,17 @@ def eliminate_distinct_on(expression: exp.Expression) -> exp.Expression:
|
|||
|
||||
if order:
|
||||
window.set("order", order.pop().copy())
|
||||
else:
|
||||
window.set("order", exp.Order(expressions=[c.copy() for c in distinct_cols]))
|
||||
|
||||
window = exp.alias_(window, row_number)
|
||||
expression.select(window, copy=False)
|
||||
|
||||
return exp.select(*outer_selects).from_(expression.subquery()).where(f'"{row_number}" = 1')
|
||||
return (
|
||||
exp.select(*outer_selects)
|
||||
.from_(expression.subquery())
|
||||
.where(exp.column(row_number).eq(1))
|
||||
)
|
||||
|
||||
return expression
|
||||
|
||||
|
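Building the outer filter as `exp.column(row_number).eq(1)` instead of an f-string means quoting of the synthetic row-number column is handled by the generator for whatever the target dialect is. Usage sketch:

```python
import sqlglot
from sqlglot.transforms import eliminate_distinct_on

# Rewrite Postgres DISTINCT ON into a ROW_NUMBER() window plus filter,
# e.g. to target a dialect without DISTINCT ON support.
ast = sqlglot.parse_one(
    "SELECT DISTINCT ON (a) a, b FROM x ORDER BY a, b DESC",
    read="postgres",
)
print(eliminate_distinct_on(ast).sql())
```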
@@ -126,7 +132,7 @@ def remove_precision_parameterized_types(expression: exp.Expression) -> exp.Expression:
     """
     for node in expression.find_all(exp.DataType):
         node.set(
-            "expressions", [e for e in node.expressions if not isinstance(e, exp.DataTypeSize)]
+            "expressions", [e for e in node.expressions if not isinstance(e, exp.DataTypeParam)]
         )

     return expression
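Only the class name changed here (`DataTypeSize` became `DataTypeParam`, since type arguments are not always sizes); the transform behaves the same. A sketch:

```python
import sqlglot
from sqlglot.transforms import remove_precision_parameterized_types

# Drop precision/scale parameters from types: VARCHAR(10) -> VARCHAR.
ast = sqlglot.parse_one("SELECT CAST(x AS VARCHAR(10)) FROM t")
print(remove_precision_parameterized_types(ast).sql())
# expected: SELECT CAST(x AS VARCHAR) FROM t
```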