
Merging upstream version 17.12.0.

Signed-off-by: Daniel Baumann <daniel@debian.org>
Daniel Baumann 2025-02-13 20:55:29 +01:00
parent aa315e6009
commit aae08e0bb3
Signed by: daniel
GPG key ID: FBB4F0E80A80222F
64 changed files with 12465 additions and 11885 deletions

sqlglot/dialects/__init__.py

@@ -60,6 +60,7 @@ from sqlglot.dialects.bigquery import BigQuery
from sqlglot.dialects.clickhouse import ClickHouse
from sqlglot.dialects.databricks import Databricks
from sqlglot.dialects.dialect import Dialect, Dialects
from sqlglot.dialects.doris import Doris
from sqlglot.dialects.drill import Drill
from sqlglot.dialects.duckdb import DuckDB
from sqlglot.dialects.hive import Hive

sqlglot/dialects/clickhouse.py

@@ -37,17 +37,22 @@ class ClickHouse(Dialect):
"ATTACH": TokenType.COMMAND,
"DATETIME64": TokenType.DATETIME64,
"DICTIONARY": TokenType.DICTIONARY,
"ENUM": TokenType.ENUM,
"ENUM8": TokenType.ENUM8,
"ENUM16": TokenType.ENUM16,
"FINAL": TokenType.FINAL,
"FIXEDSTRING": TokenType.FIXEDSTRING,
"FLOAT32": TokenType.FLOAT,
"FLOAT64": TokenType.DOUBLE,
"GLOBAL": TokenType.GLOBAL,
"INT128": TokenType.INT128,
"INT16": TokenType.SMALLINT,
"INT256": TokenType.INT256,
"INT32": TokenType.INT,
"INT64": TokenType.BIGINT,
"INT8": TokenType.TINYINT,
"LOWCARDINALITY": TokenType.LOWCARDINALITY,
"MAP": TokenType.MAP,
"NESTED": TokenType.NESTED,
"TUPLE": TokenType.STRUCT,
"UINT128": TokenType.UINT128,
"UINT16": TokenType.USMALLINT,
@@ -294,11 +299,17 @@ class ClickHouse(Dialect):
exp.DataType.Type.BIGINT: "Int64",
exp.DataType.Type.DATETIME64: "DateTime64",
exp.DataType.Type.DOUBLE: "Float64",
exp.DataType.Type.ENUM: "Enum",
exp.DataType.Type.ENUM8: "Enum8",
exp.DataType.Type.ENUM16: "Enum16",
exp.DataType.Type.FIXEDSTRING: "FixedString",
exp.DataType.Type.FLOAT: "Float32",
exp.DataType.Type.INT: "Int32",
exp.DataType.Type.INT128: "Int128",
exp.DataType.Type.INT256: "Int256",
exp.DataType.Type.LOWCARDINALITY: "LowCardinality",
exp.DataType.Type.MAP: "Map",
exp.DataType.Type.NESTED: "Nested",
exp.DataType.Type.NULLABLE: "Nullable",
exp.DataType.Type.SMALLINT: "Int16",
exp.DataType.Type.STRUCT: "Tuple",

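A minimal usage sketch of the new ClickHouse types (illustrative, not part of the diff):

import sqlglot

sql = "SELECT CAST(x AS FixedString(10)), CAST(y AS Enum8('a' = 1, 'b' = 2))"
print(sqlglot.transpile(sql, read="clickhouse", write="clickhouse")[0])
# expected to round-trip unchanged, per the tokenizer and TYPE_MAPPING entries above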
sqlglot/dialects/dialect.py

@@ -39,6 +39,7 @@ class Dialects(str, Enum):
TERADATA = "teradata"
TRINO = "trino"
TSQL = "tsql"
Doris = "doris"
class _Dialect(type):
@@ -121,7 +122,7 @@ class _Dialect(type):
if hasattr(subclass, name):
setattr(subclass, name, value)
-if not klass.STRICT_STRING_CONCAT:
+if not klass.STRICT_STRING_CONCAT and klass.DPIPE_IS_STRING_CONCAT:
klass.parser_class.BITWISE[TokenType.DPIPE] = exp.SafeDPipe
klass.generator_class.can_identify = klass.can_identify
@@ -146,6 +147,9 @@ class Dialect(metaclass=_Dialect):
# Determines whether or not an unquoted identifier can start with a digit
IDENTIFIERS_CAN_START_WITH_DIGIT = False
# Determines whether or not the DPIPE token ('||') is a string concatenation operator
DPIPE_IS_STRING_CONCAT = True
# Determines whether or not CONCAT's arguments must be strings
STRICT_STRING_CONCAT = False
@@ -460,6 +464,20 @@ def format_time_lambda(
return _format_time
def time_format(
dialect: DialectType = None,
) -> t.Callable[[Generator, exp.UnixToStr | exp.StrToUnix], t.Optional[str]]:
def _time_format(self: Generator, expression: exp.UnixToStr | exp.StrToUnix) -> t.Optional[str]:
"""
Returns the time format for a given expression, unless it's equivalent
to the default time format of the dialect of interest.
"""
time_format = self.format_time(expression)
return time_format if time_format != Dialect.get_or_raise(dialect).TIME_FORMAT else None
return _time_format
def create_with_partitions_sql(self: Generator, expression: exp.Create) -> str:
"""
In Hive and Spark, the PARTITIONED BY property acts as an extension of a table's schema. When the
@@ -699,3 +717,8 @@ def simplify_literal(expression: E) -> E:
def binary_from_function(expr_type: t.Type[B]) -> t.Callable[[t.List], B]:
return lambda args: expr_type(this=seq_get(args, 0), expression=seq_get(args, 1))
# Used to represent DATE_TRUNC in Doris, Postgres and Starrocks dialects
def parse_timestamp_trunc(args: t.List) -> exp.TimestampTrunc:
return exp.TimestampTrunc(this=seq_get(args, 1), unit=seq_get(args, 0))

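For context, a small sketch (illustrative, not part of the diff) of what parse_timestamp_trunc buys the dialects that use it: the leading unit argument of DATE_TRUNC is folded into exp.TimestampTrunc.

import sqlglot
from sqlglot import exp

node = sqlglot.parse_one("DATE_TRUNC('month', d)", read="postgres")
assert isinstance(node, exp.TimestampTrunc)
assert node.text("unit") == "month"  # the unit comes from the first argument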
sqlglot/dialects/doris.py (new file, 65 lines)

@@ -0,0 +1,65 @@
from __future__ import annotations
from sqlglot import exp
from sqlglot.dialects.dialect import (
approx_count_distinct_sql,
arrow_json_extract_sql,
parse_timestamp_trunc,
rename_func,
time_format,
)
from sqlglot.dialects.mysql import MySQL
class Doris(MySQL):
DATE_FORMAT = "'yyyy-MM-dd'"
DATEINT_FORMAT = "'yyyyMMdd'"
TIME_FORMAT = "'yyyy-MM-dd HH:mm:ss'"
class Parser(MySQL.Parser):
FUNCTIONS = {
**MySQL.Parser.FUNCTIONS,
"DATE_TRUNC": parse_timestamp_trunc,
"REGEXP": exp.RegexpLike.from_arg_list,
}
class Generator(MySQL.Generator):
CAST_MAPPING = {}
TYPE_MAPPING = {
**MySQL.Generator.TYPE_MAPPING,
exp.DataType.Type.TEXT: "STRING",
exp.DataType.Type.TIMESTAMP: "DATETIME",
exp.DataType.Type.TIMESTAMPTZ: "DATETIME",
}
TRANSFORMS = {
**MySQL.Generator.TRANSFORMS,
exp.ApproxDistinct: approx_count_distinct_sql,
exp.ArrayAgg: rename_func("COLLECT_LIST"),
exp.Coalesce: rename_func("NVL"),
exp.CurrentTimestamp: lambda *_: "NOW()",
exp.DateTrunc: lambda self, e: self.func(
"DATE_TRUNC", e.this, "'" + e.text("unit") + "'"
),
exp.JSONExtractScalar: arrow_json_extract_sql,
exp.JSONExtract: arrow_json_extract_sql,
exp.RegexpLike: rename_func("REGEXP"),
exp.RegexpSplit: rename_func("SPLIT_BY_STRING"),
exp.SetAgg: rename_func("COLLECT_SET"),
exp.StrToUnix: lambda self, e: f"UNIX_TIMESTAMP({self.sql(e, 'this')}, {self.format_time(e)})",
exp.Split: rename_func("SPLIT_BY_STRING"),
exp.TimeStrToDate: rename_func("TO_DATE"),
exp.ToChar: lambda self, e: f"DATE_FORMAT({self.sql(e, 'this')}, {self.format_time(e)})",
exp.TsOrDsAdd: lambda self, e: f"DATE_ADD({self.sql(e, 'this')}, {self.sql(e, 'expression')})",  # Only supports day-level addition
exp.TsOrDsToDate: lambda self, e: self.func("TO_DATE", e.this),
exp.TimeToUnix: rename_func("UNIX_TIMESTAMP"),
exp.TimestampTrunc: lambda self, e: self.func(
"DATE_TRUNC", e.this, "'" + e.text("unit") + "'"
),
exp.UnixToStr: lambda self, e: self.func(
"FROM_UNIXTIME", e.this, time_format("doris")(self, e)
),
exp.UnixToTime: rename_func("FROM_UNIXTIME"),
exp.Map: rename_func("ARRAY_MAP"),
}

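Usage sketch for the new dialect (illustrative; the expected output is assumed from the TRANSFORMS table above):

import sqlglot

print(sqlglot.transpile("SELECT COALESCE(a, b) FROM t", read="postgres", write="doris")[0])
# expected, roughly: SELECT NVL(a, b) FROM t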
sqlglot/dialects/duckdb.py

@@ -89,6 +89,11 @@ def _struct_sql(self: generator.Generator, expression: exp.Struct) -> str:
def _datatype_sql(self: generator.Generator, expression: exp.DataType) -> str:
if expression.is_type("array"):
return f"{self.expressions(expression, flat=True)}[]"
# Type TIMESTAMP / TIME WITH TIME ZONE does not support any modifiers
if expression.is_type("timestamptz", "timetz"):
return expression.this.value
return self.datatype_sql(expression)
@@ -110,14 +115,14 @@ class DuckDB(Dialect):
"//": TokenType.DIV,
"ATTACH": TokenType.COMMAND,
"BINARY": TokenType.VARBINARY,
"BPCHAR": TokenType.TEXT,
"BITSTRING": TokenType.BIT,
"BPCHAR": TokenType.TEXT,
"CHAR": TokenType.TEXT,
"CHARACTER VARYING": TokenType.TEXT,
"EXCLUDE": TokenType.EXCEPT,
"HUGEINT": TokenType.INT128,
"INT1": TokenType.TINYINT,
"LOGICAL": TokenType.BOOLEAN,
"NUMERIC": TokenType.DOUBLE,
"PIVOT_WIDER": TokenType.PIVOT,
"SIGNED": TokenType.INT,
"STRING": TokenType.VARCHAR,
@@ -186,6 +191,22 @@ class DuckDB(Dialect):
TokenType.UTINYINT,
}
def _parse_types(
self, check_func: bool = False, schema: bool = False
) -> t.Optional[exp.Expression]:
this = super()._parse_types(check_func=check_func, schema=schema)
# DuckDB treats NUMERIC and DECIMAL without precision as DECIMAL(18, 3)
# See: https://duckdb.org/docs/sql/data_types/numeric
if (
isinstance(this, exp.DataType)
and this.is_type("numeric", "decimal")
and not this.expressions
):
return exp.DataType.build("DECIMAL(18, 3)")
return this
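Sketch of the effect (illustrative, not part of the diff):

import sqlglot

print(sqlglot.parse_one("CAST(x AS NUMERIC)", read="duckdb").sql("duckdb"))
# expected: CAST(x AS DECIMAL(18, 3))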
def _pivot_column_names(self, aggregations: t.List[exp.Expression]) -> t.List[str]:
if len(aggregations) == 1:
return super()._pivot_column_names(aggregations)
@@ -231,6 +252,7 @@ class DuckDB(Dialect):
exp.Encode: lambda self, e: encode_decode_sql(self, e, "ENCODE", replace=False),
exp.Explode: rename_func("UNNEST"),
exp.IntDiv: lambda self, e: self.binary(e, "//"),
exp.IsNan: rename_func("ISNAN"),
exp.JSONExtract: arrow_json_extract_sql,
exp.JSONExtractScalar: arrow_json_extract_scalar_sql,
exp.JSONFormat: _json_format_sql,

sqlglot/dialects/hive.py

@@ -23,6 +23,7 @@ from sqlglot.dialects.dialect import (
right_to_substring_sql,
strposition_to_locate_sql,
struct_extract_sql,
time_format,
timestrtotime_sql,
var_map_sql,
)
@@ -113,7 +114,7 @@ def _property_sql(self: generator.Generator, expression: exp.Property) -> str:
def _str_to_unix_sql(self: generator.Generator, expression: exp.StrToUnix) -> str:
return self.func("UNIX_TIMESTAMP", expression.this, _time_format(self, expression))
return self.func("UNIX_TIMESTAMP", expression.this, time_format("hive")(self, expression))
def _str_to_date_sql(self: generator.Generator, expression: exp.StrToDate) -> str:
@@ -132,15 +133,6 @@ def _str_to_time_sql(self: generator.Generator, expression: exp.StrToTime) -> st
return f"CAST({this} AS TIMESTAMP)"
-def _time_format(
-self: generator.Generator, expression: exp.UnixToStr | exp.StrToUnix
-) -> t.Optional[str]:
-time_format = self.format_time(expression)
-if time_format == Hive.TIME_FORMAT:
-return None
-return time_format
def _time_to_str(self: generator.Generator, expression: exp.TimeToStr) -> str:
this = self.sql(expression, "this")
time_format = self.format_time(expression)
@@ -439,7 +431,7 @@ class Hive(Dialect):
exp.TsOrDsToDate: _to_date_sql,
exp.TryCast: no_trycast_sql,
exp.UnixToStr: lambda self, e: self.func(
"FROM_UNIXTIME", e.this, _time_format(self, e)
"FROM_UNIXTIME", e.this, time_format("hive")(self, e)
),
exp.UnixToTime: rename_func("FROM_UNIXTIME"),
exp.UnixToTimeStr: rename_func("FROM_UNIXTIME"),

sqlglot/dialects/mysql.py

@@ -94,6 +94,7 @@ def _date_add_sql(kind: str) -> t.Callable[[generator.Generator, exp.DateAdd | e
class MySQL(Dialect):
TIME_FORMAT = "'%Y-%m-%d %T'"
DPIPE_IS_STRING_CONCAT = False
# https://prestodb.io/docs/current/functions/datetime.html#mysql-date-functions
TIME_MAPPING = {
@@ -103,7 +104,6 @@ class MySQL(Dialect):
"%h": "%I",
"%i": "%M",
"%s": "%S",
"%S": "%S",
"%u": "%W",
"%k": "%-H",
"%l": "%-I",
@@ -196,8 +196,14 @@ class MySQL(Dialect):
**parser.Parser.CONJUNCTION,
TokenType.DAMP: exp.And,
TokenType.XOR: exp.Xor,
TokenType.DPIPE: exp.Or,
}
# MySQL uses || as a synonym for the logical OR operator
# https://dev.mysql.com/doc/refman/8.0/en/logical-operators.html#operator_or
BITWISE = parser.Parser.BITWISE.copy()
BITWISE.pop(TokenType.DPIPE)
TABLE_ALIAS_TOKENS = (
parser.Parser.TABLE_ALIAS_TOKENS - parser.Parser.TABLE_INDEX_HINT_TOKENS
)

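Illustration (a sketch, assuming this version is installed): with DPIPE_IS_STRING_CONCAT disabled, MySQL's || parses as logical OR and transpiles accordingly.

import sqlglot
from sqlglot import exp

node = sqlglot.parse_one("SELECT a || b", read="mysql")
assert isinstance(node.expressions[0], exp.Or)
print(sqlglot.transpile("SELECT a || b", read="mysql", write="postgres")[0])
# expected, roughly: SELECT a OR b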
sqlglot/dialects/postgres.py

@@ -16,6 +16,7 @@ from sqlglot.dialects.dialect import (
no_pivot_sql,
no_tablesample_sql,
no_trycast_sql,
parse_timestamp_trunc,
rename_func,
simplify_literal,
str_position_sql,
@@ -286,9 +287,7 @@ class Postgres(Dialect):
FUNCTIONS = {
**parser.Parser.FUNCTIONS,
"DATE_TRUNC": lambda args: exp.TimestampTrunc(
this=seq_get(args, 1), unit=seq_get(args, 0)
),
"DATE_TRUNC": parse_timestamp_trunc,
"GENERATE_SERIES": _generate_series,
"NOW": exp.CurrentTimestamp.from_arg_list,
"TO_CHAR": format_time_lambda(exp.TimeToStr, "postgres"),

sqlglot/dialects/presto.py

@@ -32,13 +32,6 @@ def _approx_distinct_sql(self: generator.Generator, expression: exp.ApproxDistin
return f"APPROX_DISTINCT({self.sql(expression, 'this')}{accuracy})"
-def _datatype_sql(self: generator.Generator, expression: exp.DataType) -> str:
-sql = self.datatype_sql(expression)
-if expression.is_type("timestamptz"):
-sql = f"{sql} WITH TIME ZONE"
-return sql
def _explode_to_unnest_sql(self: generator.Generator, expression: exp.Lateral) -> str:
if isinstance(expression.this, (exp.Explode, exp.Posexplode)):
expression = expression.copy()
@@ -231,6 +224,7 @@ class Presto(Dialect):
TABLE_HINTS = False
QUERY_HINTS = False
IS_BOOL_ALLOWED = False
TZ_TO_WITH_TIME_ZONE = True
STRUCT_DELIMITER = ("(", ")")
PROPERTIES_LOCATION = {
@@ -245,6 +239,7 @@ class Presto(Dialect):
exp.DataType.Type.FLOAT: "REAL",
exp.DataType.Type.BINARY: "VARBINARY",
exp.DataType.Type.TEXT: "VARCHAR",
exp.DataType.Type.TIMETZ: "TIME",
exp.DataType.Type.TIMESTAMPTZ: "TIMESTAMP",
exp.DataType.Type.STRUCT: "ROW",
}
@@ -265,7 +260,6 @@ class Presto(Dialect):
exp.BitwiseXor: lambda self, e: f"BITWISE_XOR({self.sql(e, 'this')}, {self.sql(e, 'expression')})",
exp.Cast: transforms.preprocess([transforms.epoch_cast_to_ts]),
exp.CurrentTimestamp: lambda *_: "CURRENT_TIMESTAMP",
-exp.DataType: _datatype_sql,
exp.DateAdd: lambda self, e: self.func(
"DATE_ADD", exp.Literal.string(e.text("unit") or "day"), e.expression, e.this
),

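Sketch of the new flag in action (illustrative): TIMESTAMPTZ maps to plain TIMESTAMP and the generator re-appends the suffix.

import sqlglot

print(sqlglot.parse_one("CAST(x AS TIMESTAMPTZ)").sql("presto"))
# expected: CAST(x AS TIMESTAMP WITH TIME ZONE)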
sqlglot/dialects/redshift.py

@@ -85,8 +85,6 @@ class Redshift(Postgres):
"HLLSKETCH": TokenType.HLLSKETCH,
"SUPER": TokenType.SUPER,
"SYSDATE": TokenType.CURRENT_TIMESTAMP,
"TIME": TokenType.TIMESTAMP,
"TIMETZ": TokenType.TIMESTAMPTZ,
"TOP": TokenType.TOP,
"UNLOAD": TokenType.COMMAND,
"VARBYTE": TokenType.VARBINARY,
@@ -101,12 +99,15 @@ class Redshift(Postgres):
RENAME_TABLE_WITH_DB = False
QUERY_HINTS = False
VALUES_AS_TABLE = False
TZ_TO_WITH_TIME_ZONE = True
TYPE_MAPPING = {
**Postgres.Generator.TYPE_MAPPING,
exp.DataType.Type.BINARY: "VARBYTE",
exp.DataType.Type.VARBINARY: "VARBYTE",
exp.DataType.Type.INT: "INTEGER",
exp.DataType.Type.TIMETZ: "TIME",
exp.DataType.Type.TIMESTAMPTZ: "TIMESTAMP",
exp.DataType.Type.VARBINARY: "VARBYTE",
}
PROPERTIES_LOCATION = {

sqlglot/dialects/spark.py

@@ -52,6 +52,9 @@ class Spark(Spark2):
TRANSFORMS = {
**Spark2.Generator.TRANSFORMS,
exp.StartsWith: rename_func("STARTSWITH"),
exp.TimestampAdd: lambda self, e: self.func(
"DATEADD", e.args.get("unit") or "DAY", e.expression, e.this
),
}
TRANSFORMS.pop(exp.DateDiff)
TRANSFORMS.pop(exp.Group)

sqlglot/dialects/starrocks.py

@@ -4,6 +4,7 @@ from sqlglot import exp
from sqlglot.dialects.dialect import (
approx_count_distinct_sql,
arrow_json_extract_sql,
parse_timestamp_trunc,
rename_func,
)
from sqlglot.dialects.mysql import MySQL
@@ -14,9 +15,7 @@ class StarRocks(MySQL):
class Parser(MySQL.Parser):
FUNCTIONS = {
**MySQL.Parser.FUNCTIONS,
"DATE_TRUNC": lambda args: exp.TimestampTrunc(
this=seq_get(args, 1), unit=seq_get(args, 0)
),
"DATE_TRUNC": parse_timestamp_trunc,
"DATEDIFF": lambda args: exp.DateDiff(
this=seq_get(args, 0), expression=seq_get(args, 1), unit=exp.Literal.string("DAY")
),

sqlglot/executor/__init__.py

@@ -28,6 +28,11 @@ if t.TYPE_CHECKING:
from sqlglot.schema import Schema
PYTHON_TYPE_TO_SQLGLOT = {
"dict": "MAP",
}
def execute(
sql: str | Expression,
schema: t.Optional[t.Dict | Schema] = None,
@@ -50,7 +55,7 @@ def execute(
Returns:
Simple columnar data structure.
"""
-tables_ = ensure_tables(tables)
+tables_ = ensure_tables(tables, dialect=read)
if not schema:
schema = {}
@@ -61,7 +66,8 @@ def execute(
assert table is not None
for column in table.columns:
-nested_set(schema, [*keys, column], type(table[0][column]).__name__)
+py_type = type(table[0][column]).__name__
+nested_set(schema, [*keys, column], PYTHON_TYPE_TO_SQLGLOT.get(py_type) or py_type)
schema = ensure_schema(schema, dialect=read)

sqlglot/executor/table.py

@@ -2,8 +2,9 @@ from __future__ import annotations
import typing as t
from sqlglot.dialects.dialect import DialectType
from sqlglot.helper import dict_depth
-from sqlglot.schema import AbstractMappingSchema
+from sqlglot.schema import AbstractMappingSchema, normalize_name
class Table:
@@ -108,26 +109,37 @@ class Tables(AbstractMappingSchema[Table]):
pass
-def ensure_tables(d: t.Optional[t.Dict]) -> Tables:
-return Tables(_ensure_tables(d))
+def ensure_tables(d: t.Optional[t.Dict], dialect: DialectType = None) -> Tables:
+return Tables(_ensure_tables(d, dialect=dialect))
-def _ensure_tables(d: t.Optional[t.Dict]) -> t.Dict:
+def _ensure_tables(d: t.Optional[t.Dict], dialect: DialectType = None) -> t.Dict:
if not d:
return {}
depth = dict_depth(d)
if depth > 1:
-return {k: _ensure_tables(v) for k, v in d.items()}
+return {
+normalize_name(k, dialect=dialect, is_table=True): _ensure_tables(v, dialect=dialect)
+for k, v in d.items()
+}
result = {}
-for name, table in d.items():
+for table_name, table in d.items():
+table_name = normalize_name(table_name, dialect=dialect)
if isinstance(table, Table):
-result[name] = table
+result[table_name] = table
else:
-columns = tuple(table[0]) if table else ()
-rows = [tuple(row[c] for c in columns) for row in table]
-result[name] = Table(columns=columns, rows=rows)
+table = [
+{
+normalize_name(column_name, dialect=dialect): value
+for column_name, value in row.items()
+}
+for row in table
+]
+column_names = tuple(column_name for column_name in table[0]) if table else ()
+rows = [tuple(row[name] for name in column_names) for row in table]
+result[table_name] = Table(columns=column_names, rows=rows)
return result

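Why the normalization matters, as a hypothetical example (not from the commit): caller-supplied table and column keys no longer need to match the query's identifier case for case-insensitive dialects.

from sqlglot.executor import execute

# "T" and "A" are normalized (lowercased for the default dialect), so the
# unquoted identifiers in the query resolve against them
result = execute("SELECT a FROM t", tables={"T": [{"A": 1}, {"A": 2}]})
print(result.rows)  # expected: [(1,), (2,)]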
sqlglot/expressions.py

@@ -3309,6 +3309,7 @@ class Pivot(Expression):
"using": False,
"group": False,
"columns": False,
"include_nulls": False,
}
@@ -3397,23 +3398,16 @@ class DataType(Expression):
BOOLEAN = auto()
CHAR = auto()
DATE = auto()
DATEMULTIRANGE = auto()
DATERANGE = auto()
DATETIME = auto()
DATETIME64 = auto()
ENUM = auto()
INT4RANGE = auto()
INT4MULTIRANGE = auto()
INT8RANGE = auto()
INT8MULTIRANGE = auto()
NUMRANGE = auto()
NUMMULTIRANGE = auto()
TSRANGE = auto()
TSMULTIRANGE = auto()
TSTZRANGE = auto()
TSTZMULTIRANGE = auto()
DATERANGE = auto()
DATEMULTIRANGE = auto()
DECIMAL = auto()
DOUBLE = auto()
ENUM = auto()
ENUM8 = auto()
ENUM16 = auto()
FIXEDSTRING = auto()
FLOAT = auto()
GEOGRAPHY = auto()
GEOMETRY = auto()
@@ -3421,23 +3415,31 @@ class DataType(Expression):
HSTORE = auto()
IMAGE = auto()
INET = auto()
IPADDRESS = auto()
IPPREFIX = auto()
INT = auto()
INT128 = auto()
INT256 = auto()
INT4MULTIRANGE = auto()
INT4RANGE = auto()
INT8MULTIRANGE = auto()
INT8RANGE = auto()
INTERVAL = auto()
IPADDRESS = auto()
IPPREFIX = auto()
JSON = auto()
JSONB = auto()
LONGBLOB = auto()
LONGTEXT = auto()
LOWCARDINALITY = auto()
MAP = auto()
MEDIUMBLOB = auto()
MEDIUMTEXT = auto()
MONEY = auto()
NCHAR = auto()
NESTED = auto()
NULL = auto()
NULLABLE = auto()
NUMMULTIRANGE = auto()
NUMRANGE = auto()
NVARCHAR = auto()
OBJECT = auto()
ROWVERSION = auto()
@@ -3450,19 +3452,24 @@ class DataType(Expression):
SUPER = auto()
TEXT = auto()
TIME = auto()
TIMETZ = auto()
TIMESTAMP = auto()
TIMESTAMPTZ = auto()
TIMESTAMPLTZ = auto()
TIMESTAMPTZ = auto()
TINYINT = auto()
TSMULTIRANGE = auto()
TSRANGE = auto()
TSTZMULTIRANGE = auto()
TSTZRANGE = auto()
UBIGINT = auto()
UINT = auto()
USMALLINT = auto()
UTINYINT = auto()
UNKNOWN = auto() # Sentinel value, useful for type annotation
UINT128 = auto()
UINT256 = auto()
UNIQUEIDENTIFIER = auto()
UNKNOWN = auto() # Sentinel value, useful for type annotation
USERDEFINED = "USER-DEFINED"
USMALLINT = auto()
UTINYINT = auto()
UUID = auto()
VARBINARY = auto()
VARCHAR = auto()
@@ -3495,6 +3502,7 @@ class DataType(Expression):
TEMPORAL_TYPES = {
Type.TIME,
Type.TIMETZ,
Type.TIMESTAMP,
Type.TIMESTAMPTZ,
Type.TIMESTAMPLTZ,
@@ -3858,6 +3866,18 @@ class TimeUnit(Expression):
super().__init__(**args)
# https://www.oracletutorial.com/oracle-basics/oracle-interval/
# https://trino.io/docs/current/language/types.html#interval-year-to-month
class IntervalYearToMonthSpan(Expression):
arg_types = {}
# https://www.oracletutorial.com/oracle-basics/oracle-interval/
# https://trino.io/docs/current/language/types.html#interval-day-to-second
class IntervalDayToSecondSpan(Expression):
arg_types = {}
class Interval(TimeUnit):
arg_types = {"this": False, "unit": False}

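A sketch of the new span expressions surfacing end to end (illustrative, not part of the diff):

import sqlglot

print(sqlglot.parse_one("CAST(x AS INTERVAL DAY TO SECOND)").sql())
# expected: CAST(x AS INTERVAL DAY TO SECOND)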
sqlglot/generator.py

@@ -71,6 +71,8 @@ class Generator:
exp.ExternalProperty: lambda self, e: "EXTERNAL",
exp.HeapProperty: lambda self, e: "HEAP",
exp.InlineLengthColumnConstraint: lambda self, e: f"INLINE LENGTH {self.sql(e, 'this')}",
exp.IntervalDayToSecondSpan: "DAY TO SECOND",
exp.IntervalYearToMonthSpan: "YEAR TO MONTH",
exp.LanguageProperty: lambda self, e: self.naked_property(e),
exp.LocationProperty: lambda self, e: self.naked_property(e),
exp.LogProperty: lambda self, e: f"{'NO ' if e.args.get('no') else ''}LOG",
@@ -166,6 +168,9 @@ class Generator:
# Whether or not to generate an unquoted value for EXTRACT's date part argument
EXTRACT_ALLOWS_QUOTES = True
# Whether or not TIMETZ / TIMESTAMPTZ will be generated using the "WITH TIME ZONE" syntax
TZ_TO_WITH_TIME_ZONE = False
# https://cloud.google.com/bigquery/docs/reference/standard-sql/query-syntax
SELECT_KINDS: t.Tuple[str, ...] = ("STRUCT", "VALUE")
@@ -271,10 +276,12 @@ class Generator:
# Expressions whose comments are separated from them for better formatting
WITH_SEPARATED_COMMENTS: t.Tuple[t.Type[exp.Expression], ...] = (
exp.Create,
exp.Delete,
exp.Drop,
exp.From,
exp.Insert,
exp.Join,
exp.Select,
exp.Update,
exp.Where,
@@ -831,14 +838,17 @@ class Generator:
def datatype_sql(self, expression: exp.DataType) -> str:
type_value = expression.this
type_sql = (
self.TYPE_MAPPING.get(type_value, type_value.value)
if isinstance(type_value, exp.DataType.Type)
else type_value
)
nested = ""
interior = self.expressions(expression, flat=True)
values = ""
if interior:
if expression.args.get("nested"):
nested = f"{self.STRUCT_DELIMITER[0]}{interior}{self.STRUCT_DELIMITER[1]}"
@@ -846,10 +856,19 @@ class Generator:
delimiters = ("[", "]") if type_value == exp.DataType.Type.ARRAY else ("(", ")")
values = self.expressions(expression, key="values", flat=True)
values = f"{delimiters[0]}{values}{delimiters[1]}"
elif type_value == exp.DataType.Type.INTERVAL:
nested = f" {interior}"
else:
nested = f"({interior})"
return f"{type_sql}{nested}{values}"
type_sql = f"{type_sql}{nested}{values}"
if self.TZ_TO_WITH_TIME_ZONE and type_value in (
exp.DataType.Type.TIMETZ,
exp.DataType.Type.TIMESTAMPTZ,
):
type_sql = f"{type_sql} WITH TIME ZONE"
return type_sql
def directory_sql(self, expression: exp.Directory) -> str:
local = "LOCAL " if expression.args.get("local") else ""
@@ -1288,7 +1307,12 @@ class Generator:
unpivot = expression.args.get("unpivot")
direction = "UNPIVOT" if unpivot else "PIVOT"
field = self.sql(expression, "field")
return f"{direction}({expressions} FOR {field}){alias}"
include_nulls = expression.args.get("include_nulls")
if include_nulls is not None:
nulls = " INCLUDE NULLS " if include_nulls else " EXCLUDE NULLS "
else:
nulls = ""
return f"{direction}{nulls}({expressions} FOR {field}){alias}"
def tuple_sql(self, expression: exp.Tuple) -> str:
return f"({self.expressions(expression, flat=True)})"

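Round-trip sketch for the new nulls clause (illustrative; Databricks syntax per the link in the parser change below):

import sqlglot

sql = "SELECT * FROM t UNPIVOT INCLUDE NULLS (v FOR k IN (a, b))"
print(sqlglot.transpile(sql, read="databricks", write="databricks")[0])
# expected to round-trip with the INCLUDE NULLS clause preserved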
sqlglot/optimizer/simplify.py

@@ -54,11 +54,17 @@ def simplify(expression):
def _simplify(expression, root=True):
if expression.meta.get(FINAL):
return expression
# Pre-order transformations
node = expression
node = rewrite_between(node)
node = uniq_sort(node, generate, root)
node = absorb_and_eliminate(node, root)
node = simplify_concat(node)
exp.replace_children(node, lambda e: _simplify(e, False))
# Post-order transformations
node = simplify_not(node)
node = flatten(node)
node = simplify_connectors(node, root)
@@ -66,8 +72,11 @@ def simplify(expression):
node.parent = expression.parent
node = simplify_literals(node, root)
node = simplify_parens(node)
node = simplify_coalesce(node)
if root:
expression.replace(node)
return node
expression = while_changing(expression, _simplify)
@@ -184,6 +193,7 @@ COMPARISONS = (
*GT_GTE,
exp.EQ,
exp.NEQ,
exp.Is,
)
INVERSE_COMPARISONS = {
@@ -430,6 +440,103 @@ def simplify_parens(expression):
return expression
CONSTANTS = (
exp.Literal,
exp.Boolean,
exp.Null,
)
def simplify_coalesce(expression):
# COALESCE(x) -> x
if (
isinstance(expression, exp.Coalesce)
and not expression.expressions
# COALESCE is also used as a Spark partitioning hint
and not isinstance(expression.parent, exp.Hint)
):
return expression.this
if not isinstance(expression, COMPARISONS):
return expression
if isinstance(expression.left, exp.Coalesce):
coalesce = expression.left
other = expression.right
elif isinstance(expression.right, exp.Coalesce):
coalesce = expression.right
other = expression.left
else:
return expression
# This transformation is valid for non-constants,
# but it really only does anything if they are both constants.
if not isinstance(other, CONSTANTS):
return expression
# Find the first constant arg
for arg_index, arg in enumerate(coalesce.expressions):
if isinstance(arg, CONSTANTS):
break
else:
return expression
coalesce.set("expressions", coalesce.expressions[:arg_index])
# Remove the COALESCE function. This is an optimization, skipping a simplify iteration,
# since we already remove COALESCE at the top of this function.
coalesce = coalesce if coalesce.expressions else coalesce.this
# This expression is more complex than when we started, but it will get simplified further
return exp.or_(
exp.and_(
coalesce.is_(exp.null()).not_(copy=False),
expression.copy(),
copy=False,
),
exp.and_(
coalesce.is_(exp.null()),
type(expression)(this=arg.copy(), expression=other.copy()),
copy=False,
),
copy=False,
)
CONCATS = (exp.Concat, exp.DPipe)
SAFE_CONCATS = (exp.SafeConcat, exp.SafeDPipe)
def simplify_concat(expression):
"""Reduces all groups that contain string literals by concatenating them."""
if not isinstance(expression, CONCATS) or isinstance(expression, exp.ConcatWs):
return expression
new_args = []
for is_string_group, group in itertools.groupby(
expression.expressions or expression.flatten(), lambda e: e.is_string
):
if is_string_group:
new_args.append(exp.Literal.string("".join(string.name for string in group)))
else:
new_args.extend(group)
# Ensures we preserve the right concat type, i.e. whether it's "safe" or not
concat_type = exp.SafeConcat if isinstance(expression, SAFE_CONCATS) else exp.Concat
return new_args[0] if len(new_args) == 1 else concat_type(expressions=new_args)
# CROSS joins result in an empty table if the right table is empty.
# So we can only simplify certain types of joins to CROSS.
# Or in other words, LEFT JOIN x ON TRUE != CROSS JOIN x
JOINS = {
("", ""),
("", "INNER"),
("RIGHT", ""),
("RIGHT", "OUTER"),
}
def remove_where_true(expression):
for where in expression.find_all(exp.Where):
if always_true(where.this):
@@ -439,6 +546,7 @@ def remove_where_true(expression):
always_true(join.args.get("on"))
and not join.args.get("using")
and not join.args.get("method")
and (join.side, join.kind) in JOINS
):
join.set("on", None)
join.set("side", None)

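Behavior sketch for the new rules (illustrative; exact output may vary by version): single-argument COALESCE unwraps, and adjacent string literals in a concat fold.

from sqlglot import parse_one
from sqlglot.optimizer.simplify import simplify

print(simplify(parse_one("COALESCE(x) = 1")).sql())  # expected: x = 1
print(simplify(parse_one("'a' || 'b' || x")).sql())  # expected, roughly: CONCAT('ab', x)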
sqlglot/parser.py

@@ -102,15 +102,23 @@ class Parser(metaclass=_Parser):
TokenType.CURRENT_USER: exp.CurrentUser,
}
STRUCT_TYPE_TOKENS = {
TokenType.NESTED,
TokenType.STRUCT,
}
NESTED_TYPE_TOKENS = {
TokenType.ARRAY,
TokenType.LOWCARDINALITY,
TokenType.MAP,
TokenType.NULLABLE,
-TokenType.STRUCT,
+*STRUCT_TYPE_TOKENS,
}
ENUM_TYPE_TOKENS = {
TokenType.ENUM,
TokenType.ENUM8,
TokenType.ENUM16,
}
TYPE_TOKENS = {
@@ -128,6 +136,7 @@ class Parser(metaclass=_Parser):
TokenType.UINT128,
TokenType.INT256,
TokenType.UINT256,
TokenType.FIXEDSTRING,
TokenType.FLOAT,
TokenType.DOUBLE,
TokenType.CHAR,
@@ -145,6 +154,7 @@ class Parser(metaclass=_Parser):
TokenType.JSONB,
TokenType.INTERVAL,
TokenType.TIME,
TokenType.TIMETZ,
TokenType.TIMESTAMP,
TokenType.TIMESTAMPTZ,
TokenType.TIMESTAMPLTZ,
@@ -187,7 +197,7 @@ class Parser(metaclass=_Parser):
TokenType.INET,
TokenType.IPADDRESS,
TokenType.IPPREFIX,
-TokenType.ENUM,
+*ENUM_TYPE_TOKENS,
*NESTED_TYPE_TOKENS,
}
@@ -384,11 +394,16 @@ class Parser(metaclass=_Parser):
TokenType.STAR: exp.Mul,
}
-TIMESTAMPS = {
+TIMES = {
TokenType.TIME,
TokenType.TIMETZ,
}
TIMESTAMPS = {
TokenType.TIMESTAMP,
TokenType.TIMESTAMPTZ,
TokenType.TIMESTAMPLTZ,
*TIMES,
}
SET_OPERATIONS = {
@@ -1165,6 +1180,8 @@ class Parser(metaclass=_Parser):
def _parse_create(self) -> exp.Create | exp.Command:
# Note: this can't be None because we've matched a statement parser
start = self._prev
comments = self._prev_comments
replace = start.text.upper() == "REPLACE" or self._match_pair(
TokenType.OR, TokenType.REPLACE
)
@@ -1273,6 +1290,7 @@ class Parser(metaclass=_Parser):
return self.expression(
exp.Create,
comments=comments,
this=this,
kind=create_token.text,
replace=replace,
@@ -2338,7 +2356,8 @@ class Parser(metaclass=_Parser):
kwargs["this"].set("joins", joins)
-return self.expression(exp.Join, **kwargs)
+comments = [c for token in (method, side, kind) if token for c in token.comments]
+return self.expression(exp.Join, comments=comments, **kwargs)
def _parse_index(
self,
@@ -2619,11 +2638,18 @@ class Parser(metaclass=_Parser):
def _parse_pivot(self) -> t.Optional[exp.Pivot]:
index = self._index
include_nulls = None
if self._match(TokenType.PIVOT):
unpivot = False
elif self._match(TokenType.UNPIVOT):
unpivot = True
# https://docs.databricks.com/en/sql/language-manual/sql-ref-syntax-qry-select-unpivot.html#syntax
if self._match_text_seq("INCLUDE", "NULLS"):
include_nulls = True
elif self._match_text_seq("EXCLUDE", "NULLS"):
include_nulls = False
else:
return None
@@ -2654,7 +2680,13 @@ class Parser(metaclass=_Parser):
self._match_r_paren()
-pivot = self.expression(exp.Pivot, expressions=expressions, field=field, unpivot=unpivot)
+pivot = self.expression(
+exp.Pivot,
+expressions=expressions,
+field=field,
+unpivot=unpivot,
+include_nulls=include_nulls,
+)
if not self._match_set((TokenType.PIVOT, TokenType.UNPIVOT), advance=False):
pivot.set("alias", self._parse_table_alias())
@@ -3096,7 +3128,7 @@ class Parser(metaclass=_Parser):
return self.expression(exp.PseudoType, this=self._prev.text)
nested = type_token in self.NESTED_TYPE_TOKENS
-is_struct = type_token == TokenType.STRUCT
+is_struct = type_token in self.STRUCT_TYPE_TOKENS
expressions = None
maybe_func = False
@@ -3108,7 +3140,7 @@ class Parser(metaclass=_Parser):
lambda: self._parse_types(check_func=check_func, schema=schema)
)
elif type_token in self.ENUM_TYPE_TOKENS:
-expressions = self._parse_csv(self._parse_primary)
+expressions = self._parse_csv(self._parse_equality)
else:
expressions = self._parse_csv(self._parse_type_size)
@@ -3118,29 +3150,9 @@ class Parser(metaclass=_Parser):
maybe_func = True
-if self._match_pair(TokenType.L_BRACKET, TokenType.R_BRACKET):
-this = exp.DataType(
-this=exp.DataType.Type.ARRAY,
-expressions=[
-exp.DataType(
-this=exp.DataType.Type[type_token.value],
-expressions=expressions,
-nested=nested,
-)
-],
-nested=True,
-)
-while self._match_pair(TokenType.L_BRACKET, TokenType.R_BRACKET):
-this = exp.DataType(this=exp.DataType.Type.ARRAY, expressions=[this], nested=True)
-return this
if self._match(TokenType.L_BRACKET):
self._retreat(index)
return None
this: t.Optional[exp.Expression] = None
values: t.Optional[t.List[t.Optional[exp.Expression]]] = None
if nested and self._match(TokenType.LT):
if is_struct:
expressions = self._parse_csv(self._parse_struct_types)
@@ -3156,23 +3168,35 @@ class Parser(metaclass=_Parser):
values = self._parse_csv(self._parse_conjunction)
self._match_set((TokenType.R_BRACKET, TokenType.R_PAREN))
-value: t.Optional[exp.Expression] = None
if type_token in self.TIMESTAMPS:
if self._match_text_seq("WITH", "TIME", "ZONE"):
maybe_func = False
-value = exp.DataType(this=exp.DataType.Type.TIMESTAMPTZ, expressions=expressions)
+tz_type = (
+exp.DataType.Type.TIMETZ
+if type_token in self.TIMES
+else exp.DataType.Type.TIMESTAMPTZ
+)
+this = exp.DataType(this=tz_type, expressions=expressions)
elif self._match_text_seq("WITH", "LOCAL", "TIME", "ZONE"):
maybe_func = False
-value = exp.DataType(this=exp.DataType.Type.TIMESTAMPLTZ, expressions=expressions)
+this = exp.DataType(this=exp.DataType.Type.TIMESTAMPLTZ, expressions=expressions)
elif self._match_text_seq("WITHOUT", "TIME", "ZONE"):
maybe_func = False
elif type_token == TokenType.INTERVAL:
-unit = self._parse_var()
-if not unit:
-value = self.expression(exp.DataType, this=exp.DataType.Type.INTERVAL)
-else:
-value = self.expression(exp.Interval, unit=unit)
+if self._match_text_seq("YEAR", "TO", "MONTH"):
+span: t.Optional[t.List[exp.Expression]] = [exp.IntervalYearToMonthSpan()]
+elif self._match_text_seq("DAY", "TO", "SECOND"):
+span = [exp.IntervalDayToSecondSpan()]
+else:
+span = None
+unit = not span and self._parse_var()
+if not unit:
+this = self.expression(
+exp.DataType, this=exp.DataType.Type.INTERVAL, expressions=span
+)
+else:
+this = self.expression(exp.Interval, unit=unit)
if maybe_func and check_func:
index2 = self._index
@@ -3184,16 +3208,19 @@ class Parser(metaclass=_Parser):
self._retreat(index2)
-if value:
-return value
+if not this:
+this = exp.DataType(
+this=exp.DataType.Type[type_token.value],
+expressions=expressions,
+nested=nested,
+values=values,
+prefix=prefix,
+)
-return exp.DataType(
-this=exp.DataType.Type[type_token.value],
-expressions=expressions,
-nested=nested,
-values=values,
-prefix=prefix,
-)
+while self._match_pair(TokenType.L_BRACKET, TokenType.R_BRACKET):
+this = exp.DataType(this=exp.DataType.Type.ARRAY, expressions=[this], nested=True)
+return this
def _parse_struct_types(self) -> t.Optional[exp.Expression]:
this = self._parse_type() or self._parse_id_var()
@@ -3738,6 +3765,7 @@ class Parser(metaclass=_Parser):
ifs = []
default = None
comments = self._prev_comments
expression = self._parse_conjunction()
while self._match(TokenType.WHEN):
@@ -3753,7 +3781,7 @@ class Parser(metaclass=_Parser):
self.raise_error("Expected END after CASE", self._prev)
return self._parse_window(
-self.expression(exp.Case, this=expression, ifs=ifs, default=default)
+self.expression(exp.Case, comments=comments, this=expression, ifs=ifs, default=default)
)
def _parse_if(self) -> t.Optional[exp.Expression]:

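End-to-end sketch of the TIMETZ handling (illustrative): TIME WITH TIME ZONE now parses to its own type rather than TIMESTAMPTZ, and Presto/Redshift print it back with the suffix.

import sqlglot

node = sqlglot.parse_one("CAST(x AS TIME WITH TIME ZONE)")
assert node.to.is_type("timetz")
print(node.sql("redshift"))  # expected: CAST(x AS TIME WITH TIME ZONE)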
sqlglot/schema.py

@@ -372,21 +372,12 @@ class MappingSchema(AbstractMappingSchema[t.Dict[str, str]], Schema):
is_table: bool = False,
normalize: t.Optional[bool] = None,
) -> str:
-dialect = dialect or self.dialect
-normalize = self.normalize if normalize is None else normalize
-try:
-identifier = sqlglot.maybe_parse(name, dialect=dialect, into=exp.Identifier)
-except ParseError:
-return name if isinstance(name, str) else name.name
-name = identifier.name
-if not normalize:
-return name
-# This can be useful for normalize_identifier
-identifier.meta["is_table"] = is_table
-return Dialect.get_or_raise(dialect).normalize_identifier(identifier).name
+return normalize_name(
+name,
+dialect=dialect or self.dialect,
+is_table=is_table,
+normalize=self.normalize if normalize is None else normalize,
+)
def depth(self) -> int:
if not self.empty and not self._depth:
@@ -418,6 +409,26 @@ class MappingSchema(AbstractMappingSchema[t.Dict[str, str]], Schema):
return self._type_mapping_cache[schema_type]
def normalize_name(
name: str | exp.Identifier,
dialect: DialectType = None,
is_table: bool = False,
normalize: t.Optional[bool] = True,
) -> str:
try:
identifier = sqlglot.maybe_parse(name, dialect=dialect, into=exp.Identifier)
except ParseError:
return name if isinstance(name, str) else name.name
name = identifier.name
if not normalize:
return name
# This can be useful for normalize_identifier
identifier.meta["is_table"] = is_table
return Dialect.get_or_raise(dialect).normalize_identifier(identifier).name
def ensure_schema(schema: Schema | t.Optional[t.Dict], **kwargs: t.Any) -> Schema:
if isinstance(schema, Schema):
return schema

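Sketch of the extracted helper (illustrative): quoted identifiers keep their case, unquoted ones follow the dialect's resolution rules.

from sqlglot.schema import normalize_name

print(normalize_name("Foo", dialect="snowflake"))    # expected: FOO (unquoted -> uppercase)
print(normalize_name('"Foo"', dialect="snowflake"))  # expected: Foo (quoted case preserved)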
sqlglot/tokens.py

@@ -110,6 +110,7 @@ class TokenType(AutoName):
JSON = auto()
JSONB = auto()
TIME = auto()
TIMETZ = auto()
TIMESTAMP = auto()
TIMESTAMPTZ = auto()
TIMESTAMPLTZ = auto()
@@ -151,6 +152,11 @@ class TokenType(AutoName):
IPADDRESS = auto()
IPPREFIX = auto()
ENUM = auto()
ENUM8 = auto()
ENUM16 = auto()
FIXEDSTRING = auto()
LOWCARDINALITY = auto()
NESTED = auto()
# keywords
ALIAS = auto()
@@ -659,6 +665,7 @@ class Tokenizer(metaclass=_Tokenizer):
"TINYINT": TokenType.TINYINT,
"SHORT": TokenType.SMALLINT,
"SMALLINT": TokenType.SMALLINT,
"INT128": TokenType.INT128,
"INT2": TokenType.SMALLINT,
"INTEGER": TokenType.INT,
"INT": TokenType.INT,
@@ -699,6 +706,7 @@ class Tokenizer(metaclass=_Tokenizer):
"BYTEA": TokenType.VARBINARY,
"VARBINARY": TokenType.VARBINARY,
"TIME": TokenType.TIME,
"TIMETZ": TokenType.TIMETZ,
"TIMESTAMP": TokenType.TIMESTAMP,
"TIMESTAMPTZ": TokenType.TIMESTAMPTZ,
"TIMESTAMPLTZ": TokenType.TIMESTAMPLTZ,
@@ -879,6 +887,11 @@ class Tokenizer(metaclass=_Tokenizer):
def _add(self, token_type: TokenType, text: t.Optional[str] = None) -> None:
self._prev_token_line = self._line
if self._comments and token_type == TokenType.SEMICOLON and self.tokens:
self.tokens[-1].comments.extend(self._comments)
self._comments = []
self.tokens.append(
Token(
token_type,