1
0
Fork 0

Merging upstream version 10.0.1.

Signed-off-by: Daniel Baumann <daniel@debian.org>
This commit is contained in:
Daniel Baumann 2025-02-13 14:53:05 +01:00
parent 528822bfd4
commit b7d21c45b7
Signed by: daniel
GPG key ID: FBB4F0E80A80222F
98 changed files with 4080 additions and 1666 deletions

View file

@ -1,21 +1,21 @@
from sqlglot import exp
from __future__ import annotations
from sqlglot import exp, generator, parser, tokens
from sqlglot.dialects.dialect import (
Dialect,
inline_array_sql,
no_ilike_sql,
rename_func,
)
from sqlglot.generator import Generator
from sqlglot.helper import list_get
from sqlglot.parser import Parser
from sqlglot.tokens import Tokenizer, TokenType
from sqlglot.helper import seq_get
from sqlglot.tokens import TokenType
def _date_add(expression_class):
def func(args):
interval = list_get(args, 1)
interval = seq_get(args, 1)
return expression_class(
this=list_get(args, 0),
this=seq_get(args, 0),
expression=interval.this,
unit=interval.args.get("unit"),
)
@ -23,6 +23,13 @@ def _date_add(expression_class):
return func
def _date_trunc(args):
    """Build an ``exp.DateTrunc`` node from a parsed argument list.

    The first argument becomes ``this`` and the second becomes
    ``expression``. A second argument that is an ``exp.Column`` is swapped
    for an ``exp.Var`` carrying the column's name.
    """
    operand = seq_get(args, 0)
    granularity = seq_get(args, 1)
    truncate_to = (
        exp.Var(this=granularity.name)
        if isinstance(granularity, exp.Column)
        else granularity
    )
    return exp.DateTrunc(this=operand, expression=truncate_to)
