Merging upstream version 10.0.8.
Signed-off-by: Daniel Baumann <daniel@debian.org>
This commit is contained in:
parent
407314e8d2
commit
efc1e37108
67 changed files with 2461 additions and 840 deletions
|
@ -2,6 +2,7 @@ from sqlglot.dialects.bigquery import BigQuery
|
|||
from sqlglot.dialects.clickhouse import ClickHouse
|
||||
from sqlglot.dialects.databricks import Databricks
|
||||
from sqlglot.dialects.dialect import Dialect, Dialects
|
||||
from sqlglot.dialects.drill import Drill
|
||||
from sqlglot.dialects.duckdb import DuckDB
|
||||
from sqlglot.dialects.hive import Hive
|
||||
from sqlglot.dialects.mysql import MySQL
|
||||
|
|
|
@ -119,6 +119,8 @@ class BigQuery(Dialect):
|
|||
"UNKNOWN": TokenType.NULL,
|
||||
"WINDOW": TokenType.WINDOW,
|
||||
"NOT DETERMINISTIC": TokenType.VOLATILE,
|
||||
"BEGIN": TokenType.COMMAND,
|
||||
"BEGIN TRANSACTION": TokenType.BEGIN,
|
||||
}
|
||||
KEYWORDS.pop("DIV")
|
||||
|
||||
|
@ -204,6 +206,15 @@ class BigQuery(Dialect):
|
|||
|
||||
EXPLICIT_UNION = True
|
||||
|
||||
def transaction_sql(self, *_):
    """Return the fixed statement text that opens a transaction.

    Any positional arguments are accepted and ignored.
    """
    return "BEGIN TRANSACTION"
|
||||
|
||||
def commit_sql(self, *_):
    """Return the fixed statement text that commits a transaction.

    Any positional arguments are accepted and ignored.
    """
    return "COMMIT TRANSACTION"
|
||||
|
||||
def rollback_sql(self, *_):
    """Return the fixed statement text that rolls back a transaction.

    Any positional arguments are accepted and ignored.
    """
    return "ROLLBACK TRANSACTION"
|
||||
|
||||
def in_unnest_op(self, unnest):
    """Render the operand of an IN ... UNNEST by delegating to self.sql.

    The rendered text is returned as-is, with nothing wrapped around it.
    """
    rendered = self.sql(unnest)
    return rendered
|
||||
|
||||
|
|
|
@ -32,6 +32,7 @@ class Dialects(str, Enum):
|
|||
TRINO = "trino"
|
||||
TSQL = "tsql"
|
||||
DATABRICKS = "databricks"
|
||||
DRILL = "drill"
|
||||
|
||||
|
||||
class _Dialect(type):
|
||||
|
@ -362,3 +363,18 @@ def parse_date_delta(exp_class, unit_mapping=None):
|
|||
return exp_class(this=this, expression=expression, unit=unit)
|
||||
|
||||
return inner_func
|
||||
|
||||
|
||||
def locate_to_strposition(args):
    """Build an exp.StrPosition from LOCATE-style positional arguments.

    The incoming list is ordered (substr, string[, position]); the first two
    entries are swapped into StrPosition's `this` / `substr` slots. Missing
    entries are whatever seq_get yields for an out-of-range index.
    """
    needle = seq_get(args, 0)
    haystack = seq_get(args, 1)
    start = seq_get(args, 2)
    return exp.StrPosition(this=haystack, substr=needle, position=start)
|
||||
|
||||
|
||||
def strposition_to_local_sql(self, expression):
    """Render a StrPosition-shaped expression as a LOCATE(...) call.

    Arguments are passed to self.format_args in LOCATE order: the `substr`
    arg, then `expression.this`, then the `position` arg.
    """
    node_args = expression.args
    rendered = self.format_args(
        node_args.get("substr"),
        expression.this,
        node_args.get("position"),
    )
    return f"LOCATE({rendered})"
|
||||
|
|
174
sqlglot/dialects/drill.py
Normal file
174
sqlglot/dialects/drill.py
Normal file
|
@ -0,0 +1,174 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import re
|
||||
|
||||
from sqlglot import exp, generator, parser, tokens
|
||||
from sqlglot.dialects.dialect import (
|
||||
Dialect,
|
||||
create_with_partitions_sql,
|
||||
format_time_lambda,
|
||||
no_pivot_sql,
|
||||
no_trycast_sql,
|
||||
rename_func,
|
||||
str_position_sql,
|
||||
)
|
||||
from sqlglot.dialects.postgres import _lateral_sql
|
||||
|
||||
|
||||
def _to_timestamp(args):
    """Turn TO_TIMESTAMP arguments into an expression node.

    A single numeric argument becomes exp.UnixToTime; anything else is handed
    to the "drill" format_time_lambda for exp.StrToTime.
    """
    # TO_TIMESTAMP accepts either a single double argument or (text, text)
    is_epoch_form = len(args) == 1 and args[0].is_number
    if not is_epoch_form:
        return format_time_lambda(exp.StrToTime, "drill")(args)
    return exp.UnixToTime.from_arg_list(args)
