Merging upstream version 18.11.2.

Signed-off-by: Daniel Baumann <daniel@debian.org>

parent 15b8b39545
commit c37998973e

88 changed files with 52059 additions and 46960 deletions
sqlglot/dialects/bigquery.py
@@ -190,6 +190,16 @@ class BigQuery(Dialect):
         "%D": "%m/%d/%y",
     }
 
+    ESCAPE_SEQUENCES = {
+        "\\a": "\a",
+        "\\b": "\b",
+        "\\f": "\f",
+        "\\n": "\n",
+        "\\r": "\r",
+        "\\t": "\t",
+        "\\v": "\v",
+    }
+
     FORMAT_MAPPING = {
         "DD": "%d",
         "MM": "%m",
@@ -212,15 +222,14 @@ class BigQuery(Dialect):
 
     @classmethod
     def normalize_identifier(cls, expression: E) -> E:
-        # In BigQuery, CTEs aren't case-sensitive, but table names are (by default, at least).
-        # The following check is essentially a heuristic to detect tables based on whether or
-        # not they're qualified.
         if isinstance(expression, exp.Identifier):
            parent = expression.parent
 
            while isinstance(parent, exp.Dot):
                parent = parent.parent
 
+            # In BigQuery, CTEs aren't case-sensitive, but table names are (by default, at least).
+            # The following check is essentially a heuristic to detect tables based on whether or
+            # not they're qualified. It also avoids normalizing UDFs, because they're case-sensitive.
            if (
                not isinstance(parent, exp.UserDefinedFunction)
                and not (isinstance(parent, exp.Table) and parent.db)
@@ -419,6 +428,7 @@ class BigQuery(Dialect):
         RENAME_TABLE_WITH_DB = False
         NVL2_SUPPORTED = False
         UNNEST_WITH_ORDINALITY = False
+        COLLATE_IS_FUNC = True
 
         TRANSFORMS = {
             **generator.Generator.TRANSFORMS,
@@ -520,18 +530,6 @@ class BigQuery(Dialect):
             exp.VolatileProperty: exp.Properties.Location.UNSUPPORTED,
         }
 
-        UNESCAPED_SEQUENCE_TABLE = str.maketrans(  # type: ignore
-            {
-                "\a": "\\a",
-                "\b": "\\b",
-                "\f": "\\f",
-                "\n": "\\n",
-                "\r": "\\r",
-                "\t": "\\t",
-                "\v": "\\v",
-            }
-        )
-
         # from: https://cloud.google.com/bigquery/docs/reference/standard-sql/lexical#reserved_keywords
         RESERVED_KEYWORDS = {
             *generator.Generator.RESERVED_KEYWORDS,
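Note: the normalize_identifier heuristic above is observable through the optimizer. A hedged sketch (the exact output is an assumption based on the rule described in this hunk):

    import sqlglot
    from sqlglot.optimizer.normalize_identifiers import normalize_identifiers

    # Unqualified identifiers (e.g. column and CTE names) are case-insensitive in
    # BigQuery and get lowercased; qualified table names and UDFs keep their case.
    ast = sqlglot.parse_one("SELECT Col FROM Project.Dataset.Tbl", read="bigquery")
    print(normalize_identifiers(ast, dialect="bigquery").sql(dialect="bigquery"))
    # expected: SELECT col FROM Project.Dataset.Tbl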
sqlglot/dialects/clickhouse.py
@@ -2,7 +2,7 @@ from __future__ import annotations
 
 import typing as t
 
-from sqlglot import exp, generator, parser, tokens
+from sqlglot import exp, generator, parser, tokens, transforms
 from sqlglot.dialects.dialect import (
     Dialect,
     inline_array_sql,
@@ -21,18 +21,33 @@ def _lower_func(sql: str) -> str:
     return sql[:index].lower() + sql[index:]
 
 
+def _quantile_sql(self, e):
+    quantile = e.args["quantile"]
+    args = f"({self.sql(e, 'this')})"
+    if isinstance(quantile, exp.Array):
+        func = self.func("quantiles", *quantile)
+    else:
+        func = self.func("quantile", quantile)
+    return func + args
+
+
 class ClickHouse(Dialect):
     NORMALIZE_FUNCTIONS: bool | str = False
     NULL_ORDERING = "nulls_are_last"
     STRICT_STRING_CONCAT = True
     SUPPORTS_USER_DEFINED_TYPES = False
 
+    ESCAPE_SEQUENCES = {
+        "\\0": "\0",
+    }
+
     class Tokenizer(tokens.Tokenizer):
         COMMENTS = ["--", "#", "#!", ("/*", "*/")]
         IDENTIFIERS = ['"', "`"]
         STRING_ESCAPES = ["'", "\\"]
         BIT_STRINGS = [("0b", "")]
         HEX_STRINGS = [("0x", ""), ("0X", "")]
+        HEREDOC_STRINGS = ["$"]
 
         KEYWORDS = {
             **tokens.Tokenizer.KEYWORDS,
@@ -55,6 +70,7 @@ class ClickHouse(Dialect):
             "LOWCARDINALITY": TokenType.LOWCARDINALITY,
             "MAP": TokenType.MAP,
             "NESTED": TokenType.NESTED,
+            "SAMPLE": TokenType.TABLE_SAMPLE,
             "TUPLE": TokenType.STRUCT,
             "UINT128": TokenType.UINT128,
             "UINT16": TokenType.USMALLINT,
@@ -64,6 +80,11 @@ class ClickHouse(Dialect):
             "UINT8": TokenType.UTINYINT,
         }
 
+        SINGLE_TOKENS = {
+            **tokens.Tokenizer.SINGLE_TOKENS,
+            "$": TokenType.HEREDOC_STRING,
+        }
+
     class Parser(parser.Parser):
         FUNCTIONS = {
             **parser.Parser.FUNCTIONS,
@@ -301,6 +322,7 @@ class ClickHouse(Dialect):
         QUERY_HINTS = False
         STRUCT_DELIMITER = ("(", ")")
         NVL2_SUPPORTED = False
+        TABLESAMPLE_REQUIRES_PARENS = False
 
         STRING_TYPE_MAPPING = {
             exp.DataType.Type.CHAR: "String",
@@ -348,6 +370,7 @@ class ClickHouse(Dialect):
 
         TRANSFORMS = {
             **generator.Generator.TRANSFORMS,
+            exp.Select: transforms.preprocess([transforms.eliminate_qualify]),
             exp.AnyValue: rename_func("any"),
             exp.ApproxDistinct: rename_func("uniq"),
             exp.Array: inline_array_sql,
@@ -359,12 +382,13 @@ class ClickHouse(Dialect):
                 "DATE_DIFF", exp.Literal.string(e.text("unit") or "day"), e.expression, e.this
             ),
             exp.Final: lambda self, e: f"{self.sql(e, 'this')} FINAL",
             exp.IsNan: rename_func("isNaN"),
             exp.Map: lambda self, e: _lower_func(var_map_sql(self, e)),
             exp.PartitionedByProperty: lambda self, e: f"PARTITION BY {self.sql(e, 'this')}",
             exp.Pivot: no_pivot_sql,
-            exp.Quantile: lambda self, e: self.func("quantile", e.args.get("quantile"))
-            + f"({self.sql(e, 'this')})",
+            exp.Quantile: _quantile_sql,
             exp.RegexpLike: lambda self, e: f"match({self.format_args(e.this, e.expression)})",
             exp.StartsWith: rename_func("startsWith"),
             exp.StrPosition: lambda self, e: f"position({self.format_args(e.this, e.args.get('substr'), e.args.get('position'))})",
             exp.VarMap: lambda self, e: _lower_func(var_map_sql(self, e)),
             exp.Xor: lambda self, e: self.func("xor", e.this, e.expression, *e.expressions),
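_quantile_sql exists because ClickHouse's quantile aggregates take the level(s) in a separate, parameter-style argument list, and a list of levels switches the function name to quantiles. A hedged round trip (outputs assumed from the helper's logic):

    import sqlglot

    print(sqlglot.transpile("SELECT quantile(0.5)(x) FROM t", read="clickhouse", write="clickhouse")[0])
    # expected: SELECT quantile(0.5)(x) FROM t

    print(sqlglot.transpile("SELECT quantiles(0.25, 0.75)(x) FROM t", read="clickhouse", write="clickhouse")[0])
    # expected: SELECT quantiles(0.25, 0.75)(x) FROM t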
sqlglot/dialects/databricks.py
@@ -51,6 +51,26 @@ class Databricks(Spark):
             exp.ToChar: lambda self, e: self.function_fallback_sql(e),
         }
 
+        def columndef_sql(self, expression: exp.ColumnDef, sep: str = " ") -> str:
+            constraint = expression.find(exp.GeneratedAsIdentityColumnConstraint)
+            kind = expression.args.get("kind")
+            if (
+                constraint
+                and isinstance(kind, exp.DataType)
+                and kind.this in exp.DataType.INTEGER_TYPES
+            ):
+                # only BIGINT generated identity constraints are supported
+                expression = expression.copy()
+                expression.set("kind", exp.DataType.build("bigint"))
+            return super().columndef_sql(expression, sep)
+
+        def generatedasidentitycolumnconstraint_sql(
+            self, expression: exp.GeneratedAsIdentityColumnConstraint
+        ) -> str:
+            expression = expression.copy()
+            expression.set("this", True)  # trigger ALWAYS in super class
+            return super().generatedasidentitycolumnconstraint_sql(expression)
+
     class Tokenizer(Spark.Tokenizer):
         HEX_STRINGS = []
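columndef_sql narrows any integer-typed generated-identity column to BIGINT, the only identity type Databricks supports. A hedged example:

    import sqlglot

    sql = "CREATE TABLE t (id INT GENERATED ALWAYS AS IDENTITY)"
    print(sqlglot.transpile(sql, read="databricks", write="databricks")[0])
    # expected: CREATE TABLE t (id BIGINT GENERATED ALWAYS AS IDENTITY)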
sqlglot/dialects/dialect.py
@@ -81,6 +81,8 @@ class _Dialect(type):
         klass.INVERSE_TIME_MAPPING = {v: k for k, v in klass.TIME_MAPPING.items()}
         klass.INVERSE_TIME_TRIE = new_trie(klass.INVERSE_TIME_MAPPING)
 
+        klass.INVERSE_ESCAPE_SEQUENCES = {v: k for k, v in klass.ESCAPE_SEQUENCES.items()}
+
         klass.tokenizer_class = getattr(klass, "Tokenizer", Tokenizer)
         klass.parser_class = getattr(klass, "Parser", Parser)
         klass.generator_class = getattr(klass, "Generator", Generator)
@@ -188,6 +190,9 @@ class Dialect(metaclass=_Dialect):
     # special syntax cast(x as date format 'yyyy') defaults to time_mapping
     FORMAT_MAPPING: t.Dict[str, str] = {}
 
+    # Mapping of an unescaped escape sequence to the corresponding character
+    ESCAPE_SEQUENCES: t.Dict[str, str] = {}
+
     # Columns that are auto-generated by the engine corresponding to this dialect
     # Such columns may be excluded from SELECT * queries, for example
     PSEUDOCOLUMNS: t.Set[str] = set()
@@ -204,6 +209,8 @@ class Dialect(metaclass=_Dialect):
     INVERSE_TIME_MAPPING: t.Dict[str, str] = {}
     INVERSE_TIME_TRIE: t.Dict = {}
 
+    INVERSE_ESCAPE_SEQUENCES: t.Dict[str, str] = {}
+
     def __eq__(self, other: t.Any) -> bool:
         return type(self) == other
 
@@ -245,7 +252,7 @@ class Dialect(metaclass=_Dialect):
         """
         Normalizes an unquoted identifier to either lower or upper case, thus essentially
         making it case-insensitive. If a dialect treats all identifiers as case-insensitive,
-        they will be normalized regardless of being quoted or not.
+        they will be normalized to lowercase regardless of being quoted or not.
         """
         if isinstance(expression, exp.Identifier) and (
            not expression.quoted or cls.RESOLVES_IDENTIFIERS_AS_UPPERCASE is None
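The metaclass autofill above pairs with the generator change later in this diff (escape_str): dialects declare ESCAPE_SEQUENCES once, and the inverse mapping is derived for re-escaping on output. A minimal standalone sketch of that mechanism (names mirror the diff; not the full implementation):

    # Escaped spellings -> the characters they denote (cf. BigQuery.ESCAPE_SEQUENCES).
    ESCAPE_SEQUENCES = {"\\n": "\n", "\\t": "\t"}

    # Autofilled inverse, as in the _Dialect metaclass above.
    INVERSE_ESCAPE_SEQUENCES = {v: k for k, v in ESCAPE_SEQUENCES.items()}

    def escape_str(text: str) -> str:
        # Re-escape control characters when generating SQL,
        # mirroring Generator.escape_str further down.
        return "".join(INVERSE_ESCAPE_SEQUENCES.get(ch, ch) for ch in text)

    assert escape_str("a\nb") == "a\\nb"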
sqlglot/dialects/hive.py
@@ -51,6 +51,32 @@ TIME_DIFF_FACTOR = {
 DIFF_MONTH_SWITCH = ("YEAR", "QUARTER", "MONTH")
 
 
+def _create_sql(self, expression: exp.Create) -> str:
+    expression = expression.copy()
+
+    # remove UNIQUE column constraints
+    for constraint in expression.find_all(exp.UniqueColumnConstraint):
+        if constraint.parent:
+            constraint.parent.pop()
+
+    properties = expression.args.get("properties")
+    temporary = any(
+        isinstance(prop, exp.TemporaryProperty)
+        for prop in (properties.expressions if properties else [])
+    )
+
+    # CTAS with temp tables map to CREATE TEMPORARY VIEW
+    kind = expression.args["kind"]
+    if kind.upper() == "TABLE" and temporary:
+        if expression.expression:
+            return f"CREATE TEMPORARY VIEW {self.sql(expression, 'this')} AS {self.sql(expression, 'expression')}"
+        else:
+            # CREATE TEMPORARY TABLE may require storage provider
+            expression = self.temporary_storage_provider(expression)
+
+    return create_with_partitions_sql(self, expression)
+
+
 def _add_date_sql(self: Hive.Generator, expression: exp.DateAdd | exp.DateSub) -> str:
     unit = expression.text("unit").upper()
     func, multiplier = DATE_DELTA_INTERVAL.get(unit, ("DATE_ADD", 1))
@@ -429,7 +455,7 @@ class Hive(Dialect):
             if e.args.get("allow_null")
             else "NOT NULL",
             exp.VarMap: var_map_sql,
-            exp.Create: create_with_partitions_sql,
+            exp.Create: _create_sql,
             exp.Quantile: rename_func("PERCENTILE"),
             exp.ApproxQuantile: rename_func("PERCENTILE_APPROX"),
             exp.RegexpExtract: regexp_extract_sql,
@@ -478,8 +504,13 @@ class Hive(Dialect):
             exp.FileFormatProperty: exp.Properties.Location.POST_SCHEMA,
             exp.PartitionedByProperty: exp.Properties.Location.POST_SCHEMA,
             exp.VolatileProperty: exp.Properties.Location.UNSUPPORTED,
             exp.WithDataProperty: exp.Properties.Location.UNSUPPORTED,
         }
 
+        def temporary_storage_provider(self, expression: exp.Create) -> exp.Create:
+            # Hive has no temporary storage provider (there are hive settings though)
+            return expression
+
         def parameter_sql(self, expression: exp.Parameter) -> str:
             this = self.sql(expression, "this")
             parent = expression.parent
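_create_sql (moved here from spark2.py later in this diff) turns a temporary CTAS into CREATE TEMPORARY VIEW, and routes temporary tables without a query through temporary_storage_provider, which Hive leaves as a no-op and Spark2 overrides to attach a Parquet format. A hedged example:

    import sqlglot

    sql = "CREATE TEMPORARY TABLE t AS SELECT 1 AS c"
    print(sqlglot.transpile(sql, read="spark", write="hive")[0])
    # expected: CREATE TEMPORARY VIEW t AS SELECT 1 AS c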
sqlglot/dialects/mysql.py
@@ -66,7 +66,9 @@ def _str_to_date(args: t.List) -> exp.StrToDate:
     return exp.StrToDate(this=seq_get(args, 0), format=date_format)
 
 
-def _str_to_date_sql(self: MySQL.Generator, expression: exp.StrToDate | exp.StrToTime) -> str:
+def _str_to_date_sql(
+    self: MySQL.Generator, expression: exp.StrToDate | exp.StrToTime | exp.TsOrDsToDate
+) -> str:
     date_format = self.format_time(expression)
     return f"STR_TO_DATE({self.sql(expression.this)}, {date_format})"
 
@@ -86,8 +88,10 @@ def _trim_sql(self: MySQL.Generator, expression: exp.Trim) -> str:
     return f"TRIM({trim_type}{remove_chars}{from_part}{target})"
 
 
-def _date_add_sql(kind: str) -> t.Callable[[MySQL.Generator, exp.DateAdd | exp.DateSub], str]:
-    def func(self: MySQL.Generator, expression: exp.DateAdd | exp.DateSub) -> str:
+def _date_add_sql(
+    kind: str,
+) -> t.Callable[[MySQL.Generator, exp.Expression], str]:
+    def func(self: MySQL.Generator, expression: exp.Expression) -> str:
         this = self.sql(expression, "this")
         unit = expression.text("unit").upper() or "DAY"
         return f"DATE_{kind}({this}, {self.sql(exp.Interval(this=expression.expression.copy(), unit=unit))})"
@@ -95,6 +99,30 @@ def _date_add_sql(kind: str) -> t.Callable[[MySQL.Generator, exp.DateAdd | exp.DateSub], str]:
     return func
 
 
+def _ts_or_ds_to_date_sql(self: MySQL.Generator, expression: exp.TsOrDsToDate) -> str:
+    time_format = expression.args.get("format")
+    if time_format:
+        return _str_to_date_sql(self, expression)
+    return f"DATE({self.sql(expression, 'this')})"
+
+
+def _remove_ts_or_ds_to_date(
+    to_sql: t.Optional[t.Callable[[MySQL.Generator, exp.Expression], str]] = None,
+    args: t.Tuple[str, ...] = ("this",),
+) -> t.Callable[[MySQL.Generator, exp.Func], str]:
+    def func(self: MySQL.Generator, expression: exp.Func) -> str:
+        expression = expression.copy()
+
+        for arg_key in args:
+            arg = expression.args.get(arg_key)
+            if isinstance(arg, exp.TsOrDsToDate) and not arg.args.get("format"):
+                expression.set(arg_key, arg.this)
+
+        return to_sql(self, expression) if to_sql else self.function_fallback_sql(expression)
+
+    return func
+
+
 class MySQL(Dialect):
     # https://dev.mysql.com/doc/refman/8.0/en/identifiers.html
     IDENTIFIERS_CAN_START_WITH_DIGIT = True
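The two helpers above implement a wrap/unwrap pattern: the parser (next hunks) wraps string-or-date arguments in exp.TsOrDsToDate so other dialects can cast explicitly, and _remove_ts_or_ds_to_date strips the wrapper again when generating MySQL, which accepts both. A hedged round trip (node shape assumed):

    import sqlglot

    ast = sqlglot.parse_one("SELECT DAYOFWEEK('2023-01-01')", read="mysql")
    # The argument is wrapped for cross-dialect fidelity...
    print(type(ast.find(sqlglot.exp.DayOfWeek).this).__name__)  # expected: TsOrDsToDate
    # ...but MySQL output unwraps it, so the SQL round-trips unchanged.
    print(ast.sql(dialect="mysql"))  # expected: SELECT DAYOFWEEK('2023-01-01')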
@@ -233,6 +261,7 @@ class MySQL(Dialect):
 
         FUNCTIONS = {
             **parser.Parser.FUNCTIONS,
+            "DATE": lambda args: exp.TsOrDsToDate(this=seq_get(args, 0)),
             "DATE_ADD": parse_date_delta_with_interval(exp.DateAdd),
             "DATE_FORMAT": format_time_lambda(exp.TimeToStr, "mysql"),
             "DATE_SUB": parse_date_delta_with_interval(exp.DateSub),
@@ -240,14 +269,33 @@ class MySQL(Dialect):
             "ISNULL": isnull_to_is_null,
             "LOCATE": locate_to_strposition,
             "MONTHNAME": lambda args: exp.TimeToStr(
-                this=seq_get(args, 0),
+                this=exp.TsOrDsToDate(this=seq_get(args, 0)),
                 format=exp.Literal.string("%B"),
             ),
             "STR_TO_DATE": _str_to_date,
+            "TO_DAYS": lambda args: exp.paren(
+                exp.DateDiff(
+                    this=exp.TsOrDsToDate(this=seq_get(args, 0)),
+                    expression=exp.TsOrDsToDate(this=exp.Literal.string("0000-01-01")),
+                    unit=exp.var("DAY"),
+                )
+                + 1
+            ),
+            "DAY": lambda args: exp.Day(this=exp.TsOrDsToDate(this=seq_get(args, 0))),
+            "DAYOFMONTH": lambda args: exp.DayOfMonth(this=exp.TsOrDsToDate(this=seq_get(args, 0))),
+            "DAYOFWEEK": lambda args: exp.DayOfWeek(this=exp.TsOrDsToDate(this=seq_get(args, 0))),
+            "DAYOFYEAR": lambda args: exp.DayOfYear(this=exp.TsOrDsToDate(this=seq_get(args, 0))),
+            "MONTH": lambda args: exp.Month(this=exp.TsOrDsToDate(this=seq_get(args, 0))),
+            "WEEK": lambda args: exp.Week(
+                this=exp.TsOrDsToDate(this=seq_get(args, 0)), mode=seq_get(args, 1)
+            ),
+            "WEEKOFYEAR": lambda args: exp.WeekOfYear(this=exp.TsOrDsToDate(this=seq_get(args, 0))),
+            "YEAR": lambda args: exp.Year(this=exp.TsOrDsToDate(this=seq_get(args, 0))),
         }
 
         FUNCTION_PARSERS = {
             **parser.Parser.FUNCTION_PARSERS,
+            "CHAR": lambda self: self._parse_chr(),
             "GROUP_CONCAT": lambda self: self.expression(
                 exp.GroupConcat,
                 this=self._parse_lambda(),
@@ -531,6 +579,18 @@ class MySQL(Dialect):
 
             return super()._parse_type(parse_interval=parse_interval)
 
+        def _parse_chr(self) -> t.Optional[exp.Expression]:
+            expressions = self._parse_csv(self._parse_conjunction)
+            kwargs: t.Dict[str, t.Any] = {"this": seq_get(expressions, 0)}
+
+            if len(expressions) > 1:
+                kwargs["expressions"] = expressions[1:]
+
+            if self._match(TokenType.USING):
+                kwargs["charset"] = self._parse_var()
+
+            return self.expression(exp.Chr, **kwargs)
+
     class Generator(generator.Generator):
         LOCKING_READS_SUPPORTED = True
         NULL_ORDERING_SUPPORTED = False
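_parse_chr captures MySQL's CHAR(... USING charset) form into the new exp.Chr node (see expressions.py below); chr_sql, added at the end of this file's hunks, emits it back. A hedged round trip:

    import sqlglot

    sql = "SELECT CHAR(77, 121 USING utf8mb4)"
    print(sqlglot.transpile(sql, read="mysql", write="mysql")[0])
    # expected: SELECT CHAR(77, 121 USING utf8mb4)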
@@ -544,25 +604,33 @@ class MySQL(Dialect):
         TRANSFORMS = {
             **generator.Generator.TRANSFORMS,
             exp.CurrentDate: no_paren_current_date_sql,
-            exp.DateDiff: lambda self, e: self.func("DATEDIFF", e.this, e.expression),
-            exp.DateAdd: _date_add_sql("ADD"),
+            exp.DateDiff: _remove_ts_or_ds_to_date(
+                lambda self, e: self.func("DATEDIFF", e.this, e.expression), ("this", "expression")
+            ),
+            exp.DateAdd: _remove_ts_or_ds_to_date(_date_add_sql("ADD")),
             exp.DateStrToDate: datestrtodate_sql,
-            exp.DateSub: _date_add_sql("SUB"),
+            exp.DateSub: _remove_ts_or_ds_to_date(_date_add_sql("SUB")),
             exp.DateTrunc: _date_trunc_sql,
-            exp.DayOfMonth: rename_func("DAYOFMONTH"),
-            exp.DayOfWeek: rename_func("DAYOFWEEK"),
-            exp.DayOfYear: rename_func("DAYOFYEAR"),
+            exp.Day: _remove_ts_or_ds_to_date(),
+            exp.DayOfMonth: _remove_ts_or_ds_to_date(rename_func("DAYOFMONTH")),
+            exp.DayOfWeek: _remove_ts_or_ds_to_date(rename_func("DAYOFWEEK")),
+            exp.DayOfYear: _remove_ts_or_ds_to_date(rename_func("DAYOFYEAR")),
             exp.GroupConcat: lambda self, e: f"""GROUP_CONCAT({self.sql(e, "this")} SEPARATOR {self.sql(e, "separator") or "','"})""",
             exp.ILike: no_ilike_sql,
             exp.JSONExtractScalar: arrow_json_extract_scalar_sql,
             exp.JSONKeyValue: json_keyvalue_comma_sql,
             exp.Max: max_or_greatest,
             exp.Min: min_or_least,
+            exp.Month: _remove_ts_or_ds_to_date(),
             exp.NullSafeEQ: lambda self, e: self.binary(e, "<=>"),
             exp.NullSafeNEQ: lambda self, e: self.not_sql(self.binary(e, "<=>")),
             exp.Pivot: no_pivot_sql,
             exp.Select: transforms.preprocess(
-                [transforms.eliminate_distinct_on, transforms.eliminate_semi_and_anti_joins]
+                [
+                    transforms.eliminate_distinct_on,
+                    transforms.eliminate_semi_and_anti_joins,
+                    transforms.eliminate_qualify,
+                ]
             ),
             exp.StrPosition: strposition_to_locate_sql,
             exp.StrToDate: _str_to_date_sql,
@@ -573,10 +641,16 @@ class MySQL(Dialect):
             exp.TimestampSub: date_add_interval_sql("DATE", "SUB"),
             exp.TimeStrToUnix: rename_func("UNIX_TIMESTAMP"),
             exp.TimeStrToTime: lambda self, e: self.sql(exp.cast(e.this, "datetime", copy=True)),
-            exp.TimeToStr: lambda self, e: self.func("DATE_FORMAT", e.this, self.format_time(e)),
+            exp.TimeToStr: _remove_ts_or_ds_to_date(
+                lambda self, e: self.func("DATE_FORMAT", e.this, self.format_time(e))
+            ),
             exp.Trim: _trim_sql,
             exp.TryCast: no_trycast_sql,
-            exp.WeekOfYear: rename_func("WEEKOFYEAR"),
             exp.TsOrDsAdd: _date_add_sql("ADD"),
+            exp.TsOrDsToDate: _ts_or_ds_to_date_sql,
+            exp.Week: _remove_ts_or_ds_to_date(),
+            exp.WeekOfYear: _remove_ts_or_ds_to_date(rename_func("WEEKOFYEAR")),
+            exp.Year: _remove_ts_or_ds_to_date(),
         }
 
         UNSIGNED_TYPE_MAPPING = {
@@ -585,6 +659,7 @@ class MySQL(Dialect):
             exp.DataType.Type.UMEDIUMINT: "MEDIUMINT",
             exp.DataType.Type.USMALLINT: "SMALLINT",
             exp.DataType.Type.UTINYINT: "TINYINT",
+            exp.DataType.Type.UDECIMAL: "DECIMAL",
         }
 
         TIMESTAMP_TYPE_MAPPING = {
@@ -717,3 +792,9 @@ class MySQL(Dialect):
             limit_offset = f"{offset}, {limit}" if offset else limit
             return f" LIMIT {limit_offset}"
         return ""
+
+    def chr_sql(self, expression: exp.Chr) -> str:
+        this = self.expressions(sqls=[expression.this] + expression.expressions)
+        charset = expression.args.get("charset")
+        using = f" USING {self.sql(charset)}" if charset else ""
+        return f"CHAR({this}{using})"
sqlglot/dialects/oracle.py
@@ -153,6 +153,7 @@ class Oracle(Dialect):
         JOIN_HINTS = False
         TABLE_HINTS = False
         COLUMN_JOIN_MARKS_SUPPORTED = True
+        DATA_TYPE_SPECIFIERS_ALLOWED = True
 
         LIMIT_FETCH = "FETCH"
 
@@ -179,7 +180,12 @@ class Oracle(Dialect):
             ),
             exp.Group: transforms.preprocess([transforms.unalias_group]),
             exp.ILike: no_ilike_sql,
-            exp.Select: transforms.preprocess([transforms.eliminate_distinct_on]),
+            exp.Select: transforms.preprocess(
+                [
+                    transforms.eliminate_distinct_on,
+                    transforms.eliminate_qualify,
+                ]
+            ),
             exp.StrToTime: lambda self, e: f"TO_TIMESTAMP({self.sql(e, 'this')}, {self.format_time(e)})",
             exp.Subquery: lambda self, e: self.subquery_sql(e, sep=" "),
             exp.Substring: rename_func("SUBSTR"),
sqlglot/dialects/postgres.py
@@ -22,6 +22,7 @@ from sqlglot.dialects.dialect import (
     rename_func,
     simplify_literal,
     str_position_sql,
     struct_extract_sql,
     timestamptrunc_sql,
     timestrtotime_sql,
     trim_sql,
@@ -248,11 +249,10 @@ class Postgres(Dialect):
     }
 
     class Tokenizer(tokens.Tokenizer):
         QUOTES = ["'", "$$"]
 
         BIT_STRINGS = [("b'", "'"), ("B'", "'")]
         HEX_STRINGS = [("x'", "'"), ("X'", "'")]
         BYTE_STRINGS = [("e'", "'"), ("E'", "'")]
+        HEREDOC_STRINGS = ["$"]
 
         KEYWORDS = {
             **tokens.Tokenizer.KEYWORDS,
@@ -296,7 +296,7 @@ class Postgres(Dialect):
 
         SINGLE_TOKENS = {
             **tokens.Tokenizer.SINGLE_TOKENS,
-            "$": TokenType.PARAMETER,
+            "$": TokenType.HEREDOC_STRING,
         }
 
         VAR_SINGLE_TOKENS = {"$"}
@@ -420,9 +420,15 @@ class Postgres(Dialect):
             exp.Pow: lambda self, e: self.binary(e, "^"),
             exp.RegexpLike: lambda self, e: self.binary(e, "~"),
             exp.RegexpILike: lambda self, e: self.binary(e, "~*"),
-            exp.Select: transforms.preprocess([transforms.eliminate_semi_and_anti_joins]),
+            exp.Select: transforms.preprocess(
+                [
+                    transforms.eliminate_semi_and_anti_joins,
+                    transforms.eliminate_qualify,
+                ]
+            ),
             exp.StrPosition: str_position_sql,
             exp.StrToTime: lambda self, e: f"TO_TIMESTAMP({self.sql(e, 'this')}, {self.format_time(e)})",
             exp.StructExtract: struct_extract_sql,
             exp.Substring: _substring_sql,
             exp.TimestampTrunc: timestamptrunc_sql,
             exp.TimeStrToTime: timestrtotime_sql,
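Retagging "$" from PARAMETER to HEREDOC_STRING makes dollar-quoted bodies tokenize as ordinary strings. A hedged example (output form assumed):

    import sqlglot

    print(sqlglot.transpile("SELECT $$hello$$", read="postgres", write="postgres")[0])
    # expected: SELECT 'hello'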
sqlglot/dialects/presto.py
@@ -309,6 +309,9 @@ class Presto(Dialect):
             exp.FileFormatProperty: lambda self, e: f"FORMAT='{e.name.upper()}'",
             exp.First: _first_last_sql,
             exp.Group: transforms.preprocess([transforms.unalias_group]),
+            exp.GroupConcat: lambda self, e: self.func(
+                "ARRAY_JOIN", self.func("ARRAY_AGG", e.this), e.args.get("separator")
+            ),
             exp.Hex: rename_func("TO_HEX"),
             exp.If: if_sql(),
             exp.ILike: no_ilike_sql,
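Presto lacks GROUP_CONCAT, so the new transform rewrites it as ARRAY_JOIN over ARRAY_AGG. A hedged example:

    import sqlglot

    sql = "SELECT GROUP_CONCAT(name SEPARATOR ', ') FROM users"
    print(sqlglot.transpile(sql, read="mysql", write="presto")[0])
    # expected shape: SELECT ARRAY_JOIN(ARRAY_AGG(name), ', ') FROM users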
sqlglot/dialects/redshift.py
@@ -83,7 +83,7 @@ class Redshift(Postgres):
     class Tokenizer(Postgres.Tokenizer):
         BIT_STRINGS = []
         HEX_STRINGS = []
-        STRING_ESCAPES = ["\\"]
+        STRING_ESCAPES = ["\\", "'"]
 
         KEYWORDS = {
             **Postgres.Tokenizer.KEYWORDS,
sqlglot/dialects/snowflake.py
@@ -239,6 +239,8 @@ class Snowflake(Dialect):
     class Parser(parser.Parser):
         IDENTIFY_PIVOT_STRINGS = True
 
+        TABLE_ALIAS_TOKENS = parser.Parser.TABLE_ALIAS_TOKENS | {TokenType.WINDOW}
+
         FUNCTIONS = {
             **parser.Parser.FUNCTIONS,
             "ARRAYAGG": exp.ArrayAgg.from_arg_list,
@@ -318,6 +320,43 @@ class Snowflake(Dialect):
             "TERSE PRIMARY KEYS": _show_parser("PRIMARY KEYS"),
         }
 
+        STAGED_FILE_SINGLE_TOKENS = {
+            TokenType.DOT,
+            TokenType.MOD,
+            TokenType.SLASH,
+        }
+
+        def _parse_table_parts(self, schema: bool = False) -> exp.Table:
+            # https://docs.snowflake.com/en/user-guide/querying-stage
+            table: t.Optional[exp.Expression] = None
+            if self._match_text_seq("@"):
+                table_name = "@"
+                while True:
+                    self._advance()
+                    table_name += self._prev.text
+                    if not self._match_set(self.STAGED_FILE_SINGLE_TOKENS, advance=False):
+                        break
+                    while self._match_set(self.STAGED_FILE_SINGLE_TOKENS):
+                        table_name += self._prev.text
+
+                table = exp.var(table_name)
+            elif self._match(TokenType.STRING, advance=False):
+                table = self._parse_string()
+
+            if table:
+                file_format = None
+                pattern = None
+
+                if self._match_text_seq("(", "FILE_FORMAT", "=>"):
+                    file_format = self._parse_string() or super()._parse_table_parts()
+                    if self._match_text_seq(",", "PATTERN", "=>"):
+                        pattern = self._parse_string()
+                    self._match_r_paren()
+
+                return self.expression(exp.Table, this=table, format=file_format, pattern=pattern)
+
+            return super()._parse_table_parts(schema=schema)
+
         def _parse_id_var(
             self,
             any_token: bool = True,
@@ -394,6 +433,8 @@ class Snowflake(Dialect):
         TABLE_HINTS = False
         QUERY_HINTS = False
         AGGREGATE_FILTER_SUPPORTED = False
+        SUPPORTS_TABLE_COPY = False
+        COLLATE_IS_FUNC = True
 
         TRANSFORMS = {
             **generator.Generator.TRANSFORMS,
@@ -423,6 +464,12 @@ class Snowflake(Dialect):
             exp.Max: max_or_greatest,
             exp.Min: min_or_least,
             exp.PartitionedByProperty: lambda self, e: f"PARTITION BY {self.sql(e, 'this')}",
+            exp.PercentileCont: transforms.preprocess(
+                [transforms.add_within_group_for_percentiles]
+            ),
+            exp.PercentileDisc: transforms.preprocess(
+                [transforms.add_within_group_for_percentiles]
+            ),
             exp.RegexpILike: _regexpilike_sql,
             exp.Select: transforms.preprocess(
                 [
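_parse_table_parts gives Snowflake staged-file queries a structured representation: the @stage path becomes the table, and FILE_FORMAT / PATTERN land in the new "format" / "pattern" args on exp.Table (added in expressions.py below, and rendered back by generator.table_sql). A hedged example:

    import sqlglot

    sql = "SELECT * FROM @my_stage/data (FILE_FORMAT => 'my_csv_format', PATTERN => '.*[.]csv')"
    table = sqlglot.parse_one(sql, read="snowflake").find(sqlglot.exp.Table)
    print(table.sql(dialect="snowflake"))
    # expected: @my_stage/data (FILE_FORMAT => 'my_csv_format', PATTERN => '.*[.]csv')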
sqlglot/dialects/spark.py
@@ -54,6 +54,14 @@ class Spark(Spark2):
         FUNCTION_PARSERS = Spark2.Parser.FUNCTION_PARSERS.copy()
         FUNCTION_PARSERS.pop("ANY_VALUE")
 
+        def _parse_generated_as_identity(
+            self,
+        ) -> exp.GeneratedAsIdentityColumnConstraint | exp.ComputedColumnConstraint:
+            this = super()._parse_generated_as_identity()
+            if this.expression:
+                return self.expression(exp.ComputedColumnConstraint, this=this.expression)
+            return this
+
     class Generator(Spark2.Generator):
         TYPE_MAPPING = {
             **Spark2.Generator.TYPE_MAPPING,
@@ -73,6 +81,9 @@ class Spark(Spark2):
         TRANSFORMS.pop(exp.DateDiff)
         TRANSFORMS.pop(exp.Group)
 
+        def computedcolumnconstraint_sql(self, expression: exp.ComputedColumnConstraint) -> str:
+            return f"GENERATED ALWAYS AS ({self.sql(expression, 'this')})"
+
         def anyvalue_sql(self, expression: exp.AnyValue) -> str:
             return self.function_fallback_sql(expression)
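Spark now distinguishes computed columns from identity columns: when GENERATED ALWAYS AS is followed by an expression, the parser emits exp.ComputedColumnConstraint, and the generator renders it back (T-SQL, which sets COMPUTED_COLUMN_WITH_TYPE = False below, additionally drops the column type). A hedged round trip:

    import sqlglot

    sql = "CREATE TABLE t (a INT, b INT GENERATED ALWAYS AS (a + 1))"
    print(sqlglot.transpile(sql, read="spark", write="spark")[0])
    # expected: the computed definition of b is preserved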
sqlglot/dialects/spark2.py
@@ -5,7 +5,6 @@ import typing as t
 
 from sqlglot import exp, transforms
 from sqlglot.dialects.dialect import (
     binary_from_function,
-    create_with_partitions_sql,
     format_time_lambda,
     is_parse_json,
     move_insert_cte_sql,
@@ -17,22 +16,6 @@ from sqlglot.dialects.hive import Hive
 from sqlglot.helper import seq_get
 
 
-def _create_sql(self: Spark2.Generator, e: exp.Create) -> str:
-    kind = e.args["kind"]
-    properties = e.args.get("properties")
-
-    if (
-        kind.upper() == "TABLE"
-        and e.expression
-        and any(
-            isinstance(prop, exp.TemporaryProperty)
-            for prop in (properties.expressions if properties else [])
-        )
-    ):
-        return f"CREATE TEMPORARY VIEW {self.sql(e, 'this')} AS {self.sql(e, 'expression')}"
-    return create_with_partitions_sql(self, e)
-
-
 def _map_sql(self: Spark2.Generator, expression: exp.Map) -> str:
     keys = expression.args.get("keys")
     values = expression.args.get("values")
@@ -118,6 +101,8 @@ def _unqualify_pivot_columns(expression: exp.Expression) -> exp.Expression:
 
 class Spark2(Hive):
     class Parser(Hive.Parser):
+        TRIM_PATTERN_FIRST = True
+
         FUNCTIONS = {
             **Hive.Parser.FUNCTIONS,
             "AGGREGATE": exp.Reduce.from_arg_list,
@@ -192,7 +177,6 @@ class Spark2(Hive):
             exp.AtTimeZone: lambda self, e: f"FROM_UTC_TIMESTAMP({self.sql(e, 'this')}, {self.sql(e, 'zone')})",
             exp.BitwiseLeftShift: rename_func("SHIFTLEFT"),
             exp.BitwiseRightShift: rename_func("SHIFTRIGHT"),
-            exp.Create: _create_sql,
             exp.DateFromParts: rename_func("MAKE_DATE"),
             exp.DateTrunc: lambda self, e: self.func("TRUNC", e.this, e.args.get("unit")),
             exp.DayOfMonth: rename_func("DAYOFMONTH"),
@@ -236,6 +220,12 @@ class Spark2(Hive):
         WRAP_DERIVED_VALUES = False
         CREATE_FUNCTION_RETURN_AS = False
 
+        def temporary_storage_provider(self, expression: exp.Create) -> exp.Create:
+            # spark2, spark, Databricks require a storage provider for temporary tables
+            provider = exp.FileFormatProperty(this=exp.Literal.string("parquet"))
+            expression.args["properties"].append("expressions", provider)
+            return expression
+
         def cast_sql(self, expression: exp.Cast, safe_prefix: t.Optional[str] = None) -> str:
             if is_parse_json(expression.this):
                 schema = f"'{self.sql(expression, 'to')}'"
sqlglot/dialects/tsql.py
@@ -14,6 +14,7 @@ from sqlglot.dialects.dialect import (
     parse_date_delta,
     rename_func,
     timestrtotime_sql,
+    ts_or_ds_to_date_sql,
 )
 from sqlglot.expressions import DataType
 from sqlglot.helper import seq_get
@@ -590,6 +591,7 @@ class TSQL(Dialect):
         NVL2_SUPPORTED = False
         ALTER_TABLE_ADD_COLUMN_KEYWORD = False
         LIMIT_FETCH = "FETCH"
+        COMPUTED_COLUMN_WITH_TYPE = False
 
         TYPE_MAPPING = {
             **generator.Generator.TYPE_MAPPING,
@@ -619,7 +621,11 @@ class TSQL(Dialect):
             exp.Min: min_or_least,
             exp.NumberToStr: _format_sql,
             exp.Select: transforms.preprocess(
-                [transforms.eliminate_distinct_on, transforms.eliminate_semi_and_anti_joins]
+                [
+                    transforms.eliminate_distinct_on,
+                    transforms.eliminate_semi_and_anti_joins,
+                    transforms.eliminate_qualify,
+                ]
             ),
             exp.SHA: lambda self, e: self.func("HASHBYTES", exp.Literal.string("SHA1"), e.this),
             exp.SHA2: lambda self, e: self.func(
@@ -630,6 +636,7 @@ class TSQL(Dialect):
             exp.TemporaryProperty: lambda self, e: "",
             exp.TimeStrToTime: timestrtotime_sql,
             exp.TimeToStr: _format_sql,
+            exp.TsOrDsToDate: ts_or_ds_to_date_sql("tsql"),
         }
 
         TRANSFORMS.pop(exp.ReturnsProperty)
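Several dialects in this diff (ClickHouse, MySQL, Oracle, Postgres, T-SQL) add transforms.eliminate_qualify to their SELECT preprocessing, since they have no QUALIFY clause; the transform rewrites it into a wrapping subquery with a WHERE filter. A hedged example (output shape and alias assumed):

    import sqlglot

    sql = "SELECT x, ROW_NUMBER() OVER (ORDER BY y) AS rn FROM t QUALIFY rn = 1"
    print(sqlglot.transpile(sql, read="snowflake", write="tsql")[0])
    # expected shape:
    # SELECT x, rn FROM (SELECT x, ROW_NUMBER() OVER (ORDER BY y) AS rn FROM t) AS _t WHERE rn = 1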
sqlglot/executor/env.py
@@ -202,4 +202,5 @@ ENV = {
     "CURRENTTIME": datetime.datetime.now,
     "CURRENTDATE": datetime.date.today,
     "STRFTIME": null_if_any(lambda fmt, arg: datetime.datetime.fromisoformat(arg).strftime(fmt)),
+    "TRIM": null_if_any(lambda this, e=None: this.strip(e)),
 }
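The Python executor picks up a TRIM implementation; null_if_any keeps it NULL-safe (returning None when the argument is None). A hedged example:

    from sqlglot.executor import execute

    result = execute(
        "SELECT TRIM(a) AS a FROM t",
        tables={"t": [{"a": "  hi  "}, {"a": None}]},
    )
    print(result.rows)  # expected: [('hi',), (None,)]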
sqlglot/expressions.py
@@ -52,6 +52,9 @@ class _Expression(type):
         return klass
 
 
+SQLGLOT_META = "sqlglot.meta"
+
+
 class Expression(metaclass=_Expression):
     """
     The base class for all expressions in a syntax tree. Each Expression encapsulates any necessary
@@ -266,7 +269,14 @@ class Expression(metaclass=_Expression):
         if self.comments is None:
             self.comments = []
         if comments:
-            self.comments.extend(comments)
+            for comment in comments:
+                _, *meta = comment.split(SQLGLOT_META)
+                if meta:
+                    for kv in "".join(meta).split(","):
+                        k, *v = kv.split("=")
+                        value = v[0].strip() if v else True
+                        self.meta[k.strip()] = value
+                self.comments.append(comment)
 
     def append(self, arg_key: str, value: t.Any) -> None:
         """
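add_comments now doubles as a lightweight metadata channel: anything after the SQLGLOT_META marker in a comment is parsed into expression.meta, with bare keys becoming True. The normalize_identifiers change near the end of this diff consumes exactly this. A hedged example:

    import sqlglot

    ast = sqlglot.parse_one("SELECT a /* sqlglot.meta case_sensitive */ FROM tbl")
    print(ast.expressions[0].meta)  # expected: {'case_sensitive': True}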
@@ -1036,11 +1046,14 @@ class Create(DDL):
         "indexes": False,
         "no_schema_binding": False,
         "begin": False,
+        "end": False,
         "clone": False,
     }
 
 
 # https://docs.snowflake.com/en/sql-reference/sql/create-clone
+# https://cloud.google.com/bigquery/docs/reference/standard-sql/data-definition-language#create_table_clone_statement
+# https://cloud.google.com/bigquery/docs/reference/standard-sql/data-definition-language#create_table_copy
 class Clone(Expression):
     arg_types = {
         "this": True,
@@ -1048,6 +1061,7 @@ class Clone(Expression):
         "kind": False,
         "shallow": False,
         "expression": False,
+        "copy": False,
     }
 
 
@@ -1610,6 +1624,11 @@ class Identifier(Expression):
         return self.name
 
 
+# https://www.postgresql.org/docs/current/indexes-opclass.html
+class Opclass(Expression):
+    arg_types = {"this": True, "expression": True}
+
+
 class Index(Expression):
     arg_types = {
         "this": False,
@@ -2156,6 +2175,10 @@ class QueryTransform(Expression):
     }
 
 
+class SampleProperty(Property):
+    arg_types = {"this": True}
+
+
 class SchemaCommentProperty(Property):
     arg_types = {"this": True}
 
@@ -2440,6 +2463,8 @@ class Table(Expression):
         "hints": False,
         "system_time": False,
         "version": False,
+        "format": False,
+        "pattern": False,
     }
 
     @property
@@ -2465,17 +2490,17 @@ class Table(Expression):
         return []
 
     @property
-    def parts(self) -> t.List[Identifier]:
+    def parts(self) -> t.List[Expression]:
         """Return the parts of a table in order catalog, db, table."""
-        parts: t.List[Identifier] = []
+        parts: t.List[Expression] = []
 
         for arg in ("catalog", "db", "this"):
             part = self.args.get(arg)
 
-            if isinstance(part, Identifier):
-                parts.append(part)
-            elif isinstance(part, Dot):
+            if isinstance(part, Dot):
                 parts.extend(part.flatten())
+            elif isinstance(part, Expression):
+                parts.append(part)
 
         return parts
 
@@ -2910,6 +2935,7 @@ class Select(Subqueryable):
             prefix="OFFSET",
             dialect=dialect,
             copy=copy,
+            into_arg="expression",
             **opts,
         )
 
@@ -3572,6 +3598,7 @@ class DataType(Expression):
         UINT128 = auto()
         UINT256 = auto()
         UMEDIUMINT = auto()
+        UDECIMAL = auto()
         UNIQUEIDENTIFIER = auto()
         UNKNOWN = auto()  # Sentinel value, useful for type annotation
         USERDEFINED = "USER-DEFINED"
@@ -3693,13 +3720,13 @@ class DataType(Expression):
 
 
 # https://www.postgresql.org/docs/15/datatype-pseudo.html
-class PseudoType(Expression):
-    pass
+class PseudoType(DataType):
+    arg_types = {"this": True}
 
 
 # https://www.postgresql.org/docs/15/datatype-oid.html
-class ObjectIdentifier(Expression):
-    pass
+class ObjectIdentifier(DataType):
+    arg_types = {"this": True}
 
 
 # WHERE x <OP> EXISTS|ALL|ANY|SOME(SELECT ...)
@@ -4027,10 +4054,20 @@ class TimeUnit(Expression):
         return self.args.get("unit")
 
 
+class IntervalOp(TimeUnit):
+    arg_types = {"unit": True, "expression": True}
+
+    def interval(self):
+        return Interval(
+            this=self.expression.copy(),
+            unit=self.unit.copy(),
+        )
+
+
 # https://www.oracletutorial.com/oracle-basics/oracle-interval/
 # https://trino.io/docs/current/language/types.html#interval-day-to-second
 # https://docs.databricks.com/en/sql/language-manual/data-types/interval-type.html
-class IntervalSpan(Expression):
+class IntervalSpan(DataType):
     arg_types = {"this": True, "expression": True}
 
 
@@ -4269,7 +4306,7 @@ class CastToStrType(Func):
     arg_types = {"this": True, "to": True}
 
 
-class Collate(Binary):
+class Collate(Binary, Func):
     pass
 
 
@@ -4284,6 +4321,12 @@ class Coalesce(Func):
     _sql_names = ["COALESCE", "IFNULL", "NVL"]
 
 
+class Chr(Func):
+    arg_types = {"this": True, "charset": False, "expressions": False}
+    is_var_len_args = True
+    _sql_names = ["CHR", "CHAR"]
+
+
 class Concat(Func):
     arg_types = {"expressions": True}
     is_var_len_args = True
@@ -4326,11 +4369,11 @@ class CurrentUser(Func):
     arg_types = {"this": False}
 
 
-class DateAdd(Func, TimeUnit):
+class DateAdd(Func, IntervalOp):
     arg_types = {"this": True, "expression": True, "unit": False}
 
 
-class DateSub(Func, TimeUnit):
+class DateSub(Func, IntervalOp):
     arg_types = {"this": True, "expression": True, "unit": False}
 
 
@@ -4347,11 +4390,11 @@ class DateTrunc(Func):
         return self.args["unit"]
 
 
-class DatetimeAdd(Func, TimeUnit):
+class DatetimeAdd(Func, IntervalOp):
     arg_types = {"this": True, "expression": True, "unit": False}
 
 
-class DatetimeSub(Func, TimeUnit):
+class DatetimeSub(Func, IntervalOp):
     arg_types = {"this": True, "expression": True, "unit": False}
 
 
@@ -4375,6 +4418,10 @@ class DayOfYear(Func):
     _sql_names = ["DAY_OF_YEAR", "DAYOFYEAR"]
 
 
+class ToDays(Func):
+    pass
+
+
 class WeekOfYear(Func):
     _sql_names = ["WEEK_OF_YEAR", "WEEKOFYEAR"]
 
@@ -6160,7 +6207,7 @@ def table_name(table: Table | str, dialect: DialectType = None) -> str:
         The table name.
     """
 
-    table = maybe_parse(table, into=Table)
+    table = maybe_parse(table, into=Table, dialect=dialect)
 
     if not table:
         raise ValueError(f"Cannot parse {table}")
sqlglot/generator.py
@@ -86,6 +86,7 @@ class Generator:
         exp.OnUpdateColumnConstraint: lambda self, e: f"ON UPDATE {self.sql(e, 'this')}",
         exp.PathColumnConstraint: lambda self, e: f"PATH {self.sql(e, 'this')}",
         exp.ReturnsProperty: lambda self, e: self.naked_property(e),
+        exp.SampleProperty: lambda self, e: f"SAMPLE BY {self.sql(e, 'this')}",
         exp.SetProperty: lambda self, e: f"{'MULTI' if e.args.get('multi') else ''}SET",
         exp.SettingsProperty: lambda self, e: f"SETTINGS{self.seg('')}{(self.expressions(e))}",
         exp.SqlSecurityProperty: lambda self, e: f"SQL SECURITY {'DEFINER' if e.args.get('definer') else 'INVOKER'}",
@@ -204,6 +205,21 @@ class Generator:
     # Whether or not session variables / parameters are supported, e.g. @x in T-SQL
     SUPPORTS_PARAMETERS = True
 
+    # Whether or not to include the type of a computed column in the CREATE DDL
+    COMPUTED_COLUMN_WITH_TYPE = True
+
+    # Whether or not CREATE TABLE .. COPY .. is supported. False means we'll generate CLONE instead of COPY
+    SUPPORTS_TABLE_COPY = True
+
+    # Whether or not parentheses are required around the table sample's expression
+    TABLESAMPLE_REQUIRES_PARENS = True
+
+    # Whether or not COLLATE is a function instead of a binary operator
+    COLLATE_IS_FUNC = False
+
+    # Whether or not data types support additional specifiers like e.g. CHAR or BYTE (oracle)
+    DATA_TYPE_SPECIFIERS_ALLOWED = False
+
     TYPE_MAPPING = {
         exp.DataType.Type.NCHAR: "CHAR",
         exp.DataType.Type.NVARCHAR: "VARCHAR",
@@ -282,6 +298,7 @@ class Generator:
         exp.RowFormatProperty: exp.Properties.Location.POST_SCHEMA,
         exp.RowFormatDelimitedProperty: exp.Properties.Location.POST_SCHEMA,
         exp.RowFormatSerdeProperty: exp.Properties.Location.POST_SCHEMA,
+        exp.SampleProperty: exp.Properties.Location.POST_SCHEMA,
         exp.SchemaCommentProperty: exp.Properties.Location.POST_SCHEMA,
         exp.SerdeProperties: exp.Properties.Location.POST_SCHEMA,
         exp.Set: exp.Properties.Location.POST_SCHEMA,
@@ -324,13 +341,12 @@ class Generator:
         exp.Paren,
     )
 
-    UNESCAPED_SEQUENCE_TABLE = None  # type: ignore
-
     SENTINEL_LINE_BREAK = "__SQLGLOT__LB__"
 
     # Autofilled
     INVERSE_TIME_MAPPING: t.Dict[str, str] = {}
     INVERSE_TIME_TRIE: t.Dict = {}
+    INVERSE_ESCAPE_SEQUENCES: t.Dict[str, str] = {}
     INDEX_OFFSET = 0
     UNNEST_COLUMN_ONLY = False
     ALIAS_POST_TABLESAMPLE = False
@@ -480,8 +496,7 @@ class Generator:
         if not comments or isinstance(expression, exp.Binary):
             return sql
 
-        sep = "\n" if self.pretty else " "
-        comments_sql = sep.join(
+        comments_sql = " ".join(
             f"/*{self.pad_comment(comment)}*/" for comment in comments if comment
         )
 
@@ -649,6 +664,9 @@ class Generator:
         position = self.sql(expression, "position")
         position = f" {position}" if position else ""
 
+        if expression.find(exp.ComputedColumnConstraint) and not self.COMPUTED_COLUMN_WITH_TYPE:
+            kind = ""
+
         return f"{exists}{column}{kind}{constraints}{position}"
 
     def columnconstraint_sql(self, expression: exp.ColumnConstraint) -> str:
@@ -750,9 +768,11 @@ class Generator:
         )
 
         begin = " BEGIN" if expression.args.get("begin") else ""
+        end = " END" if expression.args.get("end") else ""
 
         expression_sql = self.sql(expression, "expression")
         if expression_sql:
-            expression_sql = f"{begin}{self.sep()}{expression_sql}"
+            expression_sql = f"{begin}{self.sep()}{expression_sql}{end}"
 
         if self.CREATE_FUNCTION_RETURN_AS or not isinstance(expression.expression, exp.Return):
             if properties_locs.get(exp.Properties.Location.POST_ALIAS):
@@ -817,7 +837,8 @@ class Generator:
     def clone_sql(self, expression: exp.Clone) -> str:
         this = self.sql(expression, "this")
         shallow = "SHALLOW " if expression.args.get("shallow") else ""
-        this = f"{shallow}CLONE {this}"
+        keyword = "COPY" if expression.args.get("copy") and self.SUPPORTS_TABLE_COPY else "CLONE"
+        this = f"{shallow}{keyword} {this}"
         when = self.sql(expression, "when")
 
         if when:
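clone_sql now separates COPY from CLONE: a Clone node flagged copy=True renders as COPY only when the dialect allows it (SUPPORTS_TABLE_COPY; Snowflake turns it off earlier in this diff). A hedged example:

    import sqlglot

    sql = "CREATE TABLE new_t COPY old_t"
    print(sqlglot.transpile(sql, read="bigquery", write="bigquery")[0])   # expected: CREATE TABLE new_t COPY old_t
    print(sqlglot.transpile(sql, read="bigquery", write="snowflake")[0])  # expected: CREATE TABLE new_t CLONE old_t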
@@ -877,7 +898,7 @@ class Generator:
     def datatypeparam_sql(self, expression: exp.DataTypeParam) -> str:
         this = self.sql(expression, "this")
         specifier = self.sql(expression, "expression")
-        specifier = f" {specifier}" if specifier else ""
+        specifier = f" {specifier}" if specifier and self.DATA_TYPE_SPECIFIERS_ALLOWED else ""
         return f"{this}{specifier}"
 
     def datatype_sql(self, expression: exp.DataType) -> str:
@@ -1329,8 +1350,13 @@ class Generator:
         pivots = f" {pivots}" if pivots else ""
         joins = self.expressions(expression, key="joins", sep="", skip_first=True)
         laterals = self.expressions(expression, key="laterals", sep="")
+        file_format = self.sql(expression, "format")
+        if file_format:
+            pattern = self.sql(expression, "pattern")
+            pattern = f", PATTERN => {pattern}" if pattern else ""
+            file_format = f" (FILE_FORMAT => {file_format}{pattern})"
 
-        return f"{table}{version}{alias}{hints}{pivots}{joins}{laterals}"
+        return f"{table}{version}{file_format}{alias}{hints}{pivots}{joins}{laterals}"
 
     def tablesample_sql(
         self, expression: exp.TableSample, seed_prefix: str = "SEED", sep=" AS "
@@ -1343,6 +1369,7 @@ class Generator:
         else:
             this = self.sql(expression, "this")
             alias = ""
 
         method = self.sql(expression, "method")
         method = f"{method.upper()} " if method and self.TABLESAMPLE_WITH_METHOD else ""
         numerator = self.sql(expression, "bucket_numerator")
@@ -1354,13 +1381,20 @@ class Generator:
         percent = f"{percent} PERCENT" if percent else ""
         rows = self.sql(expression, "rows")
         rows = f"{rows} ROWS" if rows else ""
 
         size = self.sql(expression, "size")
         if size and self.TABLESAMPLE_SIZE_IS_PERCENT:
             size = f"{size} PERCENT"
 
         seed = self.sql(expression, "seed")
         seed = f" {seed_prefix} ({seed})" if seed else ""
         kind = expression.args.get("kind", "TABLESAMPLE")
-        return f"{this} {kind} {method}({bucket}{percent}{rows}{size}){seed}{alias}"
+
+        expr = f"{bucket}{percent}{rows}{size}"
+        if self.TABLESAMPLE_REQUIRES_PARENS:
+            expr = f"({expr})"
+
+        return f"{this} {kind} {method}{expr}{seed}{alias}"
 
     def pivot_sql(self, expression: exp.Pivot) -> str:
         expressions = self.expressions(expression, flat=True)
@@ -1638,8 +1672,8 @@ class Generator:
 
     def escape_str(self, text: str) -> str:
         text = text.replace(self.QUOTE_END, self._escaped_quote_end)
-        if self.UNESCAPED_SEQUENCE_TABLE:
-            text = text.translate(self.UNESCAPED_SEQUENCE_TABLE)
+        if self.INVERSE_ESCAPE_SEQUENCES:
+            text = "".join(self.INVERSE_ESCAPE_SEQUENCES.get(ch, ch) for ch in text)
         elif self.pretty:
             text = text.replace("\n", self.SENTINEL_LINE_BREAK)
         return text
@@ -2301,6 +2335,8 @@ class Generator:
         return f"CURRENT_DATE({zone})" if zone else "CURRENT_DATE"
 
     def collate_sql(self, expression: exp.Collate) -> str:
+        if self.COLLATE_IS_FUNC:
+            return self.function_fallback_sql(expression)
         return self.binary(expression, "COLLATE")
 
     def command_sql(self, expression: exp.Command) -> str:
@@ -2359,7 +2395,7 @@ class Generator:
         collate = f" COLLATE {collate}" if collate else ""
         using = self.sql(expression, "using")
         using = f" USING {using}" if using else ""
-        return f"ALTER COLUMN {this} TYPE {dtype}{collate}{using}"
+        return f"ALTER COLUMN {this} SET DATA TYPE {dtype}{collate}{using}"
 
         default = self.sql(expression, "default")
         if default:
@@ -2396,7 +2432,7 @@ class Generator:
         elif isinstance(actions[0], exp.Delete):
             actions = self.expressions(expression, key="actions", flat=True)
         else:
-            actions = self.expressions(expression, key="actions")
+            actions = self.expressions(expression, key="actions", flat=True)
 
         exists = " IF EXISTS" if expression.args.get("exists") else ""
         only = " ONLY" if expression.args.get("only") else ""
@@ -2593,7 +2629,7 @@ class Generator:
         self,
         expression: t.Optional[exp.Expression] = None,
         key: t.Optional[str] = None,
-        sqls: t.Optional[t.List[str]] = None,
+        sqls: t.Optional[t.Collection[str | exp.Expression]] = None,
         flat: bool = False,
         indent: bool = True,
         skip_first: bool = False,
@@ -2841,6 +2877,9 @@ class Generator:
     def columnprefix_sql(self, expression: exp.ColumnPrefix) -> str:
         return f"{self.sql(expression, 'this')}({self.sql(expression, 'expression')})"
 
+    def opclass_sql(self, expression: exp.Opclass) -> str:
+        return f"{self.sql(expression, 'this')} {self.sql(expression, 'expression')}"
+
 
 def cached_generator(
     cache: t.Optional[t.Dict[int, str]] = None
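One user-visible consequence of the altercolumn change above: ALTER COLUMN now emits the standard SET DATA TYPE keywords. A hedged example:

    import sqlglot

    sql = "ALTER TABLE t ALTER COLUMN c TYPE BIGINT"
    print(sqlglot.transpile(sql, read="postgres", write="postgres")[0])
    # expected: ALTER TABLE t ALTER COLUMN c SET DATA TYPE BIGINT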
sqlglot/optimizer/annotate_types.py
@@ -1,5 +1,7 @@
 from __future__ import annotations
 
+import datetime
+import functools
 import typing as t
 
 from sqlglot import exp
@@ -11,6 +13,16 @@ from sqlglot.schema import Schema, ensure_schema
 if t.TYPE_CHECKING:
     B = t.TypeVar("B", bound=exp.Binary)
 
+    BinaryCoercionFunc = t.Callable[[exp.Expression, exp.Expression], exp.DataType.Type]
+    BinaryCoercions = t.Dict[
+        t.Tuple[exp.DataType.Type, exp.DataType.Type],
+        BinaryCoercionFunc,
+    ]
+
+
+# Interval units that operate on date components
+DATE_UNITS = {"day", "week", "month", "quarter", "year", "year_month"}
+
 
 def annotate_types(
     expression: E,
@@ -48,6 +60,59 @@ def _annotate_with_type_lambda(data_type: exp.DataType.Type) -> t.Callable[[TypeAnnotator, E], E]:
     return lambda self, e: self._annotate_with_type(e, data_type)
 
 
+def _is_iso_date(text: str) -> bool:
+    try:
+        datetime.date.fromisoformat(text)
+        return True
+    except ValueError:
+        return False
+
+
+def _is_iso_datetime(text: str) -> bool:
+    try:
+        datetime.datetime.fromisoformat(text)
+        return True
+    except ValueError:
+        return False
+
+
+def _coerce_literal_and_interval(l: exp.Expression, r: exp.Expression) -> exp.DataType.Type:
+    date_text = l.name
+    unit = r.text("unit").lower()
+
+    is_iso_date = _is_iso_date(date_text)
+
+    if is_iso_date and unit in DATE_UNITS:
+        l.replace(exp.cast(l.copy(), to=exp.DataType.Type.DATE))
+        return exp.DataType.Type.DATE
+
+    # An ISO date is also an ISO datetime, but not vice versa
+    if is_iso_date or _is_iso_datetime(date_text):
+        l.replace(exp.cast(l.copy(), to=exp.DataType.Type.DATETIME))
+        return exp.DataType.Type.DATETIME
+
+    return exp.DataType.Type.UNKNOWN
+
+
+def _coerce_date_and_interval(l: exp.Expression, r: exp.Expression) -> exp.DataType.Type:
+    unit = r.text("unit").lower()
+    if unit not in DATE_UNITS:
+        return exp.DataType.Type.DATETIME
+    return l.type.this if l.type else exp.DataType.Type.UNKNOWN
+
+
+def swap_args(func: BinaryCoercionFunc) -> BinaryCoercionFunc:
+    @functools.wraps(func)
+    def _swapped(l: exp.Expression, r: exp.Expression) -> exp.DataType.Type:
+        return func(r, l)
+
+    return _swapped
+
+
+def swap_all(coercions: BinaryCoercions) -> BinaryCoercions:
+    return {**coercions, **{(b, a): swap_args(func) for (a, b), func in coercions.items()}}
+
+
 class _TypeAnnotator(type):
     def __new__(cls, clsname, bases, attrs):
         klass = super().__new__(cls, clsname, bases, attrs)
@@ -104,10 +169,8 @@ class TypeAnnotator(metaclass=_TypeAnnotator):
         exp.DataType.Type.DATE: {
             exp.CurrentDate,
             exp.Date,
-            exp.DateAdd,
             exp.DateFromParts,
             exp.DateStrToDate,
-            exp.DateSub,
             exp.DateTrunc,
             exp.DiToDate,
             exp.StrToDate,
@@ -212,6 +275,8 @@ class TypeAnnotator(metaclass=_TypeAnnotator):
         exp.Case: lambda self, e: self._annotate_by_args(e, "default", "ifs"),
         exp.Coalesce: lambda self, e: self._annotate_by_args(e, "this", "expressions"),
         exp.DataType: lambda self, e: self._annotate_with_type(e, e.copy()),
+        exp.DateAdd: lambda self, e: self._annotate_dateadd(e),
+        exp.DateSub: lambda self, e: self._annotate_dateadd(e),
         exp.Distinct: lambda self, e: self._annotate_by_args(e, "expressions"),
         exp.Filter: lambda self, e: self._annotate_by_args(e, "this"),
         exp.If: lambda self, e: self._annotate_by_args(e, "true", "false"),
@@ -234,21 +299,41 @@ class TypeAnnotator(metaclass=_TypeAnnotator):
     # Specifies what types a given type can be coerced into (autofilled)
     COERCES_TO: t.Dict[exp.DataType.Type, t.Set[exp.DataType.Type]] = {}
 
+    # Coercion functions for binary operations.
+    # Map of type pairs to a callable that takes both sides of the binary operation and returns the resulting type.
+    BINARY_COERCIONS: BinaryCoercions = {
+        **swap_all(
+            {
+                (t, exp.DataType.Type.INTERVAL): _coerce_literal_and_interval
+                for t in exp.DataType.TEXT_TYPES
+            }
+        ),
+        **swap_all(
+            {
+                (exp.DataType.Type.DATE, exp.DataType.Type.INTERVAL): _coerce_date_and_interval,
+            }
+        ),
+    }
+
     def __init__(
         self,
         schema: Schema,
         annotators: t.Optional[t.Dict[t.Type[E], t.Callable[[TypeAnnotator, E], E]]] = None,
         coerces_to: t.Optional[t.Dict[exp.DataType.Type, t.Set[exp.DataType.Type]]] = None,
+        binary_coercions: t.Optional[BinaryCoercions] = None,
     ) -> None:
         self.schema = schema
         self.annotators = annotators or self.ANNOTATORS
         self.coerces_to = coerces_to or self.COERCES_TO
+        self.binary_coercions = binary_coercions or self.BINARY_COERCIONS
 
         # Caches the ids of annotated sub-Expressions, to ensure we only visit them once
         self._visited: t.Set[int] = set()
 
-    def _set_type(self, expression: exp.Expression, target_type: exp.DataType) -> None:
-        expression.type = target_type
+    def _set_type(
+        self, expression: exp.Expression, target_type: exp.DataType | exp.DataType.Type
+    ) -> None:
+        expression.type = target_type  # type: ignore
         self._visited.add(id(expression))
 
     def annotate(self, expression: E) -> E:
@@ -342,8 +427,8 @@ class TypeAnnotator(metaclass=_TypeAnnotator):
     def _annotate_binary(self, expression: B) -> B:
         self._annotate_args(expression)
 
-        left_type = expression.left.type.this
-        right_type = expression.right.type.this
+        left, right = expression.left, expression.right
+        left_type, right_type = left.type.this, right.type.this
 
         if isinstance(expression, exp.Connector):
             if left_type == exp.DataType.Type.NULL and right_type == exp.DataType.Type.NULL:
@@ -357,6 +442,8 @@ class TypeAnnotator(metaclass=_TypeAnnotator):
                 self._set_type(expression, exp.DataType.Type.BOOLEAN)
         elif isinstance(expression, exp.Predicate):
             self._set_type(expression, exp.DataType.Type.BOOLEAN)
+        elif (left_type, right_type) in self.binary_coercions:
+            self._set_type(expression, self.binary_coercions[(left_type, right_type)](left, right))
         else:
             self._set_type(expression, self._maybe_coerce(left_type, right_type))
 
@@ -421,3 +508,19 @@ class TypeAnnotator(metaclass=_TypeAnnotator):
         )
 
         return expression
+
+    def _annotate_dateadd(self, expression: exp.IntervalOp) -> exp.IntervalOp:
+        self._annotate_args(expression)
+
+        if expression.this.type.this in exp.DataType.TEXT_TYPES:
+            datatype = _coerce_literal_and_interval(expression.this, expression.interval())
+        elif (
+            expression.this.type.is_type(exp.DataType.Type.DATE)
+            and expression.text("unit").lower() not in DATE_UNITS
+        ):
+            datatype = exp.DataType.Type.DATETIME
+        else:
+            datatype = expression.this.type
+
+        self._set_type(expression, datatype)
+        return expression
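Taken together, BINARY_COERCIONS and _annotate_dateadd type date arithmetic: an ISO date string combined with a day-level interval is typed (and cast) as DATE, while sub-day units produce DATETIME. A hedged usage sketch (resulting types assumed from the logic above):

    import sqlglot
    from sqlglot.optimizer.annotate_types import annotate_types

    day = annotate_types(sqlglot.parse_one("SELECT '2023-01-01' + INTERVAL '1' DAY"))
    print(day.expressions[0].type)   # expected: DATE

    hour = annotate_types(sqlglot.parse_one("SELECT '2023-01-01' + INTERVAL '1' HOUR"))
    print(hour.expressions[0].type)  # expected: DATETIME (an ISO date is also an ISO datetime)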
sqlglot/optimizer/canonicalize.py
@@ -45,9 +45,11 @@ def coerce_type(node: exp.Expression) -> exp.Expression:
         _coerce_date(node.left, node.right)
     elif isinstance(node, exp.Between):
         _coerce_date(node.this, node.args["low"])
-    elif isinstance(node, exp.Extract):
-        if node.expression.type.this not in exp.DataType.TEMPORAL_TYPES:
-            _replace_cast(node.expression, "datetime")
+    elif isinstance(node, exp.Extract) and not node.expression.type.is_type(
+        *exp.DataType.TEMPORAL_TYPES
+    ):
+        _replace_cast(node.expression, exp.DataType.Type.DATETIME)
 
     return node
 
 
@@ -67,7 +69,7 @@ def ensure_bool_predicates(expression: exp.Expression) -> exp.Expression:
         _replace_int_predicate(expression.left)
         _replace_int_predicate(expression.right)
 
-    elif isinstance(expression, (exp.Where, exp.Having)):
+    elif isinstance(expression, (exp.Where, exp.Having, exp.If)):
         _replace_int_predicate(expression.this)
 
     return expression
@@ -89,13 +91,16 @@ def _coerce_date(a: exp.Expression, b: exp.Expression) -> None:
         and b.type
         and b.type.this not in (exp.DataType.Type.DATE, exp.DataType.Type.INTERVAL)
     ):
-        _replace_cast(b, "date")
+        _replace_cast(b, exp.DataType.Type.DATE)
 
 
-def _replace_cast(node: exp.Expression, to: str) -> None:
+def _replace_cast(node: exp.Expression, to: exp.DataType.Type) -> None:
     node.replace(exp.cast(node.copy(), to=to))
 
 
 def _replace_int_predicate(expression: exp.Expression) -> None:
-    if expression.type and expression.type.this in exp.DataType.INTEGER_TYPES:
+    if isinstance(expression, exp.Coalesce):
+        for _, child in expression.iter_expressions():
+            _replace_int_predicate(child)
+    elif expression.type and expression.type.this in exp.DataType.INTEGER_TYPES:
         expression.replace(exp.NEQ(this=expression.copy(), expression=exp.Literal.number(0)))
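Because _replace_int_predicate now recurses through COALESCE, an integer-typed predicate is rewritten on each branch rather than on the COALESCE as a whole. A hedged sketch via the optimizer (output shape assumed; later rules may simplify it further):

    import sqlglot
    from sqlglot.optimizer import optimize

    sql = "SELECT a FROM t WHERE COALESCE(a, 0)"
    print(optimize(sql, schema={"t": {"a": "INT"}}).sql())
    # expected shape: ... WHERE COALESCE(a <> 0, 0 <> 0)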
@ -181,7 +181,7 @@ def _mergeable(outer_scope, inner_scope, leave_tables_isolated, from_or_join):
        and not any(inner_select.args.get(arg) for arg in UNMERGABLE_ARGS)
        and inner_select.args.get("from")
        and not outer_scope.pivots
        and not any(e.find(exp.AggFunc, exp.Select) for e in inner_select.expressions)
        and not any(e.find(exp.AggFunc, exp.Select, exp.Explode) for e in inner_select.expressions)
        and not (leave_tables_isolated and len(outer_scope.selected_sources) > 1)
        and not (
            isinstance(from_or_join, exp.Join)
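
Adding `exp.Explode` to the projection scan keeps subqueries that flatten arrays from being merged into the outer query, which would otherwise change row multiplicity. The guard is easy to check in isolation; a sketch with placeholder names, read with the Spark dialect so `EXPLODE` maps to `exp.Explode`:

    import sqlglot
    from sqlglot import exp

    inner = sqlglot.parse_one("SELECT EXPLODE(arr) AS a FROM t", read="spark")
    # Mirrors the new guard: an aggregate, subquery, or explode in any
    # projection blocks merging.
    print(any(e.find(exp.AggFunc, exp.Select, exp.Explode) for e in inner.expressions))  # True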
@ -22,6 +22,13 @@ def normalize_identifiers(expression, dialect=None):
    Normalize all unquoted identifiers to either lower or upper case, depending
    on the dialect. This essentially makes those identifiers case-insensitive.

    It's possible to make this a no-op by adding a special comment next to the
    identifier of interest:

        SELECT a /* sqlglot.meta case_sensitive */ FROM table

    In this example, the identifier `a` will not be normalized.

    Note:
        Some dialects (e.g. BigQuery) treat identifiers as case-insensitive even
        when they're quoted, so in these cases all identifiers are normalized.

@ -43,4 +50,13 @@ def normalize_identifiers(expression, dialect=None):
    """
    if isinstance(expression, str):
        expression = exp.to_identifier(expression)
    return expression.transform(Dialect.get_or_raise(dialect).normalize_identifier, copy=False)

    dialect = Dialect.get_or_raise(dialect)

    def _normalize(node: E) -> E:
        if not node.meta.get("case_sensitive"):
            exp.replace_children(node, _normalize)
            node = dialect.normalize_identifier(node)
        return node

    return _normalize(expression)
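
The `case_sensitive` escape hatch documented above works through the new `_normalize` walker, which skips a node and its subtree when the meta flag is set. A minimal sketch; the default dialect lower-cases unquoted identifiers, and comment placement in the output may vary:

    import sqlglot
    from sqlglot.optimizer.normalize_identifiers import normalize_identifiers

    tree = sqlglot.parse_one("SELECT A /* sqlglot.meta case_sensitive */, B FROM t")
    print(normalize_identifiers(tree).sql())
    # `A` keeps its case; `B` and `t` are normalized:
    # SELECT A /* sqlglot.meta case_sensitive */, b FROM t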
@ -387,10 +387,6 @@ def _is_number(expression: exp.Expression) -> bool:
    return expression.is_number


def _is_date(expression: exp.Expression) -> bool:
    return isinstance(expression, exp.Cast) and extract_date(expression) is not None


def _is_interval(expression: exp.Expression) -> bool:
    return isinstance(expression, exp.Interval) and extract_interval(expression) is not None


@ -422,18 +418,15 @@ def simplify_equality(expression: exp.Expression) -> exp.Expression:
        if r.is_number:
            a_predicate = _is_number
            b_predicate = _is_number
        elif _is_date(r):
            a_predicate = _is_date
        elif _is_date_literal(r):
            a_predicate = _is_date_literal
            b_predicate = _is_interval
        else:
            return expression

        if l.__class__ in INVERSE_DATE_OPS:
            a = l.this
            b = exp.Interval(
                this=l.expression.copy(),
                unit=l.unit.copy(),
            )
            b = l.interval()
        else:
            a, b = l.left, l.right

@ -509,14 +502,14 @@ def _simplify_binary(expression, a, b):

    if boolean:
        return boolean
    elif isinstance(a, exp.Cast) and isinstance(b, exp.Interval):
    elif _is_date_literal(a) and isinstance(b, exp.Interval):
        a, b = extract_date(a), extract_interval(b)
        if a and b:
            if isinstance(expression, exp.Add):
                return date_literal(a + b)
            if isinstance(expression, exp.Sub):
                return date_literal(a - b)
    elif isinstance(a, exp.Interval) and isinstance(b, exp.Cast):
    elif isinstance(a, exp.Interval) and _is_date_literal(b):
        a, b = extract_interval(a), extract_date(b)
        # you cannot subtract a date from an interval
        if a and b and isinstance(expression, exp.Add):
@ -702,11 +695,7 @@ DATETRUNC_COMPARISONS = {exp.In, *DATETRUNC_BINARY_COMPARISONS}


def _is_datetrunc_predicate(left: exp.Expression, right: exp.Expression) -> bool:
    return (
        isinstance(left, (exp.DateTrunc, exp.TimestampTrunc))
        and isinstance(right, exp.Cast)
        and right.is_type(*exp.DataType.TEMPORAL_TYPES)
    )
    return isinstance(left, (exp.DateTrunc, exp.TimestampTrunc)) and _is_date_literal(right)


@catch(ModuleNotFoundError, UnsupportedUnit)

@ -731,15 +720,26 @@ def simplify_datetrunc_predicate(expression: exp.Expression) -> exp.Expression:
        unit = l.unit.name.lower()
        date = extract_date(r)

        if not date:
            return expression

        return DATETRUNC_BINARY_COMPARISONS[comparison](l.this, date, unit) or expression
    elif isinstance(expression, exp.In):
        l = expression.this
        rs = expression.expressions

        if all(_is_datetrunc_predicate(l, r) for r in rs):
        if rs and all(_is_datetrunc_predicate(l, r) for r in rs):
            unit = l.unit.name.lower()

            ranges = [r for r in [_datetrunc_range(extract_date(r), unit) for r in rs] if r]
            ranges = []
            for r in rs:
                date = extract_date(r)
                if not date:
                    return expression
                drange = _datetrunc_range(date, unit)
                if drange:
                    ranges.append(drange)

            if not ranges:
                return expression
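
Taken together, these changes collapse `DATE_TRUNC` comparisons against date literals into plain ranges over the underlying column. A sketch; the range expansion relies on python-dateutil being installed, and the output ordering may differ:

    import sqlglot
    from sqlglot.optimizer.simplify import simplify

    e = sqlglot.parse_one("DATE_TRUNC('YEAR', x) = CAST('2021-01-01' AS DATE)")
    print(simplify(e).sql())
    # Roughly: x < CAST('2022-01-01' AS DATE) AND x >= CAST('2021-01-01' AS DATE)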
@ -811,18 +811,59 @@ def eval_boolean(expression, a, b):
    return None


def extract_date(cast):
    # The "fromisoformat" conversion could fail if the cast is used on an identifier,
    # so in that case we can't extract the date.
def cast_as_date(value: t.Any) -> t.Optional[datetime.date]:
    if isinstance(value, datetime.datetime):
        return value.date()
    if isinstance(value, datetime.date):
        return value
    try:
        if cast.args["to"].this == exp.DataType.Type.DATE:
            return datetime.date.fromisoformat(cast.name)
        if cast.args["to"].this == exp.DataType.Type.DATETIME:
            return datetime.datetime.fromisoformat(cast.name)
        return datetime.datetime.fromisoformat(value).date()
    except ValueError:
        return None


def cast_as_datetime(value: t.Any) -> t.Optional[datetime.datetime]:
    if isinstance(value, datetime.datetime):
        return value
    if isinstance(value, datetime.date):
        return datetime.datetime(year=value.year, month=value.month, day=value.day)
    try:
        return datetime.datetime.fromisoformat(value)
    except ValueError:
        return None


def cast_value(value: t.Any, to: exp.DataType) -> t.Optional[t.Union[datetime.date, datetime.datetime]]:
    if not value:
        return None
    if to.is_type(exp.DataType.Type.DATE):
        return cast_as_date(value)
    if to.is_type(*exp.DataType.TEMPORAL_TYPES):
        return cast_as_datetime(value)
    return None


def extract_date(cast: exp.Expression) -> t.Optional[t.Union[datetime.date, datetime.datetime]]:
    if isinstance(cast, exp.Cast):
        to = cast.to
    elif isinstance(cast, exp.TsOrDsToDate):
        to = exp.DataType.build(exp.DataType.Type.DATE)
    else:
        return None

    if isinstance(cast.this, exp.Literal):
        value: t.Any = cast.this.name
    elif isinstance(cast.this, (exp.Cast, exp.TsOrDsToDate)):
        value = extract_date(cast.this)
    else:
        return None
    return cast_value(value, to)


def _is_date_literal(expression: exp.Expression) -> bool:
    return extract_date(expression) is not None


def extract_interval(expression):
    n = int(expression.name)
    unit = expression.text("unit").lower()


@ -836,7 +877,9 @@ def extract_interval(expression):
def date_literal(date):
    return exp.cast(
        exp.Literal.string(date),
        "DATETIME" if isinstance(date, datetime.datetime) else "DATE",
        exp.DataType.Type.DATETIME
        if isinstance(date, datetime.datetime)
        else exp.DataType.Type.DATE,
    )
@ -178,6 +178,7 @@ class Parser(metaclass=_Parser):
        TokenType.DATERANGE,
        TokenType.DATEMULTIRANGE,
        TokenType.DECIMAL,
        TokenType.UDECIMAL,
        TokenType.BIGDECIMAL,
        TokenType.UUID,
        TokenType.GEOGRAPHY,

@ -215,6 +216,7 @@ class Parser(metaclass=_Parser):
        TokenType.MEDIUMINT: TokenType.UMEDIUMINT,
        TokenType.SMALLINT: TokenType.USMALLINT,
        TokenType.TINYINT: TokenType.UTINYINT,
        TokenType.DECIMAL: TokenType.UDECIMAL,
    }

    SUBQUERY_PREDICATES = {

@ -338,6 +340,7 @@ class Parser(metaclass=_Parser):
    TRIM_TYPES = {"LEADING", "TRAILING", "BOTH"}

    FUNC_TOKENS = {
        TokenType.COLLATE,
        TokenType.COMMAND,
        TokenType.CURRENT_DATE,
        TokenType.CURRENT_DATETIME,
@ -590,6 +593,9 @@ class Parser(metaclass=_Parser):
            exp.National, this=token.text
        ),
        TokenType.RAW_STRING: lambda self, token: self.expression(exp.RawString, this=token.text),
        TokenType.HEREDOC_STRING: lambda self, token: self.expression(
            exp.RawString, this=token.text
        ),
        TokenType.SESSION_PARAMETER: lambda self, _: self._parse_session_parameter(),
    }
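
Heredoc strings reuse the existing `exp.RawString` node, so downstream code sees them exactly like raw strings. A sketch using ClickHouse-style dollar quoting, on the assumption that its tokenizer registers `$` as the heredoc delimiter in this release:

    import sqlglot
    from sqlglot import exp

    e = sqlglot.parse_one("SELECT $tag$hello$tag$", read="clickhouse")
    assert isinstance(e.selects[0], exp.RawString)
    print(e.sql())  # SELECT 'hello'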
@ -666,6 +672,9 @@ class Parser(metaclass=_Parser):
        "RETURNS": lambda self: self._parse_returns(),
        "ROW": lambda self: self._parse_row(),
        "ROW_FORMAT": lambda self: self._parse_property_assignment(exp.RowFormatProperty),
        "SAMPLE": lambda self: self.expression(
            exp.SampleProperty, this=self._match_text_seq("BY") and self._parse_bitwise()
        ),
        "SET": lambda self: self.expression(exp.SetProperty, multi=False),
        "SETTINGS": lambda self: self.expression(
            exp.SettingsProperty, expressions=self._parse_csv(self._parse_set_item)

@ -847,8 +856,11 @@ class Parser(metaclass=_Parser):

    INSERT_ALTERNATIVES = {"ABORT", "FAIL", "IGNORE", "REPLACE", "ROLLBACK"}

    CLONE_KEYWORDS = {"CLONE", "COPY"}
    CLONE_KINDS = {"TIMESTAMP", "OFFSET", "STATEMENT"}

    OPCLASS_FOLLOW_KEYWORDS = {"ASC", "DESC", "NULLS"}

    TABLE_INDEX_HINT_TOKENS = {TokenType.FORCE, TokenType.IGNORE, TokenType.USE}

    WINDOW_ALIAS_TOKENS = ID_VAR_TOKENS - {TokenType.ROWS}

@ -863,6 +875,8 @@ class Parser(metaclass=_Parser):

    NULL_TOKENS = {TokenType.NULL}

    UNNEST_OFFSET_ALIAS_TOKENS = ID_VAR_TOKENS - SET_OPERATIONS

    STRICT_CAST = True

    # A NULL arg in CONCAT yields NULL by default

@ -880,9 +894,12 @@ class Parser(metaclass=_Parser):
    # Whether or not the table sample clause expects CSV syntax
    TABLESAMPLE_CSV = False

    # Whether or not the SET command needs a delimiter (e.g. "=") for assignments.
    # Whether or not the SET command needs a delimiter (e.g. "=") for assignments
    SET_REQUIRES_ASSIGNMENT_DELIMITER = True

    # Whether the TRIM function expects the characters to trim as its first argument
    TRIM_PATTERN_FIRST = False

    __slots__ = (
        "error_level",
        "error_message_context",

@ -1268,6 +1285,7 @@ class Parser(metaclass=_Parser):
        indexes = None
        no_schema_binding = None
        begin = None
        end = None
        clone = None

        def extend_props(temp_props: t.Optional[exp.Properties]) -> None:

@ -1299,6 +1317,8 @@ class Parser(metaclass=_Parser):
                else:
                    expression = self._parse_statement()

                end = self._match_text_seq("END")

                if return_:
                    expression = self.expression(exp.Return, this=expression)
        elif create_token.token_type == TokenType.INDEX:

@ -1344,7 +1364,8 @@ class Parser(metaclass=_Parser):

            shallow = self._match_text_seq("SHALLOW")

            if self._match_text_seq("CLONE"):
            if self._match_texts(self.CLONE_KEYWORDS):
                copy = self._prev.text.lower() == "copy"
                clone = self._parse_table(schema=True)
                when = self._match_texts({"AT", "BEFORE"}) and self._prev.text.upper()
                clone_kind = (

@ -1361,6 +1382,7 @@ class Parser(metaclass=_Parser):
                    kind=clone_kind,
                    shallow=shallow,
                    expression=clone_expression,
                    copy=copy,
                )

        return self.expression(

@ -1376,6 +1398,7 @@ class Parser(metaclass=_Parser):
            indexes=indexes,
            no_schema_binding=no_schema_binding,
            begin=begin,
            end=end,
            clone=clone,
        )
@ -2445,21 +2468,32 @@ class Parser(metaclass=_Parser):
            kwargs["using"] = self._parse_wrapped_id_vars()
        elif not (kind and kind.token_type == TokenType.CROSS):
            index = self._index
            joins = self._parse_joins()
            join = self._parse_join()

            if joins and self._match(TokenType.ON):
            if join and self._match(TokenType.ON):
                kwargs["on"] = self._parse_conjunction()
            elif joins and self._match(TokenType.USING):
            elif join and self._match(TokenType.USING):
                kwargs["using"] = self._parse_wrapped_id_vars()
            else:
                joins = None
                join = None
                self._retreat(index)

            kwargs["this"].set("joins", joins)
            kwargs["this"].set("joins", [join] if join else None)

        comments = [c for token in (method, side, kind) if token for c in token.comments]
        return self.expression(exp.Join, comments=comments, **kwargs)

    def _parse_opclass(self) -> t.Optional[exp.Expression]:
        this = self._parse_conjunction()

        if self._match_texts(self.OPCLASS_FOLLOW_KEYWORDS, advance=False):
            return this

        opclass = self._parse_var(any_token=True)
        if opclass:
            return self.expression(exp.Opclass, this=this, expression=opclass)

        return this

    def _parse_index(
        self,
        index: t.Optional[exp.Expression] = None,

@ -2486,7 +2520,7 @@ class Parser(metaclass=_Parser):
        using = self._parse_var(any_token=True) if self._match(TokenType.USING) else None

        if self._match(TokenType.L_PAREN, advance=False):
            columns = self._parse_wrapped_csv(self._parse_ordered)
            columns = self._parse_wrapped_csv(lambda: self._parse_ordered(self._parse_opclass))
        else:
            columns = None
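
`_parse_opclass` lets an index column carry a PostgreSQL operator class, while `OPCLASS_FOLLOW_KEYWORDS` stops `ASC`, `DESC` and `NULLS` from being mistaken for one. A hedged round-trip sketch:

    import sqlglot

    sql = "CREATE INDEX idx ON t (col text_pattern_ops DESC)"
    # The operator class is kept as an exp.Opclass node and re-emitted as-is.
    print(sqlglot.parse_one(sql, read="postgres").sql("postgres"))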
@ -2677,7 +2711,9 @@ class Parser(metaclass=_Parser):

        if not offset and self._match_pair(TokenType.WITH, TokenType.OFFSET):
            self._match(TokenType.ALIAS)
            offset = self._parse_id_var() or exp.to_identifier("offset")
            offset = self._parse_id_var(
                any_token=False, tokens=self.UNNEST_OFFSET_ALIAS_TOKENS
            ) or exp.to_identifier("offset")

        return self.expression(exp.Unnest, expressions=expressions, alias=alias, offset=offset)

@ -2715,14 +2751,18 @@ class Parser(metaclass=_Parser):
        )
        method = self._parse_var(tokens=(TokenType.ROW,))

        self._match(TokenType.L_PAREN)
        matched_l_paren = self._match(TokenType.L_PAREN)

        if self.TABLESAMPLE_CSV:
            num = None
            expressions = self._parse_csv(self._parse_primary)
        else:
            expressions = None
            num = self._parse_primary()
            num = (
                self._parse_factor()
                if self._match(TokenType.NUMBER, advance=False)
                else self._parse_primary()
            )

        if self._match_text_seq("BUCKET"):
            bucket_numerator = self._parse_number()

@ -2737,7 +2777,8 @@ class Parser(metaclass=_Parser):
        elif num:
            size = num

        self._match(TokenType.R_PAREN)
        if matched_l_paren:
            self._match_r_paren()

        if self._match(TokenType.L_PAREN):
            method = self._parse_var()

@ -2965,8 +3006,8 @@ class Parser(metaclass=_Parser):
            return None
        return self.expression(exp_class, expressions=self._parse_csv(self._parse_ordered))

    def _parse_ordered(self) -> exp.Ordered:
        this = self._parse_conjunction()
    def _parse_ordered(self, parse_method: t.Optional[t.Callable] = None) -> exp.Ordered:
        this = parse_method() if parse_method else self._parse_conjunction()

        asc = self._match(TokenType.ASC)
        desc = self._match(TokenType.DESC) or (asc and False)

@ -3144,7 +3185,7 @@ class Parser(metaclass=_Parser):

        if self._match_text_seq("DISTINCT", "FROM"):
            klass = exp.NullSafeEQ if negate else exp.NullSafeNEQ
            return self.expression(klass, this=this, expression=self._parse_expression())
            return self.expression(klass, this=this, expression=self._parse_conjunction())

        expression = self._parse_null() or self._parse_boolean()
        if not expression:

@ -3760,7 +3801,9 @@ class Parser(metaclass=_Parser):

        return self.expression(exp.CompressColumnConstraint, this=self._parse_bitwise())

    def _parse_generated_as_identity(self) -> exp.GeneratedAsIdentityColumnConstraint:
    def _parse_generated_as_identity(
        self,
    ) -> exp.GeneratedAsIdentityColumnConstraint | exp.ComputedColumnConstraint:
        if self._match_text_seq("BY", "DEFAULT"):
            on_null = self._match_pair(TokenType.ON, TokenType.NULL)
            this = self.expression(

@ -4382,16 +4425,18 @@ class Parser(metaclass=_Parser):

        position = None
        collation = None
        expression = None

        if self._match_texts(self.TRIM_TYPES):
            position = self._prev.text.upper()

        expression = self._parse_bitwise()
        this = self._parse_bitwise()
        if self._match_set((TokenType.FROM, TokenType.COMMA)):
            this = self._parse_bitwise()
        else:
            this = expression
            expression = None
            invert_order = self._prev.token_type == TokenType.FROM or self.TRIM_PATTERN_FIRST
            expression = self._parse_bitwise()

            if invert_order:
                this, expression = expression, this

        if self._match(TokenType.COLLATE):
            collation = self._parse_bitwise()
@ -77,6 +77,7 @@ class TokenType(AutoName):
    BYTE_STRING = auto()
    NATIONAL_STRING = auto()
    RAW_STRING = auto()
    HEREDOC_STRING = auto()

    # types
    BIT = auto()

@ -98,6 +99,7 @@ class TokenType(AutoName):
    FLOAT = auto()
    DOUBLE = auto()
    DECIMAL = auto()
    UDECIMAL = auto()
    BIGDECIMAL = auto()
    CHAR = auto()
    NCHAR = auto()

@ -418,6 +420,7 @@ class _Tokenizer(type):
            **_quotes_to_format(TokenType.BYTE_STRING, klass.BYTE_STRINGS),
            **_quotes_to_format(TokenType.HEX_STRING, klass.HEX_STRINGS),
            **_quotes_to_format(TokenType.RAW_STRING, klass.RAW_STRINGS),
            **_quotes_to_format(TokenType.HEREDOC_STRING, klass.HEREDOC_STRINGS),
        }

        klass._STRING_ESCAPES = set(klass.STRING_ESCAPES)

@ -484,11 +487,13 @@ class Tokenizer(metaclass=_Tokenizer):
    BYTE_STRINGS: t.List[str | t.Tuple[str, str]] = []
    HEX_STRINGS: t.List[str | t.Tuple[str, str]] = []
    RAW_STRINGS: t.List[str | t.Tuple[str, str]] = []
    HEREDOC_STRINGS: t.List[str | t.Tuple[str, str]] = []
    IDENTIFIERS: t.List[str | t.Tuple[str, str]] = ['"']
    IDENTIFIER_ESCAPES = ['"']
    QUOTES: t.List[t.Tuple[str, str] | str] = ["'"]
    STRING_ESCAPES = ["'"]
    VAR_SINGLE_TOKENS: t.Set[str] = set()
    ESCAPE_SEQUENCES: t.Dict[str, str] = {}

    # Autofilled
    IDENTIFIERS_CAN_START_WITH_DIGIT: bool = False

@ -997,9 +1002,11 @@ class Tokenizer(metaclass=_Tokenizer):
                word = word.upper()
                self._add(self.KEYWORDS[word], text=word)
                return

        if self._char in self.SINGLE_TOKENS:
            self._add(self.SINGLE_TOKENS[self._char], text=self._char)
            return

        self._scan_var()

    def _scan_comment(self, comment_start: str) -> bool:

@ -1126,6 +1133,10 @@ class Tokenizer(metaclass=_Tokenizer):
            base = 16
        elif token_type == TokenType.BIT_STRING:
            base = 2
        elif token_type == TokenType.HEREDOC_STRING:
            self._advance()
            tag = "" if self._char == end else self._extract_string(end)
            end = f"{start}{tag}{end}"
        else:
            return False

@ -1193,6 +1204,13 @@ class Tokenizer(metaclass=_Tokenizer):
                if self._end:
                    raise TokenError(f"Missing {delimiter} from {self._line}:{self._start}")

            if self.ESCAPE_SEQUENCES and self._peek and self._char in self.STRING_ESCAPES:
                escaped_sequence = self.ESCAPE_SEQUENCES.get(self._char + self._peek)
                if escaped_sequence:
                    self._advance(2)
                    text += escaped_sequence
                    continue

            current = self._current - 1
            self._advance(alnum=True)
            text += self.sql[current : self._current - 1]
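
`ESCAPE_SEQUENCES` gives a dialect's tokenizer a declarative table for translating backslash escapes while a string token is extracted, instead of post-processing the literal. A sketch with BigQuery, whose mapping turns the two characters `\n` into a real newline:

    import sqlglot

    lit = sqlglot.parse_one(r"SELECT 'a\nb'", read="bigquery").selects[0]
    # The tokenizer consumed the escape pair and stored an actual newline.
    assert "\n" in lit.this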