Merging upstream version 10.4.2.
Signed-off-by: Daniel Baumann <daniel@debian.org>
parent de4e42d4d3
commit 0c79f8b507
88 changed files with 1637 additions and 436 deletions

sqlglot/__init__.py
@@ -1,4 +1,6 @@
-"""## Python SQL parser, transpiler and optimizer."""
+"""
+.. include:: ../README.md
+"""
 
 from __future__ import annotations
 
@@ -30,7 +32,7 @@ from sqlglot.parser import Parser
 from sqlglot.schema import MappingSchema
 from sqlglot.tokens import Tokenizer, TokenType
 
-__version__ = "10.2.9"
+__version__ = "10.4.2"
 
 pretty = False
 

sqlglot/__main__.py
@@ -1,9 +1,15 @@
 import argparse
+import sys
 
 import sqlglot
 
 parser = argparse.ArgumentParser(description="Transpile SQL")
-parser.add_argument("sql", metavar="sql", type=str, help="SQL string to transpile")
+parser.add_argument(
+    "sql",
+    metavar="sql",
+    type=str,
+    help="SQL statement(s) to transpile, or - to parse stdin.",
+)
 parser.add_argument(
     "--read",
     dest="read",
@@ -48,14 +54,20 @@ parser.add_argument(
 args = parser.parse_args()
 error_level = sqlglot.ErrorLevel[args.error_level.upper()]
 
+sql = sys.stdin.read() if args.sql == "-" else args.sql
+
 if args.parse:
     sqls = [
         repr(expression)
-        for expression in sqlglot.parse(args.sql, read=args.read, error_level=error_level)
+        for expression in sqlglot.parse(
+            sql,
+            read=args.read,
+            error_level=error_level,
+        )
     ]
 else:
     sqls = sqlglot.transpile(
-        args.sql,
+        sql,
        read=args.read,
        write=args.write,
        identify=args.identify,
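
With the new "-" sentinel above, statements can be piped in on stdin. A usage sketch (module invocation follows from this being sqlglot/__main__.py; the dialect flags are the ones defined in this file):

    $ echo "SELECT 1" | python -m sqlglot - --read duckdb --write hive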

sqlglot/dataframe/__init__.py
@@ -0,0 +1,3 @@
+"""
+.. include:: ./README.md
+"""

sqlglot/dataframe/sql/_typing.py
@@ -9,18 +9,8 @@ if t.TYPE_CHECKING:
     from sqlglot.dataframe.sql.column import Column
     from sqlglot.dataframe.sql.types import StructType
 
-ColumnLiterals = t.TypeVar(
-    "ColumnLiterals",
-    bound=t.Union[str, float, int, bool, t.List, t.Tuple, datetime.date, datetime.datetime],
-)
-ColumnOrName = t.TypeVar("ColumnOrName", bound=t.Union[Column, str])
-ColumnOrLiteral = t.TypeVar(
-    "ColumnOrLiteral",
-    bound=t.Union[Column, str, float, int, bool, t.List, t.Tuple, datetime.date, datetime.datetime],
-)
-SchemaInput = t.TypeVar(
-    "SchemaInput", bound=t.Union[str, t.List[str], StructType, t.Dict[str, str]]
-)
-OutputExpressionContainer = t.TypeVar(
-    "OutputExpressionContainer", bound=t.Union[exp.Select, exp.Create, exp.Insert]
-)
+ColumnLiterals = t.Union[str, float, int, bool, t.List, t.Tuple, datetime.date, datetime.datetime]
+ColumnOrName = t.Union[Column, str]
+ColumnOrLiteral = t.Union[Column, str, float, int, bool, t.List, t.Tuple, datetime.date, datetime.datetime]
+SchemaInput = t.Union[str, t.List[str], StructType, t.Dict[str, t.Optional[str]]]
+OutputExpressionContainer = t.Union[exp.Select, exp.Create, exp.Insert]

sqlglot/dataframe/sql/dataframe.py
@@ -634,7 +634,7 @@ class DataFrame:
         all_columns = self._get_outer_select_columns(new_df.expression)
         all_column_mapping = {column.alias_or_name: column for column in all_columns}
         if isinstance(value, dict):
-            values = value.values()
+            values = list(value.values())
             columns = self._ensure_and_normalize_cols(list(value))
         if not columns:
             columns = self._ensure_and_normalize_cols(subset) if subset else all_columns

sqlglot/dialects/bigquery.py
@@ -1,11 +1,15 @@
 """Supports BigQuery Standard SQL."""
 
 from __future__ import annotations
 
 from sqlglot import exp, generator, parser, tokens
 from sqlglot.dialects.dialect import (
     Dialect,
+    datestrtodate_sql,
     inline_array_sql,
     no_ilike_sql,
     rename_func,
+    timestrtotime_sql,
 )
+from sqlglot.helper import seq_get
+from sqlglot.tokens import TokenType
@@ -120,13 +124,12 @@ class BigQuery(Dialect):
             "NOT DETERMINISTIC": TokenType.VOLATILE,
             "QUALIFY": TokenType.QUALIFY,
             "UNKNOWN": TokenType.NULL,
-            "WINDOW": TokenType.WINDOW,
         }
         KEYWORDS.pop("DIV")
 
     class Parser(parser.Parser):
         FUNCTIONS = {
-            **parser.Parser.FUNCTIONS,
+            **parser.Parser.FUNCTIONS,  # type: ignore
             "DATE_TRUNC": _date_trunc,
             "DATE_ADD": _date_add(exp.DateAdd),
             "DATETIME_ADD": _date_add(exp.DatetimeAdd),
@@ -144,31 +147,33 @@ class BigQuery(Dialect):
         }
 
         FUNCTION_PARSERS = {
-            **parser.Parser.FUNCTION_PARSERS,
+            **parser.Parser.FUNCTION_PARSERS,  # type: ignore
+            "ARRAY": lambda self: self.expression(exp.Array, expressions=[self._parse_statement()]),
         }
         FUNCTION_PARSERS.pop("TRIM")
 
         NO_PAREN_FUNCTIONS = {
-            **parser.Parser.NO_PAREN_FUNCTIONS,
+            **parser.Parser.NO_PAREN_FUNCTIONS,  # type: ignore
             TokenType.CURRENT_DATETIME: exp.CurrentDatetime,
             TokenType.CURRENT_TIME: exp.CurrentTime,
         }
 
         NESTED_TYPE_TOKENS = {
-            *parser.Parser.NESTED_TYPE_TOKENS,
+            *parser.Parser.NESTED_TYPE_TOKENS,  # type: ignore
             TokenType.TABLE,
         }
 
     class Generator(generator.Generator):
         TRANSFORMS = {
-            **generator.Generator.TRANSFORMS,
+            **generator.Generator.TRANSFORMS,  # type: ignore
             exp.ArraySize: rename_func("ARRAY_LENGTH"),
             exp.DateAdd: _date_add_sql("DATE", "ADD"),
             exp.DateSub: _date_add_sql("DATE", "SUB"),
             exp.DatetimeAdd: _date_add_sql("DATETIME", "ADD"),
             exp.DatetimeSub: _date_add_sql("DATETIME", "SUB"),
             exp.DateDiff: lambda self, e: f"DATE_DIFF({self.sql(e, 'this')}, {self.sql(e, 'expression')}, {self.sql(e.args.get('unit', 'DAY'))})",
+            exp.DateStrToDate: datestrtodate_sql,
             exp.GroupConcat: rename_func("STRING_AGG"),
             exp.ILike: no_ilike_sql,
             exp.IntDiv: rename_func("DIV"),
             exp.StrToTime: lambda self, e: f"PARSE_TIMESTAMP({self.format_time(e)}, {self.sql(e, 'this')})",
@@ -176,6 +181,7 @@ class BigQuery(Dialect):
             exp.TimeSub: _date_add_sql("TIME", "SUB"),
             exp.TimestampAdd: _date_add_sql("TIMESTAMP", "ADD"),
             exp.TimestampSub: _date_add_sql("TIMESTAMP", "SUB"),
+            exp.TimeStrToTime: timestrtotime_sql,
             exp.VariancePop: rename_func("VAR_POP"),
             exp.Values: _derived_table_values_to_unnest,
             exp.ReturnsProperty: _returnsproperty_sql,
@@ -188,7 +194,7 @@ class BigQuery(Dialect):
         }
 
         TYPE_MAPPING = {
-            **generator.Generator.TYPE_MAPPING,
+            **generator.Generator.TYPE_MAPPING,  # type: ignore
             exp.DataType.Type.TINYINT: "INT64",
             exp.DataType.Type.SMALLINT: "INT64",
             exp.DataType.Type.INT: "INT64",

sqlglot/dialects/clickhouse.py
@@ -35,13 +35,13 @@ class ClickHouse(Dialect):
 
     class Parser(parser.Parser):
         FUNCTIONS = {
-            **parser.Parser.FUNCTIONS,
+            **parser.Parser.FUNCTIONS,  # type: ignore
             "MAP": parse_var_map,
         }
 
-        JOIN_KINDS = {*parser.Parser.JOIN_KINDS, TokenType.ANY, TokenType.ASOF}
+        JOIN_KINDS = {*parser.Parser.JOIN_KINDS, TokenType.ANY, TokenType.ASOF}  # type: ignore
 
-        TABLE_ALIAS_TOKENS = {*parser.Parser.TABLE_ALIAS_TOKENS} - {TokenType.ANY}
+        TABLE_ALIAS_TOKENS = {*parser.Parser.TABLE_ALIAS_TOKENS} - {TokenType.ANY}  # type: ignore
 
         def _parse_table(self, schema=False):
             this = super()._parse_table(schema)
@@ -55,7 +55,7 @@ class ClickHouse(Dialect):
         STRUCT_DELIMITER = ("(", ")")
 
         TYPE_MAPPING = {
-            **generator.Generator.TYPE_MAPPING,
+            **generator.Generator.TYPE_MAPPING,  # type: ignore
             exp.DataType.Type.NULLABLE: "Nullable",
             exp.DataType.Type.DATETIME: "DateTime64",
             exp.DataType.Type.MAP: "Map",
@@ -70,7 +70,7 @@ class ClickHouse(Dialect):
         }
 
         TRANSFORMS = {
-            **generator.Generator.TRANSFORMS,
+            **generator.Generator.TRANSFORMS,  # type: ignore
             exp.Array: inline_array_sql,
             exp.StrPosition: lambda self, e: f"position({self.format_args(e.this, e.args.get('substr'), e.args.get('position'))})",
             exp.Final: lambda self, e: f"{self.sql(e, 'this')} FINAL",

sqlglot/dialects/dialect.py
@@ -198,7 +198,7 @@ class Dialect(metaclass=_Dialect):
 def rename_func(name):
     def _rename(self, expression):
         args = flatten(expression.args.values())
-        return f"{name}({self.format_args(*args)})"
+        return f"{self.normalize_func(name)}({self.format_args(*args)})"
 
     return _rename
 
@@ -217,11 +217,11 @@ def if_sql(self, expression):
 
 
 def arrow_json_extract_sql(self, expression):
-    return f"{self.sql(expression, 'this')}->{self.sql(expression, 'path')}"
+    return self.binary(expression, "->")
 
 
 def arrow_json_extract_scalar_sql(self, expression):
-    return f"{self.sql(expression, 'this')}->>{self.sql(expression, 'path')}"
+    return self.binary(expression, "->>")
 
 
 def inline_array_sql(self, expression):
@@ -373,3 +373,11 @@ def strposition_to_local_sql(self, expression):
         expression.args.get("substr"), expression.this, expression.args.get("position")
     )
     return f"LOCATE({args})"
+
+
+def timestrtotime_sql(self, expression: exp.TimeStrToTime) -> str:
+    return f"CAST({self.sql(expression, 'this')} AS TIMESTAMP)"
+
+
+def datestrtodate_sql(self, expression: exp.DateStrToDate) -> str:
+    return f"CAST({self.sql(expression, 'this')} AS DATE)"

sqlglot/dialects/drill.py
@@ -6,13 +6,14 @@ from sqlglot import exp, generator, parser, tokens
 from sqlglot.dialects.dialect import (
     Dialect,
     create_with_partitions_sql,
+    datestrtodate_sql,
     format_time_lambda,
     no_pivot_sql,
     no_trycast_sql,
     rename_func,
     str_position_sql,
+    timestrtotime_sql,
 )
-from sqlglot.dialects.postgres import _lateral_sql
 
 
 def _to_timestamp(args):
@@ -117,14 +118,14 @@ class Drill(Dialect):
         STRICT_CAST = False
 
         FUNCTIONS = {
-            **parser.Parser.FUNCTIONS,
+            **parser.Parser.FUNCTIONS,  # type: ignore
             "TO_TIMESTAMP": exp.TimeStrToTime.from_arg_list,
             "TO_CHAR": format_time_lambda(exp.TimeToStr, "drill"),
         }
 
     class Generator(generator.Generator):
         TYPE_MAPPING = {
-            **generator.Generator.TYPE_MAPPING,
+            **generator.Generator.TYPE_MAPPING,  # type: ignore
             exp.DataType.Type.INT: "INTEGER",
             exp.DataType.Type.SMALLINT: "INTEGER",
             exp.DataType.Type.TINYINT: "INTEGER",
@@ -139,14 +140,13 @@ class Drill(Dialect):
         ROOT_PROPERTIES = {exp.PartitionedByProperty}
 
         TRANSFORMS = {
-            **generator.Generator.TRANSFORMS,
+            **generator.Generator.TRANSFORMS,  # type: ignore
             exp.CurrentTimestamp: lambda *_: "CURRENT_TIMESTAMP",
-            exp.Lateral: _lateral_sql,
             exp.ArrayContains: rename_func("REPEATED_CONTAINS"),
             exp.ArraySize: rename_func("REPEATED_COUNT"),
             exp.Create: create_with_partitions_sql,
             exp.DateAdd: _date_add_sql("ADD"),
-            exp.DateStrToDate: lambda self, e: f"CAST({self.sql(e, 'this')} AS DATE)",
+            exp.DateStrToDate: datestrtodate_sql,
             exp.DateSub: _date_add_sql("SUB"),
             exp.DateToDi: lambda self, e: f"CAST(TO_DATE({self.sql(e, 'this')}, {Drill.dateint_format}) AS INT)",
             exp.DiToDate: lambda self, e: f"TO_DATE(CAST({self.sql(e, 'this')} AS VARCHAR), {Drill.dateint_format})",
@@ -160,7 +160,7 @@ class Drill(Dialect):
             exp.StrToDate: _str_to_date,
             exp.StrToTime: lambda self, e: f"TO_TIMESTAMP({self.sql(e, 'this')}, {self.format_time(e)})",
             exp.TimeStrToDate: lambda self, e: f"CAST({self.sql(e, 'this')} AS DATE)",
-            exp.TimeStrToTime: lambda self, e: f"CAST({self.sql(e, 'this')} AS TIMESTAMP)",
+            exp.TimeStrToTime: timestrtotime_sql,
             exp.TimeStrToUnix: rename_func("UNIX_TIMESTAMP"),
             exp.TimeToStr: lambda self, e: f"TO_CHAR({self.sql(e, 'this')}, {self.format_time(e)})",
             exp.TimeToUnix: rename_func("UNIX_TIMESTAMP"),

sqlglot/dialects/duckdb.py
@@ -6,6 +6,7 @@ from sqlglot.dialects.dialect import (
     approx_count_distinct_sql,
     arrow_json_extract_scalar_sql,
     arrow_json_extract_sql,
+    datestrtodate_sql,
     format_time_lambda,
     no_pivot_sql,
     no_properties_sql,
@@ -13,6 +14,7 @@ from sqlglot.dialects.dialect import (
     no_tablesample_sql,
     rename_func,
     str_position_sql,
+    timestrtotime_sql,
 )
 from sqlglot.helper import seq_get
 from sqlglot.tokens import TokenType
@@ -83,11 +85,12 @@ class DuckDB(Dialect):
         KEYWORDS = {
             **tokens.Tokenizer.KEYWORDS,
             ":=": TokenType.EQ,
+            "CHARACTER VARYING": TokenType.VARCHAR,
         }
 
     class Parser(parser.Parser):
         FUNCTIONS = {
-            **parser.Parser.FUNCTIONS,
+            **parser.Parser.FUNCTIONS,  # type: ignore
             "APPROX_COUNT_DISTINCT": exp.ApproxDistinct.from_arg_list,
             "ARRAY_LENGTH": exp.ArraySize.from_arg_list,
             "ARRAY_SORT": exp.SortArray.from_arg_list,
@@ -119,16 +122,18 @@ class DuckDB(Dialect):
         STRUCT_DELIMITER = ("(", ")")
 
         TRANSFORMS = {
-            **generator.Generator.TRANSFORMS,
+            **generator.Generator.TRANSFORMS,  # type: ignore
             exp.ApproxDistinct: approx_count_distinct_sql,
-            exp.Array: rename_func("LIST_VALUE"),
+            exp.Array: lambda self, e: f"{self.normalize_func('ARRAY')}({self.sql(e.expressions[0])})"
+            if isinstance(seq_get(e.expressions, 0), exp.Select)
+            else rename_func("LIST_VALUE")(self, e),
             exp.ArraySize: rename_func("ARRAY_LENGTH"),
             exp.ArraySort: _array_sort_sql,
             exp.ArraySum: rename_func("LIST_SUM"),
             exp.DataType: _datatype_sql,
             exp.DateAdd: _date_add,
             exp.DateDiff: lambda self, e: f"""DATE_DIFF({self.format_args(e.args.get("unit") or "'day'", e.expression, e.this)})""",
-            exp.DateStrToDate: lambda self, e: f"CAST({self.sql(e, 'this')} AS DATE)",
+            exp.DateStrToDate: datestrtodate_sql,
             exp.DateToDi: lambda self, e: f"CAST(STRFTIME({self.sql(e, 'this')}, {DuckDB.dateint_format}) AS INT)",
             exp.DiToDate: lambda self, e: f"CAST(STRPTIME(CAST({self.sql(e, 'this')} AS TEXT), {DuckDB.dateint_format}) AS DATE)",
             exp.Explode: rename_func("UNNEST"),
@@ -136,6 +141,7 @@ class DuckDB(Dialect):
             exp.JSONExtractScalar: arrow_json_extract_scalar_sql,
             exp.JSONBExtract: arrow_json_extract_sql,
             exp.JSONBExtractScalar: arrow_json_extract_scalar_sql,
+            exp.LogicalOr: rename_func("BOOL_OR"),
             exp.Pivot: no_pivot_sql,
             exp.Properties: no_properties_sql,
             exp.RegexpLike: rename_func("REGEXP_MATCHES"),
@@ -150,7 +156,7 @@ class DuckDB(Dialect):
             exp.Struct: _struct_pack_sql,
             exp.TableSample: no_tablesample_sql,
             exp.TimeStrToDate: lambda self, e: f"CAST({self.sql(e, 'this')} AS DATE)",
-            exp.TimeStrToTime: lambda self, e: f"CAST({self.sql(e, 'this')} AS TIMESTAMP)",
+            exp.TimeStrToTime: timestrtotime_sql,
             exp.TimeStrToUnix: lambda self, e: f"EPOCH(CAST({self.sql(e, 'this')} AS TIMESTAMP))",
             exp.TimeToStr: lambda self, e: f"STRFTIME({self.sql(e, 'this')}, {self.format_time(e)})",
             exp.TimeToUnix: rename_func("EPOCH"),
@@ -163,7 +169,7 @@ class DuckDB(Dialect):
         }
 
         TYPE_MAPPING = {
-            **generator.Generator.TYPE_MAPPING,
+            **generator.Generator.TYPE_MAPPING,  # type: ignore
             exp.DataType.Type.VARCHAR: "TEXT",
             exp.DataType.Type.NVARCHAR: "TEXT",
         }
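
The reworked exp.Array transform above keeps ARRAY(...) when its first argument is a SELECT subquery and falls back to LIST_VALUE otherwise. A hedged sketch of the intended round-trip (expected outputs inferred from the lambda, not verified against this exact build):

    import sqlglot

    sqlglot.transpile("SELECT ARRAY(SELECT 1)", read="duckdb", write="duckdb")
    # expected: ['SELECT ARRAY(SELECT 1)']

    sqlglot.transpile("SELECT ARRAY(1, 2)", read="duckdb", write="duckdb")
    # expected: ['SELECT LIST_VALUE(1, 2)']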

sqlglot/dialects/hive.py
@@ -15,6 +15,7 @@ from sqlglot.dialects.dialect import (
     rename_func,
     strposition_to_local_sql,
     struct_extract_sql,
+    timestrtotime_sql,
     var_map_sql,
 )
 from sqlglot.helper import seq_get
@@ -197,7 +198,7 @@ class Hive(Dialect):
         STRICT_CAST = False
 
         FUNCTIONS = {
-            **parser.Parser.FUNCTIONS,
+            **parser.Parser.FUNCTIONS,  # type: ignore
             "APPROX_COUNT_DISTINCT": exp.ApproxDistinct.from_arg_list,
             "COLLECT_LIST": exp.ArrayAgg.from_arg_list,
             "DATE_ADD": lambda args: exp.TsOrDsAdd(
@@ -217,7 +218,12 @@ class Hive(Dialect):
                 ),
                 unit=exp.Literal.string("DAY"),
             ),
-            "DATE_FORMAT": format_time_lambda(exp.TimeToStr, "hive"),
+            "DATE_FORMAT": lambda args: format_time_lambda(exp.TimeToStr, "hive")(
+                [
+                    exp.TimeStrToTime(this=seq_get(args, 0)),
+                    seq_get(args, 1),
+                ]
+            ),
             "DAY": lambda args: exp.Day(this=exp.TsOrDsToDate(this=seq_get(args, 0))),
             "FROM_UNIXTIME": format_time_lambda(exp.UnixToStr, "hive", True),
             "GET_JSON_OBJECT": exp.JSONExtractScalar.from_arg_list,
@@ -240,7 +246,7 @@ class Hive(Dialect):
         }
 
         PROPERTY_PARSERS = {
-            **parser.Parser.PROPERTY_PARSERS,
+            **parser.Parser.PROPERTY_PARSERS,  # type: ignore
             TokenType.SERDE_PROPERTIES: lambda self: exp.SerdeProperties(
                 expressions=self._parse_wrapped_csv(self._parse_property)
             ),
@@ -248,14 +254,14 @@ class Hive(Dialect):
 
     class Generator(generator.Generator):
         TYPE_MAPPING = {
-            **generator.Generator.TYPE_MAPPING,
+            **generator.Generator.TYPE_MAPPING,  # type: ignore
             exp.DataType.Type.TEXT: "STRING",
             exp.DataType.Type.DATETIME: "TIMESTAMP",
             exp.DataType.Type.VARBINARY: "BINARY",
         }
 
         TRANSFORMS = {
-            **generator.Generator.TRANSFORMS,
+            **generator.Generator.TRANSFORMS,  # type: ignore
             **transforms.UNALIAS_GROUP,  # type: ignore
             exp.Property: _property_sql,
             exp.ApproxDistinct: approx_count_distinct_sql,
@@ -294,7 +300,7 @@ class Hive(Dialect):
             exp.StructExtract: struct_extract_sql,
             exp.TableFormatProperty: lambda self, e: f"USING {self.sql(e, 'this')}",
             exp.TimeStrToDate: rename_func("TO_DATE"),
-            exp.TimeStrToTime: lambda self, e: f"CAST({self.sql(e, 'this')} AS TIMESTAMP)",
+            exp.TimeStrToTime: timestrtotime_sql,
             exp.TimeStrToUnix: rename_func("UNIX_TIMESTAMP"),
             exp.TimeToStr: _time_to_str,
             exp.TimeToUnix: rename_func("UNIX_TIMESTAMP"),

sqlglot/dialects/mysql.py
@@ -161,8 +161,6 @@ class MySQL(Dialect):
             "_UCS2": TokenType.INTRODUCER,
             "_UJIS": TokenType.INTRODUCER,
             # https://dev.mysql.com/doc/refman/8.0/en/string-literals.html
-            "N": TokenType.INTRODUCER,
-            "n": TokenType.INTRODUCER,
             "_UTF8": TokenType.INTRODUCER,
             "_UTF16": TokenType.INTRODUCER,
             "_UTF16LE": TokenType.INTRODUCER,
@@ -175,10 +173,10 @@ class MySQL(Dialect):
         COMMANDS = tokens.Tokenizer.COMMANDS - {TokenType.SET, TokenType.SHOW}
 
     class Parser(parser.Parser):
-        FUNC_TOKENS = {*parser.Parser.FUNC_TOKENS, TokenType.SCHEMA}
+        FUNC_TOKENS = {*parser.Parser.FUNC_TOKENS, TokenType.SCHEMA}  # type: ignore
 
         FUNCTIONS = {
-            **parser.Parser.FUNCTIONS,
+            **parser.Parser.FUNCTIONS,  # type: ignore
             "DATE_ADD": _date_add(exp.DateAdd),
             "DATE_SUB": _date_add(exp.DateSub),
             "STR_TO_DATE": _str_to_date,
@@ -190,7 +188,7 @@ class MySQL(Dialect):
         }
 
         FUNCTION_PARSERS = {
-            **parser.Parser.FUNCTION_PARSERS,
+            **parser.Parser.FUNCTION_PARSERS,  # type: ignore
             "GROUP_CONCAT": lambda self: self.expression(
                 exp.GroupConcat,
                 this=self._parse_lambda(),
@@ -199,12 +197,12 @@ class MySQL(Dialect):
         }
 
         PROPERTY_PARSERS = {
-            **parser.Parser.PROPERTY_PARSERS,
+            **parser.Parser.PROPERTY_PARSERS,  # type: ignore
             TokenType.ENGINE: lambda self: self._parse_property_assignment(exp.EngineProperty),
         }
 
         STATEMENT_PARSERS = {
-            **parser.Parser.STATEMENT_PARSERS,
+            **parser.Parser.STATEMENT_PARSERS,  # type: ignore
             TokenType.SHOW: lambda self: self._parse_show(),
             TokenType.SET: lambda self: self._parse_set(),
         }
@@ -429,7 +427,7 @@ class MySQL(Dialect):
         NULL_ORDERING_SUPPORTED = False
 
         TRANSFORMS = {
-            **generator.Generator.TRANSFORMS,
+            **generator.Generator.TRANSFORMS,  # type: ignore
             exp.CurrentDate: no_paren_current_date_sql,
             exp.CurrentTimestamp: lambda *_: "CURRENT_TIMESTAMP",
             exp.ILike: no_ilike_sql,

sqlglot/dialects/oracle.py
@@ -39,13 +39,13 @@ class Oracle(Dialect):
 
     class Parser(parser.Parser):
         FUNCTIONS = {
-            **parser.Parser.FUNCTIONS,
+            **parser.Parser.FUNCTIONS,  # type: ignore
             "DECODE": exp.Matches.from_arg_list,
         }
 
     class Generator(generator.Generator):
         TYPE_MAPPING = {
-            **generator.Generator.TYPE_MAPPING,
+            **generator.Generator.TYPE_MAPPING,  # type: ignore
             exp.DataType.Type.TINYINT: "NUMBER",
             exp.DataType.Type.SMALLINT: "NUMBER",
             exp.DataType.Type.INT: "NUMBER",
@@ -60,7 +60,7 @@ class Oracle(Dialect):
         }
 
         TRANSFORMS = {
-            **generator.Generator.TRANSFORMS,
+            **generator.Generator.TRANSFORMS,  # type: ignore
             **transforms.UNALIAS_GROUP,  # type: ignore
             exp.ILike: no_ilike_sql,
             exp.Limit: _limit_sql,

sqlglot/dialects/postgres.py
@@ -11,9 +11,19 @@ from sqlglot.dialects.dialect import (
     no_trycast_sql,
     str_position_sql,
 )
+from sqlglot.helper import seq_get
 from sqlglot.tokens import TokenType
 from sqlglot.transforms import delegate, preprocess
 
+DATE_DIFF_FACTOR = {
+    "MICROSECOND": " * 1000000",
+    "MILLISECOND": " * 1000",
+    "SECOND": "",
+    "MINUTE": " / 60",
+    "HOUR": " / 3600",
+    "DAY": " / 86400",
+}
+
 
 def _date_add_sql(kind):
     def func(self, expression):
@@ -34,16 +44,30 @@ def _date_add_sql(kind):
     return func
 
 
-def _lateral_sql(self, expression):
-    this = self.sql(expression, "this")
-    if isinstance(expression.this, exp.Subquery):
-        return f"LATERAL{self.sep()}{this}"
-    alias = expression.args["alias"]
-    table = alias.name
-    table = f" {table}" if table else table
-    columns = self.expressions(alias, key="columns", flat=True)
-    columns = f" AS {columns}" if columns else ""
-    return f"LATERAL{self.sep()}{this}{table}{columns}"
+def _date_diff_sql(self, expression):
+    unit = expression.text("unit").upper()
+    factor = DATE_DIFF_FACTOR.get(unit)
+
+    end = f"CAST({expression.this} AS TIMESTAMP)"
+    start = f"CAST({expression.expression} AS TIMESTAMP)"
+
+    if factor is not None:
+        return f"CAST(EXTRACT(epoch FROM {end} - {start}){factor} AS BIGINT)"
+
+    age = f"AGE({end}, {start})"
+
+    if unit == "WEEK":
+        extract = f"EXTRACT(year FROM {age}) * 48 + EXTRACT(month FROM {age}) * 4 + EXTRACT(day FROM {age}) / 7"
+    elif unit == "MONTH":
+        extract = f"EXTRACT(year FROM {age}) * 12 + EXTRACT(month FROM {age})"
+    elif unit == "QUARTER":
+        extract = f"EXTRACT(year FROM {age}) * 4 + EXTRACT(month FROM {age}) / 3"
+    elif unit == "YEAR":
+        extract = f"EXTRACT(year FROM {age})"
+    else:
+        self.unsupported(f"Unsupported DATEDIFF unit {unit}")
+
+    return f"CAST({extract} AS BIGINT)"
 
 
 def _substring_sql(self, expression):
@@ -141,7 +165,7 @@ def _serial_to_generated(expression):
 
 def _to_timestamp(args):
     # TO_TIMESTAMP accepts either a single double argument or (text, text)
-    if len(args) == 1 and args[0].is_number:
+    if len(args) == 1:
         # https://www.postgresql.org/docs/current/functions-datetime.html#FUNCTIONS-DATETIME-TABLE
         return exp.UnixToTime.from_arg_list(args)
     # https://www.postgresql.org/docs/current/functions-formatting.html
@@ -211,11 +235,16 @@ class Postgres(Dialect):
 
         KEYWORDS = {
             **tokens.Tokenizer.KEYWORDS,
+            "~~": TokenType.LIKE,
+            "~~*": TokenType.ILIKE,
+            "~*": TokenType.IRLIKE,
+            "~": TokenType.RLIKE,
             "ALWAYS": TokenType.ALWAYS,
             "BEGIN": TokenType.COMMAND,
+            "BEGIN TRANSACTION": TokenType.BEGIN,
             "BIGSERIAL": TokenType.BIGSERIAL,
             "BY DEFAULT": TokenType.BY_DEFAULT,
             "CHARACTER VARYING": TokenType.VARCHAR,
             "COMMENT ON": TokenType.COMMAND,
             "DECLARE": TokenType.COMMAND,
             "DO": TokenType.COMMAND,
@@ -233,6 +262,7 @@ class Postgres(Dialect):
             "SMALLSERIAL": TokenType.SMALLSERIAL,
             "TEMP": TokenType.TEMPORARY,
             "UUID": TokenType.UUID,
+            "CSTRING": TokenType.PSEUDO_TYPE,
             **{f"CREATE {kind}": TokenType.COMMAND for kind in CREATABLES},
             **{f"DROP {kind}": TokenType.COMMAND for kind in CREATABLES},
         }
@@ -244,17 +274,16 @@ class Postgres(Dialect):
 
     class Parser(parser.Parser):
         STRICT_CAST = False
-        LATERAL_FUNCTION_AS_VIEW = True
 
         FUNCTIONS = {
-            **parser.Parser.FUNCTIONS,
+            **parser.Parser.FUNCTIONS,  # type: ignore
             "TO_TIMESTAMP": _to_timestamp,
             "TO_CHAR": format_time_lambda(exp.TimeToStr, "postgres"),
         }
 
     class Generator(generator.Generator):
         TYPE_MAPPING = {
-            **generator.Generator.TYPE_MAPPING,
+            **generator.Generator.TYPE_MAPPING,  # type: ignore
             exp.DataType.Type.TINYINT: "SMALLINT",
             exp.DataType.Type.FLOAT: "REAL",
             exp.DataType.Type.DOUBLE: "DOUBLE PRECISION",
@@ -264,7 +293,7 @@ class Postgres(Dialect):
         }
 
         TRANSFORMS = {
-            **generator.Generator.TRANSFORMS,
+            **generator.Generator.TRANSFORMS,  # type: ignore
             exp.ColumnDef: preprocess(
                 [
                     _auto_increment_to_serial,
@@ -274,13 +303,16 @@ class Postgres(Dialect):
             ),
             exp.JSONExtract: arrow_json_extract_sql,
             exp.JSONExtractScalar: arrow_json_extract_scalar_sql,
-            exp.JSONBExtract: lambda self, e: f"{self.sql(e, 'this')}#>{self.sql(e, 'path')}",
-            exp.JSONBExtractScalar: lambda self, e: f"{self.sql(e, 'this')}#>>{self.sql(e, 'path')}",
+            exp.JSONBExtract: lambda self, e: self.binary(e, "#>"),
+            exp.JSONBExtractScalar: lambda self, e: self.binary(e, "#>>"),
+            exp.JSONBContains: lambda self, e: self.binary(e, "?"),
             exp.CurrentDate: no_paren_current_date_sql,
             exp.CurrentTimestamp: lambda *_: "CURRENT_TIMESTAMP",
             exp.DateAdd: _date_add_sql("+"),
             exp.DateSub: _date_add_sql("-"),
-            exp.Lateral: _lateral_sql,
+            exp.DateDiff: _date_diff_sql,
+            exp.RegexpLike: lambda self, e: self.binary(e, "~"),
+            exp.RegexpILike: lambda self, e: self.binary(e, "~*"),
             exp.StrPosition: str_position_sql,
             exp.StrToTime: lambda self, e: f"TO_TIMESTAMP({self.sql(e, 'this')}, {self.format_time(e)})",
             exp.Substring: _substring_sql,
@@ -291,5 +323,7 @@ class Postgres(Dialect):
             exp.UnixToTime: lambda self, e: f"TO_TIMESTAMP({self.sql(e, 'this')})",
             exp.DataType: _datatype_sql,
             exp.GroupConcat: _string_agg_sql,
-            exp.Array: lambda self, e: f"ARRAY[{self.expressions(e, flat=True)}]",
+            exp.Array: lambda self, e: f"{self.normalize_func('ARRAY')}({self.sql(e.expressions[0])})"
+            if isinstance(seq_get(e.expressions, 0), exp.Select)
+            else f"{self.normalize_func('ARRAY')}[{self.expressions(e, flat=True)}]",
         }
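
A sketch of what the new _date_diff_sql emits, built directly from an exp.DateDiff node (column names illustrative; the expected output is traced by hand through DATE_DIFF_FACTOR above, not verified against this exact build):

    from sqlglot import exp

    node = exp.DateDiff(
        this=exp.column("b"),          # end of the interval
        expression=exp.column("a"),    # start of the interval
        unit=exp.Var(this="MINUTE"),
    )
    node.sql(dialect="postgres")
    # expected: CAST(EXTRACT(epoch FROM CAST(b AS TIMESTAMP) - CAST(a AS TIMESTAMP)) / 60 AS BIGINT)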

sqlglot/dialects/presto.py
@@ -10,6 +10,7 @@ from sqlglot.dialects.dialect import (
     rename_func,
     str_position_sql,
     struct_extract_sql,
+    timestrtotime_sql,
 )
 from sqlglot.dialects.mysql import MySQL
 from sqlglot.errors import UnsupportedError
@@ -38,10 +39,6 @@ def _datatype_sql(self, expression):
     return sql
 
 
-def _date_parse_sql(self, expression):
-    return f"DATE_PARSE({self.sql(expression, 'this')}, '%Y-%m-%d %H:%i:%s')"
-
-
 def _explode_to_unnest_sql(self, expression):
     if isinstance(expression.this, (exp.Explode, exp.Posexplode)):
         return self.sql(
@@ -137,7 +134,7 @@ class Presto(Dialect):
 
     class Parser(parser.Parser):
         FUNCTIONS = {
-            **parser.Parser.FUNCTIONS,
+            **parser.Parser.FUNCTIONS,  # type: ignore
             "APPROX_DISTINCT": exp.ApproxDistinct.from_arg_list,
             "CARDINALITY": exp.ArraySize.from_arg_list,
             "CONTAINS": exp.ArrayContains.from_arg_list,
@@ -174,7 +171,7 @@ class Presto(Dialect):
         ROOT_PROPERTIES = {exp.SchemaCommentProperty}
 
         TYPE_MAPPING = {
-            **generator.Generator.TYPE_MAPPING,
+            **generator.Generator.TYPE_MAPPING,  # type: ignore
             exp.DataType.Type.INT: "INTEGER",
             exp.DataType.Type.FLOAT: "REAL",
             exp.DataType.Type.BINARY: "VARBINARY",
@@ -184,7 +181,7 @@ class Presto(Dialect):
         }
 
         TRANSFORMS = {
-            **generator.Generator.TRANSFORMS,
+            **generator.Generator.TRANSFORMS,  # type: ignore
             **transforms.UNALIAS_GROUP,  # type: ignore
             exp.ApproxDistinct: _approx_distinct_sql,
             exp.Array: lambda self, e: f"ARRAY[{self.expressions(e, flat=True)}]",
@@ -224,8 +221,8 @@ class Presto(Dialect):
             exp.StructExtract: struct_extract_sql,
             exp.TableFormatProperty: lambda self, e: f"TABLE_FORMAT='{e.name.upper()}'",
             exp.FileFormatProperty: lambda self, e: f"FORMAT='{e.name.upper()}'",
-            exp.TimeStrToDate: _date_parse_sql,
-            exp.TimeStrToTime: _date_parse_sql,
+            exp.TimeStrToDate: timestrtotime_sql,
+            exp.TimeStrToTime: timestrtotime_sql,
             exp.TimeStrToUnix: lambda self, e: f"TO_UNIXTIME(DATE_PARSE({self.sql(e, 'this')}, {Presto.time_format}))",
             exp.TimeToStr: lambda self, e: f"DATE_FORMAT({self.sql(e, 'this')}, {self.format_time(e)})",
             exp.TimeToUnix: rename_func("TO_UNIXTIME"),

sqlglot/dialects/redshift.py
@@ -36,7 +36,6 @@ class Redshift(Postgres):
             "TIMETZ": TokenType.TIMESTAMPTZ,
             "UNLOAD": TokenType.COMMAND,
             "VARBYTE": TokenType.VARBINARY,
-            "SIMILAR TO": TokenType.SIMILAR_TO,
         }
 
     class Generator(Postgres.Generator):

sqlglot/dialects/snowflake.py
@@ -3,13 +3,15 @@ from __future__ import annotations
 from sqlglot import exp, generator, parser, tokens
 from sqlglot.dialects.dialect import (
     Dialect,
+    datestrtodate_sql,
     format_time_lambda,
     inline_array_sql,
     rename_func,
+    timestrtotime_sql,
     var_map_sql,
 )
 from sqlglot.expressions import Literal
-from sqlglot.helper import seq_get
+from sqlglot.helper import flatten, seq_get
 from sqlglot.tokens import TokenType
 
 
@@ -183,7 +185,7 @@ class Snowflake(Dialect):
 
     class Tokenizer(tokens.Tokenizer):
         QUOTES = ["'", "$$"]
-        ESCAPES = ["\\"]
+        ESCAPES = ["\\", "'"]
 
         SINGLE_TOKENS = {
             **tokens.Tokenizer.SINGLE_TOKENS,
@@ -206,9 +208,10 @@ class Snowflake(Dialect):
         CREATE_TRANSIENT = True
 
         TRANSFORMS = {
-            **generator.Generator.TRANSFORMS,
+            **generator.Generator.TRANSFORMS,  # type: ignore
             exp.Array: inline_array_sql,
             exp.ArrayConcat: rename_func("ARRAY_CAT"),
+            exp.DateStrToDate: datestrtodate_sql,
             exp.DataType: _datatype_sql,
             exp.If: rename_func("IFF"),
             exp.Map: lambda self, e: var_map_sql(self, e, "OBJECT_CONSTRUCT"),
@@ -218,13 +221,14 @@ class Snowflake(Dialect):
             exp.Matches: rename_func("DECODE"),
             exp.StrPosition: rename_func("POSITION"),
             exp.StrToTime: lambda self, e: f"TO_TIMESTAMP({self.sql(e, 'this')}, {self.format_time(e)})",
+            exp.TimeStrToTime: timestrtotime_sql,
             exp.TimeToUnix: lambda self, e: f"EXTRACT(epoch_second FROM {self.sql(e, 'this')})",
             exp.Trim: lambda self, e: f"TRIM({self.format_args(e.this, e.expression)})",
             exp.UnixToTime: _unix_to_time_sql,
         }
 
         TYPE_MAPPING = {
-            **generator.Generator.TYPE_MAPPING,
+            **generator.Generator.TYPE_MAPPING,  # type: ignore
             exp.DataType.Type.TIMESTAMP: "TIMESTAMPNTZ",
         }
 
@@ -246,3 +250,47 @@ class Snowflake(Dialect):
         if not expression.args.get("distinct", False):
             self.unsupported("INTERSECT with All is not supported in Snowflake")
         return super().intersect_op(expression)
+
+    def values_sql(self, expression: exp.Values) -> str:
+        """Due to a bug in Snowflake we want to make sure that all columns in a VALUES table alias are unquoted.
+
+        We also want to make sure that after we find matches where we need to unquote a column that we prevent users
+        from adding quotes to the column by using the `identify` argument when generating the SQL.
+        """
+        alias = expression.args.get("alias")
+        if alias and alias.args.get("columns"):
+            expression = expression.transform(
+                lambda node: exp.Identifier(**{**node.args, "quoted": False})
+                if isinstance(node, exp.Identifier)
+                and isinstance(node.parent, exp.TableAlias)
+                and node.arg_key == "columns"
+                else node,
+            )
+            return self.no_identify(lambda: super(self.__class__, self).values_sql(expression))
+        return super().values_sql(expression)
+
+    def select_sql(self, expression: exp.Select) -> str:
+        """Due to a bug in Snowflake we want to make sure that all columns in a VALUES table alias are unquoted and also
+        that all columns in a SELECT are unquoted. We also want to make sure that after we find matches where we need
+        to unquote a column that we prevent users from adding quotes to the column by using the `identify` argument when
+        generating the SQL.
+
+        Note: We make an assumption that any columns referenced in a VALUES expression should be unquoted throughout the
+        expression. This might not be true in a case where the same column name can be sourced from another table that can
+        properly quote but should be true in most cases.
+        """
+        values_expressions = expression.find_all(exp.Values)
+        values_identifiers = set(
+            flatten(
+                v.args.get("alias", exp.Alias()).args.get("columns", [])
+                for v in values_expressions
+            )
+        )
+        if values_identifiers:
+            expression = expression.transform(
+                lambda node: exp.Identifier(**{**node.args, "quoted": False})
+                if isinstance(node, exp.Identifier) and node in values_identifiers
+                else node,
+            )
+            return self.no_identify(lambda: super(self.__class__, self).select_sql(expression))
+        return super().select_sql(expression)

sqlglot/dialects/spark.py
@@ -76,7 +76,7 @@ class Spark(Hive):
         }
 
         FUNCTION_PARSERS = {
-            **parser.Parser.FUNCTION_PARSERS,
+            **parser.Parser.FUNCTION_PARSERS,  # type: ignore
             "BROADCAST": lambda self: self._parse_join_hint("BROADCAST"),
             "BROADCASTJOIN": lambda self: self._parse_join_hint("BROADCASTJOIN"),
             "MAPJOIN": lambda self: self._parse_join_hint("MAPJOIN"),
@@ -87,6 +87,16 @@ class Spark(Hive):
             "SHUFFLE_REPLICATE_NL": lambda self: self._parse_join_hint("SHUFFLE_REPLICATE_NL"),
         }
 
+        def _parse_add_column(self):
+            return self._match_text_seq("ADD", "COLUMNS") and self._parse_schema()
+
+        def _parse_drop_column(self):
+            return self._match_text_seq("DROP", "COLUMNS") and self.expression(
+                exp.Drop,
+                this=self._parse_schema(),
+                kind="COLUMNS",
+            )
+
     class Generator(Hive.Generator):
         TYPE_MAPPING = {
             **Hive.Generator.TYPE_MAPPING,  # type: ignore
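
With _parse_add_column/_parse_drop_column in place, Spark-style column DDL should parse into the new exp.AlterTable node. A hedged sketch (assumes the generic ALTER TABLE parsing added elsewhere in this release routes through these hooks):

    import sqlglot

    sqlglot.parse_one("ALTER TABLE db.t ADD COLUMNS (x INT, y STRING)", read="spark")
    # expected: an exp.AlterTable whose actions contain an exp.Schema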

sqlglot/dialects/sqlite.py
@@ -42,13 +42,13 @@ class SQLite(Dialect):
 
     class Parser(parser.Parser):
         FUNCTIONS = {
-            **parser.Parser.FUNCTIONS,
+            **parser.Parser.FUNCTIONS,  # type: ignore
             "EDITDIST3": exp.Levenshtein.from_arg_list,
         }
 
     class Generator(generator.Generator):
         TYPE_MAPPING = {
-            **generator.Generator.TYPE_MAPPING,
+            **generator.Generator.TYPE_MAPPING,  # type: ignore
             exp.DataType.Type.BOOLEAN: "INTEGER",
             exp.DataType.Type.TINYINT: "INTEGER",
             exp.DataType.Type.SMALLINT: "INTEGER",
@@ -70,7 +70,7 @@ class SQLite(Dialect):
         }
 
         TRANSFORMS = {
-            **generator.Generator.TRANSFORMS,
+            **generator.Generator.TRANSFORMS,  # type: ignore
             exp.ILike: no_ilike_sql,
             exp.JSONExtract: arrow_json_extract_sql,
             exp.JSONExtractScalar: arrow_json_extract_scalar_sql,

sqlglot/dialects/starrocks.py
@@ -8,7 +8,7 @@ from sqlglot.dialects.mysql import MySQL
 class StarRocks(MySQL):
     class Generator(MySQL.Generator):  # type: ignore
         TYPE_MAPPING = {
-            **MySQL.Generator.TYPE_MAPPING,
+            **MySQL.Generator.TYPE_MAPPING,  # type: ignore
             exp.DataType.Type.TEXT: "STRING",
             exp.DataType.Type.TIMESTAMP: "DATETIME",
             exp.DataType.Type.TIMESTAMPTZ: "DATETIME",

sqlglot/dialects/tableau.py
@@ -30,7 +30,7 @@ class Tableau(Dialect):
 
     class Parser(parser.Parser):
         FUNCTIONS = {
-            **parser.Parser.FUNCTIONS,
+            **parser.Parser.FUNCTIONS,  # type: ignore
             "IFNULL": exp.Coalesce.from_arg_list,
             "COUNTD": lambda args: exp.Count(this=exp.Distinct(expressions=args)),
         }

sqlglot/dialects/tsql.py
@@ -224,11 +224,7 @@ class TSQL(Dialect):
     class Tokenizer(tokens.Tokenizer):
         IDENTIFIERS = ['"', ("[", "]")]
 
-        QUOTES = [
-            (prefix + quote, quote) if prefix else quote
-            for quote in ["'", '"']
-            for prefix in ["", "n", "N"]
-        ]
+        QUOTES = ["'", '"']
 
         KEYWORDS = {
             **tokens.Tokenizer.KEYWORDS,
@@ -253,7 +249,7 @@ class TSQL(Dialect):
 
     class Parser(parser.Parser):
         FUNCTIONS = {
-            **parser.Parser.FUNCTIONS,
+            **parser.Parser.FUNCTIONS,  # type: ignore
             "CHARINDEX": exp.StrPosition.from_arg_list,
             "ISNULL": exp.Coalesce.from_arg_list,
             "DATEADD": parse_date_delta(exp.DateAdd, unit_mapping=DATE_DELTA_INTERVAL),
@@ -314,7 +310,7 @@ class TSQL(Dialect):
 
     class Generator(generator.Generator):
         TYPE_MAPPING = {
-            **generator.Generator.TYPE_MAPPING,
+            **generator.Generator.TYPE_MAPPING,  # type: ignore
             exp.DataType.Type.BOOLEAN: "BIT",
             exp.DataType.Type.INT: "INTEGER",
             exp.DataType.Type.DECIMAL: "NUMERIC",

sqlglot/diff.py
@@ -1,3 +1,7 @@
+"""
+.. include:: ../posts/sql_diff.md
+"""
+
 from __future__ import annotations
 
 import typing as t

sqlglot/executor/context.py
@@ -29,10 +29,10 @@ class Context:
         self._table: t.Optional[Table] = None
         self.range_readers = {name: table.range_reader for name, table in self.tables.items()}
         self.row_readers = {name: table.reader for name, table in tables.items()}
-        self.env = {**(env or {}), "scope": self.row_readers}
+        self.env = {**ENV, **(env or {}), "scope": self.row_readers}
 
     def eval(self, code):
-        return eval(code, ENV, self.env)
+        return eval(code, self.env)
 
     def eval_tuple(self, codes):
         return tuple(self.eval(code) for code in codes)

sqlglot/executor/env.py
@@ -127,14 +127,16 @@ def interval(this, unit):
 ENV = {
     "exp": exp,
-    "SUM": filter_nulls(sum),
+    # aggs
     "ARRAYAGG": list,
     "AVG": filter_nulls(statistics.fmean if PYTHON_VERSION >= (3, 8) else statistics.mean),  # type: ignore
     "COUNT": filter_nulls(lambda acc: sum(1 for _ in acc), False),
     "MAX": filter_nulls(max),
     "MIN": filter_nulls(min),
+    "SUM": filter_nulls(sum),
+    # scalar functions
     "ABS": null_if_any(lambda this: abs(this)),
     "ADD": null_if_any(lambda e, this: e + this),
     "ARRAYANY": null_if_any(lambda arr, func: any(func(e) for e in arr)),
     "BETWEEN": null_if_any(lambda this, low, high: low <= this and this <= high),
     "BITWISEAND": null_if_any(lambda this, e: this & e),
     "BITWISELEFTSHIFT": null_if_any(lambda this, e: this << e),

sqlglot/executor/python.py
@@ -394,6 +394,18 @@ def _case_sql(self, expression):
     return chain
 
 
+def _lambda_sql(self, e: exp.Lambda) -> str:
+    names = {e.name.lower() for e in e.expressions}
+
+    e = e.transform(
+        lambda n: exp.Var(this=n.name)
+        if isinstance(n, exp.Identifier) and n.name.lower() in names
+        else n
+    )
+
+    return f"lambda {self.expressions(e, flat=True)}: {self.sql(e, 'this')}"
+
+
 class Python(Dialect):
     class Tokenizer(tokens.Tokenizer):
         ESCAPES = ["\\"]
@@ -414,6 +426,7 @@ class Python(Dialect):
             exp.Extract: lambda self, e: f"EXTRACT('{e.name.lower()}', {self.sql(e, 'expression')})",
             exp.In: lambda self, e: f"{self.sql(e, 'this')} in ({self.expressions(e, flat=True)})",
             exp.Is: lambda self, e: self.binary(e, "is"),
+            exp.Lambda: _lambda_sql,
             exp.Not: lambda self, e: f"not {self.sql(e.this)}",
             exp.Null: lambda *_: "None",
             exp.Or: lambda self, e: self.binary(e, "or"),

sqlglot/expressions.py
@@ -1,6 +1,11 @@
+"""
+.. include:: ../pdoc/docs/expressions.md
+"""
+
 from __future__ import annotations
 
 import datetime
 import math
 import numbers
 import re
+import typing as t
@@ -682,6 +687,10 @@ class CharacterSet(Expression):
 class With(Expression):
     arg_types = {"expressions": True, "recursive": False}
 
+    @property
+    def recursive(self) -> bool:
+        return bool(self.args.get("recursive"))
+
 
 class WithinGroup(Expression):
     arg_types = {"this": True, "expression": False}
@@ -724,6 +733,18 @@ class ColumnDef(Expression):
         "this": True,
         "kind": True,
         "constraints": False,
+        "exists": False,
     }
 
 
+class AlterColumn(Expression):
+    arg_types = {
+        "this": True,
+        "dtype": False,
+        "collate": False,
+        "using": False,
+        "default": False,
+        "drop": False,
+    }
+
+
@@ -877,6 +898,11 @@ class Introducer(Expression):
     arg_types = {"this": True, "expression": True}
 
 
+# national char, like n'utf8'
+class National(Expression):
+    pass
+
+
 class LoadData(Expression):
     arg_types = {
         "this": True,
@@ -894,7 +920,7 @@ class Partition(Expression):
 
 
 class Fetch(Expression):
-    arg_types = {"direction": False, "count": True}
+    arg_types = {"direction": False, "count": False}
 
 
 class Group(Expression):
@@ -1316,7 +1342,7 @@ QUERY_MODIFIERS = {
     "group": False,
     "having": False,
     "qualify": False,
-    "window": False,
+    "windows": False,
     "distribute": False,
     "sort": False,
     "cluster": False,
@@ -1353,7 +1379,7 @@ class Union(Subqueryable):
 
     Example:
         >>> select("1").union(select("1")).limit(1).sql()
-        'SELECT * FROM (SELECT 1 UNION SELECT 1) AS "_l_0" LIMIT 1'
+        'SELECT * FROM (SELECT 1 UNION SELECT 1) AS _l_0 LIMIT 1'
 
     Args:
         expression (str | int | Expression): the SQL code string to parse.
@@ -1889,6 +1915,18 @@ class Select(Subqueryable):
             **opts,
         )
 
+    def window(self, *expressions, append=True, dialect=None, copy=True, **opts) -> Select:
+        return _apply_list_builder(
+            *expressions,
+            instance=self,
+            arg="windows",
+            append=append,
+            into=Window,
+            dialect=dialect,
+            copy=copy,
+            **opts,
+        )
+
     def distinct(self, distinct=True, copy=True) -> Select:
         """
         Set the OFFSET expression.
@@ -2140,6 +2178,11 @@ class DataType(Expression):
     )
 
 
+# https://www.postgresql.org/docs/15/datatype-pseudo.html
+class PseudoType(Expression):
+    pass
+
+
 class StructKwarg(Expression):
     arg_types = {"this": True, "expression": True}
 
@@ -2167,18 +2210,26 @@ class Command(Expression):
     arg_types = {"this": True, "expression": False}
 
 
-class Transaction(Command):
+class Transaction(Expression):
     arg_types = {"this": False, "modes": False}
 
 
-class Commit(Command):
+class Commit(Expression):
     arg_types = {"chain": False}
 
 
-class Rollback(Command):
+class Rollback(Expression):
     arg_types = {"savepoint": False}
 
 
+class AlterTable(Expression):
+    arg_types = {
+        "this": True,
+        "actions": True,
+        "exists": False,
+    }
+
+
 # Binary expressions like (ADD a b)
 class Binary(Expression):
     arg_types = {"this": True, "expression": True}
@@ -2312,6 +2363,10 @@ class SimilarTo(Binary, Predicate):
     pass
 
 
+class Slice(Binary):
+    arg_types = {"this": False, "expression": False}
+
+
 class Sub(Binary):
     pass
 
@@ -2392,7 +2447,7 @@ class TimeUnit(Expression):
 
 
 class Interval(TimeUnit):
-    arg_types = {"this": True, "unit": False}
+    arg_types = {"this": False, "unit": False}
 
 
 class IgnoreNulls(Expression):
@@ -2730,8 +2785,11 @@ class Initcap(Func):
     pass
 
 
-class JSONExtract(Func):
-    arg_types = {"this": True, "path": True}
+class JSONBContains(Binary):
+    _sql_names = ["JSONB_CONTAINS"]
+
+
+class JSONExtract(Binary, Func):
+    _sql_names = ["JSON_EXTRACT"]
 
 
@@ -2776,6 +2834,10 @@ class Log10(Func):
     pass
 
 
+class LogicalOr(AggFunc):
+    _sql_names = ["LOGICAL_OR", "BOOL_OR"]
+
+
 class Lower(Func):
     _sql_names = ["LOWER", "LCASE"]
 
@@ -2846,6 +2908,10 @@ class RegexpLike(Func):
     arg_types = {"this": True, "expression": True, "flag": False}
 
 
+class RegexpILike(Func):
+    arg_types = {"this": True, "expression": True, "flag": False}
+
+
 class RegexpSplit(Func):
     arg_types = {"this": True, "expression": True}
 
@@ -3388,11 +3454,17 @@ def update(table, properties, where=None, from_=None, dialect=None, **opts) -> Update:
         ],
     )
     if from_:
-        update.set("from", maybe_parse(from_, into=From, dialect=dialect, prefix="FROM", **opts))
+        update.set(
+            "from",
+            maybe_parse(from_, into=From, dialect=dialect, prefix="FROM", **opts),
+        )
     if isinstance(where, Condition):
         where = Where(this=where)
     if where:
-        update.set("where", maybe_parse(where, into=Where, dialect=dialect, prefix="WHERE", **opts))
+        update.set(
+            "where",
+            maybe_parse(where, into=Where, dialect=dialect, prefix="WHERE", **opts),
+        )
     return update
 
 
@@ -3522,7 +3594,7 @@ def paren(expression) -> Paren:
     return Paren(this=expression)
 
 
-SAFE_IDENTIFIER_RE = re.compile(r"^[a-zA-Z][\w]*$")
+SAFE_IDENTIFIER_RE = re.compile(r"^[_a-zA-Z][\w]*$")
 
 
 def to_identifier(alias, quoted=None) -> t.Optional[Identifier]:
@@ -3724,6 +3796,8 @@ def convert(value) -> Expression:
         return Boolean(this=value)
     if isinstance(value, str):
         return Literal.string(value)
+    if isinstance(value, float) and math.isnan(value):
+        return NULL
     if isinstance(value, numbers.Number):
         return Literal.number(value)
     if isinstance(value, tuple):
@@ -3732,11 +3806,13 @@ def convert(value) -> Expression:
         return Array(expressions=[convert(v) for v in value])
     if isinstance(value, dict):
         return Map(
-            keys=[convert(k) for k in value.keys()],
+            keys=[convert(k) for k in value],
             values=[convert(v) for v in value.values()],
         )
     if isinstance(value, datetime.datetime):
-        datetime_literal = Literal.string(value.strftime("%Y-%m-%d %H:%M:%S.%f%z"))
+        datetime_literal = Literal.string(
+            (value if value.tzinfo else value.replace(tzinfo=datetime.timezone.utc)).isoformat()
+        )
         return TimeStrToTime(this=datetime_literal)
     if isinstance(value, datetime.date):
         date_literal = Literal.string(value.strftime("%Y-%m-%d"))
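
One visible effect of the convert() change above: naive datetimes are now assumed UTC and serialized via isoformat(). A small sketch (default-dialect rendering inferred from the TimeStrToTime function name):

    import datetime
    from sqlglot import exp

    exp.convert(datetime.datetime(2022, 1, 1)).sql()
    # expected: TIME_STR_TO_TIME('2022-01-01T00:00:00+00:00')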

sqlglot/generator.py
@@ -361,10 +361,11 @@ class Generator:
         column = self.sql(expression, "this")
         kind = self.sql(expression, "kind")
         constraints = self.expressions(expression, key="constraints", sep=" ", flat=True)
+        exists = "IF NOT EXISTS " if expression.args.get("exists") else ""
 
         if not constraints:
-            return f"{column} {kind}"
-        return f"{column} {kind} {constraints}"
+            return f"{exists}{column} {kind}"
+        return f"{exists}{column} {kind} {constraints}"
 
     def columnconstraint_sql(self, expression: exp.ColumnConstraint) -> str:
         this = self.sql(expression, "this")
@@ -549,6 +550,9 @@ class Generator:
             text = f"{self.identifier_start}{text}{self.identifier_end}"
         return text
 
+    def national_sql(self, expression: exp.National) -> str:
+        return f"N{self.sql(expression, 'this')}"
+
     def partition_sql(self, expression: exp.Partition) -> str:
         keys = csv(
             *[
@@ -633,6 +637,9 @@ class Generator:
     def introducer_sql(self, expression: exp.Introducer) -> str:
         return f"{self.sql(expression, 'this')} {self.sql(expression, 'expression')}"
 
+    def pseudotype_sql(self, expression: exp.PseudoType) -> str:
+        return expression.name.upper()
+
     def rowformatdelimitedproperty_sql(self, expression: exp.RowFormatDelimitedProperty) -> str:
         fields = expression.args.get("fields")
         fields = f" FIELDS TERMINATED BY {fields}" if fields else ""
@@ -793,19 +800,17 @@ class Generator:
         if isinstance(expression.this, exp.Subquery):
             return f"LATERAL {this}"
 
-        alias = expression.args["alias"]
-        table = alias.name
-        columns = self.expressions(alias, key="columns", flat=True)
-
         if expression.args.get("view"):
-            table = f" {table}" if table else table
+            alias = expression.args["alias"]
+            columns = self.expressions(alias, key="columns", flat=True)
+            table = f" {alias.name}" if alias.name else ""
             columns = f" AS {columns}" if columns else ""
             op_sql = self.seg(f"LATERAL VIEW{' OUTER' if expression.args.get('outer') else ''}")
             return f"{op_sql}{self.sep()}{this}{table}{columns}"
 
-        table = f" AS {table}" if table else table
-        columns = f"({columns})" if columns else ""
-        return f"LATERAL {this}{table}{columns}"
+        alias = self.sql(expression, "alias")
+        alias = f" AS {alias}" if alias else ""
+        return f"LATERAL {this}{alias}"
 
     def limit_sql(self, expression: exp.Limit) -> str:
         this = self.sql(expression, "this")
@@ -891,13 +896,15 @@ class Generator:
     def query_modifiers(self, expression: exp.Expression, *sqls: str) -> str:
         return csv(
             *sqls,
-            *[self.sql(sql) for sql in expression.args.get("joins", [])],
-            *[self.sql(sql) for sql in expression.args.get("laterals", [])],
+            *[self.sql(sql) for sql in expression.args.get("joins") or []],
+            *[self.sql(sql) for sql in expression.args.get("laterals") or []],
             self.sql(expression, "where"),
             self.sql(expression, "group"),
             self.sql(expression, "having"),
             self.sql(expression, "qualify"),
-            self.sql(expression, "window"),
+            self.seg("WINDOW ") + self.expressions(expression, "windows", flat=True)
+            if expression.args.get("windows")
+            else "",
             self.sql(expression, "distribute"),
             self.sql(expression, "sort"),
             self.sql(expression, "cluster"),
@@ -1008,11 +1015,7 @@ class Generator:
         spec_sql = " " + self.window_spec_sql(spec) if spec else ""
 
         alias = self.sql(expression, "alias")
-
-        if expression.arg_key == "window":
-            this = this = f"{self.seg('WINDOW')} {this} AS"
-        else:
-            this = f"{this} OVER"
+        this = f"{this} {'AS' if expression.arg_key == 'windows' else 'OVER'}"
 
         if not partition and not order and not spec and alias:
             return f"{this} {alias}"
@@ -1141,9 +1144,11 @@ class Generator:
         return f"(SELECT {self.sql(unnest)})"
 
     def interval_sql(self, expression: exp.Interval) -> str:
+        this = self.sql(expression, "this")
+        this = f" {this}" if this else ""
         unit = self.sql(expression, "unit")
         unit = f" {unit}" if unit else ""
-        return f"INTERVAL {self.sql(expression, 'this')}{unit}"
+        return f"INTERVAL{this}{unit}"
 
     def reference_sql(self, expression: exp.Reference) -> str:
         this = self.sql(expression, "this")
@@ -1245,6 +1250,43 @@ class Generator:
         savepoint = f" TO {savepoint}" if savepoint else ""
         return f"ROLLBACK{savepoint}"
 
+    def altercolumn_sql(self, expression: exp.AlterColumn) -> str:
+        this = self.sql(expression, "this")
+
+        dtype = self.sql(expression, "dtype")
+        if dtype:
+            collate = self.sql(expression, "collate")
+            collate = f" COLLATE {collate}" if collate else ""
+            using = self.sql(expression, "using")
+            using = f" USING {using}" if using else ""
+            return f"ALTER COLUMN {this} TYPE {dtype}{collate}{using}"
+
+        default = self.sql(expression, "default")
+        if default:
+            return f"ALTER COLUMN {this} SET DEFAULT {default}"
+
+        if not expression.args.get("drop"):
+            self.unsupported("Unsupported ALTER COLUMN syntax")
+
+        return f"ALTER COLUMN {this} DROP DEFAULT"
+
+    def altertable_sql(self, expression: exp.AlterTable) -> str:
+        actions = expression.args["actions"]
+
+        if isinstance(actions[0], exp.ColumnDef):
+            actions = self.expressions(expression, "actions", prefix="ADD COLUMN ")
+        elif isinstance(actions[0], exp.Schema):
+            actions = self.expressions(expression, "actions", prefix="ADD COLUMNS ")
+        elif isinstance(actions[0], exp.Drop):
+            actions = self.expressions(expression, "actions")
+        elif isinstance(actions[0], exp.AlterColumn):
+            actions = self.sql(actions[0])
+        else:
+            self.unsupported(f"Unsupported ALTER TABLE action {actions[0].__class__.__name__}")
+
+        exists = " IF EXISTS" if expression.args.get("exists") else ""
+        return f"ALTER TABLE{exists} {self.sql(expression, 'this')} {actions}"
+
     def distinct_sql(self, expression: exp.Distinct) -> str:
         this = self.expressions(expression, flat=True)
         this = f" {this}" if this else ""
@@ -1327,6 +1369,9 @@ class Generator:
     def or_sql(self, expression: exp.Or) -> str:
         return self.connector_sql(expression, "OR")
 
+    def slice_sql(self, expression: exp.Slice) -> str:
+        return self.binary(expression, ":")
+
     def sub_sql(self, expression: exp.Sub) -> str:
         return self.binary(expression, "-")
 
@@ -1369,6 +1414,7 @@ class Generator:
         flat: bool = False,
         indent: bool = True,
         sep: str = ", ",
+        prefix: str = "",
     ) -> str:
         expressions = expression.args.get(key or "expressions")
 
@@ -1391,11 +1437,13 @@ class Generator:
 
             if self.pretty:
                 if self._leading_comma:
-                    result_sqls.append(f"{sep if i > 0 else pad}{sql}{comments}")
+                    result_sqls.append(f"{sep if i > 0 else pad}{prefix}{sql}{comments}")
                 else:
-                    result_sqls.append(f"{sql}{stripped_sep if i + 1 < num_sqls else ''}{comments}")
+                    result_sqls.append(
+                        f"{prefix}{sql}{stripped_sep if i + 1 < num_sqls else ''}{comments}"
+                    )
             else:
-                result_sqls.append(f"{sql}{comments}{sep if i + 1 < num_sqls else ''}")
+                result_sqls.append(f"{prefix}{sql}{comments}{sep if i + 1 < num_sqls else ''}")
 
         result_sql = "\n".join(result_sqls) if self.pretty else "".join(result_sqls)
         return self.indent(result_sql, skip_first=False) if indent else result_sql
@ -18,6 +18,9 @@ def canonicalize(expression: exp.Expression) -> exp.Expression:
|
|||
expression = coerce_type(expression)
|
||||
expression = remove_redundant_casts(expression)
|
||||
|
||||
if isinstance(expression, exp.Identifier):
|
||||
expression.set("quoted", True)
|
||||
|
||||
return expression
|
||||
|
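
Identifier quoting now happens inside canonicalize, which is what lets the standalone quote_identities rule be deleted further down. A rough sketch of the end-to-end effect through the optimizer, assuming its default rule set:

    import sqlglot
    from sqlglot.optimizer import optimize

    expression = sqlglot.parse_one("SELECT x.a AS a FROM db.x")
    # after canonicalize runs, every identifier comes out quoted
    print(optimize(expression, schema={"db": {"x": {"a": "int"}}}).sql())
    # roughly: SELECT "x"."a" AS "a" FROM "db"."x" AS "x"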
@@ -129,10 +129,23 @@ def join_condition(join):
    """
    name = join.this.alias_or_name
    on = (join.args.get("on") or exp.true()).copy()
    on = on if isinstance(on, exp.And) else exp.and_(on, exp.true())
    source_key = []
    join_key = []

    def extract_condition(condition):
        left, right = condition.unnest_operands()
        left_tables = exp.column_table_names(left)
        right_tables = exp.column_table_names(right)

        if name in left_tables and name not in right_tables:
            join_key.append(left)
            source_key.append(right)
            condition.replace(exp.true())
        elif name in right_tables and name not in left_tables:
            join_key.append(right)
            source_key.append(left)
            condition.replace(exp.true())

    # find the join keys
    # SELECT
    # FROM x

@@ -141,20 +154,30 @@ def join_condition(join):
    #
    # should pull y.b as the join key and x.a as the source key
    if normalized(on):
        on = on if isinstance(on, exp.And) else exp.and_(on, exp.true())

        for condition in on.flatten():
            if isinstance(condition, exp.EQ):
                left, right = condition.unnest_operands()
                left_tables = exp.column_table_names(left)
                right_tables = exp.column_table_names(right)
                extract_condition(condition)
    elif normalized(on, dnf=True):
        conditions = None

            if name in left_tables and name not in right_tables:
                join_key.append(left)
                source_key.append(right)
                condition.replace(exp.true())
            elif name in right_tables and name not in left_tables:
                join_key.append(right)
                source_key.append(left)
                condition.replace(exp.true())
        for condition in on.flatten():
            parts = [part for part in condition.flatten() if isinstance(part, exp.EQ)]
            if conditions is None:
                conditions = parts
            else:
                temp = []
                for p in parts:
                    cs = [c for c in conditions if p == c]

                    if cs:
                        temp.append(p)
                        temp.extend(cs)
                conditions = temp

        for condition in conditions:
            extract_condition(condition)

    on = simplify(on)
    remaining_condition = None if on == exp.true() else on
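
To illustrate the new DNF branch: an equality is only usable as a join key if it appears in every OR branch, which is what the intersection loop above computes. A hedged sketch, assuming join_condition keeps its (source_key, join_key, remaining_condition) return shape from sqlglot.planner:

    import sqlglot
    from sqlglot import exp
    from sqlglot.planner import join_condition

    expression = sqlglot.parse_one(
        "SELECT * FROM x JOIN y ON (x.a = y.a AND x.b = 1) OR (x.a = y.a AND x.b = 2)"
    )
    source_key, join_key, remaining = join_condition(expression.find(exp.Join))
    # x.a = y.a occurs in both OR branches, so it is extracted as the join key
    print([key.sql() for key in join_key])  # roughly: ['y.a']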
@@ -58,7 +58,9 @@ def eliminate_subqueries(expression):
    existing_ctes = {}

    with_ = root.expression.args.get("with")
    recursive = False
    if with_:
        recursive = with_.args.get("recursive")
        for cte in with_.expressions:
            existing_ctes[cte.this] = cte.alias
    new_ctes = []

@@ -88,7 +90,7 @@ def eliminate_subqueries(expression):
            new_ctes.append(new_cte)

    if new_ctes:
        expression.set("with", exp.With(expressions=new_ctes))
        expression.set("with", exp.With(expressions=new_ctes, recursive=recursive))

    return expression
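
The recursive flag is now threaded through when the WITH clause is rebuilt. A small sketch of the intended behavior (output hedged, not verified against this exact build):

    import sqlglot
    from sqlglot.optimizer.eliminate_subqueries import eliminate_subqueries

    sql = (
        "WITH RECURSIVE t AS (SELECT 1 AS n UNION ALL SELECT n + 1 FROM t WHERE n < 3) "
        "SELECT n FROM t"
    )
    # previously the rebuilt WITH silently dropped the RECURSIVE keyword
    print(eliminate_subqueries(sqlglot.parse_one(sql)).sql())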
@@ -69,8 +69,9 @@ def _predicate_lengths(expression, dnf):
    left, right = expression.args.values()

    if isinstance(expression, exp.And if dnf else exp.Or):
        x = [a + b for a in _predicate_lengths(left, dnf) for b in _predicate_lengths(right, dnf)]
        return x
        return [
            a + b for a in _predicate_lengths(left, dnf) for b in _predicate_lengths(right, dnf)
        ]
    return _predicate_lengths(left, dnf) + _predicate_lengths(right, dnf)
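
_predicate_lengths estimates how large a predicate would grow under CNF/DNF conversion; the list comprehension above is the distributive step, where branch lengths multiply. For intuition, a hedged example of the normalization it guards:

    import sqlglot
    from sqlglot.optimizer.normalize import normalize

    # CNF distributes OR over AND, which is exactly where the size blowup comes from
    print(normalize(sqlglot.parse_one("(a AND b) OR c")).sql())
    # roughly: (a OR c) AND (b OR c)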
@@ -14,7 +14,6 @@ from sqlglot.optimizer.pushdown_predicates import pushdown_predicates
from sqlglot.optimizer.pushdown_projections import pushdown_projections
from sqlglot.optimizer.qualify_columns import qualify_columns
from sqlglot.optimizer.qualify_tables import qualify_tables
from sqlglot.optimizer.quote_identities import quote_identities
from sqlglot.optimizer.unnest_subqueries import unnest_subqueries

RULES = (

@@ -34,7 +33,6 @@ RULES = (
    eliminate_ctes,
    annotate_types,
    canonicalize,
    quote_identities,
)
@@ -27,7 +27,14 @@ def pushdown_predicates(expression):
        select = scope.expression
        where = select.args.get("where")
        if where:
            pushdown(where.this, scope.selected_sources, scope_ref_count)
            selected_sources = scope.selected_sources
            # a right join can only push down to itself and not the source FROM table
            for k, (node, source) in selected_sources.items():
                parent = node.find_ancestor(exp.Join, exp.From)
                if isinstance(parent, exp.Join) and parent.side == "RIGHT":
                    selected_sources = {k: (node, source)}
                    break
            pushdown(where.this, selected_sources, scope_ref_count)

        # joins should only pushdown into itself, not to other joins
        # so we limit the selected sources to only itself

@@ -148,10 +155,13 @@ def nodes_for_predicate(predicate, sources, scope_ref_count):

        # a node can reference a CTE which should be pushed down
        if isinstance(node, exp.From) and not isinstance(source, exp.Table):
            with_ = source.parent.expression.args.get("with")
            if with_ and with_.recursive:
                return {}
            node = source.expression

        if isinstance(node, exp.Join):
            if node.side:
            if node.side and node.side != "RIGHT":
                return {}
            nodes[table] = node
        elif isinstance(node, exp.Select) and len(tables) == 1:
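
A hedged sketch of what pushdown means here, mirroring the module's documented example; the RIGHT join guard above restricts the candidate sources so predicates cannot leak past the preserved side:

    import sqlglot
    from sqlglot.optimizer.pushdown_predicates import pushdown_predicates

    sql = "SELECT * FROM (SELECT * FROM x AS x) AS y WHERE y.a = 1"
    # the WHERE predicate moves into the derived table
    print(pushdown_predicates(sqlglot.parse_one(sql)).sql())
    # roughly: SELECT * FROM (SELECT * FROM x AS x WHERE x.a = 1) AS y WHERE TRUE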
@@ -6,7 +6,7 @@ from sqlglot.optimizer.scope import Scope, traverse_scope
# Sentinel value that means an outer query selecting ALL columns
SELECT_ALL = object()

# SELECTION TO USE IF SELECTION LIST IS EMPTY
# Selection to use if selection list is empty
DEFAULT_SELECTION = alias("1", "_")

@@ -91,7 +91,7 @@ def _remove_unused_selections(scope, parent_selections):

    # If there are no remaining selections, just select a single constant
    if not new_selections:
        new_selections.append(DEFAULT_SELECTION)
        new_selections.append(DEFAULT_SELECTION.copy())

    scope.expression.set("expressions", new_selections)
    return removed_indexes

@@ -102,5 +102,5 @@ def _remove_indexed_selections(scope, indexes_to_remove):
        selection for i, selection in enumerate(scope.selects) if i not in indexes_to_remove
    ]
    if not new_selections:
        new_selections.append(DEFAULT_SELECTION)
        new_selections.append(DEFAULT_SELECTION.copy())
    scope.expression.set("expressions", new_selections)
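
Why .copy() matters: DEFAULT_SELECTION is a module-level AST node, and inserting the same node object into two trees re-parents it, corrupting the first tree. A minimal, hedged illustration of the bug class:

    from sqlglot import alias, exp

    DEFAULT = alias("1", "_")
    a = exp.Select(expressions=[DEFAULT])
    b = exp.Select(expressions=[DEFAULT])  # the same node object now sits in two trees
    print(DEFAULT.parent is b)  # True: tree `a` no longer owns its selection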
@@ -311,6 +311,9 @@ def _qualify_outputs(scope):
            alias_ = alias(exp.column(""), alias=selection.name)
            alias_.set("this", selection)
            selection = alias_
        elif isinstance(selection, exp.Subquery):
            if not selection.alias:
                selection.set("alias", exp.TableAlias(this=exp.to_identifier(f"_col_{i}")))
        elif not isinstance(selection, exp.Alias):
            alias_ = alias(exp.column(""), f"_col_{i}")
            alias_.set("this", selection)
@@ -1,25 +0,0 @@
from sqlglot import exp


def quote_identities(expression):
    """
    Rewrite sqlglot AST to ensure all identities are quoted.

    Example:
        >>> import sqlglot
        >>> expression = sqlglot.parse_one("SELECT x.a AS a FROM db.x")
        >>> quote_identities(expression).sql()
        'SELECT "x"."a" AS "a" FROM "db"."x"'

    Args:
        expression (sqlglot.Expression): expression to quote
    Returns:
        sqlglot.Expression: quoted expression
    """

    def qualify(node):
        if isinstance(node, exp.Identifier):
            node.set("quoted", True)
        return node

    return expression.transform(qualify, copy=False)
@@ -511,9 +511,20 @@ def _traverse_union(scope):

def _traverse_derived_tables(derived_tables, scope, scope_type):
    sources = {}
    is_cte = scope_type == ScopeType.CTE

    for derived_table in derived_tables:
        top = None
        recursive_scope = None

        # if the scope is a recursive cte, it must be in the form of
        # base_case UNION recursive. thus the recursive scope is the first
        # section of the union.
        if is_cte and scope.expression.args["with"].recursive:
            union = derived_table.this

            if isinstance(union, exp.Union):
                recursive_scope = scope.branch(union.this, scope_type=ScopeType.CTE)

        for child_scope in _traverse_scope(
            scope.branch(
                derived_table if isinstance(derived_table, exp.UDTF) else derived_table.this,

@@ -523,16 +534,23 @@ def _traverse_derived_tables(derived_tables, scope, scope_type):
            )
        ):
            yield child_scope
            top = child_scope

        # Tables without aliases will be set as ""
        # This shouldn't be a problem once qualify_columns runs, as it adds aliases on everything.
        # Until then, this means that only a single, unaliased derived table is allowed (rather,
        # the latest one wins.
        sources[derived_table.alias] = child_scope
        if scope_type == ScopeType.CTE:
            scope.cte_scopes.append(top)
        alias = derived_table.alias
        sources[alias] = child_scope

        if recursive_scope:
            child_scope.add_source(alias, recursive_scope)

        # append the final child_scope yielded
        if is_cte:
            scope.cte_scopes.append(child_scope)
        else:
            scope.derived_table_scopes.append(top)
            scope.derived_table_scopes.append(child_scope)

    scope.sources.update(sources)
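
With the recursive branch wired in, the recursive arm of a CTE can resolve the CTE's own name as a source. A hedged way to observe it:

    import sqlglot
    from sqlglot.optimizer.scope import traverse_scope

    sql = (
        "WITH RECURSIVE t AS (SELECT 1 AS n UNION ALL SELECT n + 1 FROM t WHERE n < 3) "
        "SELECT n FROM t"
    )
    for scope in traverse_scope(sqlglot.parse_one(sql)):
        # the recursive arm's scope should now list `t` among its sources
        print(scope.expression.sql(), "->", list(scope.sources))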
@@ -16,7 +16,7 @@ def unnest_subqueries(expression):
        >>> expression = sqlglot.parse_one("SELECT * FROM x AS x WHERE (SELECT y.a AS a FROM y AS y WHERE x.a = y.a) = 1 ")
        >>> unnest_subqueries(expression).sql()
        'SELECT * FROM x AS x LEFT JOIN (SELECT y.a AS a FROM y AS y WHERE TRUE GROUP BY y.a)\
        AS "_u_0" ON x.a = "_u_0".a WHERE ("_u_0".a = 1 AND NOT "_u_0".a IS NULL)'
        AS _u_0 ON x.a = _u_0.a WHERE (_u_0.a = 1 AND NOT _u_0.a IS NULL)'

    Args:
        expression (sqlglot.Expression): expression to unnest

@@ -97,8 +97,8 @@ def decorrelate(select, parent_select, external_columns, sequence):
    table_alias = _alias(sequence)
    keys = []

    # for all external columns in the where statement,
    # split out the relevant data to convert it into a join
    # for all external columns in the where statement, find the relevant predicate
    # keys to convert it into a join
    for column in external_columns:
        if column.find_ancestor(exp.Where) is not where:
            return

@@ -122,6 +122,10 @@ def decorrelate(select, parent_select, external_columns, sequence):
    if not any(isinstance(predicate, exp.EQ) for *_, predicate in keys):
        return

    is_subquery_projection = any(
        node is select.parent for node in parent_select.selects if isinstance(node, exp.Subquery)
    )

    value = select.selects[0]
    key_aliases = {}
    group_by = []

@@ -142,9 +146,14 @@ def decorrelate(select, parent_select, external_columns, sequence):
    parent_predicate = select.find_ancestor(exp.Predicate)

    # if the value of the subquery is not an agg or a key, we need to collect it into an array
    # so that it can be grouped
    # so that it can be grouped. For subquery projections, we use a MAX aggregation instead.
    agg_func = exp.Max if is_subquery_projection else exp.ArrayAgg
    if not value.find(exp.AggFunc) and value.this not in group_by:
        select.select(f"ARRAY_AGG({value.this}) AS {value.alias}", append=False, copy=False)
        select.select(
            exp.alias_(agg_func(this=value.this), value.alias, quoted=False),
            append=False,
            copy=False,
        )

    # exists queries should not have any selects as it only checks if there are any rows
    # all selects will be added by the optimizer and only used for join keys

@@ -158,7 +167,7 @@ def decorrelate(select, parent_select, external_columns, sequence):
        if isinstance(parent_predicate, exp.Exists) or key != value.this:
            select.select(f"{key} AS {alias}", copy=False)
        else:
            select.select(f"ARRAY_AGG({key}) AS {alias}", copy=False)
            select.select(exp.alias_(agg_func(this=key.copy()), alias, quoted=False), copy=False)

        alias = exp.column(value.alias, table_alias)
        other = _other_operand(parent_predicate)

@@ -186,12 +195,18 @@ def decorrelate(select, parent_select, external_columns, sequence):
                f"ARRAY_ANY({alias}, _x -> _x = {parent_predicate.this})",
            )
        else:
            if is_subquery_projection:
                alias = exp.alias_(alias, select.parent.alias)
                select.parent.replace(alias)

            for key, column, predicate in keys:
                predicate.replace(exp.true())
                nested = exp.column(key_aliases[key], table_alias)

                if is_subquery_projection:
                    key.replace(nested)
                    continue

                if key in group_by:
                    key.replace(nested)
                    parent_predicate = _replace(
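
For scalar subqueries used as projections, decorrelation now aggregates with MAX instead of ARRAY_AGG and rewrites the projection in place. A hedged sketch of the expected shape (aliases and exact output may differ):

    import sqlglot
    from sqlglot.optimizer.unnest_subqueries import unnest_subqueries

    sql = "SELECT (SELECT y.b FROM y WHERE x.a = y.a) AS b FROM x"
    # the correlated projection becomes a LEFT JOIN against a MAX-aggregated derived table
    print(unnest_subqueries(sqlglot.parse_one(sql)).sql())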
@@ -5,7 +5,7 @@ import typing as t

from sqlglot import exp
from sqlglot.errors import ErrorLevel, ParseError, concat_messages, merge_errors
from sqlglot.helper import apply_index_offset, ensure_collection, seq_get
from sqlglot.helper import apply_index_offset, ensure_collection, ensure_list, seq_get
from sqlglot.tokens import Token, Tokenizer, TokenType
from sqlglot.trie import in_trie, new_trie

@@ -117,6 +117,7 @@ class Parser(metaclass=_Parser):
        TokenType.GEOMETRY,
        TokenType.HLLSKETCH,
        TokenType.HSTORE,
        TokenType.PSEUDO_TYPE,
        TokenType.SUPER,
        TokenType.SERIAL,
        TokenType.SMALLSERIAL,

@@ -153,6 +154,7 @@ class Parser(metaclass=_Parser):
        TokenType.CACHE,
        TokenType.CASCADE,
        TokenType.COLLATE,
        TokenType.COLUMN,
        TokenType.COMMAND,
        TokenType.COMMIT,
        TokenType.COMPOUND,

@@ -169,6 +171,7 @@ class Parser(metaclass=_Parser):
        TokenType.ESCAPE,
        TokenType.FALSE,
        TokenType.FIRST,
        TokenType.FILTER,
        TokenType.FOLLOWING,
        TokenType.FORMAT,
        TokenType.FUNCTION,

@@ -188,6 +191,7 @@ class Parser(metaclass=_Parser):
        TokenType.MERGE,
        TokenType.NATURAL,
        TokenType.NEXT,
        TokenType.OFFSET,
        TokenType.ONLY,
        TokenType.OPTIONS,
        TokenType.ORDINALITY,

@@ -222,12 +226,18 @@ class Parser(metaclass=_Parser):
        TokenType.PROPERTIES,
        TokenType.PROCEDURE,
        TokenType.VOLATILE,
        TokenType.WINDOW,
        *SUBQUERY_PREDICATES,
        *TYPE_TOKENS,
        *NO_PAREN_FUNCTIONS,
    }

    TABLE_ALIAS_TOKENS = ID_VAR_TOKENS - {TokenType.NATURAL, TokenType.APPLY}
    TABLE_ALIAS_TOKENS = ID_VAR_TOKENS - {
        TokenType.APPLY,
        TokenType.NATURAL,
        TokenType.OFFSET,
        TokenType.WINDOW,
    }

    UPDATE_ALIAS_TOKENS = TABLE_ALIAS_TOKENS - {TokenType.SET}

@@ -257,6 +267,7 @@ class Parser(metaclass=_Parser):
        TokenType.TABLE,
        TokenType.TIMESTAMP,
        TokenType.TIMESTAMPTZ,
        TokenType.WINDOW,
        *TYPE_TOKENS,
        *SUBQUERY_PREDICATES,
    }

@@ -351,22 +362,27 @@ class Parser(metaclass=_Parser):
        TokenType.ARROW: lambda self, this, path: self.expression(
            exp.JSONExtract,
            this=this,
            path=path,
            expression=path,
        ),
        TokenType.DARROW: lambda self, this, path: self.expression(
            exp.JSONExtractScalar,
            this=this,
            path=path,
            expression=path,
        ),
        TokenType.HASH_ARROW: lambda self, this, path: self.expression(
            exp.JSONBExtract,
            this=this,
            path=path,
            expression=path,
        ),
        TokenType.DHASH_ARROW: lambda self, this, path: self.expression(
            exp.JSONBExtractScalar,
            this=this,
            path=path,
            expression=path,
        ),
        TokenType.PLACEHOLDER: lambda self, this, key: self.expression(
            exp.JSONBContains,
            this=this,
            expression=key,
        ),
    }

@@ -392,25 +408,27 @@ class Parser(metaclass=_Parser):
        exp.Ordered: lambda self: self._parse_ordered(),
        exp.Having: lambda self: self._parse_having(),
        exp.With: lambda self: self._parse_with(),
        exp.Window: lambda self: self._parse_named_window(),
        "JOIN_TYPE": lambda self: self._parse_join_side_and_kind(),
    }

    STATEMENT_PARSERS = {
        TokenType.ALTER: lambda self: self._parse_alter(),
        TokenType.BEGIN: lambda self: self._parse_transaction(),
        TokenType.CACHE: lambda self: self._parse_cache(),
        TokenType.COMMIT: lambda self: self._parse_commit_or_rollback(),
        TokenType.CREATE: lambda self: self._parse_create(),
        TokenType.DELETE: lambda self: self._parse_delete(),
        TokenType.DESCRIBE: lambda self: self._parse_describe(),
        TokenType.DROP: lambda self: self._parse_drop(),
        TokenType.END: lambda self: self._parse_commit_or_rollback(),
        TokenType.INSERT: lambda self: self._parse_insert(),
        TokenType.LOAD_DATA: lambda self: self._parse_load_data(),
        TokenType.UPDATE: lambda self: self._parse_update(),
        TokenType.DELETE: lambda self: self._parse_delete(),
        TokenType.CACHE: lambda self: self._parse_cache(),
        TokenType.UNCACHE: lambda self: self._parse_uncache(),
        TokenType.USE: lambda self: self.expression(exp.Use, this=self._parse_id_var()),
        TokenType.BEGIN: lambda self: self._parse_transaction(),
        TokenType.COMMIT: lambda self: self._parse_commit_or_rollback(),
        TokenType.END: lambda self: self._parse_commit_or_rollback(),
        TokenType.ROLLBACK: lambda self: self._parse_commit_or_rollback(),
        TokenType.MERGE: lambda self: self._parse_merge(),
        TokenType.ROLLBACK: lambda self: self._parse_commit_or_rollback(),
        TokenType.UNCACHE: lambda self: self._parse_uncache(),
        TokenType.UPDATE: lambda self: self._parse_update(),
        TokenType.USE: lambda self: self.expression(exp.Use, this=self._parse_id_var()),
    }

    UNARY_PARSERS = {

@@ -441,6 +459,7 @@ class Parser(metaclass=_Parser):
        TokenType.HEX_STRING: lambda self, token: self.expression(exp.HexString, this=token.text),
        TokenType.BYTE_STRING: lambda self, token: self.expression(exp.ByteString, this=token.text),
        TokenType.INTRODUCER: lambda self, token: self._parse_introducer(token),
        TokenType.NATIONAL: lambda self, token: self._parse_national(token),
        TokenType.SESSION_PARAMETER: lambda self, _: self._parse_session_parameter(),
    }

@@ -454,6 +473,9 @@ class Parser(metaclass=_Parser):
        TokenType.ILIKE: lambda self, this: self._parse_escape(
            self.expression(exp.ILike, this=this, expression=self._parse_bitwise())
        ),
        TokenType.IRLIKE: lambda self, this: self.expression(
            exp.RegexpILike, this=this, expression=self._parse_bitwise()
        ),
        TokenType.RLIKE: lambda self, this: self.expression(
            exp.RegexpLike, this=this, expression=self._parse_bitwise()
        ),

@@ -535,8 +557,7 @@ class Parser(metaclass=_Parser):
        "group": lambda self: self._parse_group(),
        "having": lambda self: self._parse_having(),
        "qualify": lambda self: self._parse_qualify(),
        "window": lambda self: self._match(TokenType.WINDOW)
        and self._parse_window(self._parse_id_var(), alias=True),
        "windows": lambda self: self._parse_window_clause(),
        "distribute": lambda self: self._parse_sort(TokenType.DISTRIBUTE_BY, exp.Distribute),
        "sort": lambda self: self._parse_sort(TokenType.SORT_BY, exp.Sort),
        "cluster": lambda self: self._parse_sort(TokenType.CLUSTER_BY, exp.Cluster),

@@ -551,18 +572,18 @@ class Parser(metaclass=_Parser):
    MODIFIABLES = (exp.Subquery, exp.Subqueryable, exp.Table)

    CREATABLES = {
        TokenType.TABLE,
        TokenType.VIEW,
        TokenType.COLUMN,
        TokenType.FUNCTION,
        TokenType.INDEX,
        TokenType.PROCEDURE,
        TokenType.SCHEMA,
        TokenType.TABLE,
        TokenType.VIEW,
    }

    TRANSACTION_KIND = {"DEFERRED", "IMMEDIATE", "EXCLUSIVE"}

    STRICT_CAST = True
    LATERAL_FUNCTION_AS_VIEW = False

    __slots__ = (
        "error_level",

@@ -782,13 +803,16 @@ class Parser(metaclass=_Parser):
        self._parse_query_modifiers(expression)
        return expression

    def _parse_drop(self):
    def _parse_drop(self, default_kind=None):
        temporary = self._match(TokenType.TEMPORARY)
        materialized = self._match(TokenType.MATERIALIZED)
        kind = self._match_set(self.CREATABLES) and self._prev.text
        if not kind:
            self.raise_error(f"Expected {self.CREATABLES}")
            return
            if default_kind:
                kind = default_kind
            else:
                self.raise_error(f"Expected {self.CREATABLES}")
                return

        return self.expression(
            exp.Drop,

@@ -876,7 +900,7 @@ class Parser(metaclass=_Parser):
        ) or self._match_pair(TokenType.STRING, TokenType.EQ, advance=False)

        if assignment:
            key = self._parse_var() or self._parse_string()
            key = self._parse_var_or_string()
            self._match(TokenType.EQ)
            return self.expression(exp.Property, this=key, value=self._parse_column())

@@ -1152,18 +1176,32 @@ class Parser(metaclass=_Parser):
        elif (table or nested) and self._match(TokenType.L_PAREN):
            this = self._parse_table() if table else self._parse_select(nested=True)
            self._parse_query_modifiers(this)
            this = self._parse_set_operations(this)
            self._match_r_paren()
            this = self._parse_subquery(this)
            # early return so that subquery unions aren't parsed again
            # SELECT * FROM (SELECT 1) UNION ALL SELECT 1
            # Union ALL should be a property of the top select node, not the subquery
            return self._parse_subquery(this)
        elif self._match(TokenType.VALUES):
            if self._curr.token_type == TokenType.L_PAREN:
                # We don't consume the left paren because it's consumed in _parse_value
                expressions = self._parse_csv(self._parse_value)
            else:
                # In presto we can have VALUES 1, 2 which results in 1 column & 2 rows.
                # Source: https://prestodb.io/docs/current/sql/values.html
                expressions = self._parse_csv(
                    lambda: self.expression(exp.Tuple, expressions=[self._parse_conjunction()])
                )

            this = self.expression(
                exp.Values,
                expressions=self._parse_csv(self._parse_value),
                expressions=expressions,
                alias=self._parse_table_alias(),
            )
        else:
            this = None

        return self._parse_set_operations(this) if this else None
        return self._parse_set_operations(this)

    def _parse_with(self, skip_with_token=False):
        if not skip_with_token and not self._match(TokenType.WITH):
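
A hedged sketch of the Presto-style VALUES form the comment above describes, where each bare scalar becomes a one-column row:

    import sqlglot

    print(sqlglot.parse_one("SELECT * FROM (VALUES 1, 2) AS t(a)").sql())
    # roughly: SELECT * FROM (VALUES (1), (2)) AS t(a)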
@@ -1201,11 +1239,12 @@ class Parser(metaclass=_Parser):
        alias = self._parse_id_var(
            any_token=any_token, tokens=alias_tokens or self.TABLE_ALIAS_TOKENS
        )
        columns = None

        if self._match(TokenType.L_PAREN):
            columns = self._parse_csv(lambda: self._parse_id_var(any_token))
            columns = self._parse_csv(lambda: self._parse_column_def(self._parse_id_var()))
            self._match_r_paren()
        else:
            columns = None

        if not alias and not columns:
            return None

@@ -1295,26 +1334,19 @@ class Parser(metaclass=_Parser):
            expression=self._parse_function() or self._parse_id_var(any_token=False),
        )

        columns = None
        table_alias = None
        if view or self.LATERAL_FUNCTION_AS_VIEW:
            table_alias = self._parse_id_var(any_token=False)
            if self._match(TokenType.ALIAS):
                columns = self._parse_csv(self._parse_id_var)
        if view:
            table = self._parse_id_var(any_token=False)
            columns = self._parse_csv(self._parse_id_var) if self._match(TokenType.ALIAS) else []
            table_alias = self.expression(exp.TableAlias, this=table, columns=columns)
        else:
            self._match(TokenType.ALIAS)
            table_alias = self._parse_id_var(any_token=False)

            if self._match(TokenType.L_PAREN):
                columns = self._parse_csv(self._parse_id_var)
                self._match_r_paren()
            table_alias = self._parse_table_alias()

        expression = self.expression(
            exp.Lateral,
            this=this,
            view=view,
            outer=outer,
            alias=self.expression(exp.TableAlias, this=table_alias, columns=columns),
            alias=table_alias,
        )

        if outer_apply or cross_apply:

@@ -1693,6 +1725,9 @@ class Parser(metaclass=_Parser):
        if negate:
            this = self.expression(exp.Not, this=this)

        if self._match(TokenType.IS):
            this = self._parse_is(this)

        return this

    def _parse_is(self, this):

@@ -1796,6 +1831,10 @@ class Parser(metaclass=_Parser):
            return None

        type_token = self._prev.token_type

        if type_token == TokenType.PSEUDO_TYPE:
            return self.expression(exp.PseudoType, this=self._prev.text)

        nested = type_token in self.NESTED_TYPE_TOKENS
        is_struct = type_token == TokenType.STRUCT
        expressions = None

@@ -1851,6 +1890,8 @@ class Parser(metaclass=_Parser):

        if value is None:
            value = exp.DataType(this=exp.DataType.Type.TIMESTAMP, expressions=expressions)
        elif type_token == TokenType.INTERVAL:
            value = self.expression(exp.Interval, unit=self._parse_var())

        if maybe_func and check_func:
            index2 = self._index

@@ -1924,7 +1965,16 @@ class Parser(metaclass=_Parser):

    def _parse_primary(self):
        if self._match_set(self.PRIMARY_PARSERS):
            return self.PRIMARY_PARSERS[self._prev.token_type](self, self._prev)
            token_type = self._prev.token_type
            primary = self.PRIMARY_PARSERS[token_type](self, self._prev)

            if token_type == TokenType.STRING:
                expressions = [primary]
                while self._match(TokenType.STRING):
                    expressions.append(exp.Literal.string(self._prev.text))
                if len(expressions) > 1:
                    return self.expression(exp.Concat, expressions=expressions)
            return primary

        if self._match_pair(TokenType.DOT, TokenType.NUMBER):
            return exp.Literal.number(f"0.{self._prev.text}")
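
Adjacent string literals now fold into a single concatenation at parse time. A hedged sketch:

    import sqlglot

    # two juxtaposed literals parse as one exp.Concat node
    print(sqlglot.transpile("SELECT 'foo' 'bar'")[0])
    # roughly: SELECT CONCAT('foo', 'bar')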
@@ -2027,6 +2077,9 @@ class Parser(metaclass=_Parser):

        return self.expression(exp.Identifier, this=token.text)

    def _parse_national(self, token):
        return self.expression(exp.National, this=exp.Literal.string(token.text))

    def _parse_session_parameter(self):
        kind = None
        this = self._parse_id_var() or self._parse_primary()

@@ -2051,7 +2104,9 @@ class Parser(metaclass=_Parser):

        if self._match(TokenType.L_PAREN):
            expressions = self._parse_csv(self._parse_id_var)
            self._match(TokenType.R_PAREN)

            if not self._match(TokenType.R_PAREN):
                self._retreat(index)
        else:
            expressions = [self._parse_id_var()]

@@ -2065,14 +2120,14 @@ class Parser(metaclass=_Parser):
                exp.Distinct, expressions=self._parse_csv(self._parse_conjunction)
            )
        else:
            this = self._parse_conjunction()
            this = self._parse_select_or_expression()

        if self._match(TokenType.IGNORE_NULLS):
            this = self.expression(exp.IgnoreNulls, this=this)
        else:
            self._match(TokenType.RESPECT_NULLS)

        return self._parse_alias(self._parse_limit(self._parse_order(this)))
        return self._parse_limit(self._parse_order(this))

    def _parse_schema(self, this=None):
        index = self._index

@@ -2081,7 +2136,8 @@ class Parser(metaclass=_Parser):
            return this

        args = self._parse_csv(
            lambda: self._parse_constraint() or self._parse_column_def(self._parse_field(True))
            lambda: self._parse_constraint()
            or self._parse_column_def(self._parse_field(any_token=True))
        )
        self._match_r_paren()
        return self.expression(exp.Schema, this=this, expressions=args)

@@ -2120,7 +2176,7 @@ class Parser(metaclass=_Parser):
        elif self._match(TokenType.ENCODE):
            kind = self.expression(exp.EncodeColumnConstraint, this=self._parse_var())
        elif self._match(TokenType.DEFAULT):
            kind = self.expression(exp.DefaultColumnConstraint, this=self._parse_conjunction())
            kind = self.expression(exp.DefaultColumnConstraint, this=self._parse_bitwise())
        elif self._match_pair(TokenType.NOT, TokenType.NULL):
            kind = exp.NotNullColumnConstraint()
        elif self._match(TokenType.NULL):

@@ -2211,7 +2267,10 @@ class Parser(metaclass=_Parser):
        if not self._match(TokenType.L_BRACKET):
            return this

        expressions = self._parse_csv(self._parse_conjunction)
        if self._match(TokenType.COLON):
            expressions = [self.expression(exp.Slice, expression=self._parse_conjunction())]
        else:
            expressions = self._parse_csv(lambda: self._parse_slice(self._parse_conjunction()))

        if not this or this.name.upper() == "ARRAY":
            this = self.expression(exp.Array, expressions=expressions)

@@ -2225,6 +2284,11 @@ class Parser(metaclass=_Parser):
        this.comments = self._prev_comments
        return self._parse_bracket(this)

    def _parse_slice(self, this):
        if self._match(TokenType.COLON):
            return self.expression(exp.Slice, this=this, expression=self._parse_conjunction())
        return this

    def _parse_case(self):
        ifs = []
        default = None
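
Bracket subscripts now understand slice syntax, paired with the slice_sql generator method added earlier. A hedged sketch:

    import sqlglot
    from sqlglot import exp

    e = sqlglot.parse_one("SELECT col[1:2] FROM t")
    print(e.find(exp.Slice) is not None)  # True: the subscript parsed as a slice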
@@ -2386,6 +2450,12 @@ class Parser(metaclass=_Parser):
            collation=collation,
        )

    def _parse_window_clause(self):
        return self._match(TokenType.WINDOW) and self._parse_csv(self._parse_named_window)

    def _parse_named_window(self):
        return self._parse_window(self._parse_id_var(), alias=True)

    def _parse_window(self, this, alias=False):
        if self._match(TokenType.FILTER):
            where = self._parse_wrapped(self._parse_where)

@@ -2501,11 +2571,9 @@ class Parser(metaclass=_Parser):
        if identifier:
            return identifier

        if any_token and self._curr and self._curr.token_type not in self.RESERVED_KEYWORDS:
            self._advance()
        elif not self._match_set(tokens or self.ID_VAR_TOKENS):
            return None
        return exp.Identifier(this=self._prev.text, quoted=False)
        if (any_token and self._advance_any()) or self._match_set(tokens or self.ID_VAR_TOKENS):
            return exp.Identifier(this=self._prev.text, quoted=False)
        return None

    def _parse_string(self):
        if self._match(TokenType.STRING):

@@ -2522,11 +2590,17 @@ class Parser(metaclass=_Parser):
            return self.expression(exp.Identifier, this=self._prev.text, quoted=True)
        return self._parse_placeholder()

    def _parse_var(self):
        if self._match(TokenType.VAR):
    def _parse_var(self, any_token=False):
        if (any_token and self._advance_any()) or self._match(TokenType.VAR):
            return self.expression(exp.Var, this=self._prev.text)
        return self._parse_placeholder()

    def _advance_any(self):
        if self._curr and self._curr.token_type not in self.RESERVED_KEYWORDS:
            self._advance()
            return self._prev
        return None

    def _parse_var_or_string(self):
        return self._parse_var() or self._parse_string()

@@ -2551,8 +2625,9 @@ class Parser(metaclass=_Parser):
        if self._match(TokenType.PLACEHOLDER):
            return self.expression(exp.Placeholder)
        elif self._match(TokenType.COLON):
            self._advance()
            return self.expression(exp.Placeholder, this=self._prev.text)
            if self._match_set((TokenType.NUMBER, TokenType.VAR)):
                return self.expression(exp.Placeholder, this=self._prev.text)
            self._advance(-1)
        return None

    def _parse_except(self):
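
Named placeholders after a colon now only consume a NUMBER or VAR token, retreating otherwise instead of swallowing arbitrary input. A hedged sketch:

    import sqlglot
    from sqlglot import exp

    e = sqlglot.parse_one("SELECT * FROM t WHERE id = :id")
    print(e.find(exp.Placeholder))  # a Placeholder node with this='id'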
@@ -2647,6 +2722,54 @@ class Parser(metaclass=_Parser):
            return self.expression(exp.Rollback, savepoint=savepoint)
        return self.expression(exp.Commit, chain=chain)

    def _parse_add_column(self):
        if not self._match_text_seq("ADD"):
            return None

        self._match(TokenType.COLUMN)
        exists_column = self._parse_exists(not_=True)
        expression = self._parse_column_def(self._parse_field(any_token=True))
        expression.set("exists", exists_column)
        return expression

    def _parse_drop_column(self):
        return self._match(TokenType.DROP) and self._parse_drop(default_kind="COLUMN")

    def _parse_alter(self):
        if not self._match(TokenType.TABLE):
            return None

        exists = self._parse_exists()
        this = self._parse_table(schema=True)

        actions = None
        if self._match_text_seq("ADD", advance=False):
            actions = self._parse_csv(self._parse_add_column)
        elif self._match_text_seq("DROP", advance=False):
            actions = self._parse_csv(self._parse_drop_column)
        elif self._match_text_seq("ALTER"):
            self._match(TokenType.COLUMN)
            column = self._parse_field(any_token=True)

            if self._match_pair(TokenType.DROP, TokenType.DEFAULT):
                actions = self.expression(exp.AlterColumn, this=column, drop=True)
            elif self._match_pair(TokenType.SET, TokenType.DEFAULT):
                actions = self.expression(
                    exp.AlterColumn, this=column, default=self._parse_conjunction()
                )
            else:
                self._match_text_seq("SET", "DATA")
                actions = self.expression(
                    exp.AlterColumn,
                    this=column,
                    dtype=self._match_text_seq("TYPE") and self._parse_types(),
                    collate=self._match(TokenType.COLLATE) and self._parse_term(),
                    using=self._match(TokenType.USING) and self._parse_conjunction(),
                )

        actions = ensure_list(actions)
        return self.expression(exp.AlterTable, this=this, exists=exists, actions=actions)

    def _parse_show(self):
        parser = self._find_parser(self.SHOW_PARSERS, self._show_trie)
        if parser:

@@ -2782,7 +2905,7 @@ class Parser(metaclass=_Parser):
            return True
        return False

    def _match_text_seq(self, *texts):
    def _match_text_seq(self, *texts, advance=True):
        index = self._index
        for text in texts:
            if self._curr and self._curr.text.upper() == text:

@@ -2790,6 +2913,10 @@ class Parser(metaclass=_Parser):
            else:
                self._retreat(index)
                return False

        if not advance:
            self._retreat(index)

        return True

    def _replace_columns_with_dots(self, this):
@@ -160,9 +160,7 @@ class MappingSchema(AbstractMappingSchema[t.Dict[str, str]], Schema):
        super().__init__(schema)
        self.visible = visible or {}
        self.dialect = dialect
        self._type_mapping_cache: t.Dict[str, exp.DataType] = {
            "STR": exp.DataType.build("text"),
        }
        self._type_mapping_cache: t.Dict[str, exp.DataType] = {}

    @classmethod
    def from_mapping_schema(cls, mapping_schema: MappingSchema) -> MappingSchema:
@@ -48,6 +48,7 @@ class TokenType(AutoName):
    DOLLAR = auto()
    PARAMETER = auto()
    SESSION_PARAMETER = auto()
    NATIONAL = auto()

    BLOCK_START = auto()
    BLOCK_END = auto()

@@ -111,6 +112,7 @@ class TokenType(AutoName):

    # keywords
    ALIAS = auto()
    ALTER = auto()
    ALWAYS = auto()
    ALL = auto()
    ANTI = auto()

@@ -196,6 +198,7 @@ class TokenType(AutoName):
    INTERVAL = auto()
    INTO = auto()
    INTRODUCER = auto()
    IRLIKE = auto()
    IS = auto()
    ISNULL = auto()
    JOIN = auto()

@@ -241,6 +244,7 @@ class TokenType(AutoName):
    PRIMARY_KEY = auto()
    PROCEDURE = auto()
    PROPERTIES = auto()
    PSEUDO_TYPE = auto()
    QUALIFY = auto()
    QUOTE = auto()
    RANGE = auto()

@@ -346,7 +350,11 @@ class _Tokenizer(type):
    def __new__(cls, clsname, bases, attrs):  # type: ignore
        klass = super().__new__(cls, clsname, bases, attrs)

        klass._QUOTES = cls._delimeter_list_to_dict(klass.QUOTES)
        klass._QUOTES = {
            f"{prefix}{s}": e
            for s, e in cls._delimeter_list_to_dict(klass.QUOTES).items()
            for prefix in (("",) if s[0].isalpha() else ("", "n", "N"))
        }
        klass._BIT_STRINGS = cls._delimeter_list_to_dict(klass.BIT_STRINGS)
        klass._HEX_STRINGS = cls._delimeter_list_to_dict(klass.HEX_STRINGS)
        klass._BYTE_STRINGS = cls._delimeter_list_to_dict(klass.BYTE_STRINGS)

@@ -470,6 +478,7 @@ class Tokenizer(metaclass=_Tokenizer):
        "CHECK": TokenType.CHECK,
        "CLUSTER BY": TokenType.CLUSTER_BY,
        "COLLATE": TokenType.COLLATE,
        "COLUMN": TokenType.COLUMN,
        "COMMENT": TokenType.SCHEMA_COMMENT,
        "COMMIT": TokenType.COMMIT,
        "COMPOUND": TokenType.COMPOUND,

@@ -587,6 +596,7 @@ class Tokenizer(metaclass=_Tokenizer):
        "SEMI": TokenType.SEMI,
        "SET": TokenType.SET,
        "SHOW": TokenType.SHOW,
        "SIMILAR TO": TokenType.SIMILAR_TO,
        "SOME": TokenType.SOME,
        "SORTKEY": TokenType.SORTKEY,
        "SORT BY": TokenType.SORT_BY,

@@ -614,6 +624,7 @@ class Tokenizer(metaclass=_Tokenizer):
        "VOLATILE": TokenType.VOLATILE,
        "WHEN": TokenType.WHEN,
        "WHERE": TokenType.WHERE,
        "WINDOW": TokenType.WINDOW,
        "WITH": TokenType.WITH,
        "WITH TIME ZONE": TokenType.WITH_TIME_ZONE,
        "WITH LOCAL TIME ZONE": TokenType.WITH_LOCAL_TIME_ZONE,

@@ -652,6 +663,7 @@ class Tokenizer(metaclass=_Tokenizer):
        "VARCHAR2": TokenType.VARCHAR,
        "NVARCHAR": TokenType.NVARCHAR,
        "NVARCHAR2": TokenType.NVARCHAR,
        "STR": TokenType.TEXT,
        "STRING": TokenType.TEXT,
        "TEXT": TokenType.TEXT,
        "CLOB": TokenType.TEXT,

@@ -667,7 +679,16 @@ class Tokenizer(metaclass=_Tokenizer):
        "UNIQUE": TokenType.UNIQUE,
        "STRUCT": TokenType.STRUCT,
        "VARIANT": TokenType.VARIANT,
        "ALTER": TokenType.COMMAND,
        "ALTER": TokenType.ALTER,
        "ALTER AGGREGATE": TokenType.COMMAND,
        "ALTER DEFAULT": TokenType.COMMAND,
        "ALTER DOMAIN": TokenType.COMMAND,
        "ALTER ROLE": TokenType.COMMAND,
        "ALTER RULE": TokenType.COMMAND,
        "ALTER SEQUENCE": TokenType.COMMAND,
        "ALTER TYPE": TokenType.COMMAND,
        "ALTER USER": TokenType.COMMAND,
        "ALTER VIEW": TokenType.COMMAND,
        "ANALYZE": TokenType.COMMAND,
        "CALL": TokenType.COMMAND,
        "EXPLAIN": TokenType.COMMAND,

@@ -967,7 +988,7 @@ class Tokenizer(metaclass=_Tokenizer):
        text = self._extract_string(quote_end)
        text = text.encode(self.ENCODE).decode(self.ENCODE) if self.ENCODE else text  # type: ignore
        text = text.replace("\\\\", "\\") if self._replace_backslash else text
        self._add(TokenType.STRING, text)
        self._add(TokenType.NATIONAL if quote[0].upper() == "N" else TokenType.STRING, text)
        return True

    # X'1234, b'0110', E'\\\\\' etc.
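
Together with the parser's _parse_national and the new NATIONAL token, the prefixed-quote table above makes N'...' literals first-class. A hedged sketch:

    import sqlglot
    from sqlglot import exp

    # N-prefixed strings tokenize as NATIONAL and parse into exp.National
    e = sqlglot.parse_one("SELECT N'some text'")
    print(e.find(exp.National) is not None)  # True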