|
||||
|
||||
|
||||
def _str_to_time_sql(self, expression):
    """Render `expression` as STRPTIME(<this>, <format>)."""
    value = self.sql(expression, "this")
    fmt = self.format_time(expression)
    return f"STRPTIME({value}, {fmt})"
|
||||
|
||||
|
||||
def _ts_or_ds_to_date_sql(self, expression):
    """Render a timestamp-or-date-string expression as a CAST to DATE.

    When the expression carries a format other than Drill.time_format or
    Drill.date_format, the value is first run through STRPTIME; otherwise the
    raw value is cast directly.
    """
    fmt = self.format_time(expression)
    plain_cast = not fmt or fmt in (Drill.time_format, Drill.date_format)
    if plain_cast:
        return f"CAST({self.sql(expression, 'this')} AS DATE)"
    return f"CAST({_str_to_time_sql(self, expression)} AS DATE)"
|
||||
|
||||
|
||||
def _date_add_sql(kind):
    """Return a generator transform that emits DATE_<kind>(<this>, INTERVAL '<n>' <UNIT>).

    `kind` is interpolated verbatim after "DATE_". The unit is the
    upper-cased `unit` text of the expression, falling back to DAY when empty.
    """

    def func(self, expression):
        target = self.sql(expression, "this")
        unit = expression.text("unit").upper() or "DAY"
        # Keep the rendered amount in its own name instead of rebinding `expression`.
        amount = self.sql(expression, "expression")
        return f"DATE_{kind}({target}, INTERVAL '{amount}' {unit})"

    return func
|
||||
|
||||
|
||||
def if_sql(self, expression):
    """
    Render an IF expression with the function name wrapped in backticks.

    Drill requires backticks around certain SQL reserved words, IF being one
    of them, so the call is emitted as `IF`(...).

    Args:
        self: The Drill generator doing the rendering
        expression: The input IF expression

    Returns: The rendered call, with IF in backticks and the condition,
        true branch and false branch formatted by self.format_args.
    """
    branches = expression.args
    rendered = self.format_args(
        expression.this,
        branches.get("true"),
        branches.get("false"),
    )
    return f"`IF`({rendered})"
|
||||
|
||||
|
||||
def _str_to_date(self, expression):
    """Render a string-to-date conversion.

    A format equal to Drill.date_format becomes a plain CAST to DATE; any
    other format is passed through to TO_DATE(<value>, <format>).
    """
    value = self.sql(expression, "this")
    fmt = self.format_time(expression)
    if fmt != Drill.date_format:
        return f"TO_DATE({value}, {fmt})"
    return f"CAST({value} AS DATE)"
