
Merging upstream version 10.0.8.

Signed-off-by: Daniel Baumann <daniel@debian.org>
Daniel Baumann 2025-02-13 14:54:32 +01:00
parent 407314e8d2
commit efc1e37108
Signed by: daniel
GPG key ID: FBB4F0E80A80222F
67 changed files with 2461 additions and 840 deletions

sqlglot/dialects/__init__.py

@@ -2,6 +2,7 @@ from sqlglot.dialects.bigquery import BigQuery
from sqlglot.dialects.clickhouse import ClickHouse
from sqlglot.dialects.databricks import Databricks
from sqlglot.dialects.dialect import Dialect, Dialects
from sqlglot.dialects.drill import Drill
from sqlglot.dialects.duckdb import DuckDB
from sqlglot.dialects.hive import Hive
from sqlglot.dialects.mysql import MySQL

sqlglot/dialects/bigquery.py

@@ -119,6 +119,8 @@ class BigQuery(Dialect):
"UNKNOWN": TokenType.NULL,
"WINDOW": TokenType.WINDOW,
"NOT DETERMINISTIC": TokenType.VOLATILE,
"BEGIN": TokenType.COMMAND,
"BEGIN TRANSACTION": TokenType.BEGIN,
}
KEYWORDS.pop("DIV")
@@ -204,6 +206,15 @@ class BigQuery(Dialect):
EXPLICIT_UNION = True
def transaction_sql(self, *_):
return "BEGIN TRANSACTION"
def commit_sql(self, *_):
return "COMMIT TRANSACTION"
def rollback_sql(self, *_):
return "ROLLBACK TRANSACTION"
def in_unnest_op(self, unnest):
return self.sql(unnest)
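As a rough sketch of what these hooks should do (assuming sqlglot 10.0.8; outputs illustrative, not taken from the diff): BEGIN TRANSACTION now tokenizes as a real transaction statement instead of an opaque command, and COMMIT/ROLLBACK render with the TRANSACTION keyword BigQuery expects.

>>> import sqlglot
>>> sqlglot.transpile("BEGIN TRANSACTION", read="bigquery", write="bigquery")
['BEGIN TRANSACTION']
>>> sqlglot.transpile("COMMIT", read="bigquery", write="bigquery")  # commit_sql appends TRANSACTION
['COMMIT TRANSACTION']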

sqlglot/dialects/dialect.py

@@ -32,6 +32,7 @@ class Dialects(str, Enum):
TRINO = "trino"
TSQL = "tsql"
DATABRICKS = "databricks"
DRILL = "drill"
class _Dialect(type):
@@ -362,3 +363,18 @@ def parse_date_delta(exp_class, unit_mapping=None):
return exp_class(this=this, expression=expression, unit=unit)
return inner_func
def locate_to_strposition(args):
return exp.StrPosition(
this=seq_get(args, 1),
substr=seq_get(args, 0),
position=seq_get(args, 2),
)
def strposition_to_local_sql(self, expression):
args = self.format_args(
expression.args.get("substr"), expression.this, expression.args.get("position")
)
return f"LOCATE({args})"

sqlglot/dialects/drill.py (new file, 174 lines)

@@ -0,0 +1,174 @@
from __future__ import annotations
import re
from sqlglot import exp, generator, parser, tokens
from sqlglot.dialects.dialect import (
Dialect,
create_with_partitions_sql,
format_time_lambda,
no_pivot_sql,
no_trycast_sql,
rename_func,
str_position_sql,
)
from sqlglot.dialects.postgres import _lateral_sql
def _to_timestamp(args):
# TO_TIMESTAMP accepts either a single double argument or (text, text)
if len(args) == 1 and args[0].is_number:
return exp.UnixToTime.from_arg_list(args)
return format_time_lambda(exp.StrToTime, "drill")(args)
def _str_to_time_sql(self, expression):
return f"STRPTIME({self.sql(expression, 'this')}, {self.format_time(expression)})"
def _ts_or_ds_to_date_sql(self, expression):
time_format = self.format_time(expression)
if time_format and time_format not in (Drill.time_format, Drill.date_format):
return f"CAST({_str_to_time_sql(self, expression)} AS DATE)"
return f"CAST({self.sql(expression, 'this')} AS DATE)"
def _date_add_sql(kind):
def func(self, expression):
this = self.sql(expression, "this")
unit = expression.text("unit").upper() or "DAY"
expression = self.sql(expression, "expression")
return f"DATE_{kind}({this}, INTERVAL '{expression}' {unit})"
return func
def if_sql(self, expression):
"""
Drill requires backticks around certain SQL reserved words, IF being one of them. This function
adds backticks around the keyword IF.
Args:
self: The Drill dialect
expression: The input IF expression
Returns: The expression with IF in backticks.
"""
expressions = self.format_args(
expression.this, expression.args.get("true"), expression.args.get("false")
)
return f"`IF`({expressions})"
def _str_to_date(self, expression):
this = self.sql(expression, "this")
time_format = self.format_time(expression)
if time_format == Drill.date_format:
return f"CAST({this} AS DATE)"
return f"TO_DATE({this}, {time_format})"
class Drill(Dialect):
normalize_functions = None
null_ordering = "nulls_are_last"
date_format = "'yyyy-MM-dd'"
dateint_format = "'yyyyMMdd'"
time_format = "'yyyy-MM-dd HH:mm:ss'"
time_mapping = {
"y": "%Y",
"Y": "%Y",
"YYYY": "%Y",
"yyyy": "%Y",
"YY": "%y",
"yy": "%y",
"MMMM": "%B",
"MMM": "%b",
"MM": "%m",
"M": "%-m",
"dd": "%d",
"d": "%-d",
"HH": "%H",
"H": "%-H",
"hh": "%I",
"h": "%-I",
"mm": "%M",
"m": "%-M",
"ss": "%S",
"s": "%-S",
"SSSSSS": "%f",
"a": "%p",
"DD": "%j",
"D": "%-j",
"E": "%a",
"EE": "%a",
"EEE": "%a",
"EEEE": "%A",
"''T''": "T",
}
class Tokenizer(tokens.Tokenizer):
QUOTES = ["'"]
IDENTIFIERS = ["`"]
ESCAPES = ["\\"]
ENCODE = "utf-8"
class Parser(parser.Parser):
STRICT_CAST = False
FUNCTIONS = {
**parser.Parser.FUNCTIONS,
"TO_TIMESTAMP": exp.TimeStrToTime.from_arg_list,
"TO_CHAR": format_time_lambda(exp.TimeToStr, "drill"),
}
class Generator(generator.Generator):
TYPE_MAPPING = {
**generator.Generator.TYPE_MAPPING,
exp.DataType.Type.INT: "INTEGER",
exp.DataType.Type.SMALLINT: "INTEGER",
exp.DataType.Type.TINYINT: "INTEGER",
exp.DataType.Type.BINARY: "VARBINARY",
exp.DataType.Type.TEXT: "VARCHAR",
exp.DataType.Type.NCHAR: "VARCHAR",
exp.DataType.Type.TIMESTAMPLTZ: "TIMESTAMP",
exp.DataType.Type.TIMESTAMPTZ: "TIMESTAMP",
exp.DataType.Type.DATETIME: "TIMESTAMP",
}
ROOT_PROPERTIES = {exp.PartitionedByProperty}
TRANSFORMS = {
**generator.Generator.TRANSFORMS,
exp.CurrentTimestamp: lambda *_: "CURRENT_TIMESTAMP",
exp.Lateral: _lateral_sql,
exp.ArrayContains: rename_func("REPEATED_CONTAINS"),
exp.ArraySize: rename_func("REPEATED_COUNT"),
exp.Create: create_with_partitions_sql,
exp.DateAdd: _date_add_sql("ADD"),
exp.DateStrToDate: lambda self, e: f"CAST({self.sql(e, 'this')} AS DATE)",
exp.DateSub: _date_add_sql("SUB"),
exp.DateToDi: lambda self, e: f"CAST(TO_DATE({self.sql(e, 'this')}, {Drill.dateint_format}) AS INT)",
exp.DiToDate: lambda self, e: f"TO_DATE(CAST({self.sql(e, 'this')} AS VARCHAR), {Drill.dateint_format})",
exp.If: if_sql,
exp.ILike: lambda self, e: f" {self.sql(e, 'this')} `ILIKE` {self.sql(e, 'expression')}",
exp.Levenshtein: rename_func("LEVENSHTEIN_DISTANCE"),
exp.PartitionedByProperty: lambda self, e: f"PARTITION BY {self.sql(e, 'value')}",
exp.Pivot: no_pivot_sql,
exp.RegexpLike: rename_func("REGEXP_MATCHES"),
exp.StrPosition: str_position_sql,
exp.StrToDate: _str_to_date,
exp.StrToTime: lambda self, e: f"TO_TIMESTAMP({self.sql(e, 'this')}, {self.format_time(e)})",
exp.TimeStrToDate: lambda self, e: f"CAST({self.sql(e, 'this')} AS DATE)",
exp.TimeStrToTime: lambda self, e: f"CAST({self.sql(e, 'this')} AS TIMESTAMP)",
exp.TimeStrToUnix: rename_func("UNIX_TIMESTAMP"),
exp.TimeToStr: lambda self, e: f"TO_CHAR({self.sql(e, 'this')}, {self.format_time(e)})",
exp.TimeToUnix: rename_func("UNIX_TIMESTAMP"),
exp.TryCast: no_trycast_sql,
exp.TsOrDsAdd: lambda self, e: f"DATE_ADD(CAST({self.sql(e, 'this')} AS DATE), INTERVAL '{self.sql(e, 'expression')}' DAY)",
exp.TsOrDsToDate: _ts_or_ds_to_date_sql,
exp.TsOrDiToDi: lambda self, e: f"CAST(SUBSTR(REPLACE(CAST({self.sql(e, 'this')} AS VARCHAR), '-', ''), 1, 8) AS INT)",
}
def normalize_func(self, name):
return name if re.match(exp.SAFE_IDENTIFIER_RE, name) else f"`{name}`"
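For illustration only (not part of the diff; outputs unverified): the new dialect registers under "drill", so SQL from other dialects can be transpiled into Drill's quirks, e.g. the backtick-quoted IF and interval-style date arithmetic handled above.

>>> import sqlglot
>>> sqlglot.transpile("SELECT IF(cond, 1, 2)", read="hive", write="drill")
['SELECT `IF`(cond, 1, 2)']
>>> sqlglot.transpile("SELECT DATE_ADD(d, INTERVAL 7 DAY)", read="mysql", write="drill")
["SELECT DATE_ADD(d, INTERVAL '7' DAY)"]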

sqlglot/dialects/duckdb.py

@@ -55,13 +55,13 @@ def _array_sort_sql(self, expression):
def _sort_array_sql(self, expression):
this = self.sql(expression, "this")
if expression.args.get("asc") == exp.FALSE:
if expression.args.get("asc") == exp.false():
return f"ARRAY_REVERSE_SORT({this})"
return f"ARRAY_SORT({this})"
def _sort_array_reverse(args):
return exp.SortArray(this=seq_get(args, 0), asc=exp.FALSE)
return exp.SortArray(this=seq_get(args, 0), asc=exp.false())
def _struct_pack_sql(self, expression):
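The swap from the shared exp.FALSE constant to the exp.false() factory gives each call site a fresh Boolean node; since sqlglot expressions compare structurally, the descending-sort check above still matches. A minimal sketch:

>>> from sqlglot import exp
>>> exp.false() == exp.false()  # distinct nodes, structurally equal
True
>>> exp.false() is exp.false()  # no longer a shared singleton
False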

sqlglot/dialects/hive.py

@@ -7,16 +7,19 @@ from sqlglot.dialects.dialect import (
create_with_partitions_sql,
format_time_lambda,
if_sql,
locate_to_strposition,
no_ilike_sql,
no_recursive_cte_sql,
no_safe_divide_sql,
no_trycast_sql,
rename_func,
strposition_to_local_sql,
struct_extract_sql,
var_map_sql,
)
from sqlglot.helper import seq_get
from sqlglot.parser import parse_var_map
from sqlglot.tokens import TokenType
# (FuncType, Multiplier)
DATE_DELTA_INTERVAL = {
@@ -181,6 +184,15 @@ class Hive(Dialect):
"F": "FLOAT",
"BD": "DECIMAL",
}
KEYWORDS = {
**tokens.Tokenizer.KEYWORDS,
"ADD ARCHIVE": TokenType.COMMAND,
"ADD ARCHIVES": TokenType.COMMAND,
"ADD FILE": TokenType.COMMAND,
"ADD FILES": TokenType.COMMAND,
"ADD JAR": TokenType.COMMAND,
"ADD JARS": TokenType.COMMAND,
}
class Parser(parser.Parser):
STRICT_CAST = False
@@ -210,11 +222,7 @@ class Hive(Dialect):
"DAY": lambda args: exp.Day(this=exp.TsOrDsToDate(this=seq_get(args, 0))),
"FROM_UNIXTIME": format_time_lambda(exp.UnixToStr, "hive", True),
"GET_JSON_OBJECT": exp.JSONExtractScalar.from_arg_list,
"LOCATE": lambda args: exp.StrPosition(
this=seq_get(args, 1),
substr=seq_get(args, 0),
position=seq_get(args, 2),
),
"LOCATE": locate_to_strposition,
"LOG": (
lambda args: exp.Log.from_arg_list(args)
if len(args) > 1
@@ -272,7 +280,7 @@ class Hive(Dialect):
exp.SchemaCommentProperty: lambda self, e: self.naked_property(e),
exp.SetAgg: rename_func("COLLECT_SET"),
exp.Split: lambda self, e: f"SPLIT({self.sql(e, 'this')}, CONCAT('\\\\Q', {self.sql(e, 'expression')}))",
exp.StrPosition: lambda self, e: f"LOCATE({self.format_args(e.args.get('substr'), e.this, e.args.get('position'))})",
exp.StrPosition: strposition_to_local_sql,
exp.StrToDate: _str_to_date,
exp.StrToTime: _str_to_time,
exp.StrToUnix: _str_to_unix,
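A sketch of the new ADD ... tokens (output illustrative): lexing them as commands lets Hive resource statements, which have no AST of their own, pass through as exp.Command roughly verbatim.

>>> import sqlglot
>>> sqlglot.transpile("ADD JAR /tmp/udfs.jar", read="hive")
['ADD JAR /tmp/udfs.jar']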

sqlglot/dialects/mysql.py

@@ -5,10 +5,12 @@ import typing as t
from sqlglot import exp, generator, parser, tokens
from sqlglot.dialects.dialect import (
Dialect,
locate_to_strposition,
no_ilike_sql,
no_paren_current_date_sql,
no_tablesample_sql,
no_trycast_sql,
strposition_to_local_sql,
)
from sqlglot.helper import seq_get
from sqlglot.tokens import TokenType
@@ -120,6 +122,7 @@ class MySQL(Dialect):
KEYWORDS = {
**tokens.Tokenizer.KEYWORDS,
"START": TokenType.BEGIN,
"SEPARATOR": TokenType.SEPARATOR,
"_ARMSCII8": TokenType.INTRODUCER,
"_ASCII": TokenType.INTRODUCER,
@@ -172,13 +175,18 @@ class MySQL(Dialect):
COMMANDS = tokens.Tokenizer.COMMANDS - {TokenType.SET, TokenType.SHOW}
class Parser(parser.Parser):
STRICT_CAST = False
FUNC_TOKENS = {*parser.Parser.FUNC_TOKENS, TokenType.SCHEMA}
FUNCTIONS = {
**parser.Parser.FUNCTIONS,
"DATE_ADD": _date_add(exp.DateAdd),
"DATE_SUB": _date_add(exp.DateSub),
"STR_TO_DATE": _str_to_date,
"LOCATE": locate_to_strposition,
"INSTR": lambda args: exp.StrPosition(substr=seq_get(args, 1), this=seq_get(args, 0)),
"LEFT": lambda args: exp.Substring(
this=seq_get(args, 0), start=exp.Literal.number(1), length=seq_get(args, 1)
),
}
FUNCTION_PARSERS = {
@@ -264,6 +272,7 @@ class MySQL(Dialect):
"CHARACTER SET": lambda self: self._parse_set_item_charset("CHARACTER SET"),
"CHARSET": lambda self: self._parse_set_item_charset("CHARACTER SET"),
"NAMES": lambda self: self._parse_set_item_names(),
"TRANSACTION": lambda self: self._parse_set_transaction(),
}
PROFILE_TYPES = {
@@ -278,39 +287,48 @@ class MySQL(Dialect):
"SWAPS",
}
TRANSACTION_CHARACTERISTICS = {
"ISOLATION LEVEL REPEATABLE READ",
"ISOLATION LEVEL READ COMMITTED",
"ISOLATION LEVEL READ UNCOMMITTED",
"ISOLATION LEVEL SERIALIZABLE",
"READ WRITE",
"READ ONLY",
}
def _parse_show_mysql(self, this, target=False, full=None, global_=None):
if target:
if isinstance(target, str):
self._match_text(target)
self._match_text_seq(target)
target_id = self._parse_id_var()
else:
target_id = None
log = self._parse_string() if self._match_text("IN") else None
log = self._parse_string() if self._match_text_seq("IN") else None
if this in {"BINLOG EVENTS", "RELAYLOG EVENTS"}:
position = self._parse_number() if self._match_text("FROM") else None
position = self._parse_number() if self._match_text_seq("FROM") else None
db = None
else:
position = None
db = self._parse_id_var() if self._match_text("FROM") else None
db = self._parse_id_var() if self._match_text_seq("FROM") else None
channel = self._parse_id_var() if self._match_text("FOR", "CHANNEL") else None
channel = self._parse_id_var() if self._match_text_seq("FOR", "CHANNEL") else None
like = self._parse_string() if self._match_text("LIKE") else None
like = self._parse_string() if self._match_text_seq("LIKE") else None
where = self._parse_where()
if this == "PROFILE":
types = self._parse_csv(self._parse_show_profile_type)
query = self._parse_number() if self._match_text("FOR", "QUERY") else None
offset = self._parse_number() if self._match_text("OFFSET") else None
limit = self._parse_number() if self._match_text("LIMIT") else None
types = self._parse_csv(lambda: self._parse_var_from_options(self.PROFILE_TYPES))
query = self._parse_number() if self._match_text_seq("FOR", "QUERY") else None
offset = self._parse_number() if self._match_text_seq("OFFSET") else None
limit = self._parse_number() if self._match_text_seq("LIMIT") else None
else:
types, query = None, None
offset, limit = self._parse_oldstyle_limit()
mutex = True if self._match_text("MUTEX") else None
mutex = False if self._match_text("STATUS") else mutex
mutex = True if self._match_text_seq("MUTEX") else None
mutex = False if self._match_text_seq("STATUS") else mutex
return self.expression(
exp.Show,
@@ -331,16 +349,16 @@ class MySQL(Dialect):
**{"global": global_},
)
def _parse_show_profile_type(self):
for type_ in self.PROFILE_TYPES:
if self._match_text(*type_.split(" ")):
return exp.Var(this=type_)
def _parse_var_from_options(self, options):
for option in options:
if self._match_text_seq(*option.split(" ")):
return exp.Var(this=option)
return None
def _parse_oldstyle_limit(self):
limit = None
offset = None
if self._match_text("LIMIT"):
if self._match_text_seq("LIMIT"):
parts = self._parse_csv(self._parse_number)
if len(parts) == 1:
limit = parts[0]
@@ -353,6 +371,9 @@ class MySQL(Dialect):
return self._parse_set_item_assignment(kind=None)
def _parse_set_item_assignment(self, kind):
if kind in {"GLOBAL", "SESSION"} and self._match_text_seq("TRANSACTION"):
return self._parse_set_transaction(global_=kind == "GLOBAL")
left = self._parse_primary() or self._parse_id_var()
if not self._match(TokenType.EQ):
self.raise_error("Expected =")
@@ -381,7 +402,7 @@ class MySQL(Dialect):
def _parse_set_item_names(self):
charset = self._parse_string() or self._parse_id_var()
if self._match_text("COLLATE"):
if self._match_text_seq("COLLATE"):
collate = self._parse_string() or self._parse_id_var()
else:
collate = None
@@ -392,6 +413,18 @@ class MySQL(Dialect):
kind="NAMES",
)
def _parse_set_transaction(self, global_=False):
self._match_text_seq("TRANSACTION")
characteristics = self._parse_csv(
lambda: self._parse_var_from_options(self.TRANSACTION_CHARACTERISTICS)
)
return self.expression(
exp.SetItem,
expressions=characteristics,
kind="TRANSACTION",
**{"global": global_},
)
class Generator(generator.Generator):
NULL_ORDERING_SUPPORTED = False
@@ -411,6 +444,7 @@ class MySQL(Dialect):
exp.Trim: _trim_sql,
exp.NullSafeEQ: lambda self, e: self.binary(e, "<=>"),
exp.NullSafeNEQ: lambda self, e: self.not_sql(self.binary(e, "<=>")),
exp.StrPosition: strposition_to_local_sql,
}
ROOT_PROPERTIES = {
@@ -481,9 +515,11 @@ class MySQL(Dialect):
kind = self.sql(expression, "kind")
kind = f"{kind} " if kind else ""
this = self.sql(expression, "this")
expressions = self.expressions(expression)
collate = self.sql(expression, "collate")
collate = f" COLLATE {collate}" if collate else ""
return f"{kind}{this}{collate}"
global_ = "GLOBAL " if expression.args.get("global") else ""
return f"{global_}{kind}{this}{expressions}{collate}"
def set_sql(self, expression):
return f"SET {self.expressions(expression)}"

sqlglot/dialects/oracle.py

@@ -91,6 +91,7 @@ class Oracle(Dialect):
class Tokenizer(tokens.Tokenizer):
KEYWORDS = {
**tokens.Tokenizer.KEYWORDS,
"START": TokenType.BEGIN,
"TOP": TokenType.TOP,
"VARCHAR2": TokenType.VARCHAR,
"NVARCHAR2": TokenType.NVARCHAR,

sqlglot/dialects/postgres.py

@@ -164,11 +164,34 @@ class Postgres(Dialect):
BIT_STRINGS = [("b'", "'"), ("B'", "'")]
HEX_STRINGS = [("x'", "'"), ("X'", "'")]
BYTE_STRINGS = [("e'", "'"), ("E'", "'")]
CREATABLES = (
"AGGREGATE",
"CAST",
"CONVERSION",
"COLLATION",
"DEFAULT CONVERSION",
"CONSTRAINT",
"DOMAIN",
"EXTENSION",
"FOREIGN",
"FUNCTION",
"OPERATOR",
"POLICY",
"ROLE",
"RULE",
"SEQUENCE",
"TEXT",
"TRIGGER",
"TYPE",
"UNLOGGED",
"USER",
)
KEYWORDS = {
**tokens.Tokenizer.KEYWORDS,
"ALWAYS": TokenType.ALWAYS,
"BY DEFAULT": TokenType.BY_DEFAULT,
"COMMENT ON": TokenType.COMMENT_ON,
"IDENTITY": TokenType.IDENTITY,
"GENERATED": TokenType.GENERATED,
"DOUBLE PRECISION": TokenType.DOUBLE,
@@ -176,6 +199,19 @@ class Postgres(Dialect):
"SERIAL": TokenType.SERIAL,
"SMALLSERIAL": TokenType.SMALLSERIAL,
"UUID": TokenType.UUID,
"TEMP": TokenType.TEMPORARY,
"BEGIN TRANSACTION": TokenType.BEGIN,
"BEGIN": TokenType.COMMAND,
"COMMENT ON": TokenType.COMMAND,
"DECLARE": TokenType.COMMAND,
"DO": TokenType.COMMAND,
"REFRESH": TokenType.COMMAND,
"REINDEX": TokenType.COMMAND,
"RESET": TokenType.COMMAND,
"REVOKE": TokenType.COMMAND,
"GRANT": TokenType.COMMAND,
**{f"CREATE {kind}": TokenType.COMMAND for kind in CREATABLES},
**{f"DROP {kind}": TokenType.COMMAND for kind in CREATABLES},
}
QUOTES = ["'", "$$"]
SINGLE_TOKENS = {
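A hedged sketch (output illustrative): CREATE/DROP for the kinds listed in CREATABLES, like the other keywords mapped to TokenType.COMMAND here, are statements sqlglot has no AST for, so they now pass through as exp.Command instead of failing to parse.

>>> import sqlglot
>>> sqlglot.transpile("CREATE EXTENSION IF NOT EXISTS pgcrypto", read="postgres")
['CREATE EXTENSION IF NOT EXISTS pgcrypto']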

sqlglot/dialects/presto.py

@@ -12,6 +12,7 @@ from sqlglot.dialects.dialect import (
struct_extract_sql,
)
from sqlglot.dialects.mysql import MySQL
from sqlglot.errors import UnsupportedError
from sqlglot.helper import seq_get
from sqlglot.tokens import TokenType
@@ -61,8 +62,18 @@ def _initcap_sql(self, expression):
return f"REGEXP_REPLACE({self.sql(expression, 'this')}, '{regex}', x -> UPPER(x[1]) || LOWER(x[2]))"
def _decode_sql(self, expression):
_ensure_utf8(expression.args.get("charset"))
return f"FROM_UTF8({self.sql(expression, 'this')})"
def _encode_sql(self, expression):
_ensure_utf8(expression.args.get("charset"))
return f"TO_UTF8({self.sql(expression, 'this')})"
def _no_sort_array(self, expression):
if expression.args.get("asc") == exp.FALSE:
if expression.args.get("asc") == exp.false():
comparator = "(a, b) -> CASE WHEN a < b THEN 1 WHEN a > b THEN -1 ELSE 0 END"
else:
comparator = None
@@ -72,7 +83,7 @@ def _no_sort_array(self, expression):
def _schema_sql(self, expression):
if isinstance(expression.parent, exp.Property):
columns = ", ".join(f"'{c.text('this')}'" for c in expression.expressions)
columns = ", ".join(f"'{c.name}'" for c in expression.expressions)
return f"ARRAY[{columns}]"
for schema in expression.parent.find_all(exp.Schema):
@@ -106,6 +117,11 @@ def _ts_or_ds_add_sql(self, expression):
return f"DATE_ADD({unit}, {e}, DATE_PARSE(SUBSTR({this}, 1, 10), {Presto.date_format}))"
def _ensure_utf8(charset):
if charset.name.lower() != "utf-8":
raise UnsupportedError(f"Unsupported charset {charset}")
class Presto(Dialect):
index_offset = 1
null_ordering = "nulls_are_last"
@@ -115,6 +131,7 @@ class Presto(Dialect):
class Tokenizer(tokens.Tokenizer):
KEYWORDS = {
**tokens.Tokenizer.KEYWORDS,
"START": TokenType.BEGIN,
"ROW": TokenType.STRUCT,
}
@@ -140,6 +157,14 @@ class Presto(Dialect):
"STRPOS": exp.StrPosition.from_arg_list,
"TO_UNIXTIME": exp.TimeToUnix.from_arg_list,
"APPROX_PERCENTILE": exp.ApproxQuantile.from_arg_list,
"FROM_HEX": exp.Unhex.from_arg_list,
"TO_HEX": exp.Hex.from_arg_list,
"TO_UTF8": lambda args: exp.Encode(
this=seq_get(args, 0), charset=exp.Literal.string("utf-8")
),
"FROM_UTF8": lambda args: exp.Decode(
this=seq_get(args, 0), charset=exp.Literal.string("utf-8")
),
}
class Generator(generator.Generator):
@@ -187,7 +212,10 @@ class Presto(Dialect):
exp.DateDiff: lambda self, e: f"""DATE_DIFF({self.sql(e, 'unit') or "'day'"}, {self.sql(e, 'expression')}, {self.sql(e, 'this')})""",
exp.DateStrToDate: lambda self, e: f"CAST(DATE_PARSE({self.sql(e, 'this')}, {Presto.date_format}) AS DATE)",
exp.DateToDi: lambda self, e: f"CAST(DATE_FORMAT({self.sql(e, 'this')}, {Presto.dateint_format}) AS INT)",
exp.Decode: _decode_sql,
exp.DiToDate: lambda self, e: f"CAST(DATE_PARSE(CAST({self.sql(e, 'this')} AS VARCHAR), {Presto.dateint_format}) AS DATE)",
exp.Encode: _encode_sql,
exp.Hex: rename_func("TO_HEX"),
exp.If: if_sql,
exp.ILike: no_ilike_sql,
exp.Initcap: _initcap_sql,
@@ -212,7 +240,13 @@ class Presto(Dialect):
exp.TsOrDiToDi: lambda self, e: f"CAST(SUBSTR(REPLACE(CAST({self.sql(e, 'this')} AS VARCHAR), '-', ''), 1, 8) AS INT)",
exp.TsOrDsAdd: _ts_or_ds_add_sql,
exp.TsOrDsToDate: _ts_or_ds_to_date_sql,
exp.Unhex: rename_func("FROM_HEX"),
exp.UnixToStr: lambda self, e: f"DATE_FORMAT(FROM_UNIXTIME({self.sql(e, 'this')}), {self.format_time(e)})",
exp.UnixToTime: rename_func("FROM_UNIXTIME"),
exp.UnixToTimeStr: lambda self, e: f"CAST(FROM_UNIXTIME({self.sql(e, 'this')}) AS VARCHAR)",
}
def transaction_sql(self, expression):
modes = expression.args.get("modes")
modes = f" {', '.join(modes)}" if modes else ""
return f"START TRANSACTION{modes}"

sqlglot/dialects/snowflake.py

@@ -148,6 +148,7 @@
**parser.Parser.FUNCTION_PARSERS,
"DATE_PART": _parse_date_part,
}
FUNCTION_PARSERS.pop("TRIM")
FUNC_TOKENS = {
*parser.Parser.FUNC_TOKENS,
@@ -203,6 +204,7 @@ class Snowflake(Dialect):
exp.StrPosition: rename_func("POSITION"),
exp.Parameter: lambda self, e: f"${self.sql(e, 'this')}",
exp.PartitionedByProperty: lambda self, e: f"PARTITION BY {self.sql(e, 'value')}",
exp.Trim: lambda self, e: f"TRIM({self.format_args(e.this, e.expression)})",
}
TYPE_MAPPING = {

sqlglot/dialects/sqlite.py

@@ -63,3 +63,8 @@ class SQLite(Dialect):
exp.TableSample: no_tablesample_sql,
exp.TryCast: no_trycast_sql,
}
def transaction_sql(self, expression):
this = expression.this
this = f" {this}" if this else ""
return f"BEGIN{this} TRANSACTION"

sqlglot/dialects/tsql.py

@@ -248,7 +248,7 @@ class TSQL(Dialect):
def _parse_convert(self, strict):
to = self._parse_types()
self._match(TokenType.COMMA)
this = self._parse_column()
this = self._parse_conjunction()
# Retrieve length of datatype and override to default if not specified
if seq_get(to.expressions, 0) is None and to.this in self.VAR_LENGTH_DATATYPES:
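Since _parse_conjunction accepts any scalar expression, CONVERT's value argument is no longer restricted to a bare column. A sketch (assuming sqlglot 10.0.8):

>>> import sqlglot
>>> e = sqlglot.parse_one("SELECT CONVERT(VARCHAR(10), a + b)", read="tsql")  # a + b would not parse under the old _parse_column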