Merging upstream version 17.12.0.
Signed-off-by: Daniel Baumann <daniel@debian.org>
This commit is contained in:
parent: aa315e6009
commit: aae08e0bb3
64 changed files with 12465 additions and 11885 deletions
@@ -60,6 +60,7 @@ from sqlglot.dialects.bigquery import BigQuery
 from sqlglot.dialects.clickhouse import ClickHouse
 from sqlglot.dialects.databricks import Databricks
 from sqlglot.dialects.dialect import Dialect, Dialects
+from sqlglot.dialects.doris import Doris
 from sqlglot.dialects.drill import Drill
 from sqlglot.dialects.duckdb import DuckDB
 from sqlglot.dialects.hive import Hive

@@ -37,17 +37,22 @@ class ClickHouse(Dialect):
             "ATTACH": TokenType.COMMAND,
             "DATETIME64": TokenType.DATETIME64,
             "DICTIONARY": TokenType.DICTIONARY,
             "ENUM": TokenType.ENUM,
+            "ENUM8": TokenType.ENUM8,
+            "ENUM16": TokenType.ENUM16,
             "FINAL": TokenType.FINAL,
+            "FIXEDSTRING": TokenType.FIXEDSTRING,
             "FLOAT32": TokenType.FLOAT,
             "FLOAT64": TokenType.DOUBLE,
             "GLOBAL": TokenType.GLOBAL,
             "INT128": TokenType.INT128,
             "INT16": TokenType.SMALLINT,
             "INT256": TokenType.INT256,
             "INT32": TokenType.INT,
             "INT64": TokenType.BIGINT,
             "INT8": TokenType.TINYINT,
+            "LOWCARDINALITY": TokenType.LOWCARDINALITY,
             "MAP": TokenType.MAP,
+            "NESTED": TokenType.NESTED,
             "TUPLE": TokenType.STRUCT,
             "UINT128": TokenType.UINT128,
             "UINT16": TokenType.USMALLINT,

@@ -294,11 +299,17 @@ class ClickHouse(Dialect):
             exp.DataType.Type.BIGINT: "Int64",
             exp.DataType.Type.DATETIME64: "DateTime64",
             exp.DataType.Type.DOUBLE: "Float64",
+            exp.DataType.Type.ENUM: "Enum",
+            exp.DataType.Type.ENUM8: "Enum8",
+            exp.DataType.Type.ENUM16: "Enum16",
+            exp.DataType.Type.FIXEDSTRING: "FixedString",
             exp.DataType.Type.FLOAT: "Float32",
             exp.DataType.Type.INT: "Int32",
             exp.DataType.Type.INT128: "Int128",
             exp.DataType.Type.INT256: "Int256",
+            exp.DataType.Type.LOWCARDINALITY: "LowCardinality",
             exp.DataType.Type.MAP: "Map",
+            exp.DataType.Type.NESTED: "Nested",
             exp.DataType.Type.NULLABLE: "Nullable",
             exp.DataType.Type.SMALLINT: "Int16",
             exp.DataType.Type.STRUCT: "Tuple",

@@ -39,6 +39,7 @@ class Dialects(str, Enum):
     TERADATA = "teradata"
     TRINO = "trino"
     TSQL = "tsql"
+    Doris = "doris"


 class _Dialect(type):

@@ -121,7 +122,7 @@ class _Dialect(type):
                 if hasattr(subclass, name):
                     setattr(subclass, name, value)

-        if not klass.STRICT_STRING_CONCAT:
+        if not klass.STRICT_STRING_CONCAT and klass.DPIPE_IS_STRING_CONCAT:
            klass.parser_class.BITWISE[TokenType.DPIPE] = exp.SafeDPipe

        klass.generator_class.can_identify = klass.can_identify

@@ -146,6 +147,9 @@ class Dialect(metaclass=_Dialect):
     # Determines whether or not an unquoted identifier can start with a digit
     IDENTIFIERS_CAN_START_WITH_DIGIT = False

+    # Determines whether or not the DPIPE token ('||') is a string concatenation operator
+    DPIPE_IS_STRING_CONCAT = True
+
     # Determines whether or not CONCAT's arguments must be strings
     STRICT_STRING_CONCAT = False

@@ -460,6 +464,20 @@ def format_time_lambda(
     return _format_time


+def time_format(
+    dialect: DialectType = None,
+) -> t.Callable[[Generator, exp.UnixToStr | exp.StrToUnix], t.Optional[str]]:
+    def _time_format(self: Generator, expression: exp.UnixToStr | exp.StrToUnix) -> t.Optional[str]:
+        """
+        Returns the time format for a given expression, unless it's equivalent
+        to the default time format of the dialect of interest.
+        """
+        time_format = self.format_time(expression)
+        return time_format if time_format != Dialect.get_or_raise(dialect).TIME_FORMAT else None
+
+    return _time_format
+
+
 def create_with_partitions_sql(self: Generator, expression: exp.Create) -> str:
     """
     In Hive and Spark, the PARTITIONED BY property acts as an extension of a table's schema. When the

@@ -699,3 +717,8 @@ def simplify_literal(expression: E) -> E:

 def binary_from_function(expr_type: t.Type[B]) -> t.Callable[[t.List], B]:
     return lambda args: expr_type(this=seq_get(args, 0), expression=seq_get(args, 1))
+
+
+# Used to represent DATE_TRUNC in Doris, Postgres and Starrocks dialects
+def parse_timestamp_trunc(args: t.List) -> exp.TimestampTrunc:
+    return exp.TimestampTrunc(this=seq_get(args, 1), unit=seq_get(args, 0))
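Illustration (not part of the diff): the shared helper captures the unit-first DATE_TRUNC signature that Doris, Postgres and StarRocks have in common. Column and table names below are hypothetical; the expected output is paraphrased.

    import sqlglot

    # Postgres writes the unit first: DATE_TRUNC('month', created_at).
    # parse_timestamp_trunc swaps the arguments into
    # exp.TimestampTrunc(this=created_at, unit='month'), so the unit
    # survives a round trip through the AST.
    sql = "SELECT DATE_TRUNC('month', created_at) FROM t"
    print(sqlglot.transpile(sql, read="postgres", write="postgres")[0])
    # e.g. SELECT DATE_TRUNC('MONTH', created_at) FROM t  (unit casing may vary)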
sqlglot/dialects/doris.py (new file, 65 lines)
@@ -0,0 +1,65 @@
+from __future__ import annotations
+
+from sqlglot import exp
+from sqlglot.dialects.dialect import (
+    approx_count_distinct_sql,
+    arrow_json_extract_sql,
+    parse_timestamp_trunc,
+    rename_func,
+    time_format,
+)
+from sqlglot.dialects.mysql import MySQL
+
+
+class Doris(MySQL):
+    DATE_FORMAT = "'yyyy-MM-dd'"
+    DATEINT_FORMAT = "'yyyyMMdd'"
+    TIME_FORMAT = "'yyyy-MM-dd HH:mm:ss'"
+
+    class Parser(MySQL.Parser):
+        FUNCTIONS = {
+            **MySQL.Parser.FUNCTIONS,
+            "DATE_TRUNC": parse_timestamp_trunc,
+            "REGEXP": exp.RegexpLike.from_arg_list,
+        }
+
+    class Generator(MySQL.Generator):
+        CAST_MAPPING = {}
+
+        TYPE_MAPPING = {
+            **MySQL.Generator.TYPE_MAPPING,
+            exp.DataType.Type.TEXT: "STRING",
+            exp.DataType.Type.TIMESTAMP: "DATETIME",
+            exp.DataType.Type.TIMESTAMPTZ: "DATETIME",
+        }
+
+        TRANSFORMS = {
+            **MySQL.Generator.TRANSFORMS,
+            exp.ApproxDistinct: approx_count_distinct_sql,
+            exp.ArrayAgg: rename_func("COLLECT_LIST"),
+            exp.Coalesce: rename_func("NVL"),
+            exp.CurrentTimestamp: lambda *_: "NOW()",
+            exp.DateTrunc: lambda self, e: self.func(
+                "DATE_TRUNC", e.this, "'" + e.text("unit") + "'"
+            ),
+            exp.JSONExtractScalar: arrow_json_extract_sql,
+            exp.JSONExtract: arrow_json_extract_sql,
+            exp.RegexpLike: rename_func("REGEXP"),
+            exp.RegexpSplit: rename_func("SPLIT_BY_STRING"),
+            exp.SetAgg: rename_func("COLLECT_SET"),
+            exp.StrToUnix: lambda self, e: f"UNIX_TIMESTAMP({self.sql(e, 'this')}, {self.format_time(e)})",
+            exp.Split: rename_func("SPLIT_BY_STRING"),
+            exp.TimeStrToDate: rename_func("TO_DATE"),
+            exp.ToChar: lambda self, e: f"DATE_FORMAT({self.sql(e, 'this')}, {self.format_time(e)})",
+            exp.TsOrDsAdd: lambda self, e: f"DATE_ADD({self.sql(e, 'this')}, {self.sql(e, 'expression')})",  # Only for day level
+            exp.TsOrDsToDate: lambda self, e: self.func("TO_DATE", e.this),
+            exp.TimeToUnix: rename_func("UNIX_TIMESTAMP"),
+            exp.TimestampTrunc: lambda self, e: self.func(
+                "DATE_TRUNC", e.this, "'" + e.text("unit") + "'"
+            ),
+            exp.UnixToStr: lambda self, e: self.func(
+                "FROM_UNIXTIME", e.this, time_format("doris")(self, e)
+            ),
+            exp.UnixToTime: rename_func("FROM_UNIXTIME"),
+            exp.Map: rename_func("ARRAY_MAP"),
+        }
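Illustration (not part of the diff): exercising the new dialect. Both outputs follow directly from the TRANSFORMS and TYPE_MAPPING above; the table and column names are hypothetical.

    import sqlglot

    # exp.Coalesce is renamed to NVL by the Doris generator.
    print(sqlglot.transpile("SELECT COALESCE(a, 0) FROM t", write="doris")[0])
    # SELECT NVL(a, 0) FROM t

    # TEXT maps to STRING in the Doris TYPE_MAPPING.
    print(sqlglot.transpile("SELECT CAST(a AS TEXT) FROM t", write="doris")[0])
    # SELECT CAST(a AS STRING) FROM t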
@@ -89,6 +89,11 @@ def _struct_sql(self: generator.Generator, expression: exp.Struct) -> str:
 def _datatype_sql(self: generator.Generator, expression: exp.DataType) -> str:
     if expression.is_type("array"):
         return f"{self.expressions(expression, flat=True)}[]"

+    # Type TIMESTAMP / TIME WITH TIME ZONE does not support any modifiers
+    if expression.is_type("timestamptz", "timetz"):
+        return expression.this.value
+
     return self.datatype_sql(expression)

@@ -110,14 +115,14 @@ class DuckDB(Dialect):
             "//": TokenType.DIV,
             "ATTACH": TokenType.COMMAND,
             "BINARY": TokenType.VARBINARY,
-            "BPCHAR": TokenType.TEXT,
             "BITSTRING": TokenType.BIT,
+            "BPCHAR": TokenType.TEXT,
             "CHAR": TokenType.TEXT,
             "CHARACTER VARYING": TokenType.TEXT,
             "EXCLUDE": TokenType.EXCEPT,
             "HUGEINT": TokenType.INT128,
             "INT1": TokenType.TINYINT,
             "LOGICAL": TokenType.BOOLEAN,
             "NUMERIC": TokenType.DOUBLE,
             "PIVOT_WIDER": TokenType.PIVOT,
             "SIGNED": TokenType.INT,
             "STRING": TokenType.VARCHAR,

@@ -186,6 +191,22 @@ class DuckDB(Dialect):
             TokenType.UTINYINT,
         }

+        def _parse_types(
+            self, check_func: bool = False, schema: bool = False
+        ) -> t.Optional[exp.Expression]:
+            this = super()._parse_types(check_func=check_func, schema=schema)
+
+            # DuckDB treats NUMERIC and DECIMAL without precision as DECIMAL(18, 3)
+            # See: https://duckdb.org/docs/sql/data_types/numeric
+            if (
+                isinstance(this, exp.DataType)
+                and this.is_type("numeric", "decimal")
+                and not this.expressions
+            ):
+                return exp.DataType.build("DECIMAL(18, 3)")
+
+            return this
+
         def _pivot_column_names(self, aggregations: t.List[exp.Expression]) -> t.List[str]:
             if len(aggregations) == 1:
                 return super()._pivot_column_names(aggregations)
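Illustration (not part of the diff): the override gives a bare NUMERIC or DECIMAL DuckDB's documented default precision.

    import sqlglot

    # Without an explicit precision, DuckDB treats DECIMAL as DECIMAL(18, 3).
    expr = sqlglot.parse_one("CAST(x AS NUMERIC)", read="duckdb")
    print(expr.sql(dialect="duckdb"))
    # CAST(x AS DECIMAL(18, 3))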
@@ -231,6 +252,7 @@ class DuckDB(Dialect):
             exp.Encode: lambda self, e: encode_decode_sql(self, e, "ENCODE", replace=False),
             exp.Explode: rename_func("UNNEST"),
             exp.IntDiv: lambda self, e: self.binary(e, "//"),
+            exp.IsNan: rename_func("ISNAN"),
             exp.JSONExtract: arrow_json_extract_sql,
             exp.JSONExtractScalar: arrow_json_extract_scalar_sql,
             exp.JSONFormat: _json_format_sql,

@@ -23,6 +23,7 @@ from sqlglot.dialects.dialect import (
     right_to_substring_sql,
     strposition_to_locate_sql,
     struct_extract_sql,
+    time_format,
     timestrtotime_sql,
     var_map_sql,
 )

@@ -113,7 +114,7 @@ def _property_sql(self: generator.Generator, expression: exp.Property) -> str:


 def _str_to_unix_sql(self: generator.Generator, expression: exp.StrToUnix) -> str:
-    return self.func("UNIX_TIMESTAMP", expression.this, _time_format(self, expression))
+    return self.func("UNIX_TIMESTAMP", expression.this, time_format("hive")(self, expression))


 def _str_to_date_sql(self: generator.Generator, expression: exp.StrToDate) -> str:

@@ -132,15 +133,6 @@ def _str_to_time_sql(self: generator.Generator, expression: exp.StrToTime) -> st
     return f"CAST({this} AS TIMESTAMP)"


-def _time_format(
-    self: generator.Generator, expression: exp.UnixToStr | exp.StrToUnix
-) -> t.Optional[str]:
-    time_format = self.format_time(expression)
-    if time_format == Hive.TIME_FORMAT:
-        return None
-    return time_format
-
-
 def _time_to_str(self: generator.Generator, expression: exp.TimeToStr) -> str:
     this = self.sql(expression, "this")
     time_format = self.format_time(expression)

@@ -439,7 +431,7 @@ class Hive(Dialect):
             exp.TsOrDsToDate: _to_date_sql,
             exp.TryCast: no_trycast_sql,
             exp.UnixToStr: lambda self, e: self.func(
-                "FROM_UNIXTIME", e.this, _time_format(self, e)
+                "FROM_UNIXTIME", e.this, time_format("hive")(self, e)
             ),
             exp.UnixToTime: rename_func("FROM_UNIXTIME"),
             exp.UnixToTimeStr: rename_func("FROM_UNIXTIME"),
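Illustration (not part of the diff): time_format("hive") returns None when the format equals Hive's default, so a redundant format argument is dropped on generation. The behavior predates this commit; the helper is simply shared across dialects now.

    import sqlglot

    # The explicit format equals Hive's TIME_FORMAT ('yyyy-MM-dd HH:mm:ss'),
    # so the generator omits it.
    sql = "SELECT FROM_UNIXTIME(t, 'yyyy-MM-dd HH:mm:ss')"
    print(sqlglot.transpile(sql, read="hive", write="hive")[0])
    # SELECT FROM_UNIXTIME(t)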
@@ -94,6 +94,7 @@ def _date_add_sql(kind: str) -> t.Callable[[generator.Generator, exp.DateAdd | e

 class MySQL(Dialect):
     TIME_FORMAT = "'%Y-%m-%d %T'"
+    DPIPE_IS_STRING_CONCAT = False

     # https://prestodb.io/docs/current/functions/datetime.html#mysql-date-functions
     TIME_MAPPING = {

@@ -103,7 +104,6 @@ class MySQL(Dialect):
         "%h": "%I",
         "%i": "%M",
         "%s": "%S",
-        "%S": "%S",
         "%u": "%W",
         "%k": "%-H",
         "%l": "%-I",

@@ -196,8 +196,14 @@ class MySQL(Dialect):
             **parser.Parser.CONJUNCTION,
             TokenType.DAMP: exp.And,
             TokenType.XOR: exp.Xor,
+            TokenType.DPIPE: exp.Or,
         }

+        # MySQL uses || as a synonym to the logical OR operator
+        # https://dev.mysql.com/doc/refman/8.0/en/logical-operators.html#operator_or
+        BITWISE = parser.Parser.BITWISE.copy()
+        BITWISE.pop(TokenType.DPIPE)
+
         TABLE_ALIAS_TOKENS = (
             parser.Parser.TABLE_ALIAS_TOKENS - parser.Parser.TABLE_INDEX_HINT_TOKENS
         )
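Illustration (not part of the diff): with DPIPE removed from BITWISE and added to CONJUNCTION, || parses as logical OR under MySQL while remaining concatenation in dialects where DPIPE_IS_STRING_CONCAT stays True.

    import sqlglot
    from sqlglot import exp

    # MySQL: || is a synonym for OR.
    assert isinstance(sqlglot.parse_one("a || b", read="mysql"), exp.Or)
    print(sqlglot.transpile("SELECT a || b", read="mysql", write="mysql")[0])
    # SELECT a OR b

    # Postgres: || is still (safe) string concatenation.
    assert isinstance(sqlglot.parse_one("a || b", read="postgres"), exp.DPipe)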
@@ -16,6 +16,7 @@ from sqlglot.dialects.dialect import (
     no_pivot_sql,
     no_tablesample_sql,
     no_trycast_sql,
+    parse_timestamp_trunc,
     rename_func,
     simplify_literal,
     str_position_sql,

@@ -286,9 +287,7 @@ class Postgres(Dialect):

         FUNCTIONS = {
             **parser.Parser.FUNCTIONS,
-            "DATE_TRUNC": lambda args: exp.TimestampTrunc(
-                this=seq_get(args, 1), unit=seq_get(args, 0)
-            ),
+            "DATE_TRUNC": parse_timestamp_trunc,
             "GENERATE_SERIES": _generate_series,
             "NOW": exp.CurrentTimestamp.from_arg_list,
             "TO_CHAR": format_time_lambda(exp.TimeToStr, "postgres"),

@@ -32,13 +32,6 @@ def _approx_distinct_sql(self: generator.Generator, expression: exp.ApproxDistin
     return f"APPROX_DISTINCT({self.sql(expression, 'this')}{accuracy})"


-def _datatype_sql(self: generator.Generator, expression: exp.DataType) -> str:
-    sql = self.datatype_sql(expression)
-    if expression.is_type("timestamptz"):
-        sql = f"{sql} WITH TIME ZONE"
-    return sql
-
-
 def _explode_to_unnest_sql(self: generator.Generator, expression: exp.Lateral) -> str:
     if isinstance(expression.this, (exp.Explode, exp.Posexplode)):
         expression = expression.copy()

@@ -231,6 +224,7 @@ class Presto(Dialect):
         TABLE_HINTS = False
         QUERY_HINTS = False
         IS_BOOL_ALLOWED = False
+        TZ_TO_WITH_TIME_ZONE = True
         STRUCT_DELIMITER = ("(", ")")

         PROPERTIES_LOCATION = {

@@ -245,6 +239,7 @@ class Presto(Dialect):
             exp.DataType.Type.FLOAT: "REAL",
             exp.DataType.Type.BINARY: "VARBINARY",
             exp.DataType.Type.TEXT: "VARCHAR",
+            exp.DataType.Type.TIMETZ: "TIME",
             exp.DataType.Type.TIMESTAMPTZ: "TIMESTAMP",
             exp.DataType.Type.STRUCT: "ROW",
         }

@@ -265,7 +260,6 @@ class Presto(Dialect):
             exp.BitwiseXor: lambda self, e: f"BITWISE_XOR({self.sql(e, 'this')}, {self.sql(e, 'expression')})",
             exp.Cast: transforms.preprocess([transforms.epoch_cast_to_ts]),
             exp.CurrentTimestamp: lambda *_: "CURRENT_TIMESTAMP",
-            exp.DataType: _datatype_sql,
             exp.DateAdd: lambda self, e: self.func(
                 "DATE_ADD", exp.Literal.string(e.text("unit") or "day"), e.expression, e.this
             ),

@@ -85,8 +85,6 @@ class Redshift(Postgres):
             "HLLSKETCH": TokenType.HLLSKETCH,
             "SUPER": TokenType.SUPER,
             "SYSDATE": TokenType.CURRENT_TIMESTAMP,
-            "TIME": TokenType.TIMESTAMP,
-            "TIMETZ": TokenType.TIMESTAMPTZ,
             "TOP": TokenType.TOP,
             "UNLOAD": TokenType.COMMAND,
             "VARBYTE": TokenType.VARBINARY,

@@ -101,12 +99,15 @@ class Redshift(Postgres):
         RENAME_TABLE_WITH_DB = False
         QUERY_HINTS = False
         VALUES_AS_TABLE = False
+        TZ_TO_WITH_TIME_ZONE = True

         TYPE_MAPPING = {
             **Postgres.Generator.TYPE_MAPPING,
             exp.DataType.Type.BINARY: "VARBYTE",
-            exp.DataType.Type.VARBINARY: "VARBYTE",
             exp.DataType.Type.INT: "INTEGER",
+            exp.DataType.Type.TIMETZ: "TIME",
+            exp.DataType.Type.TIMESTAMPTZ: "TIMESTAMP",
+            exp.DataType.Type.VARBINARY: "VARBYTE",
         }

         PROPERTIES_LOCATION = {

@@ -52,6 +52,9 @@ class Spark(Spark2):
         TRANSFORMS = {
             **Spark2.Generator.TRANSFORMS,
             exp.StartsWith: rename_func("STARTSWITH"),
+            exp.TimestampAdd: lambda self, e: self.func(
+                "DATEADD", e.args.get("unit") or "DAY", e.expression, e.this
+            ),
         }
         TRANSFORMS.pop(exp.DateDiff)
         TRANSFORMS.pop(exp.Group)

@@ -4,6 +4,7 @@ from sqlglot import exp
 from sqlglot.dialects.dialect import (
     approx_count_distinct_sql,
     arrow_json_extract_sql,
+    parse_timestamp_trunc,
     rename_func,
 )
 from sqlglot.dialects.mysql import MySQL

@@ -14,9 +15,7 @@ class StarRocks(MySQL):
     class Parser(MySQL.Parser):
         FUNCTIONS = {
             **MySQL.Parser.FUNCTIONS,
-            "DATE_TRUNC": lambda args: exp.TimestampTrunc(
-                this=seq_get(args, 1), unit=seq_get(args, 0)
-            ),
+            "DATE_TRUNC": parse_timestamp_trunc,
             "DATEDIFF": lambda args: exp.DateDiff(
                 this=seq_get(args, 0), expression=seq_get(args, 1), unit=exp.Literal.string("DAY")
             ),

@@ -28,6 +28,11 @@ if t.TYPE_CHECKING:
     from sqlglot.schema import Schema


+PYTHON_TYPE_TO_SQLGLOT = {
+    "dict": "MAP",
+}
+
+
 def execute(
     sql: str | Expression,
     schema: t.Optional[t.Dict | Schema] = None,

@@ -50,7 +55,7 @@ def execute(
     Returns:
         Simple columnar data structure.
     """
-    tables_ = ensure_tables(tables)
+    tables_ = ensure_tables(tables, dialect=read)

     if not schema:
         schema = {}

@@ -61,7 +66,8 @@ def execute(
             assert table is not None

             for column in table.columns:
-                nested_set(schema, [*keys, column], type(table[0][column]).__name__)
+                py_type = type(table[0][column]).__name__
+                nested_set(schema, [*keys, column], PYTHON_TYPE_TO_SQLGLOT.get(py_type) or py_type)

     schema = ensure_schema(schema, dialect=read)

@@ -2,8 +2,9 @@ from __future__ import annotations

 import typing as t

+from sqlglot.dialects.dialect import DialectType
 from sqlglot.helper import dict_depth
-from sqlglot.schema import AbstractMappingSchema
+from sqlglot.schema import AbstractMappingSchema, normalize_name


 class Table:

@@ -108,26 +109,37 @@ class Tables(AbstractMappingSchema[Table]):
     pass


-def ensure_tables(d: t.Optional[t.Dict]) -> Tables:
-    return Tables(_ensure_tables(d))
+def ensure_tables(d: t.Optional[t.Dict], dialect: DialectType = None) -> Tables:
+    return Tables(_ensure_tables(d, dialect=dialect))


-def _ensure_tables(d: t.Optional[t.Dict]) -> t.Dict:
+def _ensure_tables(d: t.Optional[t.Dict], dialect: DialectType = None) -> t.Dict:
     if not d:
         return {}

     depth = dict_depth(d)

     if depth > 1:
-        return {k: _ensure_tables(v) for k, v in d.items()}
+        return {
+            normalize_name(k, dialect=dialect, is_table=True): _ensure_tables(v, dialect=dialect)
+            for k, v in d.items()
+        }

     result = {}
-    for name, table in d.items():
+    for table_name, table in d.items():
+        table_name = normalize_name(table_name, dialect=dialect)
+
         if isinstance(table, Table):
-            result[name] = table
+            result[table_name] = table
         else:
-            columns = tuple(table[0]) if table else ()
-            rows = [tuple(row[c] for c in columns) for row in table]
-            result[name] = Table(columns=columns, rows=rows)
+            table = [
+                {
+                    normalize_name(column_name, dialect=dialect): value
+                    for column_name, value in row.items()
+                }
+                for row in table
+            ]
+            column_names = tuple(column_name for column_name in table[0]) if table else ()
+            rows = [tuple(row[name] for name in column_names) for row in table]
+            result[table_name] = Table(columns=column_names, rows=rows)

     return result
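Illustration (not part of the diff): the dialect-aware ensure_tables plus the dict -> MAP hint mean the executor now normalizes table and column names the same way the optimizer does. Data below is hypothetical.

    from sqlglot.executor import execute

    tables = {"t": [{"x": 1}, {"x": 2}]}

    # Unquoted identifiers are normalized per dialect (folded to lower case
    # here), so they match the lower-case keys in `tables`.
    result = execute("SELECT X FROM T", read="duckdb", tables=tables)
    print(result.rows)  # [(1,), (2,)]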
@@ -3309,6 +3309,7 @@ class Pivot(Expression):
         "using": False,
         "group": False,
         "columns": False,
+        "include_nulls": False,
     }


@@ -3397,23 +3398,16 @@ class DataType(Expression):
         BOOLEAN = auto()
         CHAR = auto()
         DATE = auto()
+        DATEMULTIRANGE = auto()
+        DATERANGE = auto()
         DATETIME = auto()
         DATETIME64 = auto()
-        ENUM = auto()
-        INT4RANGE = auto()
-        INT4MULTIRANGE = auto()
-        INT8RANGE = auto()
-        INT8MULTIRANGE = auto()
-        NUMRANGE = auto()
-        NUMMULTIRANGE = auto()
-        TSRANGE = auto()
-        TSMULTIRANGE = auto()
-        TSTZRANGE = auto()
-        TSTZMULTIRANGE = auto()
-        DATERANGE = auto()
-        DATEMULTIRANGE = auto()
         DECIMAL = auto()
         DOUBLE = auto()
+        ENUM = auto()
+        ENUM8 = auto()
+        ENUM16 = auto()
+        FIXEDSTRING = auto()
         FLOAT = auto()
         GEOGRAPHY = auto()
         GEOMETRY = auto()

@@ -3421,23 +3415,31 @@ class DataType(Expression):
         HSTORE = auto()
         IMAGE = auto()
         INET = auto()
-        IPADDRESS = auto()
-        IPPREFIX = auto()
         INT = auto()
         INT128 = auto()
         INT256 = auto()
+        INT4MULTIRANGE = auto()
+        INT4RANGE = auto()
+        INT8MULTIRANGE = auto()
+        INT8RANGE = auto()
         INTERVAL = auto()
+        IPADDRESS = auto()
+        IPPREFIX = auto()
         JSON = auto()
         JSONB = auto()
         LONGBLOB = auto()
         LONGTEXT = auto()
+        LOWCARDINALITY = auto()
         MAP = auto()
         MEDIUMBLOB = auto()
         MEDIUMTEXT = auto()
         MONEY = auto()
         NCHAR = auto()
+        NESTED = auto()
         NULL = auto()
         NULLABLE = auto()
+        NUMMULTIRANGE = auto()
+        NUMRANGE = auto()
         NVARCHAR = auto()
         OBJECT = auto()
         ROWVERSION = auto()

@@ -3450,19 +3452,24 @@ class DataType(Expression):
         SUPER = auto()
         TEXT = auto()
         TIME = auto()
+        TIMETZ = auto()
         TIMESTAMP = auto()
-        TIMESTAMPTZ = auto()
         TIMESTAMPLTZ = auto()
+        TIMESTAMPTZ = auto()
         TINYINT = auto()
+        TSMULTIRANGE = auto()
+        TSRANGE = auto()
+        TSTZMULTIRANGE = auto()
+        TSTZRANGE = auto()
         UBIGINT = auto()
         UINT = auto()
-        USMALLINT = auto()
-        UTINYINT = auto()
-        UNKNOWN = auto()  # Sentinel value, useful for type annotation
         UINT128 = auto()
         UINT256 = auto()
         UNIQUEIDENTIFIER = auto()
+        UNKNOWN = auto()  # Sentinel value, useful for type annotation
         USERDEFINED = "USER-DEFINED"
+        USMALLINT = auto()
+        UTINYINT = auto()
         UUID = auto()
         VARBINARY = auto()
         VARCHAR = auto()

@@ -3495,6 +3502,7 @@ class DataType(Expression):

     TEMPORAL_TYPES = {
         Type.TIME,
+        Type.TIMETZ,
         Type.TIMESTAMP,
         Type.TIMESTAMPTZ,
         Type.TIMESTAMPLTZ,

@@ -3858,6 +3866,18 @@ class TimeUnit(Expression):
         super().__init__(**args)


+# https://www.oracletutorial.com/oracle-basics/oracle-interval/
+# https://trino.io/docs/current/language/types.html#interval-year-to-month
+class IntervalYearToMonthSpan(Expression):
+    arg_types = {}
+
+
+# https://www.oracletutorial.com/oracle-basics/oracle-interval/
+# https://trino.io/docs/current/language/types.html#interval-day-to-second
+class IntervalDayToSecondSpan(Expression):
+    arg_types = {}
+
+
 class Interval(TimeUnit):
     arg_types = {"this": False, "unit": False}


@@ -71,6 +71,8 @@ class Generator:
         exp.ExternalProperty: lambda self, e: "EXTERNAL",
         exp.HeapProperty: lambda self, e: "HEAP",
         exp.InlineLengthColumnConstraint: lambda self, e: f"INLINE LENGTH {self.sql(e, 'this')}",
+        exp.IntervalDayToSecondSpan: "DAY TO SECOND",
+        exp.IntervalYearToMonthSpan: "YEAR TO MONTH",
         exp.LanguageProperty: lambda self, e: self.naked_property(e),
         exp.LocationProperty: lambda self, e: self.naked_property(e),
         exp.LogProperty: lambda self, e: f"{'NO ' if e.args.get('no') else ''}LOG",

@@ -166,6 +168,9 @@ class Generator:
     # Whether or not to generate an unquoted value for EXTRACT's date part argument
     EXTRACT_ALLOWS_QUOTES = True

+    # Whether or not TIMETZ / TIMESTAMPTZ will be generated using the "WITH TIME ZONE" syntax
+    TZ_TO_WITH_TIME_ZONE = False
+
     # https://cloud.google.com/bigquery/docs/reference/standard-sql/query-syntax
     SELECT_KINDS: t.Tuple[str, ...] = ("STRUCT", "VALUE")

@@ -271,10 +276,12 @@ class Generator:

     # Expressions whose comments are separated from them for better formatting
     WITH_SEPARATED_COMMENTS: t.Tuple[t.Type[exp.Expression], ...] = (
         exp.Create,
         exp.Delete,
         exp.Drop,
         exp.From,
         exp.Insert,
+        exp.Join,
         exp.Select,
         exp.Update,
         exp.Where,

@@ -831,14 +838,17 @@ class Generator:

     def datatype_sql(self, expression: exp.DataType) -> str:
         type_value = expression.this

         type_sql = (
             self.TYPE_MAPPING.get(type_value, type_value.value)
             if isinstance(type_value, exp.DataType.Type)
             else type_value
         )

         nested = ""
         interior = self.expressions(expression, flat=True)
         values = ""

         if interior:
             if expression.args.get("nested"):
                 nested = f"{self.STRUCT_DELIMITER[0]}{interior}{self.STRUCT_DELIMITER[1]}"

@@ -846,10 +856,19 @@ class Generator:
                 delimiters = ("[", "]") if type_value == exp.DataType.Type.ARRAY else ("(", ")")
                 values = self.expressions(expression, key="values", flat=True)
                 values = f"{delimiters[0]}{values}{delimiters[1]}"
+            elif type_value == exp.DataType.Type.INTERVAL:
+                nested = f" {interior}"
             else:
                 nested = f"({interior})"

-        return f"{type_sql}{nested}{values}"
+        type_sql = f"{type_sql}{nested}{values}"
+        if self.TZ_TO_WITH_TIME_ZONE and type_value in (
+            exp.DataType.Type.TIMETZ,
+            exp.DataType.Type.TIMESTAMPTZ,
+        ):
+            type_sql = f"{type_sql} WITH TIME ZONE"
+
+        return type_sql

     def directory_sql(self, expression: exp.Directory) -> str:
         local = "LOCAL " if expression.args.get("local") else ""
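Illustration (not part of the diff): any dialect that sets TZ_TO_WITH_TIME_ZONE (Presto and Redshift in this commit) now gets the standard modifier from datatype_sql instead of a bespoke per-dialect hook.

    import sqlglot

    print(sqlglot.transpile("CAST(x AS TIMESTAMPTZ)", write="presto")[0])
    # CAST(x AS TIMESTAMP WITH TIME ZONE)

    print(sqlglot.transpile("CAST(x AS TIMETZ)", write="presto")[0])
    # CAST(x AS TIME WITH TIME ZONE)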
@@ -1288,7 +1307,12 @@ class Generator:
         unpivot = expression.args.get("unpivot")
         direction = "UNPIVOT" if unpivot else "PIVOT"
         field = self.sql(expression, "field")
-        return f"{direction}({expressions} FOR {field}){alias}"
+        include_nulls = expression.args.get("include_nulls")
+        if include_nulls is not None:
+            nulls = " INCLUDE NULLS " if include_nulls else " EXCLUDE NULLS "
+        else:
+            nulls = ""
+        return f"{direction}{nulls}({expressions} FOR {field}){alias}"

     def tuple_sql(self, expression: exp.Tuple) -> str:
         return f"({self.expressions(expression, flat=True)})"

@@ -54,11 +54,17 @@ def simplify(expression):
     def _simplify(expression, root=True):
         if expression.meta.get(FINAL):
             return expression
+
+        # Pre-order transformations
         node = expression
         node = rewrite_between(node)
         node = uniq_sort(node, generate, root)
         node = absorb_and_eliminate(node, root)
+        node = simplify_concat(node)
+
         exp.replace_children(node, lambda e: _simplify(e, False))
+
+        # Post-order transformations
         node = simplify_not(node)
         node = flatten(node)
         node = simplify_connectors(node, root)

@@ -66,8 +72,11 @@ def simplify(expression):
         node.parent = expression.parent
         node = simplify_literals(node, root)
         node = simplify_parens(node)
+        node = simplify_coalesce(node)

         if root:
             expression.replace(node)

         return node

     expression = while_changing(expression, _simplify)

@@ -184,6 +193,7 @@ COMPARISONS = (
     *GT_GTE,
     exp.EQ,
     exp.NEQ,
+    exp.Is,
 )

 INVERSE_COMPARISONS = {

@@ -430,6 +440,103 @@ def simplify_parens(expression):
     return expression


+CONSTANTS = (
+    exp.Literal,
+    exp.Boolean,
+    exp.Null,
+)
+
+
+def simplify_coalesce(expression):
+    # COALESCE(x) -> x
+    if (
+        isinstance(expression, exp.Coalesce)
+        and not expression.expressions
+        # COALESCE is also used as a Spark partitioning hint
+        and not isinstance(expression.parent, exp.Hint)
+    ):
+        return expression.this
+
+    if not isinstance(expression, COMPARISONS):
+        return expression
+
+    if isinstance(expression.left, exp.Coalesce):
+        coalesce = expression.left
+        other = expression.right
+    elif isinstance(expression.right, exp.Coalesce):
+        coalesce = expression.right
+        other = expression.left
+    else:
+        return expression
+
+    # This transformation is valid for non-constants,
+    # but it really only does anything if they are both constants.
+    if not isinstance(other, CONSTANTS):
+        return expression
+
+    # Find the first constant arg
+    for arg_index, arg in enumerate(coalesce.expressions):
+        if isinstance(arg, CONSTANTS):
+            break
+    else:
+        return expression
+
+    coalesce.set("expressions", coalesce.expressions[:arg_index])
+
+    # Remove the COALESCE function. This is an optimization, skipping a simplify iteration,
+    # since we already remove COALESCE at the top of this function.
+    coalesce = coalesce if coalesce.expressions else coalesce.this
+
+    # This expression is more complex than when we started, but it will get simplified further
+    return exp.or_(
+        exp.and_(
+            coalesce.is_(exp.null()).not_(copy=False),
+            expression.copy(),
+            copy=False,
+        ),
+        exp.and_(
+            coalesce.is_(exp.null()),
+            type(expression)(this=arg.copy(), expression=other.copy()),
+            copy=False,
+        ),
+        copy=False,
+    )
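Illustration (not part of the diff): a comparison against COALESCE becomes decidable once the constant branch folds.

    from sqlglot import parse_one
    from sqlglot.optimizer.simplify import simplify

    # COALESCE(x, 1) = 2 splits into an x-IS-NOT-NULL case and an x-IS-NULL
    # case; the latter contains 1 = 2, which folds to FALSE and drops out.
    print(simplify(parse_one("COALESCE(x, 1) = 2")).sql())
    # e.g. x = 2 AND NOT x IS NULL  (exact shape depends on the other passes)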
+CONCATS = (exp.Concat, exp.DPipe)
+SAFE_CONCATS = (exp.SafeConcat, exp.SafeDPipe)
+
+
+def simplify_concat(expression):
+    """Reduces all groups that contain string literals by concatenating them."""
+    if not isinstance(expression, CONCATS) or isinstance(expression, exp.ConcatWs):
+        return expression
+
+    new_args = []
+    for is_string_group, group in itertools.groupby(
+        expression.expressions or expression.flatten(), lambda e: e.is_string
+    ):
+        if is_string_group:
+            new_args.append(exp.Literal.string("".join(string.name for string in group)))
+        else:
+            new_args.extend(group)
+
+    # Ensures we preserve the right concat type, i.e. whether it's "safe" or not
+    concat_type = exp.SafeConcat if isinstance(expression, SAFE_CONCATS) else exp.Concat
+    return new_args[0] if len(new_args) == 1 else concat_type(expressions=new_args)
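Illustration (not part of the diff): literal folding in a concat chain.

    from sqlglot import parse_one
    from sqlglot.optimizer.simplify import simplify

    # Adjacent string literals are fused into a single literal.
    print(simplify(parse_one("'a' || 'b' || 'c'")).sql())
    # 'abc'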
+# CROSS joins result in an empty table if the right table is empty.
+# So we can only simplify certain types of joins to CROSS.
+# Or in other words, LEFT JOIN x ON TRUE != CROSS JOIN x
+JOINS = {
+    ("", ""),
+    ("", "INNER"),
+    ("RIGHT", ""),
+    ("RIGHT", "OUTER"),
+}
+
+
 def remove_where_true(expression):
     for where in expression.find_all(exp.Where):
         if always_true(where.this):
@@ -439,6 +546,7 @@ def remove_where_true(expression):
             always_true(join.args.get("on"))
             and not join.args.get("using")
+            and not join.args.get("method")
             and (join.side, join.kind) in JOINS
         ):
             join.set("on", None)
             join.set("side", None)

@@ -102,15 +102,23 @@ class Parser(metaclass=_Parser):
         TokenType.CURRENT_USER: exp.CurrentUser,
     }

+    STRUCT_TYPE_TOKENS = {
+        TokenType.NESTED,
+        TokenType.STRUCT,
+    }
+
     NESTED_TYPE_TOKENS = {
         TokenType.ARRAY,
+        TokenType.LOWCARDINALITY,
         TokenType.MAP,
         TokenType.NULLABLE,
-        TokenType.STRUCT,
+        *STRUCT_TYPE_TOKENS,
     }

     ENUM_TYPE_TOKENS = {
         TokenType.ENUM,
+        TokenType.ENUM8,
+        TokenType.ENUM16,
     }

     TYPE_TOKENS = {

@@ -128,6 +136,7 @@ class Parser(metaclass=_Parser):
         TokenType.UINT128,
         TokenType.INT256,
         TokenType.UINT256,
+        TokenType.FIXEDSTRING,
         TokenType.FLOAT,
         TokenType.DOUBLE,
         TokenType.CHAR,

@@ -145,6 +154,7 @@ class Parser(metaclass=_Parser):
         TokenType.JSONB,
         TokenType.INTERVAL,
         TokenType.TIME,
+        TokenType.TIMETZ,
         TokenType.TIMESTAMP,
         TokenType.TIMESTAMPTZ,
         TokenType.TIMESTAMPLTZ,

@@ -187,7 +197,7 @@ class Parser(metaclass=_Parser):
         TokenType.INET,
         TokenType.IPADDRESS,
         TokenType.IPPREFIX,
-        TokenType.ENUM,
+        *ENUM_TYPE_TOKENS,
         *NESTED_TYPE_TOKENS,
     }

@@ -384,11 +394,16 @@ class Parser(metaclass=_Parser):
         TokenType.STAR: exp.Mul,
     }

-    TIMESTAMPS = {
+    TIMES = {
         TokenType.TIME,
+        TokenType.TIMETZ,
+    }
+
+    TIMESTAMPS = {
         TokenType.TIMESTAMP,
         TokenType.TIMESTAMPTZ,
         TokenType.TIMESTAMPLTZ,
+        *TIMES,
     }

     SET_OPERATIONS = {

@@ -1165,6 +1180,8 @@ class Parser(metaclass=_Parser):
     def _parse_create(self) -> exp.Create | exp.Command:
         # Note: this can't be None because we've matched a statement parser
         start = self._prev
+        comments = self._prev_comments
+
         replace = start.text.upper() == "REPLACE" or self._match_pair(
             TokenType.OR, TokenType.REPLACE
         )

@@ -1273,6 +1290,7 @@ class Parser(metaclass=_Parser):

         return self.expression(
             exp.Create,
+            comments=comments,
             this=this,
             kind=create_token.text,
             replace=replace,

@@ -2338,7 +2356,8 @@ class Parser(metaclass=_Parser):

             kwargs["this"].set("joins", joins)

-        return self.expression(exp.Join, **kwargs)
+        comments = [c for token in (method, side, kind) if token for c in token.comments]
+        return self.expression(exp.Join, comments=comments, **kwargs)

     def _parse_index(
         self,

@@ -2619,11 +2638,18 @@ class Parser(metaclass=_Parser):

     def _parse_pivot(self) -> t.Optional[exp.Pivot]:
         index = self._index
+        include_nulls = None

         if self._match(TokenType.PIVOT):
             unpivot = False
         elif self._match(TokenType.UNPIVOT):
             unpivot = True
+
+            # https://docs.databricks.com/en/sql/language-manual/sql-ref-syntax-qry-select-unpivot.html#syntax
+            if self._match_text_seq("INCLUDE", "NULLS"):
+                include_nulls = True
+            elif self._match_text_seq("EXCLUDE", "NULLS"):
+                include_nulls = False
         else:
             return None

@@ -2654,7 +2680,13 @@ class Parser(metaclass=_Parser):

         self._match_r_paren()

-        pivot = self.expression(exp.Pivot, expressions=expressions, field=field, unpivot=unpivot)
+        pivot = self.expression(
+            exp.Pivot,
+            expressions=expressions,
+            field=field,
+            unpivot=unpivot,
+            include_nulls=include_nulls,
+        )

         if not self._match_set((TokenType.PIVOT, TokenType.UNPIVOT), advance=False):
             pivot.set("alias", self._parse_table_alias())
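Illustration (not part of the diff): together with the pivot_sql change in the generator above, the Databricks qualifier should now survive a round trip. Table and column names are hypothetical.

    import sqlglot

    sql = "SELECT * FROM sales UNPIVOT INCLUDE NULLS (amount FOR quarter IN (q1, q2))"
    print(sqlglot.transpile(sql, read="databricks", write="databricks")[0])
    # Expectation: the INCLUDE NULLS qualifier is preserved in the output.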
@@ -3096,7 +3128,7 @@ class Parser(metaclass=_Parser):
             return self.expression(exp.PseudoType, this=self._prev.text)

         nested = type_token in self.NESTED_TYPE_TOKENS
-        is_struct = type_token == TokenType.STRUCT
+        is_struct = type_token in self.STRUCT_TYPE_TOKENS
         expressions = None
         maybe_func = False

@@ -3108,7 +3140,7 @@ class Parser(metaclass=_Parser):
                 lambda: self._parse_types(check_func=check_func, schema=schema)
             )
         elif type_token in self.ENUM_TYPE_TOKENS:
-            expressions = self._parse_csv(self._parse_primary)
+            expressions = self._parse_csv(self._parse_equality)
         else:
             expressions = self._parse_csv(self._parse_type_size)

@@ -3118,29 +3150,9 @@ class Parser(metaclass=_Parser):
         maybe_func = True

-        if self._match_pair(TokenType.L_BRACKET, TokenType.R_BRACKET):
-            this = exp.DataType(
-                this=exp.DataType.Type.ARRAY,
-                expressions=[
-                    exp.DataType(
-                        this=exp.DataType.Type[type_token.value],
-                        expressions=expressions,
-                        nested=nested,
-                    )
-                ],
-                nested=True,
-            )
-
-            while self._match_pair(TokenType.L_BRACKET, TokenType.R_BRACKET):
-                this = exp.DataType(this=exp.DataType.Type.ARRAY, expressions=[this], nested=True)
-
-            return this
-
         if self._match(TokenType.L_BRACKET):
             self._retreat(index)
             return None

+        this: t.Optional[exp.Expression] = None
+        values: t.Optional[t.List[t.Optional[exp.Expression]]] = None
+
         if nested and self._match(TokenType.LT):
             if is_struct:
                 expressions = self._parse_csv(self._parse_struct_types)

@@ -3156,23 +3168,35 @@ class Parser(metaclass=_Parser):
                 values = self._parse_csv(self._parse_conjunction)
                 self._match_set((TokenType.R_BRACKET, TokenType.R_PAREN))

-        value: t.Optional[exp.Expression] = None
         if type_token in self.TIMESTAMPS:
             if self._match_text_seq("WITH", "TIME", "ZONE"):
                 maybe_func = False
-                value = exp.DataType(this=exp.DataType.Type.TIMESTAMPTZ, expressions=expressions)
+                tz_type = (
+                    exp.DataType.Type.TIMETZ
+                    if type_token in self.TIMES
+                    else exp.DataType.Type.TIMESTAMPTZ
+                )
+                this = exp.DataType(this=tz_type, expressions=expressions)
             elif self._match_text_seq("WITH", "LOCAL", "TIME", "ZONE"):
                 maybe_func = False
-                value = exp.DataType(this=exp.DataType.Type.TIMESTAMPLTZ, expressions=expressions)
+                this = exp.DataType(this=exp.DataType.Type.TIMESTAMPLTZ, expressions=expressions)
             elif self._match_text_seq("WITHOUT", "TIME", "ZONE"):
                 maybe_func = False
         elif type_token == TokenType.INTERVAL:
-            unit = self._parse_var()
-
-            if not unit:
-                value = self.expression(exp.DataType, this=exp.DataType.Type.INTERVAL)
-            else:
-                value = self.expression(exp.Interval, unit=unit)
+            if self._match_text_seq("YEAR", "TO", "MONTH"):
+                span: t.Optional[t.List[exp.Expression]] = [exp.IntervalYearToMonthSpan()]
+            elif self._match_text_seq("DAY", "TO", "SECOND"):
+                span = [exp.IntervalDayToSecondSpan()]
+            else:
+                span = None
+
+            unit = not span and self._parse_var()
+            if not unit:
+                this = self.expression(
+                    exp.DataType, this=exp.DataType.Type.INTERVAL, expressions=span
+                )
+            else:
+                this = self.expression(exp.Interval, unit=unit)

@@ -3184,16 +3208,19 @@ class Parser(metaclass=_Parser):

         self._retreat(index2)

-        if value:
-            return value
+        if not this:
+            this = exp.DataType(
+                this=exp.DataType.Type[type_token.value],
+                expressions=expressions,
+                nested=nested,
+                values=values,
+                prefix=prefix,
+            )

-        return exp.DataType(
-            this=exp.DataType.Type[type_token.value],
-            expressions=expressions,
-            nested=nested,
-            values=values,
-            prefix=prefix,
-        )
+        while self._match_pair(TokenType.L_BRACKET, TokenType.R_BRACKET):
+            this = exp.DataType(this=exp.DataType.Type.ARRAY, expressions=[this], nested=True)
+
+        return this

     def _parse_struct_types(self) -> t.Optional[exp.Expression]:
         this = self._parse_type() or self._parse_id_var()
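Illustration (not part of the diff): a typed interval keeps its bounds through parse and generate, via the new span expressions and their Generator TRANSFORMS entries.

    import sqlglot

    # Parses into exp.DataType(this=INTERVAL,
    # expressions=[exp.IntervalDayToSecondSpan()]) and renders back as-is.
    print(sqlglot.transpile("CAST(x AS INTERVAL DAY TO SECOND)")[0])
    # CAST(x AS INTERVAL DAY TO SECOND)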
@@ -3738,6 +3765,7 @@ class Parser(metaclass=_Parser):
         ifs = []
         default = None

+        comments = self._prev_comments
         expression = self._parse_conjunction()

         while self._match(TokenType.WHEN):

@@ -3753,7 +3781,7 @@ class Parser(metaclass=_Parser):
             self.raise_error("Expected END after CASE", self._prev)

         return self._parse_window(
-            self.expression(exp.Case, this=expression, ifs=ifs, default=default)
+            self.expression(exp.Case, comments=comments, this=expression, ifs=ifs, default=default)
         )

     def _parse_if(self) -> t.Optional[exp.Expression]:

@@ -372,21 +372,12 @@ class MappingSchema(AbstractMappingSchema[t.Dict[str, str]], Schema):
         is_table: bool = False,
         normalize: t.Optional[bool] = None,
     ) -> str:
-        dialect = dialect or self.dialect
-        normalize = self.normalize if normalize is None else normalize
-
-        try:
-            identifier = sqlglot.maybe_parse(name, dialect=dialect, into=exp.Identifier)
-        except ParseError:
-            return name if isinstance(name, str) else name.name
-
-        name = identifier.name
-        if not normalize:
-            return name
-
-        # This can be useful for normalize_identifier
-        identifier.meta["is_table"] = is_table
-        return Dialect.get_or_raise(dialect).normalize_identifier(identifier).name
+        return normalize_name(
+            name,
+            dialect=dialect or self.dialect,
+            is_table=is_table,
+            normalize=self.normalize if normalize is None else normalize,
+        )

     def depth(self) -> int:
         if not self.empty and not self._depth:

@@ -418,6 +409,26 @@ class MappingSchema(AbstractMappingSchema[t.Dict[str, str]], Schema):
         return self._type_mapping_cache[schema_type]


+def normalize_name(
+    name: str | exp.Identifier,
+    dialect: DialectType = None,
+    is_table: bool = False,
+    normalize: t.Optional[bool] = True,
+) -> str:
+    try:
+        identifier = sqlglot.maybe_parse(name, dialect=dialect, into=exp.Identifier)
+    except ParseError:
+        return name if isinstance(name, str) else name.name
+
+    name = identifier.name
+    if not normalize:
+        return name
+
+    # This can be useful for normalize_identifier
+    identifier.meta["is_table"] = is_table
+    return Dialect.get_or_raise(dialect).normalize_identifier(identifier).name
+
+
 def ensure_schema(schema: Schema | t.Optional[t.Dict], **kwargs: t.Any) -> Schema:
     if isinstance(schema, Schema):
         return schema
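Illustration (not part of the diff): extracting normalize_name to module level is what lets the executor reuse it. Casing rules are per dialect.

    from sqlglot.schema import normalize_name

    print(normalize_name("FOO", dialect="duckdb"))     # foo  (unquoted folds down)
    print(normalize_name('"Foo"', dialect="duckdb"))   # Foo  (quoted is preserved)
    print(normalize_name("foo", dialect="snowflake"))  # FOO  (Snowflake folds up)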
@@ -110,6 +110,7 @@ class TokenType(AutoName):
     JSON = auto()
     JSONB = auto()
     TIME = auto()
+    TIMETZ = auto()
     TIMESTAMP = auto()
     TIMESTAMPTZ = auto()
     TIMESTAMPLTZ = auto()

@@ -151,6 +152,11 @@ class TokenType(AutoName):
     IPADDRESS = auto()
     IPPREFIX = auto()
     ENUM = auto()
+    ENUM8 = auto()
+    ENUM16 = auto()
+    FIXEDSTRING = auto()
+    LOWCARDINALITY = auto()
+    NESTED = auto()

     # keywords
     ALIAS = auto()

@@ -659,6 +665,7 @@ class Tokenizer(metaclass=_Tokenizer):
         "TINYINT": TokenType.TINYINT,
         "SHORT": TokenType.SMALLINT,
         "SMALLINT": TokenType.SMALLINT,
+        "INT128": TokenType.INT128,
         "INT2": TokenType.SMALLINT,
         "INTEGER": TokenType.INT,
         "INT": TokenType.INT,

@@ -699,6 +706,7 @@ class Tokenizer(metaclass=_Tokenizer):
         "BYTEA": TokenType.VARBINARY,
         "VARBINARY": TokenType.VARBINARY,
         "TIME": TokenType.TIME,
+        "TIMETZ": TokenType.TIMETZ,
         "TIMESTAMP": TokenType.TIMESTAMP,
         "TIMESTAMPTZ": TokenType.TIMESTAMPTZ,
         "TIMESTAMPLTZ": TokenType.TIMESTAMPLTZ,

@@ -879,6 +887,11 @@ class Tokenizer(metaclass=_Tokenizer):

     def _add(self, token_type: TokenType, text: t.Optional[str] = None) -> None:
         self._prev_token_line = self._line

+        if self._comments and token_type == TokenType.SEMICOLON and self.tokens:
+            self.tokens[-1].comments.extend(self._comments)
+            self._comments = []
+
         self.tokens.append(
             Token(
                 token_type,
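Illustration (not part of the diff): a comment that ends up immediately before a semicolon is attached to the previous token instead of being dropped with the statement terminator.

    import sqlglot

    print(sqlglot.transpile("SELECT 1 /* keep me */ ;")[0])
    # SELECT 1 /* keep me */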