|
||||
|
||||
|
||||
class Drill(Dialect):
    """Dialect definition for Drill: tokenizer, parser and generator settings."""

    normalize_functions = None
    null_ordering = "nulls_are_last"

    # Format strings are stored with their surrounding single quotes because
    # they are compared against, and interpolated into, rendered SQL text.
    date_format = "'yyyy-MM-dd'"
    dateint_format = "'yyyyMMdd'"
    time_format = "'yyyy-MM-dd HH:mm:ss'"

    # Drill format tokens on the left, strftime-style directives on the right.
    time_mapping = {
        "y": "%Y",
        "Y": "%Y",
        "YYYY": "%Y",
        "yyyy": "%Y",
        "YY": "%y",
        "yy": "%y",
        "MMMM": "%B",
        "MMM": "%b",
        "MM": "%m",
        "M": "%-m",
        "dd": "%d",
        "d": "%-d",
        "HH": "%H",
        "H": "%-H",
        "hh": "%I",
        "h": "%-I",
        "mm": "%M",
        "m": "%-M",
        "ss": "%S",
        "s": "%-S",
        "SSSSSS": "%f",
        "a": "%p",
        "DD": "%j",
        "D": "%-j",
        "E": "%a",
        "EE": "%a",
        "EEE": "%a",
        "EEEE": "%A",
        "''T''": "T",
    }

    class Tokenizer(tokens.Tokenizer):
        # Single-quoted strings, backtick-quoted identifiers, backslash escapes.
        QUOTES = ["'"]
        IDENTIFIERS = ["`"]
        ESCAPES = ["\\"]
        ENCODE = "utf-8"

    class Parser(parser.Parser):
        STRICT_CAST = False

        FUNCTIONS = {
            **parser.Parser.FUNCTIONS,
            # NOTE(review): the module-level `_to_timestamp` helper is not
            # referenced here; TO_TIMESTAMP is parsed as TimeStrToTime. Confirm
            # whether the helper was meant to be wired in.
            "TO_TIMESTAMP": exp.TimeStrToTime.from_arg_list,
            "TO_CHAR": format_time_lambda(exp.TimeToStr, "drill"),
        }

    class Generator(generator.Generator):
        # Type names substituted when rendering: the small integer types
        # collapse to INTEGER, text-like types to VARCHAR, and the
        # timestamp/datetime variants to TIMESTAMP.
        TYPE_MAPPING = {
            **generator.Generator.TYPE_MAPPING,
            exp.DataType.Type.INT: "INTEGER",
            exp.DataType.Type.SMALLINT: "INTEGER",
            exp.DataType.Type.TINYINT: "INTEGER",
            exp.DataType.Type.BINARY: "VARBINARY",
            exp.DataType.Type.TEXT: "VARCHAR",
            exp.DataType.Type.NCHAR: "VARCHAR",
            exp.DataType.Type.TIMESTAMPLTZ: "TIMESTAMP",
            exp.DataType.Type.TIMESTAMPTZ: "TIMESTAMP",
            exp.DataType.Type.DATETIME: "TIMESTAMP",
        }

        ROOT_PROPERTIES = {exp.PartitionedByProperty}

        # Per-expression renderers layered over the base generator's transforms.
        TRANSFORMS = {
            **generator.Generator.TRANSFORMS,
            exp.CurrentTimestamp: lambda *_: "CURRENT_TIMESTAMP",
            exp.Lateral: _lateral_sql,
            exp.ArrayContains: rename_func("REPEATED_CONTAINS"),
            exp.ArraySize: rename_func("REPEATED_COUNT"),
            exp.Create: create_with_partitions_sql,
            exp.DateAdd: _date_add_sql("ADD"),
            exp.DateStrToDate: lambda self, e: f"CAST({self.sql(e, 'this')} AS DATE)",
            exp.DateSub: _date_add_sql("SUB"),
            exp.DateToDi: lambda self, e: f"CAST(TO_DATE({self.sql(e, 'this')}, {Drill.dateint_format}) AS INT)",
            exp.DiToDate: lambda self, e: f"TO_DATE(CAST({self.sql(e, 'this')} AS VARCHAR), {Drill.dateint_format})",
            exp.If: if_sql,
            # The emitted text starts with a space and backtick-quotes ILIKE.
            exp.ILike: lambda self, e: f" {self.sql(e, 'this')} `ILIKE` {self.sql(e, 'expression')}",
            exp.Levenshtein: rename_func("LEVENSHTEIN_DISTANCE"),
            exp.PartitionedByProperty: lambda self, e: f"PARTITION BY {self.sql(e, 'value')}",
            exp.Pivot: no_pivot_sql,
            exp.RegexpLike: rename_func("REGEXP_MATCHES"),
            exp.StrPosition: str_position_sql,
            exp.StrToDate: _str_to_date,
            exp.StrToTime: lambda self, e: f"TO_TIMESTAMP({self.sql(e, 'this')}, {self.format_time(e)})",
            exp.TimeStrToDate: lambda self, e: f"CAST({self.sql(e, 'this')} AS DATE)",
            exp.TimeStrToTime: lambda self, e: f"CAST({self.sql(e, 'this')} AS TIMESTAMP)",
            exp.TimeStrToUnix: rename_func("UNIX_TIMESTAMP"),
            exp.TimeToStr: lambda self, e: f"TO_CHAR({self.sql(e, 'this')}, {self.format_time(e)})",
            exp.TimeToUnix: rename_func("UNIX_TIMESTAMP"),
            exp.TryCast: no_trycast_sql,
            # The added interval is always rendered with a DAY unit.
            exp.TsOrDsAdd: lambda self, e: f"DATE_ADD(CAST({self.sql(e, 'this')} AS DATE), INTERVAL '{self.sql(e, 'expression')}' DAY)",
            exp.TsOrDsToDate: _ts_or_ds_to_date_sql,
            exp.TsOrDiToDi: lambda self, e: f"CAST(SUBSTR(REPLACE(CAST({self.sql(e, 'this')} AS VARCHAR), '-', ''), 1, 8) AS INT)",
        }

        def normalize_func(self, name):
            # Names matching exp.SAFE_IDENTIFIER_RE are returned unchanged;
            # anything else is wrapped in backticks.
            return name if re.match(exp.SAFE_IDENTIFIER_RE, name) else f"`{name}`"