def _date_add_sql(data_type, kind):
def func(self, expression):
this = self.sql(expression, "this")
@ -40,7 +47,8 @@ def _derived_table_values_to_unnest(self, expression):
structs = []
for row in rows:
aliases = [
exp.alias_(value, column_name) for value, column_name in zip(row, expression.args["alias"].args["columns"])
exp.alias_(value, column_name)
for value, column_name in zip(row, expression.args["alias"].args["columns"])
]
structs.append(exp.Struct(expressions=aliases))
unnest_exp = exp.Unnest(expressions=[exp.Array(expressions=structs)])
@ -89,18 +97,19 @@ class BigQuery(Dialect):
"%j": "%-j",
}
class Tokenizer(Tokenizer):
class Tokenizer(tokens.Tokenizer):
QUOTES = [
(prefix + quote, quote) if prefix else quote
for quote in ["'", '"', '"""', "'''"]
for prefix in ["", "r", "R"]
]
COMMENTS = ["--", "#", ("/*", "*/")]
IDENTIFIERS = ["`"]
ESCAPE = "\\"
ESCAPES = ["\\"]
HEX_STRINGS = [("0x", ""), ("0X", "")]
KEYWORDS = {
**Tokenizer.KEYWORDS,
**tokens.Tokenizer.KEYWORDS,
"CURRENT_DATETIME": TokenType.CURRENT_DATETIME,
"CURRENT_TIME": TokenType.CURRENT_TIME,
"GEOGRAPHY": TokenType.GEOGRAPHY,
@ -111,35 +120,40 @@ class BigQuery(Dialect):
"WINDOW": TokenType.WINDOW,
"NOT DETERMINISTIC": TokenType.VOLATILE,
}
KEYWORDS.pop("DIV")
class Parser(Parser):
class Parser(parser.Parser):
FUNCTIONS = {
**Parser.FUNCTIONS,
**parser.Parser.FUNCTIONS,
"DATE_TRUNC": _date_trunc,
"DATE_ADD": _date_add(exp.DateAdd),
"DATETIME_ADD": _date_add(exp.DatetimeAdd),
"DIV": lambda args: exp.IntDiv(this=seq_get(args, 0), expression=seq_get(args, 1)),
"TIME_ADD": _date_add(exp.TimeAdd),
"TIMESTAMP_ADD": _date_add(exp.TimestampAdd),
"DATE_SUB": _date_add(exp.DateSub),
"DATETIME_SUB": _date_add(exp.DatetimeSub),
"TIME_SUB": _date_add(exp.TimeSub),
"TIMESTAMP_SUB": _date_add(exp.TimestampSub),
"PARSE_TIMESTAMP": lambda args: exp.StrToTime(this=list_get(args, 1), format=list_get(args, 0)),
"PARSE_TIMESTAMP": lambda args: exp.StrToTime(
this=seq_get(args, 1), format=seq_get(args, 0)
),
}
NO_PAREN_FUNCTIONS = {
**Parser.NO_PAREN_FUNCTIONS,
**parser.Parser.NO_PAREN_FUNCTIONS,
TokenType.CURRENT_DATETIME: exp.CurrentDatetime,
TokenType.CURRENT_TIME: exp.CurrentTime,
}
NESTED_TYPE_TOKENS = {
*Parser.NESTED_TYPE_TOKENS,
*parser.Parser.NESTED_TYPE_TOKENS,
TokenType.TABLE,
}
class Generator(Generator):
class Generator(generator.Generator):
TRANSFORMS = {
**Generator.TRANSFORMS,
**generator.Generator.TRANSFORMS,
exp.Array: inline_array_sql,
exp.ArraySize: rename_func("ARRAY_LENGTH"),
exp.DateAdd: _date_add_sql("DATE", "ADD"),
@ -148,6 +162,7 @@ class BigQuery(Dialect):
exp.DatetimeSub: _date_add_sql("DATETIME", "SUB"),
exp.DateDiff: lambda self, e: f"DATE_DIFF({self.sql(e, 'this')}, {self.sql(e, 'expression')}, {self.sql(e.args.get('unit', 'DAY'))})",
exp.ILike: no_ilike_sql,
exp.IntDiv: rename_func("DIV"),
exp.StrToTime: lambda self, e: f"PARSE_TIMESTAMP({self.format_time(e)}, {self.sql(e, 'this')})",
exp.TimeAdd: _date_add_sql("TIME", "ADD"),
exp.TimeSub: _date_add_sql("TIME", "SUB"),
@ -157,11 +172,13 @@ class BigQuery(Dialect):
exp.Values: _derived_table_values_to_unnest,
exp.ReturnsProperty: _returnsproperty_sql,
exp.Create: _create_sql,
exp.VolatilityProperty: lambda self, e: f"DETERMINISTIC" if e.name == "IMMUTABLE" else "NOT DETERMINISTIC",
exp.VolatilityProperty: lambda self, e: f"DETERMINISTIC"
if e.name == "IMMUTABLE"
else "NOT DETERMINISTIC",
}
TYPE_MAPPING = {
**Generator.TYPE_MAPPING,
**generator.Generator.TYPE_MAPPING,
exp.DataType.Type.TINYINT: "INT64",
exp.DataType.Type.SMALLINT: "INT64",
exp.DataType.Type.INT: "INT64",

View file

@ -1,8 +1,9 @@
from sqlglot import exp
from __future__ import annotations
from sqlglot import exp, generator, parser, tokens
from sqlglot.dialects.dialect import Dialect, inline_array_sql, var_map_sql
from sqlglot.generator import Generator
from sqlglot.parser import Parser, parse_var_map
from sqlglot.tokens import Tokenizer, TokenType
from sqlglot.parser import parse_var_map
from sqlglot.tokens import TokenType
def _lower_func(sql):
@ -14,11 +15,12 @@ class ClickHouse(Dialect):
normalize_functions = None
null_ordering = "nulls_are_last"
class Tokenizer(Tokenizer):
class Tokenizer(tokens.Tokenizer):
COMMENTS = ["--", "#", "#!", ("/*", "*/")]
IDENTIFIERS = ['"', "`"]
KEYWORDS = {
**Tokenizer.KEYWORDS,
**tokens.Tokenizer.KEYWORDS,
"FINAL": TokenType.FINAL,
"DATETIME64": TokenType.DATETIME,
"INT8": TokenType.TINYINT,
@ -30,9 +32,9 @@ class ClickHouse(Dialect):
"TUPLE": TokenType.STRUCT,
}
class Parser(Parser):
class Parser(parser.Parser):
FUNCTIONS = {
**Parser.FUNCTIONS,
**parser.Parser.FUNCTIONS,
"MAP": parse_var_map,
}
@ -44,11 +46,11 @@ class ClickHouse(Dialect):
return this
class Generator(Generator):
class Generator(generator.Generator):
STRUCT_DELIMITER = ("(", ")")
TYPE_MAPPING = {
**Generator.TYPE_MAPPING,
**generator.Generator.TYPE_MAPPING,
exp.DataType.Type.NULLABLE: "Nullable",
exp.DataType.Type.DATETIME: "DateTime64",
exp.DataType.Type.MAP: "Map",
@ -63,7 +65,7 @@ class ClickHouse(Dialect):
}
TRANSFORMS = {
**Generator.TRANSFORMS,
**generator.Generator.TRANSFORMS,
exp.Array: inline_array_sql,
exp.StrPosition: lambda self, e: f"position({self.format_args(e.this, e.args.get('substr'), e.args.get('position'))})",
exp.Final: lambda self, e: f"{self.sql(e, 'this')} FINAL",

View file

@ -1,3 +1,5 @@
from __future__ import annotations
from sqlglot import exp
from sqlglot.dialects.dialect import parse_date_delta
from sqlglot.dialects.spark import Spark
@ -15,7 +17,7 @@ class Databricks(Spark):
class Generator(Spark.Generator):
TRANSFORMS = {
**Spark.Generator.TRANSFORMS,
**Spark.Generator.TRANSFORMS, # type: ignore
exp.DateAdd: generate_date_delta_with_unit_sql,
exp.DateDiff: generate_date_delta_with_unit_sql,
}

View file

@ -1,8 +1,11 @@
from __future__ import annotations
import typing as t
from enum import Enum
from sqlglot import exp
from sqlglot.generator import Generator
from sqlglot.helper import flatten, list_get
from sqlglot.helper import flatten, seq_get
from sqlglot.parser import Parser
from sqlglot.time import format_time
from sqlglot.tokens import Tokenizer
@ -32,7 +35,7 @@ class Dialects(str, Enum):
class _Dialect(type):
classes = {}
classes: t.Dict[str, Dialect] = {}
@classmethod
def __getitem__(cls, key):
@ -56,19 +59,30 @@ class _Dialect(type):
klass.generator_class = getattr(klass, "Generator", Generator)
klass.quote_start, klass.quote_end = list(klass.tokenizer_class._QUOTES.items())[0]
klass.identifier_start, klass.identifier_end = list(klass.tokenizer_class._IDENTIFIERS.items())[0]
klass.identifier_start, klass.identifier_end = list(
klass.tokenizer_class._IDENTIFIERS.items()
)[0]
if klass.tokenizer_class._BIT_STRINGS and exp.BitString not in klass.generator_class.TRANSFORMS:
if (
klass.tokenizer_class._BIT_STRINGS
and exp.BitString not in klass.generator_class.TRANSFORMS
):
bs_start, bs_end = list(klass.tokenizer_class._BIT_STRINGS.items())[0]
klass.generator_class.TRANSFORMS[
exp.BitString
] = lambda self, e: f"{bs_start}{int(self.sql(e, 'this')):b}{bs_end}"
if klass.tokenizer_class._HEX_STRINGS and exp.HexString not in klass.generator_class.TRANSFORMS:
if (
klass.tokenizer_class._HEX_STRINGS
and exp.HexString not in klass.generator_class.TRANSFORMS
):
hs_start, hs_end = list(klass.tokenizer_class._HEX_STRINGS.items())[0]
klass.generator_class.TRANSFORMS[
exp.HexString
] = lambda self, e: f"{hs_start}{int(self.sql(e, 'this')):X}{hs_end}"
if klass.tokenizer_class._BYTE_STRINGS and exp.ByteString not in klass.generator_class.TRANSFORMS:
if (
klass.tokenizer_class._BYTE_STRINGS
and exp.ByteString not in klass.generator_class.TRANSFORMS
):
be_start, be_end = list(klass.tokenizer_class._BYTE_STRINGS.items())[0]
klass.generator_class.TRANSFORMS[
exp.ByteString
@ -81,13 +95,13 @@ class Dialect(metaclass=_Dialect):
index_offset = 0
unnest_column_only = False
alias_post_tablesample = False
normalize_functions = "upper"
normalize_functions: t.Optional[str] = "upper"
null_ordering = "nulls_are_small"
date_format = "'%Y-%m-%d'"
dateint_format = "'%Y%m%d'"
time_format = "'%Y-%m-%d %H:%M:%S'"
time_mapping = {}
time_mapping: t.Dict[str, str] = {}
# autofilled
quote_start = None
@ -167,7 +181,7 @@ class Dialect(metaclass=_Dialect):
"quote_end": self.quote_end,
"identifier_start": self.identifier_start,
"identifier_end": self.identifier_end,
"escape": self.tokenizer_class.ESCAPE,
"escape": self.tokenizer_class.ESCAPES[0],
"index_offset": self.index_offset,
"time_mapping": self.inverse_time_mapping,
"time_trie": self.inverse_time_trie,
@ -195,7 +209,9 @@ def approx_count_distinct_sql(self, expression):
def if_sql(self, expression):
    """Render a conditional as ``IF(<condition>, <true>, <false>)``.

    The condition is ``expression.this``; the branches are read from the
    ``"true"`` and ``"false"`` entries of ``expression.args`` and the three
    values are joined by ``self.format_args``.
    """
    branches = expression.args
    rendered = self.format_args(
        expression.this,
        branches.get("true"),
        branches.get("false"),
    )
    return f"IF({rendered})"
@ -298,9 +314,9 @@ def format_time_lambda(exp_class, dialect, default=None):
def _format_time(args):
return exp_class(
this=list_get(args, 0),
this=seq_get(args, 0),
format=Dialect[dialect].format_time(
list_get(args, 1) or (Dialect[dialect].time_format if default is True else default)
seq_get(args, 1) or (Dialect[dialect].time_format if default is True else default)
),
)
@ -328,7 +344,9 @@ def create_with_partitions_sql(self, expression):
"expressions",
[e for e in schema.expressions if e not in partitions],
)
prop.replace(exp.PartitionedByProperty(this=prop.this, value=exp.Schema(expressions=partitions)))
prop.replace(
exp.PartitionedByProperty(this=prop.this, value=exp.Schema(expressions=partitions))
)
expression.set("this", schema)
return self.create_sql(expression)
@ -337,9 +355,9 @@ def create_with_partitions_sql(self, expression):
def parse_date_delta(exp_class, unit_mapping=None):
def inner_func(args):
unit_based = len(args) == 3
this = list_get(args, 2) if unit_based else list_get(args, 0)
expression = list_get(args, 1) if unit_based else list_get(args, 1)
unit = list_get(args, 0) if unit_based else exp.Literal.string("DAY")
this = seq_get(args, 2) if unit_based else seq_get(args, 0)
expression = seq_get(args, 1) if unit_based else seq_get(args, 1)
unit = seq_get(args, 0) if unit_based else exp.Literal.string("DAY")
unit = unit_mapping.get(unit.name.lower(), unit) if unit_mapping else unit
return exp_class(this=this, expression=expression, unit=unit)

View file

@ -1,4 +1,6 @@
from sqlglot import exp
from __future__ import annotations
from sqlglot import exp, generator, parser, tokens
from sqlglot.dialects.dialect import (
Dialect,
approx_count_distinct_sql,
@ -12,10 +14,8 @@ from sqlglot.dialects.dialect import (
rename_func,
str_position_sql,
)
from sqlglot.generator import Generator
from sqlglot.helper import list_get
from sqlglot.parser import Parser
from sqlglot.tokens import Tokenizer, TokenType
from sqlglot.helper import seq_get
from sqlglot.tokens import TokenType
def _unix_to_time(self, expression):
@ -61,11 +61,14 @@ def _sort_array_sql(self, expression):
def _sort_array_reverse(args):
    """Build an ``exp.SortArray`` over the first argument with ``asc`` set to ``exp.FALSE``."""
    array = seq_get(args, 0)
    return exp.SortArray(this=array, asc=exp.FALSE)
def _struct_pack_sql(self, expression):
    """Render ``expression.expressions`` as a ``STRUCT_PACK(...)`` call.

    Members that are ``exp.EQ`` nodes are written through ``self.binary``
    with the ``:=`` operator; every other member goes through ``self.sql``.
    """
    rendered = []
    for member in expression.expressions:
        if isinstance(member, exp.EQ):
            rendered.append(self.binary(member, ":="))
        else:
            rendered.append(self.sql(member))
    return f"STRUCT_PACK({', '.join(rendered)})"
@ -76,15 +79,15 @@ def _datatype_sql(self, expression):
class DuckDB(Dialect):
class Tokenizer(Tokenizer):
class Tokenizer(tokens.Tokenizer):
KEYWORDS = {
**Tokenizer.KEYWORDS,
**tokens.Tokenizer.KEYWORDS,
":=": TokenType.EQ,
}
class Parser(Parser):
class Parser(parser.Parser):
FUNCTIONS = {
**Parser.FUNCTIONS,
**parser.Parser.FUNCTIONS,
"APPROX_COUNT_DISTINCT": exp.ApproxDistinct.from_arg_list,
"ARRAY_LENGTH": exp.ArraySize.from_arg_list,
"ARRAY_SORT": exp.SortArray.from_arg_list,
@ -92,7 +95,7 @@ class DuckDB(Dialect):
"EPOCH": exp.TimeToUnix.from_arg_list,
"EPOCH_MS": lambda args: exp.UnixToTime(
this=exp.Div(
this=list_get(args, 0),
this=seq_get(args, 0),
expression=exp.Literal.number(1000),
)
),
@ -112,11 +115,11 @@ class DuckDB(Dialect):
"UNNEST": exp.Explode.from_arg_list,
}
class Generator(Generator):
class Generator(generator.Generator):
STRUCT_DELIMITER = ("(", ")")
TRANSFORMS = {
**Generator.TRANSFORMS,
**generator.Generator.TRANSFORMS,
exp.ApproxDistinct: approx_count_distinct_sql,
exp.Array: rename_func("LIST_VALUE"),
exp.ArraySize: rename_func("ARRAY_LENGTH"),
@ -160,7 +163,7 @@ class DuckDB(Dialect):
}
TYPE_MAPPING = {
**Generator.TYPE_MAPPING,
**generator.Generator.TYPE_MAPPING,
exp.DataType.Type.VARCHAR: "TEXT",
exp.DataType.Type.NVARCHAR: "TEXT",
}

View file

@ -1,4 +1,6 @@
from sqlglot import exp, transforms
from __future__ import annotations
from sqlglot import exp, generator, parser, tokens, transforms
from sqlglot.dialects.dialect import (
Dialect,
approx_count_distinct_sql,
@ -13,10 +15,8 @@ from sqlglot.dialects.dialect import (
struct_extract_sql,
var_map_sql,
)
from sqlglot.generator import Generator
from sqlglot.helper import list_get
from sqlglot.parser import Parser, parse_var_map
from sqlglot.tokens import Tokenizer
from sqlglot.helper import seq_get
from sqlglot.parser import parse_var_map
# (FuncType, Multiplier)
DATE_DELTA_INTERVAL = {
@ -34,7 +34,9 @@ def _add_date_sql(self, expression):
unit = expression.text("unit").upper()
func, multiplier = DATE_DELTA_INTERVAL.get(unit, ("DATE_ADD", 1))
modified_increment = (
int(expression.text("expression")) * multiplier if expression.expression.is_number else expression.expression
int(expression.text("expression")) * multiplier
if expression.expression.is_number
else expression.expression
)
modified_increment = exp.Literal.number(modified_increment)
return f"{func}({self.format_args(expression.this, modified_increment.this)})"
@ -165,10 +167,10 @@ class Hive(Dialect):
dateint_format = "'yyyyMMdd'"
time_format = "'yyyy-MM-dd HH:mm:ss'"
class Tokenizer(Tokenizer):
class Tokenizer(tokens.Tokenizer):
QUOTES = ["'", '"']
IDENTIFIERS = ["`"]
ESCAPE = "\\"
ESCAPES = ["\\"]
ENCODE = "utf-8"
NUMERIC_LITERALS = {
@ -180,40 +182,44 @@ class Hive(Dialect):
"BD": "DECIMAL",
}
class Parser(Parser):
class Parser(parser.Parser):
STRICT_CAST = False
FUNCTIONS = {
**Parser.FUNCTIONS,
**parser.Parser.FUNCTIONS,
"APPROX_COUNT_DISTINCT": exp.ApproxDistinct.from_arg_list,
"COLLECT_LIST": exp.ArrayAgg.from_arg_list,
"DATE_ADD": lambda args: exp.TsOrDsAdd(
this=list_get(args, 0),
expression=list_get(args, 1),
this=seq_get(args, 0),
expression=seq_get(args, 1),
unit=exp.Literal.string("DAY"),
),
"DATEDIFF": lambda args: exp.DateDiff(
this=exp.TsOrDsToDate(this=list_get(args, 0)),
expression=exp.TsOrDsToDate(this=list_get(args, 1)),
this=exp.TsOrDsToDate(this=seq_get(args, 0)),
expression=exp.TsOrDsToDate(this=seq_get(args, 1)),
),
"DATE_SUB": lambda args: exp.TsOrDsAdd(
this=list_get(args, 0),
this=seq_get(args, 0),
expression=exp.Mul(
this=list_get(args, 1),
this=seq_get(args, 1),
expression=exp.Literal.number(-1),
),
unit=exp.Literal.string("DAY"),
),
"DATE_FORMAT": format_time_lambda(exp.TimeToStr, "hive"),
"DAY": lambda args: exp.Day(this=exp.TsOrDsToDate(this=list_get(args, 0))),
"DAY": lambda args: exp.Day(this=exp.TsOrDsToDate(this=seq_get(args, 0))),
"FROM_UNIXTIME": format_time_lambda(exp.UnixToStr, "hive", True),
"GET_JSON_OBJECT": exp.JSONExtractScalar.from_arg_list,
"LOCATE": lambda args: exp.StrPosition(
this=list_get(args, 1),
substr=list_get(args, 0),
position=list_get(args, 2),
this=seq_get(args, 1),
substr=seq_get(args, 0),
position=seq_get(args, 2),
),
"LOG": (
lambda args: exp.Log.from_arg_list(args)
if len(args) > 1
else exp.Ln.from_arg_list(args)
),
"LOG": (lambda args: exp.Log.from_arg_list(args) if len(args) > 1 else exp.Ln.from_arg_list(args)),
"MAP": parse_var_map,
"MONTH": lambda args: exp.Month(this=exp.TsOrDsToDate.from_arg_list(args)),
"PERCENTILE": exp.Quantile.from_arg_list,
@ -226,15 +232,16 @@ class Hive(Dialect):
"YEAR": lambda args: exp.Year(this=exp.TsOrDsToDate.from_arg_list(args)),
}
class Generator(Generator):
class Generator(generator.Generator):
TYPE_MAPPING = {
**Generator.TYPE_MAPPING,
**generator.Generator.TYPE_MAPPING,
exp.DataType.Type.TEXT: "STRING",
exp.DataType.Type.VARBINARY: "BINARY",
}
TRANSFORMS = {
**Generator.TRANSFORMS,
**transforms.UNALIAS_GROUP,
**generator.Generator.TRANSFORMS,
**transforms.UNALIAS_GROUP, # type: ignore
exp.AnonymousProperty: _property_sql,
exp.ApproxDistinct: approx_count_distinct_sql,
exp.ArrayAgg: rename_func("COLLECT_LIST"),

View file

@ -1,4 +1,8 @@
from sqlglot import exp
from __future__ import annotations
import typing as t
from sqlglot import exp, generator, parser, tokens
from sqlglot.dialects.dialect import (
Dialect,
no_ilike_sql,
@ -6,42 +10,47 @@ from sqlglot.dialects.dialect import (
no_tablesample_sql,
no_trycast_sql,
)
from sqlglot.generator import Generator
from sqlglot.helper import list_get
from sqlglot.parser import Parser
from sqlglot.tokens import Tokenizer, TokenType
from sqlglot.helper import seq_get
from sqlglot.tokens import TokenType
def _show_parser(*args, **kwargs):
def _parse(self):
return self._parse_show_mysql(*args, **kwargs)
return _parse
def _date_trunc_sql(self, expression):
unit = expression.text("unit").lower()
unit = expression.name.lower()
this = self.sql(expression.this)
expr = self.sql(expression.expression)
if unit == "day":
return f"DATE({this})"
return f"DATE({expr})"
if unit == "week":
concat = f"CONCAT(YEAR({this}), ' ', WEEK({this}, 1), ' 1')"
concat = f"CONCAT(YEAR({expr}), ' ', WEEK({expr}, 1), ' 1')"
date_format = "%Y %u %w"
elif unit == "month":
concat = f"CONCAT(YEAR({this}), ' ', MONTH({this}), ' 1')"
concat = f"CONCAT(YEAR({expr}), ' ', MONTH({expr}), ' 1')"
date_format = "%Y %c %e"
elif unit == "quarter":
concat = f"CONCAT(YEAR({this}), ' ', QUARTER({this}) * 3 - 2, ' 1')"
concat = f"CONCAT(YEAR({expr}), ' ', QUARTER({expr}) * 3 - 2, ' 1')"
date_format = "%Y %c %e"
elif unit == "year":
concat = f"CONCAT(YEAR({this}), ' 1 1')"
concat = f"CONCAT(YEAR({expr}), ' 1 1')"
date_format = "%Y %c %e"
else:
self.unsupported("Unexpected interval unit: {unit}")
return f"DATE({this})"
return f"DATE({expr})"
return f"STR_TO_DATE({concat}, '{date_format}')"
def _str_to_date(args):
    """Build an ``exp.StrToDate`` from a parsed argument list.

    The second argument is passed through ``MySQL.format_time`` and stored
    as ``format``; the first argument becomes ``this``.
    """
    converted_format = MySQL.format_time(seq_get(args, 1))
    value = seq_get(args, 0)
    return exp.StrToDate(this=value, format=converted_format)
def _str_to_date_sql(self, expression):
@ -66,9 +75,9 @@ def _trim_sql(self, expression):
def _date_add(expression_class):
def func(args):
interval = list_get(args, 1)
interval = seq_get(args, 1)
return expression_class(
this=list_get(args, 0),
this=seq_get(args, 0),
expression=interval.this,
unit=exp.Literal.string(interval.text("unit").lower()),
)
@ -101,15 +110,16 @@ class MySQL(Dialect):
"%l": "%-I",
}
class Tokenizer(Tokenizer):
class Tokenizer(tokens.Tokenizer):
QUOTES = ["'", '"']
COMMENTS = ["--", "#", ("/*", "*/")]
IDENTIFIERS = ["`"]
ESCAPES = ["'", "\\"]
BIT_STRINGS = [("b'", "'"), ("B'", "'"), ("0b", "")]
HEX_STRINGS = [("x'", "'"), ("X'", "'"), ("0x", "")]
KEYWORDS = {
**Tokenizer.KEYWORDS,
**tokens.Tokenizer.KEYWORDS,
"SEPARATOR": TokenType.SEPARATOR,
"_ARMSCII8": TokenType.INTRODUCER,
"_ASCII": TokenType.INTRODUCER,
@ -156,20 +166,23 @@ class MySQL(Dialect):
"_UTF32": TokenType.INTRODUCER,
"_UTF8MB3": TokenType.INTRODUCER,
"_UTF8MB4": TokenType.INTRODUCER,
"@@": TokenType.SESSION_PARAMETER,
}
class Parser(Parser):
COMMANDS = tokens.Tokenizer.COMMANDS - {TokenType.SET, TokenType.SHOW}
class Parser(parser.Parser):
STRICT_CAST = False
FUNCTIONS = {
**Parser.FUNCTIONS,
**parser.Parser.FUNCTIONS,
"DATE_ADD": _date_add(exp.DateAdd),
"DATE_SUB": _date_add(exp.DateSub),
"STR_TO_DATE": _str_to_date,
}
FUNCTION_PARSERS = {
**Parser.FUNCTION_PARSERS,
**parser.Parser.FUNCTION_PARSERS,
"GROUP_CONCAT": lambda self: self.expression(
exp.GroupConcat,
this=self._parse_lambda(),
@ -178,15 +191,212 @@ class MySQL(Dialect):
}
PROPERTY_PARSERS = {
**Parser.PROPERTY_PARSERS,
**parser.Parser.PROPERTY_PARSERS,
TokenType.ENGINE: lambda self: self._parse_property_assignment(exp.EngineProperty),
}
class Generator(Generator):
STATEMENT_PARSERS = {
**parser.Parser.STATEMENT_PARSERS,
TokenType.SHOW: lambda self: self._parse_show(),
TokenType.SET: lambda self: self._parse_set(),
}
SHOW_PARSERS = {
"BINARY LOGS": _show_parser("BINARY LOGS"),
"MASTER LOGS": _show_parser("BINARY LOGS"),
"BINLOG EVENTS": _show_parser("BINLOG EVENTS"),
"CHARACTER SET": _show_parser("CHARACTER SET"),
"CHARSET": _show_parser("CHARACTER SET"),
"COLLATION": _show_parser("COLLATION"),
"FULL COLUMNS": _show_parser("COLUMNS", target="FROM", full=True),
"COLUMNS": _show_parser("COLUMNS", target="FROM"),
"CREATE DATABASE": _show_parser("CREATE DATABASE", target=True),
"CREATE EVENT": _show_parser("CREATE EVENT", target=True),
"CREATE FUNCTION": _show_parser("CREATE FUNCTION", target=True),
"CREATE PROCEDURE": _show_parser("CREATE PROCEDURE", target=True),
"CREATE TABLE": _show_parser("CREATE TABLE", target=True),
"CREATE TRIGGER": _show_parser("CREATE TRIGGER", target=True),
"CREATE VIEW": _show_parser("CREATE VIEW", target=True),
"DATABASES": _show_parser("DATABASES"),
"ENGINE": _show_parser("ENGINE", target=True),
"STORAGE ENGINES": _show_parser("ENGINES"),
"ENGINES": _show_parser("ENGINES"),
"ERRORS": _show_parser("ERRORS"),
"EVENTS": _show_parser("EVENTS"),
"FUNCTION CODE": _show_parser("FUNCTION CODE", target=True),
"FUNCTION STATUS": _show_parser("FUNCTION STATUS"),
"GRANTS": _show_parser("GRANTS", target="FOR"),
"INDEX": _show_parser("INDEX", target="FROM"),
"MASTER STATUS": _show_parser("MASTER STATUS"),
"OPEN TABLES": _show_parser("OPEN TABLES"),
"PLUGINS": _show_parser("PLUGINS"),
"PROCEDURE CODE": _show_parser("PROCEDURE CODE", target=True),
"PROCEDURE STATUS": _show_parser("PROCEDURE STATUS"),
"PRIVILEGES": _show_parser("PRIVILEGES"),
"FULL PROCESSLIST": _show_parser("PROCESSLIST", full=True),
"PROCESSLIST": _show_parser("PROCESSLIST"),
"PROFILE": _show_parser("PROFILE"),
"PROFILES": _show_parser("PROFILES"),
"RELAYLOG EVENTS": _show_parser("RELAYLOG EVENTS"),
"REPLICAS": _show_parser("REPLICAS"),
"SLAVE HOSTS": _show_parser("REPLICAS"),
"REPLICA STATUS": _show_parser("REPLICA STATUS"),
"SLAVE STATUS": _show_parser("REPLICA STATUS"),
"GLOBAL STATUS": _show_parser("STATUS", global_=True),
"SESSION STATUS": _show_parser("STATUS"),
"STATUS": _show_parser("STATUS"),
"TABLE STATUS": _show_parser("TABLE STATUS"),
"FULL TABLES": _show_parser("TABLES", full=True),
"TABLES": _show_parser("TABLES"),
"TRIGGERS": _show_parser("TRIGGERS"),
"GLOBAL VARIABLES": _show_parser("VARIABLES", global_=True),
"SESSION VARIABLES": _show_parser("VARIABLES"),
"VARIABLES": _show_parser("VARIABLES"),
"WARNINGS": _show_parser("WARNINGS"),
}
SET_PARSERS = {
"GLOBAL": lambda self: self._parse_set_item_assignment("GLOBAL"),
"PERSIST": lambda self: self._parse_set_item_assignment("PERSIST"),
"PERSIST_ONLY": lambda self: self._parse_set_item_assignment("PERSIST_ONLY"),
"SESSION": lambda self: self._parse_set_item_assignment("SESSION"),
"LOCAL": lambda self: self._parse_set_item_assignment("LOCAL"),
"CHARACTER SET": lambda self: self._parse_set_item_charset("CHARACTER SET"),
"CHARSET": lambda self: self._parse_set_item_charset("CHARACTER SET"),
"NAMES": lambda self: self._parse_set_item_names(),
}
PROFILE_TYPES = {
"ALL",
"BLOCK IO",
"CONTEXT SWITCHES",
"CPU",
"IPC",
"MEMORY",
"PAGE FAULTS",
"SOURCE",
"SWAPS",
}
def _parse_show_mysql(self, this, target=False, full=None, global_=None):
    """Parse the remainder of a SHOW statement into an ``exp.Show`` node.

    Args:
        this: Name stored as the node's ``this``. It also selects the
            ``FROM <number>`` form (``"BINLOG EVENTS"`` /
            ``"RELAYLOG EVENTS"``) and the profile clauses (``"PROFILE"``).
        target: When truthy an identifier is parsed as the target; when it
            is a string, that text is matched first if present.
        full: Stored unchanged as the node's ``full`` argument.
        global_: Stored unchanged as the node's ``global`` argument.

    Clauses are attempted in a fixed order; each one that is absent leaves
    its value as ``None``.
    """
    target_id = None
    if target:
        if isinstance(target, str):
            self._match_text(target)
        target_id = self._parse_id_var()

    log = None
    if self._match_text("IN"):
        log = self._parse_string()

    # FROM introduces a numeric position for the event listings and an
    # identifier (stored as ``db``) for everything else.
    position = None
    db = None
    if self._match_text("FROM"):
        if this in {"BINLOG EVENTS", "RELAYLOG EVENTS"}:
            position = self._parse_number()
        else:
            db = self._parse_id_var()

    channel = None
    if self._match_text("FOR", "CHANNEL"):
        channel = self._parse_id_var()

    like = None
    if self._match_text("LIKE"):
        like = self._parse_string()
    where = self._parse_where()

    types = None
    query = None
    if this == "PROFILE":
        types = self._parse_csv(self._parse_show_profile_type)
        if self._match_text("FOR", "QUERY"):
            query = self._parse_number()
        offset = None
        if self._match_text("OFFSET"):
            offset = self._parse_number()
        limit = None
        if self._match_text("LIMIT"):
            limit = self._parse_number()
    else:
        offset, limit = self._parse_oldstyle_limit()

    # MUTEX yields True, STATUS yields False, neither leaves None.
    mutex = None
    if self._match_text("MUTEX"):
        mutex = True
    if self._match_text("STATUS"):
        mutex = False

    return self.expression(
        exp.Show,
        this=this,
        target=target_id,
        full=full,
        log=log,
        position=position,
        db=db,
        channel=channel,
        like=like,
        where=where,
        types=types,
        query=query,
        offset=offset,
        limit=limit,
        mutex=mutex,
        # "global" is a reserved word, so it has to be passed via a dict.
        **{"global": global_},
    )
def _parse_show_profile_type(self):
for type_ in self.PROFILE_TYPES:
if self._match_text(*type_.split(" ")):
return exp.Var(this=type_)
return None
def _parse_oldstyle_limit(self):
limit = None
offset = None
if self._match_text("LIMIT"):
parts = self._parse_csv(self._parse_number)
if len(parts) == 1:
limit = parts[0]
elif len(parts) == 2:
limit = parts[1]
offset = parts[0]
return offset, limit
def _default_parse_set_item(self):
return self._parse_set_item_assignment(kind=None)
def _parse_set_item_assignment(self, kind):
    """Parse ``<left> = <right>`` into an ``exp.SetItem`` wrapping an ``exp.EQ``.

    Args:
        kind: Stored unchanged as the ``kind`` of the resulting item.

    The left side is a primary expression or, failing that, an identifier.
    The right side is a statement or, failing that, an identifier.
    ``self.raise_error`` is called when no ``=`` token follows the left side.
    """
    left = self._parse_primary()
    if not left:
        left = self._parse_id_var()

    if not self._match(TokenType.EQ):
        self.raise_error("Expected =")

    right = self._parse_statement()
    if not right:
        right = self._parse_id_var()

    assignment = self.expression(exp.EQ, this=left, expression=right)
    return self.expression(exp.SetItem, this=assignment, kind=kind)
def _parse_set_item_charset(self, kind):
    """Parse a string or identifier into an ``exp.SetItem`` of the given ``kind``."""
    charset = self._parse_string()
    if not charset:
        charset = self._parse_id_var()
    return self.expression(exp.SetItem, this=charset, kind=kind)
def _parse_set_item_names(self):
    """Parse ``<charset> [COLLATE <collation>]`` into an ``exp.SetItem`` of kind ``"NAMES"``.

    Both the charset and the collation may be a string or an identifier;
    ``collate`` is ``None`` when ``COLLATE`` is not matched.
    """
    charset = self._parse_string() or self._parse_id_var()

    collate = None
    if self._match_text("COLLATE"):
        collate = self._parse_string() or self._parse_id_var()

    return self.expression(exp.SetItem, this=charset, collate=collate, kind="NAMES")
class Generator(generator.Generator):
NULL_ORDERING_SUPPORTED = False
TRANSFORMS = {
**Generator.TRANSFORMS,
**generator.Generator.TRANSFORMS,
exp.CurrentDate: no_paren_current_date_sql,
exp.CurrentTimestamp: lambda *_: "CURRENT_TIMESTAMP",
exp.ILike: no_ilike_sql,
@ -199,6 +409,8 @@ class MySQL(Dialect):
exp.StrToDate: _str_to_date_sql,
exp.StrToTime: _str_to_date_sql,
exp.Trim: _trim_sql,
exp.NullSafeEQ: lambda self, e: self.binary(e, "<=>"),
exp.NullSafeNEQ: lambda self, e: self.not_sql(self.binary(e, "<=>")),
}
ROOT_PROPERTIES = {
@ -209,4 +421,69 @@ class MySQL(Dialect):
exp.SchemaCommentProperty,
}
WITH_PROPERTIES = {}
WITH_PROPERTIES: t.Set[t.Type[exp.Property]] = set()
def show_sql(self, expression):
    """Render a SHOW node as SQL text.

    The output is ``SHOW`` followed, in this order, by: FULL, GLOBAL, the
    node name, target, types, db, query, log, position, channel,
    MUTEX/STATUS, like, where, offset and limit. Parts whose argument is
    missing render as an empty string.
    """
    kind = expression.name
    flags = expression.args

    full = " FULL" if flags.get("full") else ""
    global_ = " GLOBAL" if flags.get("global") else ""
    this = f" {kind}"

    target = self.sql(expression, "target")
    if target:
        target = f" {target}"
    # These kinds carry a keyword in front of their target.
    if kind in {"COLUMNS", "INDEX"}:
        target = f" FROM{target}"
    elif kind == "GRANTS":
        target = f" FOR{target}"

    db = self._prefixed_sql("FROM", expression, "db")
    like = self._prefixed_sql("LIKE", expression, "like")
    where = self.sql(expression, "where")

    types = self.expressions(expression, key="types")
    if types:
        types = f" {types}"
    query = self._prefixed_sql("FOR QUERY", expression, "query")

    if kind == "PROFILE":
        offset = self._prefixed_sql("OFFSET", expression, "offset")
        limit = self._prefixed_sql("LIMIT", expression, "limit")
    else:
        # Everything else uses the combined "LIMIT offset, count" form.
        offset = ""
        limit = self._oldstyle_limit_sql(expression)

    log = self._prefixed_sql("IN", expression, "log")
    position = self._prefixed_sql("FROM", expression, "position")
    channel = self._prefixed_sql("FOR CHANNEL", expression, "channel")

    mutex_or_status = ""
    if kind == "ENGINE":
        mutex_or_status = " MUTEX" if flags.get("mutex") else " STATUS"

    head = f"SHOW{full}{global_}{this}{target}{types}"
    tail = f"{db}{query}{log}{position}{channel}{mutex_or_status}{like}{where}{offset}{limit}"
    return head + tail
def _prefixed_sql(self, prefix, expression, arg):
sql = self.sql(expression, arg)
if not sql:
return ""
return f" {prefix} {sql}"
def _oldstyle_limit_sql(self, expression):
limit = self.sql(expression, "limit")
offset = self.sql(expression, "offset")
if limit:
limit_offset = f"{offset}, {limit}" if offset else limit
return f" LIMIT {limit_offset}"
return ""
def setitem_sql(self, expression):
    """Render a SET item as ``[<kind> ]<this>[ COLLATE <collate>]``."""
    scope = self.sql(expression, "kind")
    body = self.sql(expression, "this")
    collation = self.sql(expression, "collate")

    prefix = f"{scope} " if scope else ""
    suffix = f" COLLATE {collation}" if collation else ""
    return f"{prefix}{body}{suffix}"
def set_sql(self, expression):
    """Render a SET statement as ``SET`` followed by ``self.expressions(expression)``."""
    items = self.expressions(expression)
    return f"SET {items}"

View file

@ -1,8 +1,9 @@
from sqlglot import exp, transforms
from __future__ import annotations
from sqlglot import exp, generator, tokens, transforms
from sqlglot.dialects.dialect import Dialect, no_ilike_sql
from sqlglot.generator import Generator
from sqlglot.helper import csv
from sqlglot.tokens import Tokenizer, TokenType
from sqlglot.tokens import TokenType
def _limit_sql(self, expression):
@ -36,9 +37,9 @@ class Oracle(Dialect):
"YYYY": "%Y", # 2015
}
class Generator(Generator):
class Generator(generator.Generator):
TYPE_MAPPING = {
**Generator.TYPE_MAPPING,
**generator.Generator.TYPE_MAPPING,
exp.DataType.Type.TINYINT: "NUMBER",
exp.DataType.Type.SMALLINT: "NUMBER",
exp.DataType.Type.INT: "NUMBER",
@ -49,11 +50,12 @@ class Oracle(Dialect):
exp.DataType.Type.NVARCHAR: "NVARCHAR2",
exp.DataType.Type.TEXT: "CLOB",
exp.DataType.Type.BINARY: "BLOB",
exp.DataType.Type.VARBINARY: "BLOB",
}
TRANSFORMS = {
**Generator.TRANSFORMS,
**transforms.UNALIAS_GROUP,
**generator.Generator.TRANSFORMS,
**transforms.UNALIAS_GROUP, # type: ignore
exp.ILike: no_ilike_sql,
exp.Limit: _limit_sql,
exp.StrToTime: lambda self, e: f"TO_TIMESTAMP({self.sql(e, 'this')}, {self.format_time(e)})",
@ -86,9 +88,9 @@ class Oracle(Dialect):
def table_sql(self, expression):
return super().table_sql(expression, sep=" ")
class Tokenizer(Tokenizer):
class Tokenizer(tokens.Tokenizer):
KEYWORDS = {
**Tokenizer.KEYWORDS,
**tokens.Tokenizer.KEYWORDS,
"TOP": TokenType.TOP,
"VARCHAR2": TokenType.VARCHAR,
"NVARCHAR2": TokenType.NVARCHAR,

View file

@ -1,4 +1,6 @@
from sqlglot import exp
from __future__ import annotations
from sqlglot import exp, generator, parser, tokens
from sqlglot.dialects.dialect import (
Dialect,
arrow_json_extract_scalar_sql,
@ -9,9 +11,7 @@ from sqlglot.dialects.dialect import (
no_trycast_sql,
str_position_sql,
)
from sqlglot.generator import Generator
from sqlglot.parser import Parser
from sqlglot.tokens import Tokenizer, TokenType
from sqlglot.tokens import TokenType
from sqlglot.transforms import delegate, preprocess
@ -160,12 +160,12 @@ class Postgres(Dialect):
"YYYY": "%Y", # 2015
}
class Tokenizer(Tokenizer):
class Tokenizer(tokens.Tokenizer):
BIT_STRINGS = [("b'", "'"), ("B'", "'")]
HEX_STRINGS = [("x'", "'"), ("X'", "'")]
BYTE_STRINGS = [("e'", "'"), ("E'", "'")]
KEYWORDS = {
**Tokenizer.KEYWORDS,
**tokens.Tokenizer.KEYWORDS,
"ALWAYS": TokenType.ALWAYS,
"BY DEFAULT": TokenType.BY_DEFAULT,
"COMMENT ON": TokenType.COMMENT_ON,
@ -179,31 +179,32 @@ class Postgres(Dialect):
}
QUOTES = ["'", "$$"]
SINGLE_TOKENS = {
**Tokenizer.SINGLE_TOKENS,
**tokens.Tokenizer.SINGLE_TOKENS,
"$": TokenType.PARAMETER,
}
class Parser(Parser):
class Parser(parser.Parser):
STRICT_CAST = False
FUNCTIONS = {
**Parser.FUNCTIONS,
**parser.Parser.FUNCTIONS,
"TO_TIMESTAMP": _to_timestamp,
"TO_CHAR": format_time_lambda(exp.TimeToStr, "postgres"),
}
class Generator(Generator):
class Generator(generator.Generator):
TYPE_MAPPING = {
**Generator.TYPE_MAPPING,
**generator.Generator.TYPE_MAPPING,
exp.DataType.Type.TINYINT: "SMALLINT",
exp.DataType.Type.FLOAT: "REAL",
exp.DataType.Type.DOUBLE: "DOUBLE PRECISION",
exp.DataType.Type.BINARY: "BYTEA",
exp.DataType.Type.VARBINARY: "BYTEA",
exp.DataType.Type.DATETIME: "TIMESTAMP",
}
TRANSFORMS = {
**Generator.TRANSFORMS,
**generator.Generator.TRANSFORMS,
exp.ColumnDef: preprocess(
[
_auto_increment_to_serial,

View file

@ -1,4 +1,6 @@
from sqlglot import exp, transforms
from __future__ import annotations
from sqlglot import exp, generator, parser, tokens, transforms
from sqlglot.dialects.dialect import (
Dialect,
format_time_lambda,
@ -10,10 +12,8 @@ from sqlglot.dialects.dialect import (
struct_extract_sql,
)
from sqlglot.dialects.mysql import MySQL
from sqlglot.generator import Generator
from sqlglot.helper import list_get
from sqlglot.parser import Parser
from sqlglot.tokens import Tokenizer, TokenType
from sqlglot.helper import seq_get
from sqlglot.tokens import TokenType
def _approx_distinct_sql(self, expression):
@ -110,30 +110,29 @@ class Presto(Dialect):
index_offset = 1
null_ordering = "nulls_are_last"
time_format = "'%Y-%m-%d %H:%i:%S'"
time_mapping = MySQL.time_mapping
time_mapping = MySQL.time_mapping # type: ignore
class Tokenizer(Tokenizer):
class Tokenizer(tokens.Tokenizer):
KEYWORDS = {
**Tokenizer.KEYWORDS,
"VARBINARY": TokenType.BINARY,
**tokens.Tokenizer.KEYWORDS,
"ROW": TokenType.STRUCT,
}
class Parser(Parser):
class Parser(parser.Parser):
FUNCTIONS = {
**Parser.FUNCTIONS,
**parser.Parser.FUNCTIONS,
"APPROX_DISTINCT": exp.ApproxDistinct.from_arg_list,
"CARDINALITY": exp.ArraySize.from_arg_list,
"CONTAINS": exp.ArrayContains.from_arg_list,
"DATE_ADD": lambda args: exp.DateAdd(
this=list_get(args, 2),
expression=list_get(args, 1),
unit=list_get(args, 0),
this=seq_get(args, 2),
expression=seq_get(args, 1),
unit=seq_get(args, 0),
),
"DATE_DIFF": lambda args: exp.DateDiff(
this=list_get(args, 2),
expression=list_get(args, 1),
unit=list_get(args, 0),
this=seq_get(args, 2),
expression=seq_get(args, 1),
unit=seq_get(args, 0),
),
"DATE_FORMAT": format_time_lambda(exp.TimeToStr, "presto"),
"DATE_PARSE": format_time_lambda(exp.StrToTime, "presto"),
@ -143,7 +142,7 @@ class Presto(Dialect):
"APPROX_PERCENTILE": exp.ApproxQuantile.from_arg_list,
}
class Generator(Generator):
class Generator(generator.Generator):
STRUCT_DELIMITER = ("(", ")")
@ -159,7 +158,7 @@ class Presto(Dialect):
}
TYPE_MAPPING = {
**Generator.TYPE_MAPPING,
**generator.Generator.TYPE_MAPPING,
exp.DataType.Type.INT: "INTEGER",
exp.DataType.Type.FLOAT: "REAL",
exp.DataType.Type.BINARY: "VARBINARY",
@ -169,8 +168,8 @@ class Presto(Dialect):
}
TRANSFORMS = {
**Generator.TRANSFORMS,
**transforms.UNALIAS_GROUP,
**generator.Generator.TRANSFORMS,
**transforms.UNALIAS_GROUP, # type: ignore
exp.ApproxDistinct: _approx_distinct_sql,
exp.Array: lambda self, e: f"ARRAY[{self.expressions(e, flat=True)}]",
exp.ArrayConcat: rename_func("CONCAT"),

View file

@ -1,3 +1,5 @@
from __future__ import annotations
from sqlglot import exp
from sqlglot.dialects.postgres import Postgres
from sqlglot.tokens import TokenType
@ -6,29 +8,30 @@ from sqlglot.tokens import TokenType
class Redshift(Postgres):
time_format = "'YYYY-MM-DD HH:MI:SS'"
time_mapping = {
**Postgres.time_mapping,
**Postgres.time_mapping, # type: ignore
"MON": "%b",
"HH": "%H",
}
class Tokenizer(Postgres.Tokenizer):
ESCAPE = "\\"
ESCAPES = ["\\"]
KEYWORDS = {
**Postgres.Tokenizer.KEYWORDS,
**Postgres.Tokenizer.KEYWORDS, # type: ignore
"GEOMETRY": TokenType.GEOMETRY,
"GEOGRAPHY": TokenType.GEOGRAPHY,
"HLLSKETCH": TokenType.HLLSKETCH,
"SUPER": TokenType.SUPER,
"TIME": TokenType.TIMESTAMP,
"TIMETZ": TokenType.TIMESTAMPTZ,
"VARBYTE": TokenType.BINARY,
"VARBYTE": TokenType.VARBINARY,
"SIMILAR TO": TokenType.SIMILAR_TO,
}
class Generator(Postgres.Generator):
TYPE_MAPPING = {
**Postgres.Generator.TYPE_MAPPING,
**Postgres.Generator.TYPE_MAPPING, # type: ignore
exp.DataType.Type.BINARY: "VARBYTE",
exp.DataType.Type.VARBINARY: "VARBYTE",
exp.DataType.Type.INT: "INTEGER",
}

View file

@ -1,4 +1,6 @@
from sqlglot import exp
from __future__ import annotations
from sqlglot import exp, generator, parser, tokens
from sqlglot.dialects.dialect import (
Dialect,
format_time_lambda,
@ -6,10 +8,8 @@ from sqlglot.dialects.dialect import (
rename_func,
)
from sqlglot.expressions import Literal
from sqlglot.generator import Generator
from sqlglot.helper import list_get
from sqlglot.parser import Parser
from sqlglot.tokens import Tokenizer, TokenType
from sqlglot.helper import seq_get
from sqlglot.tokens import TokenType
def _check_int(s):
@ -28,7 +28,9 @@ def _snowflake_to_timestamp(args):
# case: <numeric_expr> [ , <scale> ]
if second_arg.name not in ["0", "3", "9"]:
raise ValueError(f"Scale for snowflake numeric timestamp is {second_arg}, but should be 0, 3, or 9")
raise ValueError(
f"Scale for snowflake numeric timestamp is {second_arg}, but should be 0, 3, or 9"
)
if second_arg.name == "0":
timescale = exp.UnixToTime.SECONDS
@ -39,7 +41,7 @@ def _snowflake_to_timestamp(args):
return exp.UnixToTime(this=first_arg, scale=timescale)
first_arg = list_get(args, 0)
first_arg = seq_get(args, 0)
if not isinstance(first_arg, Literal):
# case: <variant_expr>
return format_time_lambda(exp.StrToTime, "snowflake", default=True)(args)
@ -56,7 +58,7 @@ def _snowflake_to_timestamp(args):
return exp.UnixToTime.from_arg_list(args)
def _unix_to_time(self, expression):
def _unix_to_time_sql(self, expression):
scale = expression.args.get("scale")
timestamp = self.sql(expression, "this")
if scale in [None, exp.UnixToTime.SECONDS]:
@ -132,9 +134,9 @@ class Snowflake(Dialect):
"ff6": "%f",
}
class Parser(Parser):
class Parser(parser.Parser):
FUNCTIONS = {
**Parser.FUNCTIONS,
**parser.Parser.FUNCTIONS,
"ARRAYAGG": exp.ArrayAgg.from_arg_list,
"IFF": exp.If.from_arg_list,
"TO_TIMESTAMP": _snowflake_to_timestamp,
@ -143,18 +145,18 @@ class Snowflake(Dialect):
}
FUNCTION_PARSERS = {
**Parser.FUNCTION_PARSERS,
**parser.Parser.FUNCTION_PARSERS,
"DATE_PART": _parse_date_part,
}
FUNC_TOKENS = {
*Parser.FUNC_TOKENS,
*parser.Parser.FUNC_TOKENS,
TokenType.RLIKE,
TokenType.TABLE,
}
COLUMN_OPERATORS = {
**Parser.COLUMN_OPERATORS,
**parser.Parser.COLUMN_OPERATORS, # type: ignore
TokenType.COLON: lambda self, this, path: self.expression(
exp.Bracket,
this=this,
@ -163,21 +165,21 @@ class Snowflake(Dialect):
}
PROPERTY_PARSERS = {
**Parser.PROPERTY_PARSERS,
**parser.Parser.PROPERTY_PARSERS,
TokenType.PARTITION_BY: lambda self: self._parse_partitioned_by(),
}
class Tokenizer(Tokenizer):
class Tokenizer(tokens.Tokenizer):
QUOTES = ["'", "$$"]
ESCAPE = "\\"
ESCAPES = ["\\"]
SINGLE_TOKENS = {
**Tokenizer.SINGLE_TOKENS,
**tokens.Tokenizer.SINGLE_TOKENS,
"$": TokenType.PARAMETER,
}
KEYWORDS = {
**Tokenizer.KEYWORDS,
**tokens.Tokenizer.KEYWORDS,
"QUALIFY": TokenType.QUALIFY,
"DOUBLE PRECISION": TokenType.DOUBLE,
"TIMESTAMP_LTZ": TokenType.TIMESTAMPLTZ,
@ -187,15 +189,15 @@ class Snowflake(Dialect):
"SAMPLE": TokenType.TABLE_SAMPLE,
}
class Generator(Generator):
class Generator(generator.Generator):
CREATE_TRANSIENT = True
TRANSFORMS = {
**Generator.TRANSFORMS,
**generator.Generator.TRANSFORMS,
exp.ArrayConcat: rename_func("ARRAY_CAT"),
exp.If: rename_func("IFF"),
exp.StrToTime: lambda self, e: f"TO_TIMESTAMP({self.sql(e, 'this')}, {self.format_time(e)})",
exp.UnixToTime: _unix_to_time,
exp.UnixToTime: _unix_to_time_sql,
exp.TimeToUnix: lambda self, e: f"EXTRACT(epoch_second FROM {self.sql(e, 'this')})",
exp.Array: inline_array_sql,
exp.StrPosition: rename_func("POSITION"),
@ -204,7 +206,7 @@ class Snowflake(Dialect):
}
TYPE_MAPPING = {
**Generator.TYPE_MAPPING,
**generator.Generator.TYPE_MAPPING,
exp.DataType.Type.TIMESTAMP: "TIMESTAMPNTZ",
}

View file

@ -1,8 +1,9 @@
from sqlglot import exp
from __future__ import annotations
from sqlglot import exp, parser
from sqlglot.dialects.dialect import create_with_partitions_sql, rename_func
from sqlglot.dialects.hive import Hive
from sqlglot.helper import list_get
from sqlglot.parser import Parser
from sqlglot.helper import seq_get
def _create_sql(self, e):
@ -46,36 +47,36 @@ def _unix_to_time(self, expression):
class Spark(Hive):
class Parser(Hive.Parser):
FUNCTIONS = {
**Hive.Parser.FUNCTIONS,
**Hive.Parser.FUNCTIONS, # type: ignore
"MAP_FROM_ARRAYS": exp.Map.from_arg_list,
"TO_UNIX_TIMESTAMP": exp.StrToUnix.from_arg_list,
"LEFT": lambda args: exp.Substring(
this=list_get(args, 0),
this=seq_get(args, 0),
start=exp.Literal.number(1),
length=list_get(args, 1),
length=seq_get(args, 1),
),
"SHIFTLEFT": lambda args: exp.BitwiseLeftShift(
this=list_get(args, 0),
expression=list_get(args, 1),
this=seq_get(args, 0),
expression=seq_get(args, 1),
),
"SHIFTRIGHT": lambda args: exp.BitwiseRightShift(
this=list_get(args, 0),
expression=list_get(args, 1),
this=seq_get(args, 0),
expression=seq_get(args, 1),
),
"RIGHT": lambda args: exp.Substring(
this=list_get(args, 0),
this=seq_get(args, 0),
start=exp.Sub(
this=exp.Length(this=list_get(args, 0)),
expression=exp.Add(this=list_get(args, 1), expression=exp.Literal.number(1)),
this=exp.Length(this=seq_get(args, 0)),
expression=exp.Add(this=seq_get(args, 1), expression=exp.Literal.number(1)),
),
length=list_get(args, 1),
length=seq_get(args, 1),
),
"APPROX_PERCENTILE": exp.ApproxQuantile.from_arg_list,
"IIF": exp.If.from_arg_list,
}
FUNCTION_PARSERS = {
**Parser.FUNCTION_PARSERS,
**parser.Parser.FUNCTION_PARSERS,
"BROADCAST": lambda self: self._parse_join_hint("BROADCAST"),
"BROADCASTJOIN": lambda self: self._parse_join_hint("BROADCASTJOIN"),
"MAPJOIN": lambda self: self._parse_join_hint("MAPJOIN"),
@ -88,14 +89,14 @@ class Spark(Hive):
class Generator(Hive.Generator):
TYPE_MAPPING = {
**Hive.Generator.TYPE_MAPPING,
**Hive.Generator.TYPE_MAPPING, # type: ignore
exp.DataType.Type.TINYINT: "BYTE",
exp.DataType.Type.SMALLINT: "SHORT",
exp.DataType.Type.BIGINT: "LONG",
}
TRANSFORMS = {
**{k: v for k, v in Hive.Generator.TRANSFORMS.items() if k not in {exp.ArraySort, exp.ILike}},
**Hive.Generator.TRANSFORMS, # type: ignore
exp.ApproxDistinct: rename_func("APPROX_COUNT_DISTINCT"),
exp.FileFormatProperty: lambda self, e: f"USING {e.text('value').upper()}",
exp.ArraySum: lambda self, e: f"AGGREGATE({self.sql(e, 'this')}, 0, (acc, x) -> acc + x, acc -> acc)",
@ -114,6 +115,8 @@ class Spark(Hive):
exp.VariancePop: rename_func("VAR_POP"),
exp.DateFromParts: rename_func("MAKE_DATE"),
}
TRANSFORMS.pop(exp.ArraySort)
TRANSFORMS.pop(exp.ILike)
WRAP_DERIVED_VALUES = False

View file

@ -1,4 +1,6 @@
from sqlglot import exp
from __future__ import annotations
from sqlglot import exp, generator, parser, tokens
from sqlglot.dialects.dialect import (
Dialect,
arrow_json_extract_scalar_sql,
@ -8,31 +10,28 @@ from sqlglot.dialects.dialect import (
no_trycast_sql,
rename_func,
)
from sqlglot.generator import Generator
from sqlglot.parser import Parser
from sqlglot.tokens import Tokenizer, TokenType
from sqlglot.tokens import TokenType
class SQLite(Dialect):
class Tokenizer(Tokenizer):
class Tokenizer(tokens.Tokenizer):
IDENTIFIERS = ['"', ("[", "]"), "`"]
HEX_STRINGS = [("x'", "'"), ("X'", "'"), ("0x", ""), ("0X", "")]
KEYWORDS = {
**Tokenizer.KEYWORDS,
"VARBINARY": TokenType.BINARY,
**tokens.Tokenizer.KEYWORDS,
"AUTOINCREMENT": TokenType.AUTO_INCREMENT,
}
class Parser(Parser):
class Parser(parser.Parser):
FUNCTIONS = {
**Parser.FUNCTIONS,
**parser.Parser.FUNCTIONS,
"EDITDIST3": exp.Levenshtein.from_arg_list,
}
class Generator(Generator):
class Generator(generator.Generator):
TYPE_MAPPING = {
**Generator.TYPE_MAPPING,
**generator.Generator.TYPE_MAPPING,
exp.DataType.Type.BOOLEAN: "INTEGER",
exp.DataType.Type.TINYINT: "INTEGER",
exp.DataType.Type.SMALLINT: "INTEGER",
@ -46,6 +45,7 @@ class SQLite(Dialect):
exp.DataType.Type.VARCHAR: "TEXT",
exp.DataType.Type.NVARCHAR: "TEXT",
exp.DataType.Type.BINARY: "BLOB",
exp.DataType.Type.VARBINARY: "BLOB",
}
TOKEN_MAPPING = {
@ -53,7 +53,7 @@ class SQLite(Dialect):
}
TRANSFORMS = {
**Generator.TRANSFORMS,
**generator.Generator.TRANSFORMS,
exp.ILike: no_ilike_sql,
exp.JSONExtract: arrow_json_extract_sql,
exp.JSONExtractScalar: arrow_json_extract_scalar_sql,

View file

@ -1,10 +1,12 @@
from __future__ import annotations
from sqlglot import exp
from sqlglot.dialects.dialect import arrow_json_extract_sql, rename_func
from sqlglot.dialects.mysql import MySQL
class StarRocks(MySQL):
class Generator(MySQL.Generator):
class Generator(MySQL.Generator): # type: ignore
TYPE_MAPPING = {
**MySQL.Generator.TYPE_MAPPING,
exp.DataType.Type.TEXT: "STRING",
@ -13,7 +15,7 @@ class StarRocks(MySQL):
}
TRANSFORMS = {
**MySQL.Generator.TRANSFORMS,
**MySQL.Generator.TRANSFORMS, # type: ignore
exp.JSONExtractScalar: arrow_json_extract_sql,
exp.JSONExtract: arrow_json_extract_sql,
exp.DateDiff: rename_func("DATEDIFF"),
@ -22,3 +24,4 @@ class StarRocks(MySQL):
exp.UnixToStr: lambda self, e: f"FROM_UNIXTIME({self.sql(e, 'this')}, {self.format_time(e)})",
exp.UnixToTime: rename_func("FROM_UNIXTIME"),
}
TRANSFORMS.pop(exp.DateTrunc)

View file

@ -1,7 +1,7 @@
from sqlglot import exp
from __future__ import annotations
from sqlglot import exp, generator, parser
from sqlglot.dialects.dialect import Dialect
from sqlglot.generator import Generator
from sqlglot.parser import Parser
def _if_sql(self, expression):
@ -20,17 +20,17 @@ def _count_sql(self, expression):
class Tableau(Dialect):
class Generator(Generator):
class Generator(generator.Generator):
TRANSFORMS = {
**Generator.TRANSFORMS,
**generator.Generator.TRANSFORMS, # type: ignore
exp.If: _if_sql,
exp.Coalesce: _coalesce_sql,
exp.Count: _count_sql,
}
class Parser(Parser):
class Parser(parser.Parser):
FUNCTIONS = {
**Parser.FUNCTIONS,
**parser.Parser.FUNCTIONS,
"IFNULL": exp.Coalesce.from_arg_list,
"COUNTD": lambda args: exp.Count(this=exp.Distinct(expressions=args)),
}

View file

@ -1,3 +1,5 @@
from __future__ import annotations
from sqlglot import exp
from sqlglot.dialects.presto import Presto
@ -5,7 +7,7 @@ from sqlglot.dialects.presto import Presto
class Trino(Presto):
class Generator(Presto.Generator):
TRANSFORMS = {
**Presto.Generator.TRANSFORMS,
**Presto.Generator.TRANSFORMS, # type: ignore
exp.ArraySum: lambda self, e: f"REDUCE({self.sql(e, 'this')}, 0, (acc, x) -> acc + x, acc -> acc)",
}

View file

@ -1,15 +1,22 @@
from __future__ import annotations
import re
from sqlglot import exp
from sqlglot import exp, generator, parser, tokens
from sqlglot.dialects.dialect import Dialect, parse_date_delta, rename_func
from sqlglot.expressions import DataType
from sqlglot.generator import Generator
from sqlglot.helper import list_get
from sqlglot.parser import Parser
from sqlglot.helper import seq_get
from sqlglot.time import format_time
from sqlglot.tokens import Tokenizer, TokenType
from sqlglot.tokens import TokenType
FULL_FORMAT_TIME_MAPPING = {"weekday": "%A", "dw": "%A", "w": "%A", "month": "%B", "mm": "%B", "m": "%B"}
FULL_FORMAT_TIME_MAPPING = {
"weekday": "%A",
"dw": "%A",
"w": "%A",
"month": "%B",
"mm": "%B",
"m": "%B",
}
DATE_DELTA_INTERVAL = {
"year": "year",
"yyyy": "year",
@ -37,11 +44,13 @@ TRANSPILE_SAFE_NUMBER_FMT = {"N", "C"}
def tsql_format_time_lambda(exp_class, full_format_mapping=None, default=None):
def _format_time(args):
return exp_class(
this=list_get(args, 1),
this=seq_get(args, 1),
format=exp.Literal.string(
format_time(
list_get(args, 0).name or (TSQL.time_format if default is True else default),
{**TSQL.time_mapping, **FULL_FORMAT_TIME_MAPPING} if full_format_mapping else TSQL.time_mapping,
seq_get(args, 0).name or (TSQL.time_format if default is True else default),
{**TSQL.time_mapping, **FULL_FORMAT_TIME_MAPPING}
if full_format_mapping
else TSQL.time_mapping,
)
),
)
@ -50,12 +59,12 @@ def tsql_format_time_lambda(exp_class, full_format_mapping=None, default=None):
def parse_format(args):
fmt = list_get(args, 1)
fmt = seq_get(args, 1)
number_fmt = fmt.name in TRANSPILE_SAFE_NUMBER_FMT or not DATE_FMT_RE.search(fmt.this)
if number_fmt:
return exp.NumberToStr(this=list_get(args, 0), format=fmt)
return exp.NumberToStr(this=seq_get(args, 0), format=fmt)
return exp.TimeToStr(
this=list_get(args, 0),
this=seq_get(args, 0),
format=exp.Literal.string(
format_time(fmt.name, TSQL.format_time_mapping)
if len(fmt.name) == 1
@ -188,11 +197,11 @@ class TSQL(Dialect):
"Y": "%a %Y",
}
class Tokenizer(Tokenizer):
class Tokenizer(tokens.Tokenizer):
IDENTIFIERS = ['"', ("[", "]")]
KEYWORDS = {
**Tokenizer.KEYWORDS,
**tokens.Tokenizer.KEYWORDS,
"BIT": TokenType.BOOLEAN,
"REAL": TokenType.FLOAT,
"NTEXT": TokenType.TEXT,
@ -200,7 +209,6 @@ class TSQL(Dialect):
"DATETIME2": TokenType.DATETIME,
"DATETIMEOFFSET": TokenType.TIMESTAMPTZ,
"TIME": TokenType.TIMESTAMP,
"VARBINARY": TokenType.BINARY,
"IMAGE": TokenType.IMAGE,
"MONEY": TokenType.MONEY,
"SMALLMONEY": TokenType.SMALLMONEY,
@ -213,9 +221,9 @@ class TSQL(Dialect):
"TOP": TokenType.TOP,
}
class Parser(Parser):
class Parser(parser.Parser):
FUNCTIONS = {
**Parser.FUNCTIONS,
**parser.Parser.FUNCTIONS,
"CHARINDEX": exp.StrPosition.from_arg_list,
"ISNULL": exp.Coalesce.from_arg_list,
"DATEADD": parse_date_delta(exp.DateAdd, unit_mapping=DATE_DELTA_INTERVAL),
@ -243,14 +251,16 @@ class TSQL(Dialect):
this = self._parse_column()
# Retrieve length of datatype and override to default if not specified
if list_get(to.expressions, 0) is None and to.this in self.VAR_LENGTH_DATATYPES:
if seq_get(to.expressions, 0) is None and to.this in self.VAR_LENGTH_DATATYPES:
to = exp.DataType.build(to.this, expressions=[exp.Literal.number(30)], nested=False)
# Check whether a conversion with format is applicable
if self._match(TokenType.COMMA):
format_val = self._parse_number().name
if format_val not in TSQL.convert_format_mapping:
raise ValueError(f"CONVERT function at T-SQL does not support format style {format_val}")
raise ValueError(
f"CONVERT function at T-SQL does not support format style {format_val}"
)
format_norm = exp.Literal.string(TSQL.convert_format_mapping[format_val])
# Check whether the convert entails a string to date format
@ -272,9 +282,9 @@ class TSQL(Dialect):
# Entails a simple cast without any format requirement
return self.expression(exp.Cast if strict else exp.TryCast, this=this, to=to)
class Generator(Generator):
class Generator(generator.Generator):
TYPE_MAPPING = {
**Generator.TYPE_MAPPING,
**generator.Generator.TYPE_MAPPING,
exp.DataType.Type.BOOLEAN: "BIT",
exp.DataType.Type.INT: "INTEGER",
exp.DataType.Type.DECIMAL: "NUMERIC",
@ -283,7 +293,7 @@ class TSQL(Dialect):
}
TRANSFORMS = {
**Generator.TRANSFORMS,
**generator.Generator.TRANSFORMS, # type: ignore
exp.DateAdd: generate_date_delta_with_unit_sql,
exp.DateDiff: generate_date_delta_with_unit_sql,
exp.CurrentDate: rename_func("GETDATE"),