|
|
@ -55,13 +55,13 @@ def _array_sort_sql(self, expression):
|
|||
|
||||
def _sort_array_sql(self, expression):
|
||||
this = self.sql(expression, "this")
|
||||
if expression.args.get("asc") == exp.FALSE:
|
||||
if expression.args.get("asc") == exp.false():
|
||||
return f"ARRAY_REVERSE_SORT({this})"
|
||||
return f"ARRAY_SORT({this})"
|
||||
|
||||
|
||||
def _sort_array_reverse(args):
|
||||
return exp.SortArray(this=seq_get(args, 0), asc=exp.FALSE)
|
||||
return exp.SortArray(this=seq_get(args, 0), asc=exp.false())
|
||||
|
||||
|
||||
def _struct_pack_sql(self, expression):
|
||||
|
|
|
@ -7,16 +7,19 @@ from sqlglot.dialects.dialect import (
|
|||
create_with_partitions_sql,
|
||||
format_time_lambda,
|
||||
if_sql,
|
||||
locate_to_strposition,
|
||||
no_ilike_sql,
|
||||
no_recursive_cte_sql,
|
||||
no_safe_divide_sql,
|
||||
no_trycast_sql,
|
||||
rename_func,
|
||||
strposition_to_local_sql,
|
||||
struct_extract_sql,
|
||||
var_map_sql,
|
||||
)
|
||||
from sqlglot.helper import seq_get
|
||||
from sqlglot.parser import parse_var_map
|
||||
from sqlglot.tokens import TokenType
|
||||
|
||||
# (FuncType, Multiplier)
|
||||
DATE_DELTA_INTERVAL = {
|
||||
|
@ -181,6 +184,15 @@ class Hive(Dialect):
|
|||
"F": "FLOAT",
|
||||
"BD": "DECIMAL",
|
||||
}
|
||||
KEYWORDS = {
|
||||
**tokens.Tokenizer.KEYWORDS,
|
||||
"ADD ARCHIVE": TokenType.COMMAND,
|
||||
"ADD ARCHIVES": TokenType.COMMAND,
|
||||
"ADD FILE": TokenType.COMMAND,
|
||||
"ADD FILES": TokenType.COMMAND,
|
||||
"ADD JAR": TokenType.COMMAND,
|
||||
"ADD JARS": TokenType.COMMAND,
|
||||
}
|
||||
|
||||
class Parser(parser.Parser):
|
||||
STRICT_CAST = False
|
||||
|
@ -210,11 +222,7 @@ class Hive(Dialect):
|
|||
"DAY": lambda args: exp.Day(this=exp.TsOrDsToDate(this=seq_get(args, 0))),
|
||||
"FROM_UNIXTIME": format_time_lambda(exp.UnixToStr, "hive", True),
|
||||
"GET_JSON_OBJECT": exp.JSONExtractScalar.from_arg_list,
|
||||
"LOCATE": lambda args: exp.StrPosition(
|
||||
this=seq_get(args, 1),
|
||||
substr=seq_get(args, 0),
|
||||
position=seq_get(args, 2),
|
||||
),
|
||||
"LOCATE": locate_to_strposition,
|
||||
"LOG": (
|
||||
lambda args: exp.Log.from_arg_list(args)
|
||||
if len(args) > 1
|
||||
|
@ -272,7 +280,7 @@ class Hive(Dialect):
|
|||
exp.SchemaCommentProperty: lambda self, e: self.naked_property(e),
|
||||
exp.SetAgg: rename_func("COLLECT_SET"),
|
||||
exp.Split: lambda self, e: f"SPLIT({self.sql(e, 'this')}, CONCAT('\\\\Q', {self.sql(e, 'expression')}))",
|
||||
exp.StrPosition: lambda self, e: f"LOCATE({self.format_args(e.args.get('substr'), e.this, e.args.get('position'))})",
|
||||
exp.StrPosition: strposition_to_local_sql,
|
||||
exp.StrToDate: _str_to_date,
|
||||
exp.StrToTime: _str_to_time,
|
||||
exp.StrToUnix: _str_to_unix,
|
||||
|
|
|
@ -5,10 +5,12 @@ import typing as t
|
|||
from sqlglot import exp, generator, parser, tokens
|
||||
from sqlglot.dialects.dialect import (
|
||||
Dialect,
|
||||
locate_to_strposition,
|
||||
no_ilike_sql,
|
||||
no_paren_current_date_sql,
|
||||
no_tablesample_sql,
|
||||
no_trycast_sql,
|
||||
strposition_to_local_sql,
|
||||
)
|
||||
from sqlglot.helper import seq_get
|
||||
from sqlglot.tokens import TokenType
|
||||
|
@ -120,6 +122,7 @@ class MySQL(Dialect):
|
|||
|
||||
KEYWORDS = {
|
||||
**tokens.Tokenizer.KEYWORDS,
|
||||
"START": TokenType.BEGIN,
|
||||
"SEPARATOR": TokenType.SEPARATOR,
|
||||
"_ARMSCII8": TokenType.INTRODUCER,
|
||||
"_ASCII": TokenType.INTRODUCER,
|
||||
|
@ -172,13 +175,18 @@ class MySQL(Dialect):
|
|||
COMMANDS = tokens.Tokenizer.COMMANDS - {TokenType.SET, TokenType.SHOW}
|
||||
|
||||
class Parser(parser.Parser):
|
||||
STRICT_CAST = False
|
||||
FUNC_TOKENS = {*parser.Parser.FUNC_TOKENS, TokenType.SCHEMA}
|
||||
|
||||
FUNCTIONS = {
|
||||
**parser.Parser.FUNCTIONS,
|
||||
"DATE_ADD": _date_add(exp.DateAdd),
|
||||
"DATE_SUB": _date_add(exp.DateSub),
|
||||
"STR_TO_DATE": _str_to_date,
|
||||
"LOCATE": locate_to_strposition,
|
||||
"INSTR": lambda args: exp.StrPosition(substr=seq_get(args, 1), this=seq_get(args, 0)),
|
||||
"LEFT": lambda args: exp.Substring(
|
||||
this=seq_get(args, 0), start=exp.Literal.number(1), length=seq_get(args, 1)
|
||||
),
|
||||
}
|
||||
|
||||
FUNCTION_PARSERS = {
|
||||
|
@ -264,6 +272,7 @@ class MySQL(Dialect):
|
|||
"CHARACTER SET": lambda self: self._parse_set_item_charset("CHARACTER SET"),
|
||||
"CHARSET": lambda self: self._parse_set_item_charset("CHARACTER SET"),
|
||||
"NAMES": lambda self: self._parse_set_item_names(),
|
||||
"TRANSACTION": lambda self: self._parse_set_transaction(),
|
||||
}
|
||||
|
||||
PROFILE_TYPES = {
|
||||
|
@ -278,39 +287,48 @@ class MySQL(Dialect):
|
|||
"SWAPS",
|
||||
}
|
||||
|
||||
TRANSACTION_CHARACTERISTICS = {
|
||||
"ISOLATION LEVEL REPEATABLE READ",
|
||||
"ISOLATION LEVEL READ COMMITTED",
|
||||
"ISOLATION LEVEL READ UNCOMMITTED",
|
||||
"ISOLATION LEVEL SERIALIZABLE",
|
||||
"READ WRITE",
|
||||
"READ ONLY",
|
||||
}
|
||||
|
||||
def _parse_show_mysql(self, this, target=False, full=None, global_=None):
|
||||
if target:
|
||||
if isinstance(target, str):
|
||||
self._match_text(target)
|
||||
self._match_text_seq(target)
|
||||
target_id = self._parse_id_var()
|
||||
else:
|
||||
target_id = None
|
||||
|
||||
log = self._parse_string() if self._match_text("IN") else None
|
||||
log = self._parse_string() if self._match_text_seq("IN") else None
|
||||
|
||||
if this in {"BINLOG EVENTS", "RELAYLOG EVENTS"}:
|
||||
position = self._parse_number() if self._match_text("FROM") else None
|
||||
position = self._parse_number() if self._match_text_seq("FROM") else None
|
||||
db = None
|
||||
else:
|
||||
position = None
|
||||
db = self._parse_id_var() if self._match_text("FROM") else None
|
||||
db = self._parse_id_var() if self._match_text_seq("FROM") else None
|
||||
|
||||
channel = self._parse_id_var() if self._match_text("FOR", "CHANNEL") else None
|
||||
channel = self._parse_id_var() if self._match_text_seq("FOR", "CHANNEL") else None
|
||||
|
||||
like = self._parse_string() if self._match_text("LIKE") else None
|
||||
like = self._parse_string() if self._match_text_seq("LIKE") else None
|
||||
where = self._parse_where()
|
||||
|
||||
if this == "PROFILE":
|
||||
types = self._parse_csv(self._parse_show_profile_type)
|
||||
query = self._parse_number() if self._match_text("FOR", "QUERY") else None
|
||||
offset = self._parse_number() if self._match_text("OFFSET") else None
|
||||
limit = self._parse_number() if self._match_text("LIMIT") else None
|
||||
types = self._parse_csv(lambda: self._parse_var_from_options(self.PROFILE_TYPES))
|
||||
query = self._parse_number() if self._match_text_seq("FOR", "QUERY") else None
|
||||
offset = self._parse_number() if self._match_text_seq("OFFSET") else None
|
||||
limit = self._parse_number() if self._match_text_seq("LIMIT") else None
|
||||
else:
|
||||
types, query = None, None
|
||||
offset, limit = self._parse_oldstyle_limit()
|
||||
|
||||
mutex = True if self._match_text("MUTEX") else None
|
||||
mutex = False if self._match_text("STATUS") else mutex
|
||||
mutex = True if self._match_text_seq("MUTEX") else None
|
||||
mutex = False if self._match_text_seq("STATUS") else mutex
|
||||
|
||||
return self.expression(
|
||||
exp.Show,
|
||||
|
@ -331,16 +349,16 @@ class MySQL(Dialect):
|
|||
**{"global": global_},
|
||||
)
|
||||
|
||||
def _parse_show_profile_type(self):
|
||||
for type_ in self.PROFILE_TYPES:
|
||||
if self._match_text(*type_.split(" ")):
|
||||
return exp.Var(this=type_)
|
||||
def _parse_var_from_options(self, options):
    """Match the upcoming tokens against one of `options` and wrap it in exp.Var.

    Each option is split on single spaces and offered to
    self._match_text_seq; the first option that matches wins. Returns None
    when no option matches.
    """
    matched = next(
        (option for option in options if self._match_text_seq(*option.split(" "))),
        None,
    )
    if matched is None:
        return None
    return exp.Var(this=matched)
|
||||
|
||||
def _parse_oldstyle_limit(self):
|
||||
limit = None
|
||||
offset = None
|
||||
if self._match_text("LIMIT"):
|
||||
if self._match_text_seq("LIMIT"):
|
||||
parts = self._parse_csv(self._parse_number)
|
||||
if len(parts) == 1:
|
||||
limit = parts[0]
|
||||
|
@ -353,6 +371,9 @@ class MySQL(Dialect):
|
|||
return self._parse_set_item_assignment(kind=None)
|
||||
|
||||
def _parse_set_item_assignment(self, kind):
|
||||
if kind in {"GLOBAL", "SESSION"} and self._match_text_seq("TRANSACTION"):
|
||||
return self._parse_set_transaction(global_=kind == "GLOBAL")
|
||||
|
||||
left = self._parse_primary() or self._parse_id_var()
|
||||
if not self._match(TokenType.EQ):
|
||||
self.raise_error("Expected =")
|
||||
|
@ -381,7 +402,7 @@ class MySQL(Dialect):
|
|||
|
||||
def _parse_set_item_names(self):
|
||||
charset = self._parse_string() or self._parse_id_var()
|
||||
if self._match_text("COLLATE"):
|
||||
if self._match_text_seq("COLLATE"):
|
||||
collate = self._parse_string() or self._parse_id_var()
|
||||
else:
|
||||
collate = None
|
||||
|
@ -392,6 +413,18 @@ class MySQL(Dialect):
|
|||
kind="NAMES",
|
||||
)
|
||||
|
||||
def _parse_set_transaction(self, global_=False):
    """Parse the characteristics of a SET [GLOBAL] TRANSACTION item.

    Optionally consumes the TRANSACTION keyword, then reads a comma-separated
    list of entries drawn from self.TRANSACTION_CHARACTERISTICS and returns an
    exp.SetItem of kind "TRANSACTION" carrying them and the `global` flag.
    """
    self._match_text_seq("TRANSACTION")

    def parse_characteristic():
        return self._parse_var_from_options(self.TRANSACTION_CHARACTERISTICS)

    characteristics = self._parse_csv(parse_characteristic)
    item_args = {
        "expressions": characteristics,
        "kind": "TRANSACTION",
        "global": global_,
    }
    return self.expression(exp.SetItem, **item_args)
|
||||
|
||||
class Generator(generator.Generator):
|
||||
NULL_ORDERING_SUPPORTED = False
|
||||
|
||||
|
@ -411,6 +444,7 @@ class MySQL(Dialect):
|
|||
exp.Trim: _trim_sql,
|
||||
exp.NullSafeEQ: lambda self, e: self.binary(e, "<=>"),
|
||||
exp.NullSafeNEQ: lambda self, e: self.not_sql(self.binary(e, "<=>")),
|
||||
exp.StrPosition: strposition_to_local_sql,
|
||||
}
|
||||
|
||||
ROOT_PROPERTIES = {
|
||||
|
@ -481,9 +515,11 @@ class MySQL(Dialect):
|
|||
kind = self.sql(expression, "kind")
|
||||
kind = f"{kind} " if kind else ""
|
||||
this = self.sql(expression, "this")
|
||||
expressions = self.expressions(expression)
|
||||
collate = self.sql(expression, "collate")
|
||||
collate = f" COLLATE {collate}" if collate else ""
|
||||
return f"{kind}{this}{collate}"
|
||||
global_ = "GLOBAL " if expression.args.get("global") else ""
|
||||
return f"{global_}{kind}{this}{expressions}{collate}"
|
||||
|
||||
def set_sql(self, expression):
|
||||
return f"SET {self.expressions(expression)}"
|
||||
|
|
|
@ -91,6 +91,7 @@ class Oracle(Dialect):
|
|||
class Tokenizer(tokens.Tokenizer):
|
||||
KEYWORDS = {
|
||||
**tokens.Tokenizer.KEYWORDS,
|
||||
"START": TokenType.BEGIN,
|
||||
"TOP": TokenType.TOP,
|
||||
"VARCHAR2": TokenType.VARCHAR,
|
||||
"NVARCHAR2": TokenType.NVARCHAR,
|
||||
|
|
|
@ -164,11 +164,34 @@ class Postgres(Dialect):
|
|||
BIT_STRINGS = [("b'", "'"), ("B'", "'")]
|
||||
HEX_STRINGS = [("x'", "'"), ("X'", "'")]
|
||||
BYTE_STRINGS = [("e'", "'"), ("E'", "'")]
|
||||
|
||||
CREATABLES = (
|
||||
"AGGREGATE",
|
||||
"CAST",
|
||||
"CONVERSION",
|
||||
"COLLATION",
|
||||
"DEFAULT CONVERSION",
|
||||
"CONSTRAINT",
|
||||
"DOMAIN",
|
||||
"EXTENSION",
|
||||
"FOREIGN",
|
||||
"FUNCTION",
|
||||
"OPERATOR",
|
||||
"POLICY",
|
||||
"ROLE",
|
||||
"RULE",
|
||||
"SEQUENCE",
|
||||
"TEXT",
|
||||
"TRIGGER",
|
||||
"TYPE",
|
||||
"UNLOGGED",
|
||||
"USER",
|
||||
)
|
||||
|
||||
KEYWORDS = {
|
||||
**tokens.Tokenizer.KEYWORDS,
|
||||
"ALWAYS": TokenType.ALWAYS,
|
||||
"BY DEFAULT": TokenType.BY_DEFAULT,
|
||||
"COMMENT ON": TokenType.COMMENT_ON,
|
||||
"IDENTITY": TokenType.IDENTITY,
|
||||
"GENERATED": TokenType.GENERATED,
|
||||
"DOUBLE PRECISION": TokenType.DOUBLE,
|
||||
|
@ -176,6 +199,19 @@ class Postgres(Dialect):
|
|||
"SERIAL": TokenType.SERIAL,
|
||||
"SMALLSERIAL": TokenType.SMALLSERIAL,
|
||||
"UUID": TokenType.UUID,
|
||||
"TEMP": TokenType.TEMPORARY,
|
||||
"BEGIN TRANSACTION": TokenType.BEGIN,
|
||||
"BEGIN": TokenType.COMMAND,
|
||||
"COMMENT ON": TokenType.COMMAND,
|
||||
"DECLARE": TokenType.COMMAND,
|
||||
"DO": TokenType.COMMAND,
|
||||
"REFRESH": TokenType.COMMAND,
|
||||
"REINDEX": TokenType.COMMAND,
|
||||
"RESET": TokenType.COMMAND,
|
||||
"REVOKE": TokenType.COMMAND,
|
||||
"GRANT": TokenType.COMMAND,
|
||||
**{f"CREATE {kind}": TokenType.COMMAND for kind in CREATABLES},
|
||||
**{f"DROP {kind}": TokenType.COMMAND for kind in CREATABLES},
|
||||
}
|
||||
QUOTES = ["'", "$$"]
|
||||
SINGLE_TOKENS = {
|
||||
|
|
|
@ -12,6 +12,7 @@ from sqlglot.dialects.dialect import (
|
|||
struct_extract_sql,
|
||||
)
|
||||
from sqlglot.dialects.mysql import MySQL
|
||||
from sqlglot.errors import UnsupportedError
|
||||
from sqlglot.helper import seq_get
|
||||
from sqlglot.tokens import TokenType
|
||||
|
||||
|
@ -61,8 +62,18 @@ def _initcap_sql(self, expression):
|
|||
return f"REGEXP_REPLACE({self.sql(expression, 'this')}, '{regex}', x -> UPPER(x[1]) || LOWER(x[2]))"
|
||||
|
||||
|
||||
def _decode_sql(self, expression):
    """Render a Decode expression as FROM_UTF8(<this>).

    The expression's `charset` arg is checked by _ensure_utf8 first, which
    raises for anything other than UTF-8.
    """
    charset = expression.args.get("charset")
    _ensure_utf8(charset)
    payload = self.sql(expression, "this")
    return f"FROM_UTF8({payload})"
|
||||
|
||||
|
||||
def _encode_sql(self, expression):
    """Render an Encode expression as TO_UTF8(<this>).

    The expression's `charset` arg is checked by _ensure_utf8 first, which
    raises for anything other than UTF-8.
    """
    charset = expression.args.get("charset")
    _ensure_utf8(charset)
    payload = self.sql(expression, "this")
    return f"TO_UTF8({payload})"
|
||||
|
||||
|
||||
def _no_sort_array(self, expression):
|
||||
if expression.args.get("asc") == exp.FALSE:
|
||||
if expression.args.get("asc") == exp.false():
|
||||
comparator = "(a, b) -> CASE WHEN a < b THEN 1 WHEN a > b THEN -1 ELSE 0 END"
|
||||
else:
|
||||
comparator = None
|
||||
|
@ -72,7 +83,7 @@ def _no_sort_array(self, expression):
|
|||
|
||||
def _schema_sql(self, expression):
|
||||
if isinstance(expression.parent, exp.Property):
|
||||
columns = ", ".join(f"'{c.text('this')}'" for c in expression.expressions)
|
||||
columns = ", ".join(f"'{c.name}'" for c in expression.expressions)
|
||||
return f"ARRAY[{columns}]"
|
||||
|
||||
for schema in expression.parent.find_all(exp.Schema):
|
||||
|
@ -106,6 +117,11 @@ def _ts_or_ds_add_sql(self, expression):
|
|||
return f"DATE_ADD({unit}, {e}, DATE_PARSE(SUBSTR({this}, 1, 10), {Presto.date_format}))"
|
||||
|
||||
|
||||
def _ensure_utf8(charset):
    """Raise UnsupportedError unless `charset` names UTF-8.

    Args:
        charset: Expression node whose `.name` holds the charset name, or
            None when the calling expression has no `charset` arg.

    Raises:
        UnsupportedError: If the charset is missing or its name is not
            "utf-8" (compared case-insensitively).
    """
    # Callers fetch the charset with `.args.get("charset")`, so it can be None.
    # That used to fail with an AttributeError on `None.name`; report it with
    # the same error type as any other unsupported charset instead.
    if charset is None or charset.name.lower() != "utf-8":
        raise UnsupportedError(f"Unsupported charset {charset}")
|
||||
|
||||
|
||||
class Presto(Dialect):
|
||||
index_offset = 1
|
||||
null_ordering = "nulls_are_last"
|
||||
|
@ -115,6 +131,7 @@ class Presto(Dialect):
|
|||
class Tokenizer(tokens.Tokenizer):
|
||||
KEYWORDS = {
|
||||
**tokens.Tokenizer.KEYWORDS,
|
||||
"START": TokenType.BEGIN,
|
||||
"ROW": TokenType.STRUCT,
|
||||
}
|
||||
|
||||
|
@ -140,6 +157,14 @@ class Presto(Dialect):
|
|||
"STRPOS": exp.StrPosition.from_arg_list,
|
||||
"TO_UNIXTIME": exp.TimeToUnix.from_arg_list,
|
||||
"APPROX_PERCENTILE": exp.ApproxQuantile.from_arg_list,
|
||||
"FROM_HEX": exp.Unhex.from_arg_list,
|
||||
"TO_HEX": exp.Hex.from_arg_list,
|
||||
"TO_UTF8": lambda args: exp.Encode(
|
||||
this=seq_get(args, 0), charset=exp.Literal.string("utf-8")
|
||||
),
|
||||
"FROM_UTF8": lambda args: exp.Decode(
|
||||
this=seq_get(args, 0), charset=exp.Literal.string("utf-8")
|
||||
),
|
||||
}
|
||||
|
||||
class Generator(generator.Generator):
|
||||
|
@ -187,7 +212,10 @@ class Presto(Dialect):
|
|||
exp.DateDiff: lambda self, e: f"""DATE_DIFF({self.sql(e, 'unit') or "'day'"}, {self.sql(e, 'expression')}, {self.sql(e, 'this')})""",
|
||||
exp.DateStrToDate: lambda self, e: f"CAST(DATE_PARSE({self.sql(e, 'this')}, {Presto.date_format}) AS DATE)",
|
||||
exp.DateToDi: lambda self, e: f"CAST(DATE_FORMAT({self.sql(e, 'this')}, {Presto.dateint_format}) AS INT)",
|
||||
exp.Decode: _decode_sql,
|
||||
exp.DiToDate: lambda self, e: f"CAST(DATE_PARSE(CAST({self.sql(e, 'this')} AS VARCHAR), {Presto.dateint_format}) AS DATE)",
|
||||
exp.Encode: _encode_sql,
|
||||
exp.Hex: rename_func("TO_HEX"),
|
||||
exp.If: if_sql,
|
||||
exp.ILike: no_ilike_sql,
|
||||
exp.Initcap: _initcap_sql,
|
||||
|
@ -212,7 +240,13 @@ class Presto(Dialect):
|
|||
exp.TsOrDiToDi: lambda self, e: f"CAST(SUBSTR(REPLACE(CAST({self.sql(e, 'this')} AS VARCHAR), '-', ''), 1, 8) AS INT)",
|
||||
exp.TsOrDsAdd: _ts_or_ds_add_sql,
|
||||
exp.TsOrDsToDate: _ts_or_ds_to_date_sql,
|
||||
exp.Unhex: rename_func("FROM_HEX"),
|
||||
exp.UnixToStr: lambda self, e: f"DATE_FORMAT(FROM_UNIXTIME({self.sql(e, 'this')}), {self.format_time(e)})",
|
||||
exp.UnixToTime: rename_func("FROM_UNIXTIME"),
|
||||
exp.UnixToTimeStr: lambda self, e: f"CAST(FROM_UNIXTIME({self.sql(e, 'this')}) AS VARCHAR)",
|
||||
}
|
||||
|
||||
def transaction_sql(self, expression):
    """Render a transaction start as START TRANSACTION[ <mode>, <mode>, ...].

    The modes come from the expression's `modes` arg and are joined with
    ", "; when there are none, the bare statement is returned.
    """
    modes = expression.args.get("modes")
    if not modes:
        return "START TRANSACTION"
    suffix = ", ".join(modes)
    return f"START TRANSACTION {suffix}"
|
||||
|
|
|
@ -148,6 +148,7 @@ class Snowflake(Dialect):
|
|||
**parser.Parser.FUNCTION_PARSERS,
|
||||
"DATE_PART": _parse_date_part,
|
||||
}
|
||||
FUNCTION_PARSERS.pop("TRIM")
|
||||
|
||||
FUNC_TOKENS = {
|
||||
*parser.Parser.FUNC_TOKENS,
|
||||
|
@ -203,6 +204,7 @@ class Snowflake(Dialect):
|
|||
exp.StrPosition: rename_func("POSITION"),
|
||||
exp.Parameter: lambda self, e: f"${self.sql(e, 'this')}",
|
||||
exp.PartitionedByProperty: lambda self, e: f"PARTITION BY {self.sql(e, 'value')}",
|
||||
exp.Trim: lambda self, e: f"TRIM({self.format_args(e.this, e.expression)})",
|
||||
}
|
||||
|
||||
TYPE_MAPPING = {
|
||||
|
|
|
@ -63,3 +63,8 @@ class SQLite(Dialect):
|
|||
exp.TableSample: no_tablesample_sql,
|
||||
exp.TryCast: no_trycast_sql,
|
||||
}
|
||||
|
||||
def transaction_sql(self, expression):
    """Render a transaction start as BEGIN[ <this>] TRANSACTION.

    `expression.this`, when truthy, is placed between BEGIN and TRANSACTION.
    """
    mode = expression.this
    if mode:
        return f"BEGIN {mode} TRANSACTION"
    return "BEGIN TRANSACTION"
|
||||
|
|
|
@ -248,7 +248,7 @@ class TSQL(Dialect):
|
|||
def _parse_convert(self, strict):
|
||||
to = self._parse_types()
|
||||
self._match(TokenType.COMMA)
|
||||
this = self._parse_column()
|
||||
this = self._parse_conjunction()
|
||||
|
||||
# Retrieve length of datatype and override to default if not specified
|
||||
if seq_get(to.expressions, 0) is None and to.this in self.VAR_LENGTH_DATATYPES:
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue