
Merging upstream version 16.2.1.

Signed-off-by: Daniel Baumann <daniel@debian.org>
Daniel Baumann 2025-02-13 16:00:51 +01:00
parent c12f551e31
commit 718a80b164
Signed by: daniel
GPG key ID: FBB4F0E80A80222F
106 changed files with 41940 additions and 40162 deletions

View file

@ -7,6 +7,7 @@ from sqlglot import exp, generator, parser, tokens, transforms
from sqlglot.dialects.dialect import (
Dialect,
datestrtodate_sql,
format_time_lambda,
inline_array_sql,
max_or_greatest,
min_or_least,
@ -103,16 +104,26 @@ def _unqualify_unnest(expression: exp.Expression) -> exp.Expression:
class BigQuery(Dialect):
unnest_column_only = True
time_mapping = {
"%M": "%-M",
"%d": "%-d",
"%m": "%-m",
"%y": "%-y",
"%H": "%-H",
"%I": "%-I",
"%S": "%-S",
"%j": "%-j",
UNNEST_COLUMN_ONLY = True
TIME_MAPPING = {
"%D": "%m/%d/%y",
}
FORMAT_MAPPING = {
"DD": "%d",
"MM": "%m",
"MON": "%b",
"MONTH": "%B",
"YYYY": "%Y",
"YY": "%y",
"HH": "%I",
"HH12": "%I",
"HH24": "%H",
"MI": "%M",
"SS": "%S",
"SSSSS": "%f",
"TZH": "%z",
}
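
The FORMAT_MAPPING above feeds the FORMAT_TRIE built by the dialect metaclass (see dialect.py further down), so format elements such as 'YYYY' or 'MI' can be rewritten into Python strftime codes. A minimal sketch of that rewrite using sqlglot's format_time/new_trie helpers, with an illustrative subset of the mapping:

from sqlglot.time import format_time
from sqlglot.trie import new_trie

# Illustrative subset of the FORMAT_MAPPING above.
mapping = {"YYYY": "%Y", "MM": "%m", "DD": "%d", "HH24": "%H", "MI": "%M", "SS": "%S"}
print(format_time("YYYY-MM-DD HH24:MI:SS", mapping, new_trie(mapping)))
# -> %Y-%m-%d %H:%M:%S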
class Tokenizer(tokens.Tokenizer):
@ -142,6 +153,7 @@ class BigQuery(Dialect):
"FLOAT64": TokenType.DOUBLE,
"INT64": TokenType.BIGINT,
"RECORD": TokenType.STRUCT,
"TIMESTAMP": TokenType.TIMESTAMPTZ,
"NOT DETERMINISTIC": TokenType.VOLATILE,
"UNKNOWN": TokenType.NULL,
}
@ -155,13 +167,21 @@ class BigQuery(Dialect):
FUNCTIONS = {
**parser.Parser.FUNCTIONS,
"DATE_ADD": parse_date_delta_with_interval(exp.DateAdd),
"DATE_SUB": parse_date_delta_with_interval(exp.DateSub),
"DATE_TRUNC": lambda args: exp.DateTrunc(
unit=exp.Literal.string(str(seq_get(args, 1))),
this=seq_get(args, 0),
),
"DATE_ADD": parse_date_delta_with_interval(exp.DateAdd),
"DATETIME_ADD": parse_date_delta_with_interval(exp.DatetimeAdd),
"DATETIME_SUB": parse_date_delta_with_interval(exp.DatetimeSub),
"DIV": lambda args: exp.IntDiv(this=seq_get(args, 0), expression=seq_get(args, 1)),
"PARSE_DATE": lambda args: format_time_lambda(exp.StrToDate, "bigquery")(
[seq_get(args, 1), seq_get(args, 0)]
),
"PARSE_TIMESTAMP": lambda args: format_time_lambda(exp.StrToTime, "bigquery")(
[seq_get(args, 1), seq_get(args, 0)]
),
"REGEXP_CONTAINS": exp.RegexpLike.from_arg_list,
"REGEXP_EXTRACT": lambda args: exp.RegexpExtract(
this=seq_get(args, 0),
@ -172,15 +192,15 @@ class BigQuery(Dialect):
if re.compile(str(seq_get(args, 1))).groups == 1
else None,
),
"TIME_ADD": parse_date_delta_with_interval(exp.TimeAdd),
"TIMESTAMP_ADD": parse_date_delta_with_interval(exp.TimestampAdd),
"DATE_SUB": parse_date_delta_with_interval(exp.DateSub),
"DATETIME_SUB": parse_date_delta_with_interval(exp.DatetimeSub),
"TIME_SUB": parse_date_delta_with_interval(exp.TimeSub),
"TIMESTAMP_SUB": parse_date_delta_with_interval(exp.TimestampSub),
"PARSE_TIMESTAMP": lambda args: exp.StrToTime(
this=seq_get(args, 1), format=seq_get(args, 0)
"SPLIT": lambda args: exp.Split(
# https://cloud.google.com/bigquery/docs/reference/standard-sql/string_functions#split
this=seq_get(args, 0),
expression=seq_get(args, 1) or exp.Literal.string(","),
),
"TIME_ADD": parse_date_delta_with_interval(exp.TimeAdd),
"TIME_SUB": parse_date_delta_with_interval(exp.TimeSub),
"TIMESTAMP_ADD": parse_date_delta_with_interval(exp.TimestampAdd),
"TIMESTAMP_SUB": parse_date_delta_with_interval(exp.TimestampSub),
}
FUNCTION_PARSERS = {
@ -274,9 +294,18 @@ class BigQuery(Dialect):
exp.IntDiv: rename_func("DIV"),
exp.Max: max_or_greatest,
exp.Min: min_or_least,
exp.RegexpExtract: lambda self, e: self.func(
"REGEXP_EXTRACT",
e.this,
e.expression,
e.args.get("position"),
e.args.get("occurrence"),
),
exp.RegexpLike: rename_func("REGEXP_CONTAINS"),
exp.Select: transforms.preprocess(
[_unqualify_unnest, transforms.eliminate_distinct_on]
),
exp.StrToDate: lambda self, e: f"PARSE_DATE({self.format_time(e)}, {self.sql(e, 'this')})",
exp.StrToTime: lambda self, e: f"PARSE_TIMESTAMP({self.format_time(e)}, {self.sql(e, 'this')})",
exp.TimeAdd: _date_add_sql("TIME", "ADD"),
exp.TimeSub: _date_add_sql("TIME", "SUB"),
@ -295,7 +324,6 @@ class BigQuery(Dialect):
exp.StabilityProperty: lambda self, e: f"DETERMINISTIC"
if e.name == "IMMUTABLE"
else "NOT DETERMINISTIC",
exp.RegexpLike: rename_func("REGEXP_CONTAINS"),
}
TYPE_MAPPING = {
@ -315,6 +343,7 @@ class BigQuery(Dialect):
exp.DataType.Type.TEXT: "STRING",
exp.DataType.Type.TIMESTAMP: "DATETIME",
exp.DataType.Type.TIMESTAMPTZ: "TIMESTAMP",
exp.DataType.Type.TIMESTAMPLTZ: "TIMESTAMP",
exp.DataType.Type.TINYINT: "INT64",
exp.DataType.Type.VARBINARY: "BYTES",
exp.DataType.Type.VARCHAR: "STRING",
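
With PARSE_DATE/PARSE_TIMESTAMP parsed into StrToDate/StrToTime (argument order swapped to sqlglot's value-first convention) and the matching StrToDate/StrToTime transforms above, BigQuery SQL should round-trip. A hedged sketch with an illustrative column name:

import sqlglot

sql = "SELECT PARSE_DATE('%Y%m%d', date_str)"
print(sqlglot.transpile(sql, read="bigquery", write="bigquery")[0])
# Expected round trip: SELECT PARSE_DATE('%Y%m%d', date_str)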

View file

@ -21,8 +21,9 @@ def _lower_func(sql: str) -> str:
class ClickHouse(Dialect):
normalize_functions = None
null_ordering = "nulls_are_last"
NORMALIZE_FUNCTIONS: bool | str = False
NULL_ORDERING = "nulls_are_last"
STRICT_STRING_CONCAT = True
class Tokenizer(tokens.Tokenizer):
COMMENTS = ["--", "#", "#!", ("/*", "*/")]
@ -163,11 +164,11 @@ class ClickHouse(Dialect):
return this
def _parse_position(self, haystack_first: bool = False) -> exp.Expression:
def _parse_position(self, haystack_first: bool = False) -> exp.StrPosition:
return super()._parse_position(haystack_first=True)
# https://clickhouse.com/docs/en/sql-reference/statements/select/with/
def _parse_cte(self) -> exp.Expression:
def _parse_cte(self) -> exp.CTE:
index = self._index
try:
# WITH <identifier> AS <subquery expression>
@ -187,17 +188,19 @@ class ClickHouse(Dialect):
) -> t.Tuple[t.Optional[Token], t.Optional[Token], t.Optional[Token]]:
is_global = self._match(TokenType.GLOBAL) and self._prev
kind_pre = self._match_set(self.JOIN_KINDS, advance=False) and self._prev
if kind_pre:
kind = self._match_set(self.JOIN_KINDS) and self._prev
side = self._match_set(self.JOIN_SIDES) and self._prev
return is_global, side, kind
return (
is_global,
self._match_set(self.JOIN_SIDES) and self._prev,
self._match_set(self.JOIN_KINDS) and self._prev,
)
def _parse_join(self, skip_join_token: bool = False) -> t.Optional[exp.Expression]:
def _parse_join(self, skip_join_token: bool = False) -> t.Optional[exp.Join]:
join = super()._parse_join(skip_join_token)
if join:
@ -205,9 +208,14 @@ class ClickHouse(Dialect):
return join
def _parse_function(
self, functions: t.Optional[t.Dict[str, t.Callable]] = None, anonymous: bool = False
self,
functions: t.Optional[t.Dict[str, t.Callable]] = None,
anonymous: bool = False,
optional_parens: bool = True,
) -> t.Optional[exp.Expression]:
func = super()._parse_function(functions, anonymous)
func = super()._parse_function(
functions=functions, anonymous=anonymous, optional_parens=optional_parens
)
if isinstance(func, exp.Anonymous):
params = self._parse_func_params(func)
@ -227,10 +235,12 @@ class ClickHouse(Dialect):
) -> t.Optional[t.List[t.Optional[exp.Expression]]]:
if self._match_pair(TokenType.R_PAREN, TokenType.L_PAREN):
return self._parse_csv(self._parse_lambda)
if self._match(TokenType.L_PAREN):
params = self._parse_csv(self._parse_lambda)
self._match_r_paren(this)
return params
return None
def _parse_quantile(self) -> exp.Quantile:
@ -247,12 +257,12 @@ class ClickHouse(Dialect):
def _parse_primary_key(
self, wrapped_optional: bool = False, in_props: bool = False
) -> exp.Expression:
) -> exp.PrimaryKeyColumnConstraint | exp.PrimaryKey:
return super()._parse_primary_key(
wrapped_optional=wrapped_optional or in_props, in_props=in_props
)
def _parse_on_property(self) -> t.Optional[exp.Property]:
def _parse_on_property(self) -> t.Optional[exp.Expression]:
index = self._index
if self._match_text_seq("CLUSTER"):
this = self._parse_id_var()
@ -329,6 +339,16 @@ class ClickHouse(Dialect):
"NAMED COLLECTION",
}
def safeconcat_sql(self, expression: exp.SafeConcat) -> str:
# ClickHouse errors out if we try to cast a NULL value to TEXT
return self.func(
"CONCAT",
*[
exp.func("if", e.is_(exp.null()), e, exp.cast(e, "text"))
for e in expression.expressions
],
)
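
safeconcat_sql guards each CONCAT argument so NULLs bypass the cast (ClickHouse rejects casting NULL to a string type). A minimal sketch of the generated shape, building the node directly; the exact conditional and type spelling may differ:

from sqlglot import exp
from sqlglot.dialects.clickhouse import ClickHouse

concat = exp.SafeConcat(expressions=[exp.column("a"), exp.column("b")])
print(ClickHouse().generate(concat))
# Each argument comes out wrapped in a NULL guard, roughly:
# CONCAT(<a if a IS NULL else CAST(a AS String)>, <same for b>)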
def cte_sql(self, expression: exp.CTE) -> str:
if isinstance(expression.this, exp.Alias):
return self.sql(expression, "this")

View file

@ -25,6 +25,8 @@ class Dialects(str, Enum):
BIGQUERY = "bigquery"
CLICKHOUSE = "clickhouse"
DATABRICKS = "databricks"
DRILL = "drill"
DUCKDB = "duckdb"
HIVE = "hive"
MYSQL = "mysql"
@ -38,11 +40,9 @@ class Dialects(str, Enum):
SQLITE = "sqlite"
STARROCKS = "starrocks"
TABLEAU = "tableau"
TERADATA = "teradata"
TRINO = "trino"
TSQL = "tsql"
DATABRICKS = "databricks"
DRILL = "drill"
TERADATA = "teradata"
class _Dialect(type):
@ -76,16 +76,19 @@ class _Dialect(type):
enum = Dialects.__members__.get(clsname.upper())
cls.classes[enum.value if enum is not None else clsname.lower()] = klass
klass.time_trie = new_trie(klass.time_mapping)
klass.inverse_time_mapping = {v: k for k, v in klass.time_mapping.items()}
klass.inverse_time_trie = new_trie(klass.inverse_time_mapping)
klass.TIME_TRIE = new_trie(klass.TIME_MAPPING)
klass.FORMAT_TRIE = (
new_trie(klass.FORMAT_MAPPING) if klass.FORMAT_MAPPING else klass.TIME_TRIE
)
klass.INVERSE_TIME_MAPPING = {v: k for k, v in klass.TIME_MAPPING.items()}
klass.INVERSE_TIME_TRIE = new_trie(klass.INVERSE_TIME_MAPPING)
klass.tokenizer_class = getattr(klass, "Tokenizer", Tokenizer)
klass.parser_class = getattr(klass, "Parser", Parser)
klass.generator_class = getattr(klass, "Generator", Generator)
klass.quote_start, klass.quote_end = list(klass.tokenizer_class._QUOTES.items())[0]
klass.identifier_start, klass.identifier_end = list(
klass.QUOTE_START, klass.QUOTE_END = list(klass.tokenizer_class._QUOTES.items())[0]
klass.IDENTIFIER_START, klass.IDENTIFIER_END = list(
klass.tokenizer_class._IDENTIFIERS.items()
)[0]
@ -99,43 +102,80 @@ class _Dialect(type):
(None, None),
)
klass.bit_start, klass.bit_end = get_start_end(TokenType.BIT_STRING)
klass.hex_start, klass.hex_end = get_start_end(TokenType.HEX_STRING)
klass.byte_start, klass.byte_end = get_start_end(TokenType.BYTE_STRING)
klass.raw_start, klass.raw_end = get_start_end(TokenType.RAW_STRING)
klass.BIT_START, klass.BIT_END = get_start_end(TokenType.BIT_STRING)
klass.HEX_START, klass.HEX_END = get_start_end(TokenType.HEX_STRING)
klass.BYTE_START, klass.BYTE_END = get_start_end(TokenType.BYTE_STRING)
klass.RAW_START, klass.RAW_END = get_start_end(TokenType.RAW_STRING)
klass.tokenizer_class.identifiers_can_start_with_digit = (
klass.identifiers_can_start_with_digit
)
dialect_properties = {
**{
k: v
for k, v in vars(klass).items()
if not callable(v) and not isinstance(v, classmethod) and not k.startswith("__")
},
"STRING_ESCAPE": klass.tokenizer_class.STRING_ESCAPES[0],
"IDENTIFIER_ESCAPE": klass.tokenizer_class.IDENTIFIER_ESCAPES[0],
}
# Pass required dialect properties to the tokenizer, parser and generator classes
for subclass in (klass.tokenizer_class, klass.parser_class, klass.generator_class):
for name, value in dialect_properties.items():
if hasattr(subclass, name):
setattr(subclass, name, value)
if not klass.STRICT_STRING_CONCAT:
klass.parser_class.BITWISE[TokenType.DPIPE] = exp.SafeDPipe
return klass
class Dialect(metaclass=_Dialect):
index_offset = 0
unnest_column_only = False
alias_post_tablesample = False
identifiers_can_start_with_digit = False
normalize_functions: t.Optional[str] = "upper"
null_ordering = "nulls_are_small"
# Determines the base index offset for arrays
INDEX_OFFSET = 0
date_format = "'%Y-%m-%d'"
dateint_format = "'%Y%m%d'"
time_format = "'%Y-%m-%d %H:%M:%S'"
time_mapping: t.Dict[str, str] = {}
# If true, UNNEST table aliases are considered only as column aliases
UNNEST_COLUMN_ONLY = False
# autofilled
quote_start = None
quote_end = None
identifier_start = None
identifier_end = None
# Determines whether or not the table alias comes after tablesample
ALIAS_POST_TABLESAMPLE = False
time_trie = None
inverse_time_mapping = None
inverse_time_trie = None
tokenizer_class = None
parser_class = None
generator_class = None
# Determines whether or not an unquoted identifier can start with a digit
IDENTIFIERS_CAN_START_WITH_DIGIT = False
# Determines whether or not CONCAT's arguments must be strings
STRICT_STRING_CONCAT = False
# Determines how function names are going to be normalized
NORMALIZE_FUNCTIONS: bool | str = "upper"
# Indicates the default null ordering method to use if not explicitly set
# Options are: "nulls_are_small", "nulls_are_large", "nulls_are_last"
NULL_ORDERING = "nulls_are_small"
DATE_FORMAT = "'%Y-%m-%d'"
DATEINT_FORMAT = "'%Y%m%d'"
TIME_FORMAT = "'%Y-%m-%d %H:%M:%S'"
# Custom time mappings in which the key represents a dialect-specific time format
# and the value represents a Python time format
TIME_MAPPING: t.Dict[str, str] = {}
# https://cloud.google.com/bigquery/docs/reference/standard-sql/format-elements#format_model_rules_date_time
# https://docs.teradata.com/r/Teradata-Database-SQL-Functions-Operators-Expressions-and-Predicates/March-2017/Data-Type-Conversions/Character-to-DATE-Conversion/Forcing-a-FORMAT-on-CAST-for-Converting-Character-to-DATE
# The special syntax CAST(x AS DATE FORMAT 'yyyy') falls back to TIME_MAPPING when this is empty
FORMAT_MAPPING: t.Dict[str, str] = {}
# Autofilled
tokenizer_class = Tokenizer
parser_class = Parser
generator_class = Generator
# A trie of the TIME_MAPPING keys
TIME_TRIE: t.Dict = {}
FORMAT_TRIE: t.Dict = {}
INVERSE_TIME_MAPPING: t.Dict[str, str] = {}
INVERSE_TIME_TRIE: t.Dict = {}
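
These uppercase constants are now the single source of truth: the _Dialect metaclass autofills the tries and copies every matching attribute onto the dialect's Tokenizer/Parser/Generator (the hasattr/setattr loop above), which is what lets parser() and generator() below drop their kwargs plumbing. A small sketch, assuming the Presto values shown later in this diff:

from sqlglot.dialects.presto import Presto

assert Presto.INDEX_OFFSET == 1                  # set on the Dialect subclass
assert Presto.generator_class.INDEX_OFFSET == 1  # mirrored by the metaclass copy above
assert Presto.INVERSE_TIME_MAPPING               # autofilled from TIME_MAPPING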
def __eq__(self, other: t.Any) -> bool:
return type(self) == other
@ -164,20 +204,13 @@ class Dialect(metaclass=_Dialect):
) -> t.Optional[exp.Expression]:
if isinstance(expression, str):
return exp.Literal.string(
format_time(
expression[1:-1], # the time formats are quoted
cls.time_mapping,
cls.time_trie,
)
# the time formats are quoted
format_time(expression[1:-1], cls.TIME_MAPPING, cls.TIME_TRIE)
)
if expression and expression.is_string:
return exp.Literal.string(
format_time(
expression.this,
cls.time_mapping,
cls.time_trie,
)
)
return exp.Literal.string(format_time(expression.this, cls.TIME_MAPPING, cls.TIME_TRIE))
return expression
def parse(self, sql: str, **opts) -> t.List[t.Optional[exp.Expression]]:
@ -200,48 +233,14 @@ class Dialect(metaclass=_Dialect):
@property
def tokenizer(self) -> Tokenizer:
if not hasattr(self, "_tokenizer"):
self._tokenizer = self.tokenizer_class() # type: ignore
self._tokenizer = self.tokenizer_class()
return self._tokenizer
def parser(self, **opts) -> Parser:
return self.parser_class( # type: ignore
**{
"index_offset": self.index_offset,
"unnest_column_only": self.unnest_column_only,
"alias_post_tablesample": self.alias_post_tablesample,
"null_ordering": self.null_ordering,
**opts,
},
)
return self.parser_class(**opts)
def generator(self, **opts) -> Generator:
return self.generator_class( # type: ignore
**{
"quote_start": self.quote_start,
"quote_end": self.quote_end,
"bit_start": self.bit_start,
"bit_end": self.bit_end,
"hex_start": self.hex_start,
"hex_end": self.hex_end,
"byte_start": self.byte_start,
"byte_end": self.byte_end,
"raw_start": self.raw_start,
"raw_end": self.raw_end,
"identifier_start": self.identifier_start,
"identifier_end": self.identifier_end,
"string_escape": self.tokenizer_class.STRING_ESCAPES[0],
"identifier_escape": self.tokenizer_class.IDENTIFIER_ESCAPES[0],
"index_offset": self.index_offset,
"time_mapping": self.inverse_time_mapping,
"time_trie": self.inverse_time_trie,
"unnest_column_only": self.unnest_column_only,
"alias_post_tablesample": self.alias_post_tablesample,
"identifiers_can_start_with_digit": self.identifiers_can_start_with_digit,
"normalize_functions": self.normalize_functions,
"null_ordering": self.null_ordering,
**opts,
}
)
return self.generator_class(**opts)
DialectType = t.Union[str, Dialect, t.Type[Dialect], None]
@ -279,10 +278,7 @@ def inline_array_sql(self: Generator, expression: exp.Array) -> str:
def no_ilike_sql(self: Generator, expression: exp.ILike) -> str:
return self.like_sql(
exp.Like(
this=exp.Lower(this=expression.this),
expression=expression.args["expression"],
)
exp.Like(this=exp.Lower(this=expression.this), expression=expression.expression)
)
@ -359,6 +355,7 @@ def var_map_sql(
for key, value in zip(keys.expressions, values.expressions):
args.append(self.sql(key))
args.append(self.sql(value))
return self.func(map_func_name, *args)
@ -381,7 +378,7 @@ def format_time_lambda(
this=seq_get(args, 0),
format=Dialect[dialect].format_time(
seq_get(args, 1)
or (Dialect[dialect].time_format if default is True else default or None)
or (Dialect[dialect].TIME_FORMAT if default is True else default or None)
),
)
@ -437,9 +434,7 @@ def parse_date_delta_with_interval(
expression = exp.Literal.number(expression.this)
return expression_class(
this=args[0],
expression=expression,
unit=exp.Literal.string(interval.text("unit")),
this=args[0], expression=expression, unit=exp.Literal.string(interval.text("unit"))
)
return func
@ -462,9 +457,7 @@ def timestamptrunc_sql(self: Generator, expression: exp.TimestampTrunc) -> str:
def locate_to_strposition(args: t.List) -> exp.Expression:
return exp.StrPosition(
this=seq_get(args, 1),
substr=seq_get(args, 0),
position=seq_get(args, 2),
this=seq_get(args, 1), substr=seq_get(args, 0), position=seq_get(args, 2)
)
@ -546,13 +539,21 @@ def ts_or_ds_to_date_sql(dialect: str) -> t.Callable:
def _ts_or_ds_to_date_sql(self: Generator, expression: exp.TsOrDsToDate) -> str:
_dialect = Dialect.get_or_raise(dialect)
time_format = self.format_time(expression)
if time_format and time_format not in (_dialect.time_format, _dialect.date_format):
if time_format and time_format not in (_dialect.TIME_FORMAT, _dialect.DATE_FORMAT):
return f"CAST({str_to_time_sql(self, expression)} AS DATE)"
return f"CAST({self.sql(expression, 'this')} AS DATE)"
return _ts_or_ds_to_date_sql
def concat_to_dpipe_sql(self: Generator, expression: exp.Concat | exp.SafeConcat) -> str:
this, *rest_args = expression.expressions
for arg in rest_args:
this = exp.DPipe(this=this, expression=arg)
return self.sql(this)
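
concat_to_dpipe_sql folds CONCAT's argument list into nested DPipe nodes, i.e. a || b || c. A minimal sketch using Redshift, which registers the helper later in this diff (illustrative column names):

from sqlglot import exp
from sqlglot.dialects.redshift import Redshift

concat = exp.Concat(expressions=[exp.column("a"), exp.column("b"), exp.column("c")])
print(Redshift().generate(concat))
# -> a || b || c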
# Spark, DuckDB use (almost) the same naming scheme for the output columns of the PIVOT operator
def pivot_column_names(aggregations: t.List[exp.Expression], dialect: DialectType) -> t.List[str]:
names = []

View file

@ -16,21 +16,10 @@ from sqlglot.dialects.dialect import (
)
def _str_to_time_sql(self: generator.Generator, expression: exp.TsOrDsToDate) -> str:
return f"STRPTIME({self.sql(expression, 'this')}, {self.format_time(expression)})"
def _ts_or_ds_to_date_sql(self: generator.Generator, expression: exp.TsOrDsToDate) -> str:
time_format = self.format_time(expression)
if time_format and time_format not in (Drill.time_format, Drill.date_format):
return f"CAST({_str_to_time_sql(self, expression)} AS DATE)"
return f"CAST({self.sql(expression, 'this')} AS DATE)"
def _date_add_sql(kind: str) -> t.Callable[[generator.Generator, exp.DateAdd | exp.DateSub], str]:
def func(self: generator.Generator, expression: exp.DateAdd | exp.DateSub) -> str:
this = self.sql(expression, "this")
unit = exp.Var(this=expression.text("unit").upper() or "DAY")
unit = exp.var(expression.text("unit").upper() or "DAY")
return (
f"DATE_{kind}({this}, {self.sql(exp.Interval(this=expression.expression, unit=unit))})"
)
@ -41,19 +30,19 @@ def _date_add_sql(kind: str) -> t.Callable[[generator.Generator, exp.DateAdd | e
def _str_to_date(self: generator.Generator, expression: exp.StrToDate) -> str:
this = self.sql(expression, "this")
time_format = self.format_time(expression)
if time_format == Drill.date_format:
if time_format == Drill.DATE_FORMAT:
return f"CAST({this} AS DATE)"
return f"TO_DATE({this}, {time_format})"
class Drill(Dialect):
normalize_functions = None
null_ordering = "nulls_are_last"
date_format = "'yyyy-MM-dd'"
dateint_format = "'yyyyMMdd'"
time_format = "'yyyy-MM-dd HH:mm:ss'"
NORMALIZE_FUNCTIONS: bool | str = False
NULL_ORDERING = "nulls_are_last"
DATE_FORMAT = "'yyyy-MM-dd'"
DATEINT_FORMAT = "'yyyyMMdd'"
TIME_FORMAT = "'yyyy-MM-dd HH:mm:ss'"
time_mapping = {
TIME_MAPPING = {
"y": "%Y",
"Y": "%Y",
"YYYY": "%Y",
@ -93,6 +82,7 @@ class Drill(Dialect):
class Parser(parser.Parser):
STRICT_CAST = False
CONCAT_NULL_OUTPUTS_STRING = True
FUNCTIONS = {
**parser.Parser.FUNCTIONS,
@ -135,8 +125,8 @@ class Drill(Dialect):
exp.DateAdd: _date_add_sql("ADD"),
exp.DateStrToDate: datestrtodate_sql,
exp.DateSub: _date_add_sql("SUB"),
exp.DateToDi: lambda self, e: f"CAST(TO_DATE({self.sql(e, 'this')}, {Drill.dateint_format}) AS INT)",
exp.DiToDate: lambda self, e: f"TO_DATE(CAST({self.sql(e, 'this')} AS VARCHAR), {Drill.dateint_format})",
exp.DateToDi: lambda self, e: f"CAST(TO_DATE({self.sql(e, 'this')}, {Drill.DATEINT_FORMAT}) AS INT)",
exp.DiToDate: lambda self, e: f"TO_DATE(CAST({self.sql(e, 'this')} AS VARCHAR), {Drill.DATEINT_FORMAT})",
exp.If: lambda self, e: f"`IF`({self.format_args(e.this, e.args.get('true'), e.args.get('false'))})",
exp.ILike: lambda self, e: f" {self.sql(e, 'this')} `ILIKE` {self.sql(e, 'expression')}",
exp.Levenshtein: rename_func("LEVENSHTEIN_DISTANCE"),
@ -154,7 +144,7 @@ class Drill(Dialect):
exp.TimeToUnix: rename_func("UNIX_TIMESTAMP"),
exp.ToChar: lambda self, e: self.function_fallback_sql(e),
exp.TryCast: no_trycast_sql,
exp.TsOrDsAdd: lambda self, e: f"DATE_ADD(CAST({self.sql(e, 'this')} AS DATE), {self.sql(exp.Interval(this=e.expression, unit=exp.Var(this='DAY')))})",
exp.TsOrDsAdd: lambda self, e: f"DATE_ADD(CAST({self.sql(e, 'this')} AS DATE), {self.sql(exp.Interval(this=e.expression, unit=exp.var('DAY')))})",
exp.TsOrDsToDate: ts_or_ds_to_date_sql("drill"),
exp.TsOrDiToDi: lambda self, e: f"CAST(SUBSTR(REPLACE(CAST({self.sql(e, 'this')} AS VARCHAR), '-', ''), 1, 8) AS INT)",
}

View file

@ -56,11 +56,7 @@ def _sort_array_reverse(args: t.List) -> exp.Expression:
def _parse_date_diff(args: t.List) -> exp.Expression:
return exp.DateDiff(
this=seq_get(args, 2),
expression=seq_get(args, 1),
unit=seq_get(args, 0),
)
return exp.DateDiff(this=seq_get(args, 2), expression=seq_get(args, 1), unit=seq_get(args, 0))
def _struct_sql(self: generator.Generator, expression: exp.Struct) -> str:
@ -90,7 +86,7 @@ def _regexp_extract_sql(self: generator.Generator, expression: exp.RegexpExtract
class DuckDB(Dialect):
null_ordering = "nulls_are_last"
NULL_ORDERING = "nulls_are_last"
class Tokenizer(tokens.Tokenizer):
KEYWORDS = {
@ -118,6 +114,8 @@ class DuckDB(Dialect):
}
class Parser(parser.Parser):
CONCAT_NULL_OUTPUTS_STRING = True
FUNCTIONS = {
**parser.Parser.FUNCTIONS,
"ARRAY_LENGTH": exp.ArraySize.from_arg_list,
@ -127,10 +125,7 @@ class DuckDB(Dialect):
"DATE_DIFF": _parse_date_diff,
"EPOCH": exp.TimeToUnix.from_arg_list,
"EPOCH_MS": lambda args: exp.UnixToTime(
this=exp.Div(
this=seq_get(args, 0),
expression=exp.Literal.number(1000),
)
this=exp.Div(this=seq_get(args, 0), expression=exp.Literal.number(1000))
),
"LIST_REVERSE_SORT": _sort_array_reverse,
"LIST_SORT": exp.SortArray.from_arg_list,
@ -191,8 +186,8 @@ class DuckDB(Dialect):
"DATE_DIFF", f"'{e.args.get('unit', 'day')}'", e.expression, e.this
),
exp.DateStrToDate: datestrtodate_sql,
exp.DateToDi: lambda self, e: f"CAST(STRFTIME({self.sql(e, 'this')}, {DuckDB.dateint_format}) AS INT)",
exp.DiToDate: lambda self, e: f"CAST(STRPTIME(CAST({self.sql(e, 'this')} AS TEXT), {DuckDB.dateint_format}) AS DATE)",
exp.DateToDi: lambda self, e: f"CAST(STRFTIME({self.sql(e, 'this')}, {DuckDB.DATEINT_FORMAT}) AS INT)",
exp.DiToDate: lambda self, e: f"CAST(STRPTIME(CAST({self.sql(e, 'this')} AS TEXT), {DuckDB.DATEINT_FORMAT}) AS DATE)",
exp.Explode: rename_func("UNNEST"),
exp.IntDiv: lambda self, e: self.binary(e, "//"),
exp.JSONExtract: arrow_json_extract_sql,
@ -242,11 +237,27 @@ class DuckDB(Dialect):
STAR_MAPPING = {**generator.Generator.STAR_MAPPING, "except": "EXCLUDE"}
UNWRAPPED_INTERVAL_VALUES = (exp.Column, exp.Literal, exp.Paren)
PROPERTIES_LOCATION = {
**generator.Generator.PROPERTIES_LOCATION,
exp.VolatileProperty: exp.Properties.Location.UNSUPPORTED,
}
def interval_sql(self, expression: exp.Interval) -> str:
multiplier: t.Optional[int] = None
unit = expression.text("unit").lower()
if unit.startswith("week"):
multiplier = 7
if unit.startswith("quarter"):
multiplier = 90
if multiplier:
return f"({multiplier} * {super().interval_sql(exp.Interval(this=expression.this, unit=exp.var('day')))})"
return super().interval_sql(expression)
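
DuckDB has no WEEK/QUARTER interval units in this form, so interval_sql rewrites them as day multiples. A hedged sketch (exact literal quoting may differ):

import sqlglot

print(sqlglot.transpile("SELECT INTERVAL '2' WEEK", read="duckdb", write="duckdb")[0])
# Roughly: SELECT (7 * INTERVAL '2' day)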
def tablesample_sql(
self, expression: exp.TableSample, seed_prefix: str = "SEED", sep: str = " AS "
) -> str:

View file

@ -80,12 +80,12 @@ def _date_diff_sql(self: generator.Generator, expression: exp.DateDiff) -> str:
_, multiplier = DATE_DELTA_INTERVAL.get(unit, ("", 1))
multiplier_sql = f" / {multiplier}" if multiplier > 1 else ""
diff_sql = f"{sql_func}({self.format_args(expression.this, expression.expression)})"
return f"{diff_sql}{multiplier_sql}"
def _json_format_sql(self: generator.Generator, expression: exp.JSONFormat) -> str:
this = expression.this
if not this.type:
from sqlglot.optimizer.annotate_types import annotate_types
@ -113,7 +113,7 @@ def _str_to_unix_sql(self: generator.Generator, expression: exp.StrToUnix) -> st
def _str_to_date_sql(self: generator.Generator, expression: exp.StrToDate) -> str:
this = self.sql(expression, "this")
time_format = self.format_time(expression)
if time_format not in (Hive.time_format, Hive.date_format):
if time_format not in (Hive.TIME_FORMAT, Hive.DATE_FORMAT):
this = f"FROM_UNIXTIME(UNIX_TIMESTAMP({this}, {time_format}))"
return f"CAST({this} AS DATE)"
@ -121,7 +121,7 @@ def _str_to_date_sql(self: generator.Generator, expression: exp.StrToDate) -> st
def _str_to_time_sql(self: generator.Generator, expression: exp.StrToTime) -> str:
this = self.sql(expression, "this")
time_format = self.format_time(expression)
if time_format not in (Hive.time_format, Hive.date_format):
if time_format not in (Hive.TIME_FORMAT, Hive.DATE_FORMAT):
this = f"FROM_UNIXTIME(UNIX_TIMESTAMP({this}, {time_format}))"
return f"CAST({this} AS TIMESTAMP)"
@ -130,7 +130,7 @@ def _time_format(
self: generator.Generator, expression: exp.UnixToStr | exp.StrToUnix
) -> t.Optional[str]:
time_format = self.format_time(expression)
if time_format == Hive.time_format:
if time_format == Hive.TIME_FORMAT:
return None
return time_format
@ -144,16 +144,16 @@ def _time_to_str(self: generator.Generator, expression: exp.TimeToStr) -> str:
def _to_date_sql(self: generator.Generator, expression: exp.TsOrDsToDate) -> str:
this = self.sql(expression, "this")
time_format = self.format_time(expression)
if time_format and time_format not in (Hive.time_format, Hive.date_format):
if time_format and time_format not in (Hive.TIME_FORMAT, Hive.DATE_FORMAT):
return f"TO_DATE({this}, {time_format})"
return f"TO_DATE({this})"
class Hive(Dialect):
alias_post_tablesample = True
identifiers_can_start_with_digit = True
ALIAS_POST_TABLESAMPLE = True
IDENTIFIERS_CAN_START_WITH_DIGIT = True
time_mapping = {
TIME_MAPPING = {
"y": "%Y",
"Y": "%Y",
"YYYY": "%Y",
@ -184,9 +184,9 @@ class Hive(Dialect):
"EEEE": "%A",
}
date_format = "'yyyy-MM-dd'"
dateint_format = "'yyyyMMdd'"
time_format = "'yyyy-MM-dd HH:mm:ss'"
DATE_FORMAT = "'yyyy-MM-dd'"
DATEINT_FORMAT = "'yyyyMMdd'"
TIME_FORMAT = "'yyyy-MM-dd HH:mm:ss'"
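
Hive's Java-style tokens translate to Python strftime codes through TIME_MAPPING; Dialect.format_time (dialect.py above) strips the quotes and applies the trie. Sketch:

from sqlglot.dialects.hive import Hive

fmt = Hive.format_time("'yyyy-MM-dd'")  # quoted input, per format_time above
print(fmt.name)
# -> %Y-%m-%d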
class Tokenizer(tokens.Tokenizer):
QUOTES = ["'", '"']
@ -224,9 +224,7 @@ class Hive(Dialect):
"BASE64": exp.ToBase64.from_arg_list,
"COLLECT_LIST": exp.ArrayAgg.from_arg_list,
"DATE_ADD": lambda args: exp.TsOrDsAdd(
this=seq_get(args, 0),
expression=seq_get(args, 1),
unit=exp.Literal.string("DAY"),
this=seq_get(args, 0), expression=seq_get(args, 1), unit=exp.Literal.string("DAY")
),
"DATEDIFF": lambda args: exp.DateDiff(
this=exp.TsOrDsToDate(this=seq_get(args, 0)),
@ -234,10 +232,7 @@ class Hive(Dialect):
),
"DATE_SUB": lambda args: exp.TsOrDsAdd(
this=seq_get(args, 0),
expression=exp.Mul(
this=seq_get(args, 1),
expression=exp.Literal.number(-1),
),
expression=exp.Mul(this=seq_get(args, 1), expression=exp.Literal.number(-1)),
unit=exp.Literal.string("DAY"),
),
"DATE_FORMAT": lambda args: format_time_lambda(exp.TimeToStr, "hive")(
@ -349,8 +344,8 @@ class Hive(Dialect):
exp.DateDiff: _date_diff_sql,
exp.DateStrToDate: rename_func("TO_DATE"),
exp.DateSub: _add_date_sql,
exp.DateToDi: lambda self, e: f"CAST(DATE_FORMAT({self.sql(e, 'this')}, {Hive.dateint_format}) AS INT)",
exp.DiToDate: lambda self, e: f"TO_DATE(CAST({self.sql(e, 'this')} AS STRING), {Hive.dateint_format})",
exp.DateToDi: lambda self, e: f"CAST(DATE_FORMAT({self.sql(e, 'this')}, {Hive.DATEINT_FORMAT}) AS INT)",
exp.DiToDate: lambda self, e: f"TO_DATE(CAST({self.sql(e, 'this')} AS STRING), {Hive.DATEINT_FORMAT})",
exp.FileFormatProperty: lambda self, e: f"STORED AS {self.sql(e, 'this') if isinstance(e.this, exp.InputOutputFormat) else e.name.upper()}",
exp.FromBase64: rename_func("UNBASE64"),
exp.If: if_sql,
@ -415,10 +410,7 @@ class Hive(Dialect):
)
def with_properties(self, properties: exp.Properties) -> str:
return self.properties(
properties,
prefix=self.seg("TBLPROPERTIES"),
)
return self.properties(properties, prefix=self.seg("TBLPROPERTIES"))
def datatype_sql(self, expression: exp.DataType) -> str:
if (

View file

@ -94,10 +94,10 @@ def _date_add_sql(kind: str) -> t.Callable[[generator.Generator, exp.DateAdd | e
class MySQL(Dialect):
time_format = "'%Y-%m-%d %T'"
TIME_FORMAT = "'%Y-%m-%d %T'"
# https://prestodb.io/docs/current/functions/datetime.html#mysql-date-functions
time_mapping = {
TIME_MAPPING = {
"%M": "%B",
"%c": "%-m",
"%e": "%-d",
@ -128,6 +128,7 @@ class MySQL(Dialect):
"MEDIUMBLOB": TokenType.MEDIUMBLOB,
"MEDIUMTEXT": TokenType.MEDIUMTEXT,
"SEPARATOR": TokenType.SEPARATOR,
"ENUM": TokenType.ENUM,
"START": TokenType.BEGIN,
"_ARMSCII8": TokenType.INTRODUCER,
"_ASCII": TokenType.INTRODUCER,
@ -279,6 +280,16 @@ class MySQL(Dialect):
"SWAPS",
}
TYPE_TOKENS = {
*parser.Parser.TYPE_TOKENS,
TokenType.SET,
}
ENUM_TYPE_TOKENS = {
*parser.Parser.ENUM_TYPE_TOKENS,
TokenType.SET,
}
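
Adding SET to TYPE_TOKENS and ENUM_TYPE_TOKENS (together with the ENUM keyword above and the ENUM/SET DataType members added in expressions.py below) lets MySQL column types with value lists parse. A hedged round-trip sketch:

import sqlglot

ddl = "CREATE TABLE t (c SET('a', 'b'), e ENUM('x', 'y'))"
print(sqlglot.parse_one(ddl, read="mysql").sql(dialect="mysql"))
# Expected round trip: CREATE TABLE t (c SET('a', 'b'), e ENUM('x', 'y'))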
LOG_DEFAULTS_TO_LN = True
def _parse_show_mysql(
@ -372,12 +383,7 @@ class MySQL(Dialect):
else:
collate = None
return self.expression(
exp.SetItem,
this=charset,
collate=collate,
kind="NAMES",
)
return self.expression(exp.SetItem, this=charset, collate=collate, kind="NAMES")
class Generator(generator.Generator):
LOCKING_READS_SUPPORTED = True
@ -472,9 +478,7 @@ class MySQL(Dialect):
def _prefixed_sql(self, prefix: str, expression: exp.Expression, arg: str) -> str:
sql = self.sql(expression, arg)
if not sql:
return ""
return f" {prefix} {sql}"
return f" {prefix} {sql}" if sql else ""
def _oldstyle_limit_sql(self, expression: exp.Show) -> str:
limit = self.sql(expression, "limit")

View file

@ -24,21 +24,15 @@ def _parse_xml_table(self: parser.Parser) -> exp.XMLTable:
if self._match_text_seq("COLUMNS"):
columns = self._parse_csv(lambda: self._parse_column_def(self._parse_field(any_token=True)))
return self.expression(
exp.XMLTable,
this=this,
passing=passing,
columns=columns,
by_ref=by_ref,
)
return self.expression(exp.XMLTable, this=this, passing=passing, columns=columns, by_ref=by_ref)
class Oracle(Dialect):
alias_post_tablesample = True
ALIAS_POST_TABLESAMPLE = True
# https://docs.oracle.com/database/121/SQLRF/sql_elements004.htm#SQLRF00212
# https://docs.python.org/3/library/datetime.html#strftime-and-strptime-format-codes
time_mapping = {
TIME_MAPPING = {
"AM": "%p", # Meridian indicator with or without periods
"A.M.": "%p", # Meridian indicator with or without periods
"PM": "%p", # Meridian indicator with or without periods
@ -87,7 +81,7 @@ class Oracle(Dialect):
column.set("join_mark", self._match(TokenType.JOIN_MARKER))
return column
def _parse_hint(self) -> t.Optional[exp.Expression]:
def _parse_hint(self) -> t.Optional[exp.Hint]:
if self._match(TokenType.HINT):
start = self._curr
while self._curr and not self._match_pair(TokenType.STAR, TokenType.SLASH):
@ -129,7 +123,7 @@ class Oracle(Dialect):
exp.Group: transforms.preprocess([transforms.unalias_group]),
exp.Hint: lambda self, e: f" /*+ {self.expressions(e).strip()} */",
exp.ILike: no_ilike_sql,
exp.IfNull: rename_func("NVL"),
exp.Coalesce: rename_func("NVL"),
exp.Select: transforms.preprocess([transforms.eliminate_distinct_on]),
exp.StrToTime: lambda self, e: f"TO_TIMESTAMP({self.sql(e, 'this')}, {self.format_time(e)})",
exp.Subquery: lambda self, e: self.subquery_sql(e, sep=" "),
@ -179,7 +173,6 @@ class Oracle(Dialect):
"MATCH_RECOGNIZE": TokenType.MATCH_RECOGNIZE,
"MINUS": TokenType.EXCEPT,
"NVARCHAR2": TokenType.NVARCHAR,
"RETURNING": TokenType.RETURNING,
"SAMPLE": TokenType.TABLE_SAMPLE,
"START": TokenType.BEGIN,
"TOP": TokenType.TOP,

View file

@ -183,9 +183,10 @@ def _to_timestamp(args: t.List) -> exp.Expression:
class Postgres(Dialect):
null_ordering = "nulls_are_large"
time_format = "'YYYY-MM-DD HH24:MI:SS'"
time_mapping = {
INDEX_OFFSET = 1
NULL_ORDERING = "nulls_are_large"
TIME_FORMAT = "'YYYY-MM-DD HH24:MI:SS'"
TIME_MAPPING = {
"AM": "%p",
"PM": "%p",
"D": "%u", # 1-based day of week
@ -241,7 +242,6 @@ class Postgres(Dialect):
"REFRESH": TokenType.COMMAND,
"REINDEX": TokenType.COMMAND,
"RESET": TokenType.COMMAND,
"RETURNING": TokenType.RETURNING,
"REVOKE": TokenType.COMMAND,
"SERIAL": TokenType.SERIAL,
"SMALLSERIAL": TokenType.SMALLSERIAL,
@ -258,6 +258,7 @@ class Postgres(Dialect):
class Parser(parser.Parser):
STRICT_CAST = False
CONCAT_NULL_OUTPUTS_STRING = True
FUNCTIONS = {
**parser.Parser.FUNCTIONS,
@ -268,6 +269,7 @@ class Postgres(Dialect):
"NOW": exp.CurrentTimestamp.from_arg_list,
"TO_CHAR": format_time_lambda(exp.TimeToStr, "postgres"),
"TO_TIMESTAMP": _to_timestamp,
"UNNEST": exp.Explode.from_arg_list,
}
FUNCTION_PARSERS = {
@ -303,7 +305,7 @@ class Postgres(Dialect):
value = self._parse_bitwise()
if part and part.is_string:
part = exp.Var(this=part.name)
part = exp.var(part.name)
return self.expression(exp.Extract, this=part, expression=value)
@ -328,6 +330,7 @@ class Postgres(Dialect):
**generator.Generator.TRANSFORMS,
exp.BitwiseXor: lambda self, e: self.binary(e, "#"),
exp.ColumnDef: transforms.preprocess([_auto_increment_to_serial, _serial_to_generated]),
exp.Explode: rename_func("UNNEST"),
exp.JSONExtract: arrow_json_extract_sql,
exp.JSONExtractScalar: arrow_json_extract_scalar_sql,
exp.JSONBExtract: lambda self, e: self.binary(e, "#>"),
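
UNNEST now parses to exp.Explode and exp.Explode generates back to UNNEST, so set-returning calls survive translation. A hedged sketch with illustrative names:

import sqlglot

print(sqlglot.transpile("SELECT UNNEST(tags) FROM posts", read="postgres", write="spark")[0])
# Expected shape: SELECT EXPLODE(tags) FROM posts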

View file

@ -102,7 +102,7 @@ def _str_to_time_sql(
def _ts_or_ds_to_date_sql(self: generator.Generator, expression: exp.TsOrDsToDate) -> str:
time_format = self.format_time(expression)
if time_format and time_format not in (Presto.time_format, Presto.date_format):
if time_format and time_format not in (Presto.TIME_FORMAT, Presto.DATE_FORMAT):
return f"CAST({_str_to_time_sql(self, expression)} AS DATE)"
return f"CAST(SUBSTR(CAST({self.sql(expression, 'this')} AS VARCHAR), 1, 10) AS DATE)"
@ -119,7 +119,7 @@ def _ts_or_ds_add_sql(self: generator.Generator, expression: exp.TsOrDsAdd) -> s
exp.Literal.number(1),
exp.Literal.number(10),
),
Presto.date_format,
Presto.DATE_FORMAT,
)
return self.func(
@ -145,9 +145,7 @@ def _approx_percentile(args: t.List) -> exp.Expression:
)
if len(args) == 3:
return exp.ApproxQuantile(
this=seq_get(args, 0),
quantile=seq_get(args, 1),
accuracy=seq_get(args, 2),
this=seq_get(args, 0), quantile=seq_get(args, 1), accuracy=seq_get(args, 2)
)
return exp.ApproxQuantile.from_arg_list(args)
@ -160,10 +158,8 @@ def _from_unixtime(args: t.List) -> exp.Expression:
minutes=seq_get(args, 2),
)
if len(args) == 2:
return exp.UnixToTime(
this=seq_get(args, 0),
zone=seq_get(args, 1),
)
return exp.UnixToTime(this=seq_get(args, 0), zone=seq_get(args, 1))
return exp.UnixToTime.from_arg_list(args)
@ -173,21 +169,17 @@ def _unnest_sequence(expression: exp.Expression) -> exp.Expression:
unnest = exp.Unnest(expressions=[expression.this])
if expression.alias:
return exp.alias_(
unnest,
alias="_u",
table=[expression.alias],
copy=False,
)
return exp.alias_(unnest, alias="_u", table=[expression.alias], copy=False)
return unnest
return expression
class Presto(Dialect):
index_offset = 1
null_ordering = "nulls_are_last"
time_format = MySQL.time_format
time_mapping = MySQL.time_mapping
INDEX_OFFSET = 1
NULL_ORDERING = "nulls_are_last"
TIME_FORMAT = MySQL.TIME_FORMAT
TIME_MAPPING = MySQL.TIME_MAPPING
STRICT_STRING_CONCAT = True
class Tokenizer(tokens.Tokenizer):
KEYWORDS = {
@ -205,14 +197,10 @@ class Presto(Dialect):
"CARDINALITY": exp.ArraySize.from_arg_list,
"CONTAINS": exp.ArrayContains.from_arg_list,
"DATE_ADD": lambda args: exp.DateAdd(
this=seq_get(args, 2),
expression=seq_get(args, 1),
unit=seq_get(args, 0),
this=seq_get(args, 2), expression=seq_get(args, 1), unit=seq_get(args, 0)
),
"DATE_DIFF": lambda args: exp.DateDiff(
this=seq_get(args, 2),
expression=seq_get(args, 1),
unit=seq_get(args, 0),
this=seq_get(args, 2), expression=seq_get(args, 1), unit=seq_get(args, 0)
),
"DATE_FORMAT": format_time_lambda(exp.TimeToStr, "presto"),
"DATE_PARSE": format_time_lambda(exp.StrToTime, "presto"),
@ -225,9 +213,7 @@ class Presto(Dialect):
"NOW": exp.CurrentTimestamp.from_arg_list,
"SEQUENCE": exp.GenerateSeries.from_arg_list,
"STRPOS": lambda args: exp.StrPosition(
this=seq_get(args, 0),
substr=seq_get(args, 1),
instance=seq_get(args, 2),
this=seq_get(args, 0), substr=seq_get(args, 1), instance=seq_get(args, 2)
),
"TO_UNIXTIME": exp.TimeToUnix.from_arg_list,
"TO_HEX": exp.Hex.from_arg_list,
@ -242,7 +228,7 @@ class Presto(Dialect):
INTERVAL_ALLOWS_PLURAL_FORM = False
JOIN_HINTS = False
TABLE_HINTS = False
IS_BOOL = False
IS_BOOL_ALLOWED = False
STRUCT_DELIMITER = ("(", ")")
PROPERTIES_LOCATION = {
@ -284,10 +270,10 @@ class Presto(Dialect):
exp.DateDiff: lambda self, e: self.func(
"DATE_DIFF", exp.Literal.string(e.text("unit") or "day"), e.expression, e.this
),
exp.DateStrToDate: lambda self, e: f"CAST(DATE_PARSE({self.sql(e, 'this')}, {Presto.date_format}) AS DATE)",
exp.DateToDi: lambda self, e: f"CAST(DATE_FORMAT({self.sql(e, 'this')}, {Presto.dateint_format}) AS INT)",
exp.DateStrToDate: lambda self, e: f"CAST(DATE_PARSE({self.sql(e, 'this')}, {Presto.DATE_FORMAT}) AS DATE)",
exp.DateToDi: lambda self, e: f"CAST(DATE_FORMAT({self.sql(e, 'this')}, {Presto.DATEINT_FORMAT}) AS INT)",
exp.Decode: _decode_sql,
exp.DiToDate: lambda self, e: f"CAST(DATE_PARSE(CAST({self.sql(e, 'this')} AS VARCHAR), {Presto.dateint_format}) AS DATE)",
exp.DiToDate: lambda self, e: f"CAST(DATE_PARSE(CAST({self.sql(e, 'this')} AS VARCHAR), {Presto.DATEINT_FORMAT}) AS DATE)",
exp.Encode: _encode_sql,
exp.FileFormatProperty: lambda self, e: f"FORMAT='{e.name.upper()}'",
exp.Group: transforms.preprocess([transforms.unalias_group]),
@ -322,7 +308,7 @@ class Presto(Dialect):
exp.TimestampTrunc: timestamptrunc_sql,
exp.TimeStrToDate: timestrtotime_sql,
exp.TimeStrToTime: timestrtotime_sql,
exp.TimeStrToUnix: lambda self, e: f"TO_UNIXTIME(DATE_PARSE({self.sql(e, 'this')}, {Presto.time_format}))",
exp.TimeStrToUnix: lambda self, e: f"TO_UNIXTIME(DATE_PARSE({self.sql(e, 'this')}, {Presto.TIME_FORMAT}))",
exp.TimeToStr: lambda self, e: f"DATE_FORMAT({self.sql(e, 'this')}, {self.format_time(e)})",
exp.TimeToUnix: rename_func("TO_UNIXTIME"),
exp.TryCast: transforms.preprocess([transforms.epoch_cast_to_ts]),
@ -367,8 +353,16 @@ class Presto(Dialect):
to = target_type.copy()
if target_type is start.to:
end = exp.Cast(this=end, to=to)
end = exp.cast(end, to)
else:
start = exp.Cast(this=start, to=to)
start = exp.cast(start, to)
return self.func("SEQUENCE", start, end, step)
def offset_limit_modifiers(
self, expression: exp.Expression, fetch: bool, limit: t.Optional[exp.Fetch | exp.Limit]
) -> t.List[str]:
return [
self.sql(expression, "offset"),
self.sql(limit),
]
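
offset_limit_modifiers emits OFFSET before LIMIT because Presto/Trino require that clause order. Sketch:

import sqlglot

print(sqlglot.transpile("SELECT * FROM t LIMIT 10 OFFSET 5", read="mysql", write="presto")[0])
# -> SELECT * FROM t OFFSET 5 LIMIT 10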

View file

@ -3,7 +3,7 @@ from __future__ import annotations
import typing as t
from sqlglot import exp, transforms
from sqlglot.dialects.dialect import rename_func
from sqlglot.dialects.dialect import concat_to_dpipe_sql, rename_func
from sqlglot.dialects.postgres import Postgres
from sqlglot.helper import seq_get
from sqlglot.tokens import TokenType
@ -14,9 +14,9 @@ def _json_sql(self: Postgres.Generator, expression: exp.JSONExtract | exp.JSONEx
class Redshift(Postgres):
time_format = "'YYYY-MM-DD HH:MI:SS'"
time_mapping = {
**Postgres.time_mapping,
TIME_FORMAT = "'YYYY-MM-DD HH:MI:SS'"
TIME_MAPPING = {
**Postgres.TIME_MAPPING,
"MON": "%b",
"HH": "%H",
}
@ -51,7 +51,7 @@ class Redshift(Postgres):
and this.expressions
and this.expressions[0].this == exp.column("MAX")
):
this.set("expressions", [exp.Var(this="MAX")])
this.set("expressions", [exp.var("MAX")])
return this
@ -94,6 +94,7 @@ class Redshift(Postgres):
TRANSFORMS = {
**Postgres.Generator.TRANSFORMS,
exp.Concat: concat_to_dpipe_sql,
exp.CurrentTimestamp: lambda self, e: "SYSDATE",
exp.DateAdd: lambda self, e: self.func(
"DATEADD", exp.var(e.text("unit") or "day"), e.expression, e.this
@ -106,6 +107,7 @@ class Redshift(Postgres):
exp.FromBase: rename_func("STRTOL"),
exp.JSONExtract: _json_sql,
exp.JSONExtractScalar: _json_sql,
exp.SafeConcat: concat_to_dpipe_sql,
exp.Select: transforms.preprocess([transforms.eliminate_distinct_on]),
exp.SortKeyProperty: lambda self, e: f"{'COMPOUND ' if e.args['compound'] else ''}SORTKEY({self.format_args(*e.this)})",
exp.TsOrDsToDate: lambda self, e: self.sql(e.this),
@ -170,6 +172,6 @@ class Redshift(Postgres):
precision = expression.args.get("expressions")
if not precision:
expression.append("expressions", exp.Var(this="MAX"))
expression.append("expressions", exp.var("MAX"))
return super().datatype_sql(expression)
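
Two of the new Redshift transforms in action, with illustrative column names (CONCAT folds to ||, CURRENT_TIMESTAMP renders as SYSDATE):

import sqlglot

print(sqlglot.transpile("SELECT CONCAT(a, b)", read="redshift", write="redshift")[0])
# -> SELECT a || b
print(sqlglot.transpile("SELECT CURRENT_TIMESTAMP", read="postgres", write="redshift")[0])
# -> SELECT SYSDATE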

View file

@ -167,10 +167,10 @@ def _parse_convert_timezone(args: t.List) -> exp.Expression:
class Snowflake(Dialect):
null_ordering = "nulls_are_large"
time_format = "'yyyy-mm-dd hh24:mi:ss'"
NULL_ORDERING = "nulls_are_large"
TIME_FORMAT = "'YYYY-MM-DD HH24:MI:SS'"
time_mapping = {
TIME_MAPPING = {
"YYYY": "%Y",
"yyyy": "%Y",
"YY": "%y",
@ -210,14 +210,10 @@ class Snowflake(Dialect):
"CONVERT_TIMEZONE": _parse_convert_timezone,
"DATE_TRUNC": date_trunc_to_time,
"DATEADD": lambda args: exp.DateAdd(
this=seq_get(args, 2),
expression=seq_get(args, 1),
unit=seq_get(args, 0),
this=seq_get(args, 2), expression=seq_get(args, 1), unit=seq_get(args, 0)
),
"DATEDIFF": lambda args: exp.DateDiff(
this=seq_get(args, 2),
expression=seq_get(args, 1),
unit=seq_get(args, 0),
this=seq_get(args, 2), expression=seq_get(args, 1), unit=seq_get(args, 0)
),
"DIV0": _div0_to_if,
"IFF": exp.If.from_arg_list,
@ -246,9 +242,7 @@ class Snowflake(Dialect):
COLUMN_OPERATORS = {
**parser.Parser.COLUMN_OPERATORS,
TokenType.COLON: lambda self, this, path: self.expression(
exp.Bracket,
this=this,
expressions=[path],
exp.Bracket, this=this, expressions=[path]
),
}
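
The DATEADD/DATEDIFF lambdas above reorder Snowflake's (unit, delta, value) arguments into sqlglot's value-first nodes. A parse-level sketch with an illustrative column name:

import sqlglot
from sqlglot import exp

node = sqlglot.parse_one("SELECT DATEADD(day, 1, col)", read="snowflake").find(exp.DateAdd)
assert node is not None and node.this.name == "col"  # the value landed in `this`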
@ -275,6 +269,7 @@ class Snowflake(Dialect):
QUOTES = ["'", "$$"]
STRING_ESCAPES = ["\\", "'"]
HEX_STRINGS = [("x'", "'"), ("X'", "'")]
COMMENTS = ["--", "//", ("/*", "*/")]
KEYWORDS = {
**tokens.Tokenizer.KEYWORDS,

View file

@ -38,7 +38,7 @@ def _parse_as_cast(to_type: str) -> t.Callable[[t.List], exp.Expression]:
def _str_to_date(self: Hive.Generator, expression: exp.StrToDate) -> str:
this = self.sql(expression, "this")
time_format = self.format_time(expression)
if time_format == Hive.date_format:
if time_format == Hive.DATE_FORMAT:
return f"TO_DATE({this})"
return f"TO_DATE({this}, {time_format})"
@ -133,13 +133,13 @@ class Spark2(Hive):
"WEEKOFYEAR": lambda args: exp.WeekOfYear(
this=exp.TsOrDsToDate(this=seq_get(args, 0)),
),
"DATE": lambda args: exp.Cast(this=seq_get(args, 0), to=exp.DataType.build("date")),
"DATE_TRUNC": lambda args: exp.TimestampTrunc(
this=seq_get(args, 1),
unit=exp.var(seq_get(args, 0)),
),
"TRUNC": lambda args: exp.DateTrunc(unit=seq_get(args, 1), this=seq_get(args, 0)),
"BOOLEAN": _parse_as_cast("boolean"),
"DATE": _parse_as_cast("date"),
"DOUBLE": _parse_as_cast("double"),
"FLOAT": _parse_as_cast("float"),
"INT": _parse_as_cast("int"),
@ -162,11 +162,9 @@ class Spark2(Hive):
def _parse_add_column(self) -> t.Optional[exp.Expression]:
return self._match_text_seq("ADD", "COLUMNS") and self._parse_schema()
def _parse_drop_column(self) -> t.Optional[exp.Expression]:
def _parse_drop_column(self) -> t.Optional[exp.Drop | exp.Command]:
return self._match_text_seq("DROP", "COLUMNS") and self.expression(
exp.Drop,
this=self._parse_schema(),
kind="COLUMNS",
exp.Drop, this=self._parse_schema(), kind="COLUMNS"
)
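
DATE, BOOLEAN, DOUBLE, and friends now parse through _parse_as_cast, so they become plain casts that any target dialect can render. Hedged sketch:

import sqlglot

print(sqlglot.transpile("SELECT DATE('2020-01-01')", read="spark", write="duckdb")[0])
# Expected shape: SELECT CAST('2020-01-01' AS DATE)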
def _pivot_column_names(self, aggregations: t.List[exp.Expression]) -> t.List[str]:

View file

@ -5,6 +5,7 @@ from sqlglot.dialects.dialect import (
Dialect,
arrow_json_extract_scalar_sql,
arrow_json_extract_sql,
concat_to_dpipe_sql,
count_if_to_sum,
no_ilike_sql,
no_pivot_sql,
@ -62,10 +63,6 @@ class SQLite(Dialect):
IDENTIFIERS = ['"', ("[", "]"), "`"]
HEX_STRINGS = [("x'", "'"), ("X'", "'"), ("0x", ""), ("0X", "")]
KEYWORDS = {
**tokens.Tokenizer.KEYWORDS,
}
class Parser(parser.Parser):
FUNCTIONS = {
**parser.Parser.FUNCTIONS,
@ -100,6 +97,7 @@ class SQLite(Dialect):
TRANSFORMS = {
**generator.Generator.TRANSFORMS,
exp.Concat: concat_to_dpipe_sql,
exp.CountIf: count_if_to_sum,
exp.Create: transforms.preprocess([_transform_create]),
exp.CurrentDate: lambda *_: "CURRENT_DATE",
@ -116,6 +114,7 @@ class SQLite(Dialect):
exp.LogicalOr: rename_func("MAX"),
exp.LogicalAnd: rename_func("MIN"),
exp.Pivot: no_pivot_sql,
exp.SafeConcat: concat_to_dpipe_sql,
exp.Select: transforms.preprocess(
[transforms.eliminate_distinct_on, transforms.eliminate_qualify]
),
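
With exp.Concat and exp.SafeConcat both routed through concat_to_dpipe_sql, CONCAT calls fold into SQLite's || operator. Sketch (illustrative columns):

import sqlglot

print(sqlglot.transpile("SELECT CONCAT(a, b, c)", read="sqlite", write="sqlite")[0])
# -> SELECT a || b || c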

View file

@ -1,7 +1,7 @@
from __future__ import annotations
from sqlglot import exp, generator, parser, transforms
from sqlglot.dialects.dialect import Dialect
from sqlglot.dialects.dialect import Dialect, rename_func
class Tableau(Dialect):
@ -11,6 +11,7 @@ class Tableau(Dialect):
TRANSFORMS = {
**generator.Generator.TRANSFORMS,
exp.Coalesce: rename_func("IFNULL"),
exp.Select: transforms.preprocess([transforms.eliminate_distinct_on]),
}
@ -25,9 +26,6 @@ class Tableau(Dialect):
false = self.sql(expression, "false")
return f"IF {this} THEN {true} ELSE {false} END"
def coalesce_sql(self, expression: exp.Coalesce) -> str:
return f"IFNULL({self.sql(expression, 'this')}, {self.expressions(expression)})"
def count_sql(self, expression: exp.Count) -> str:
this = expression.this
if isinstance(this, exp.Distinct):

View file

@ -1,18 +1,32 @@
from __future__ import annotations
import typing as t
from sqlglot import exp, generator, parser, tokens, transforms
from sqlglot.dialects.dialect import (
Dialect,
format_time_lambda,
max_or_greatest,
min_or_least,
)
from sqlglot.dialects.dialect import Dialect, max_or_greatest, min_or_least
from sqlglot.tokens import TokenType
class Teradata(Dialect):
TIME_MAPPING = {
"Y": "%Y",
"YYYY": "%Y",
"YY": "%y",
"MMMM": "%B",
"MMM": "%b",
"DD": "%d",
"D": "%-d",
"HH": "%H",
"H": "%-H",
"MM": "%M",
"M": "%-M",
"SS": "%S",
"S": "%-S",
"SSSSSS": "%f",
"E": "%a",
"EE": "%a",
"EEE": "%a",
"EEEE": "%A",
}
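
This dialect-level TIME_MAPPING replaces the removed _parse_cast override below: CAST(x AS DATE FORMAT '...') is now handled by the generic parser (falling back to TIME_MAPPING, per the FORMAT_MAPPING comment in dialect.py above), and the new StrToDate transform renders it back. A hedged round-trip sketch:

import sqlglot

sql = "SELECT CAST(x AS DATE FORMAT 'YYYY-MM-DD')"
print(sqlglot.transpile(sql, read="teradata", write="teradata")[0])
# Should round-trip: SELECT CAST(x AS DATE FORMAT 'YYYY-MM-DD')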
class Tokenizer(tokens.Tokenizer):
# https://docs.teradata.com/r/Teradata-Database-SQL-Functions-Operators-Expressions-and-Predicates/March-2017/Comparison-Operators-and-Functions/Comparison-Operators/ANSI-Compliance
KEYWORDS = {
@ -31,7 +45,7 @@ class Teradata(Dialect):
"ST_GEOMETRY": TokenType.GEOMETRY,
}
# teradata does not support % for modulus
# Teradata does not support % as a modulo operator
SINGLE_TOKENS = {**tokens.Tokenizer.SINGLE_TOKENS}
SINGLE_TOKENS.pop("%")
@ -101,7 +115,7 @@ class Teradata(Dialect):
# FROM before SET in Teradata UPDATE syntax
# https://docs.teradata.com/r/Enterprise_IntelliFlex_VMware/Teradata-VantageTM-SQL-Data-Manipulation-Language-17.20/Statement-Syntax/UPDATE/UPDATE-Syntax-Basic-Form-FROM-Clause
def _parse_update(self) -> exp.Expression:
def _parse_update(self) -> exp.Update:
return self.expression(
exp.Update,
**{ # type: ignore
@ -122,14 +136,6 @@ class Teradata(Dialect):
return self.expression(exp.RangeN, this=this, expressions=expressions, each=each)
def _parse_cast(self, strict: bool) -> exp.Expression:
cast = t.cast(exp.Cast, super()._parse_cast(strict))
if cast.to.this == exp.DataType.Type.DATE and self._match(TokenType.FORMAT):
return format_time_lambda(exp.TimeToStr, "teradata")(
[cast.this, self._parse_string()]
)
return cast
class Generator(generator.Generator):
JOIN_HINTS = False
TABLE_HINTS = False
@ -151,7 +157,7 @@ class Teradata(Dialect):
exp.Max: max_or_greatest,
exp.Min: min_or_least,
exp.Select: transforms.preprocess([transforms.eliminate_distinct_on]),
exp.TimeToStr: lambda self, e: f"CAST({self.sql(e, 'this')} AS DATE FORMAT {self.format_time(e)})",
exp.StrToDate: lambda self, e: f"CAST({self.sql(e, 'this')} AS DATE FORMAT {self.format_time(e)})",
exp.ToChar: lambda self, e: self.function_fallback_sql(e),
}

View file

@ -64,9 +64,9 @@ def _format_time_lambda(
format=exp.Literal.string(
format_time(
args[0].name,
{**TSQL.time_mapping, **FULL_FORMAT_TIME_MAPPING}
{**TSQL.TIME_MAPPING, **FULL_FORMAT_TIME_MAPPING}
if full_format_mapping
else TSQL.time_mapping,
else TSQL.TIME_MAPPING,
)
),
)
@ -86,9 +86,9 @@ def _parse_format(args: t.List) -> exp.Expression:
return exp.TimeToStr(
this=args[0],
format=exp.Literal.string(
format_time(fmt.name, TSQL.format_time_mapping)
format_time(fmt.name, TSQL.FORMAT_TIME_MAPPING)
if len(fmt.name) == 1
else format_time(fmt.name, TSQL.time_mapping)
else format_time(fmt.name, TSQL.TIME_MAPPING)
),
)
@ -138,7 +138,7 @@ def _format_sql(self: generator.Generator, expression: exp.NumberToStr | exp.Tim
if isinstance(expression, exp.NumberToStr)
else exp.Literal.string(
format_time(
expression.text("format"), t.cast(t.Dict[str, str], TSQL.inverse_time_mapping)
expression.text("format"), t.cast(t.Dict[str, str], TSQL.INVERSE_TIME_MAPPING)
)
)
)
@ -166,10 +166,10 @@ def _string_agg_sql(self: generator.Generator, expression: exp.GroupConcat) -> s
class TSQL(Dialect):
null_ordering = "nulls_are_small"
time_format = "'yyyy-mm-dd hh:mm:ss'"
NULL_ORDERING = "nulls_are_small"
TIME_FORMAT = "'yyyy-mm-dd hh:mm:ss'"
time_mapping = {
TIME_MAPPING = {
"year": "%Y",
"qq": "%q",
"q": "%q",
@ -213,7 +213,7 @@ class TSQL(Dialect):
"yy": "%y",
}
convert_format_mapping = {
CONVERT_FORMAT_MAPPING = {
"0": "%b %d %Y %-I:%M%p",
"1": "%m/%d/%y",
"2": "%y.%m.%d",
@ -253,8 +253,8 @@ class TSQL(Dialect):
"120": "%Y-%m-%d %H:%M:%S",
"121": "%Y-%m-%d %H:%M:%S.%f",
}
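
CONVERT_FORMAT_MAPPING pins each T-SQL CONVERT style number to a strftime pattern; the CONVERT parsing code below rejects unknown styles. A direct lookup sketch:

from sqlglot.dialects.tsql import TSQL

# Style 120 is the ODBC canonical datetime format.
assert TSQL.CONVERT_FORMAT_MAPPING["120"] == "%Y-%m-%d %H:%M:%S"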
# not sure if complete
format_time_mapping = {
FORMAT_TIME_MAPPING = {
"y": "%B %Y",
"d": "%m/%d/%Y",
"H": "%-H",
@ -312,9 +312,7 @@ class TSQL(Dialect):
FUNCTIONS = {
**parser.Parser.FUNCTIONS,
"CHARINDEX": lambda args: exp.StrPosition(
this=seq_get(args, 1),
substr=seq_get(args, 0),
position=seq_get(args, 2),
this=seq_get(args, 1), substr=seq_get(args, 0), position=seq_get(args, 2)
),
"DATEADD": parse_date_delta(exp.DateAdd, unit_mapping=DATE_DELTA_INTERVAL),
"DATEDIFF": parse_date_delta(exp.DateDiff, unit_mapping=DATE_DELTA_INTERVAL),
@ -363,6 +361,8 @@ class TSQL(Dialect):
LOG_BASE_FIRST = False
LOG_DEFAULTS_TO_LN = True
CONCAT_NULL_OUTPUTS_STRING = True
def _parse_system_time(self) -> t.Optional[exp.Expression]:
if not self._match_text_seq("FOR", "SYSTEM_TIME"):
return None
@ -400,7 +400,7 @@ class TSQL(Dialect):
table.set("system_time", self._parse_system_time())
return table
def _parse_returns(self) -> exp.Expression:
def _parse_returns(self) -> exp.ReturnsProperty:
table = self._parse_id_var(any_token=False, tokens=self.RETURNS_TABLE_TOKENS)
returns = super()._parse_returns()
returns.set("table", table)
@ -423,12 +423,12 @@ class TSQL(Dialect):
format_val = self._parse_number()
format_val_name = format_val.name if format_val else ""
if format_val_name not in TSQL.convert_format_mapping:
if format_val_name not in TSQL.CONVERT_FORMAT_MAPPING:
raise ValueError(
f"CONVERT function at T-SQL does not support format style {format_val_name}"
)
format_norm = exp.Literal.string(TSQL.convert_format_mapping[format_val_name])
format_norm = exp.Literal.string(TSQL.CONVERT_FORMAT_MAPPING[format_val_name])
# Check whether the convert entails a string to date format
if to.this == DataType.Type.DATE:

View file

@ -151,6 +151,7 @@ ENV = {
"CAST": cast,
"COALESCE": lambda *args: next((a for a in args if a is not None), None),
"CONCAT": null_if_any(lambda *args: "".join(args)),
"SAFECONCAT": null_if_any(lambda *args: "".join(str(arg) for arg in args)),
"CONCATWS": null_if_any(lambda this, *args: this.join(args)),
"DATESTRTODATE": null_if_any(lambda arg: datetime.date.fromisoformat(arg)),
"DIV": null_if_any(lambda e, this: e / this),
@ -159,7 +160,6 @@ ENV = {
"EXTRACT": null_if_any(lambda this, e: getattr(e, this)),
"GT": null_if_any(lambda this, e: this > e),
"GTE": null_if_any(lambda this, e: this >= e),
"IFNULL": lambda e, alt: alt if e is None else e,
"IF": lambda predicate, true, false: true if predicate else false,
"INTDIV": null_if_any(lambda e, this: e // this),
"INTERVAL": interval,

View file

@ -394,7 +394,7 @@ def _lambda_sql(self, e: exp.Lambda) -> str:
names = {e.name.lower() for e in e.expressions}
e = e.transform(
lambda n: exp.Var(this=n.name)
lambda n: exp.var(n.name)
if isinstance(n, exp.Identifier) and n.name.lower() in names
else n
)

View file

@ -1500,6 +1500,7 @@ class Index(Expression):
arg_types = {
"this": False,
"table": False,
"using": False,
"where": False,
"columns": False,
"unique": False,
@ -1623,7 +1624,7 @@ class Lambda(Expression):
class Limit(Expression):
arg_types = {"this": False, "expression": True}
arg_types = {"this": False, "expression": True, "offset": False}
class Literal(Condition):
@ -1869,6 +1870,10 @@ class EngineProperty(Property):
arg_types = {"this": True}
class ToTableProperty(Property):
arg_types = {"this": True}
class ExecuteAsProperty(Property):
arg_types = {"this": True}
@ -3072,12 +3077,35 @@ class Select(Subqueryable):
Returns:
The modified expression.
"""
inst = _maybe_copy(self, copy)
inst.set("locks", [Lock(update=update)])
return inst
def hint(self, *hints: ExpOrStr, dialect: DialectType = None, copy: bool = True) -> Select:
"""
Set hints for this expression.
Examples:
>>> Select().select("x").from_("tbl").hint("BROADCAST(y)").sql(dialect="spark")
'SELECT /*+ BROADCAST(y) */ x FROM tbl'
Args:
hints: The SQL code strings to parse as the hints.
If an `Expression` instance is passed, it will be used as-is.
dialect: The dialect used to parse the hints.
copy: If `False`, modify this expression instance in-place.
Returns:
The modified expression.
"""
inst = _maybe_copy(self, copy)
inst.set(
"hint", Hint(expressions=[maybe_parse(h, copy=copy, dialect=dialect) for h in hints])
)
return inst
@property
def named_selects(self) -> t.List[str]:
return [e.output_name for e in self.expressions if e.alias_or_name]
@ -3244,6 +3272,7 @@ class DataType(Expression):
DATE = auto()
DATETIME = auto()
DATETIME64 = auto()
ENUM = auto()
INT4RANGE = auto()
INT4MULTIRANGE = auto()
INT8RANGE = auto()
@ -3284,6 +3313,7 @@ class DataType(Expression):
OBJECT = auto()
ROWVERSION = auto()
SERIAL = auto()
SET = auto()
SMALLINT = auto()
SMALLMONEY = auto()
SMALLSERIAL = auto()
@ -3334,6 +3364,7 @@ class DataType(Expression):
NUMERIC_TYPES = {*INTEGER_TYPES, *FLOAT_TYPES}
TEMPORAL_TYPES = {
Type.TIME,
Type.TIMESTAMP,
Type.TIMESTAMPTZ,
Type.TIMESTAMPLTZ,
@ -3342,6 +3373,8 @@ class DataType(Expression):
Type.DATETIME64,
}
META_TYPES = {"UNKNOWN", "NULL"}
@classmethod
def build(
cls, dtype: str | DataType | DataType.Type, dialect: DialectType = None, **kwargs
@ -3349,8 +3382,9 @@ class DataType(Expression):
from sqlglot import parse_one
if isinstance(dtype, str):
if dtype.upper() in cls.Type.__members__:
data_type_exp: t.Optional[Expression] = DataType(this=DataType.Type[dtype.upper()])
upper = dtype.upper()
if upper in DataType.META_TYPES:
data_type_exp: t.Optional[Expression] = DataType(this=DataType.Type[upper])
else:
data_type_exp = parse_one(dtype, read=dialect, into=DataType)
@ -3483,6 +3517,10 @@ class Dot(Binary):
def name(self) -> str:
return self.expression.name
@property
def output_name(self) -> str:
return self.name
@classmethod
def build(self, expressions: t.Sequence[Expression]) -> Dot:
"""Build a Dot object with a sequence of expressions."""
@ -3502,6 +3540,10 @@ class DPipe(Binary):
pass
class SafeDPipe(DPipe):
pass
class EQ(Binary, Predicate):
pass
@ -3615,6 +3657,10 @@ class Not(Unary):
class Paren(Unary):
arg_types = {"this": True, "with": False}
@property
def output_name(self) -> str:
return self.this.name
class Neg(Unary):
pass
@ -3904,6 +3950,7 @@ class Ceil(Func):
class Coalesce(Func):
arg_types = {"this": True, "expressions": False}
is_var_len_args = True
_sql_names = ["COALESCE", "IFNULL", "NVL"]
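
Coalesce now owns the IFNULL and NVL names (the separate IfNull class below is removed), so all three spellings parse to one node:

import sqlglot
from sqlglot import exp

assert isinstance(sqlglot.parse_one("IFNULL(a, b)"), exp.Coalesce)
assert isinstance(sqlglot.parse_one("NVL(a, b)"), exp.Coalesce)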
class Concat(Func):
@ -3911,12 +3958,17 @@ class Concat(Func):
is_var_len_args = True
class SafeConcat(Concat):
pass
class ConcatWs(Concat):
_sql_names = ["CONCAT_WS"]
class Count(AggFunc):
arg_types = {"this": False}
arg_types = {"this": False, "expressions": False}
is_var_len_args = True
class CountIf(AggFunc):
@ -4049,6 +4101,11 @@ class DateToDi(Func):
pass
class Date(Func):
arg_types = {"expressions": True}
is_var_len_args = True
class Day(Func):
pass
@ -4102,11 +4159,6 @@ class If(Func):
arg_types = {"this": True, "true": True, "false": False}
class IfNull(Func):
arg_types = {"this": True, "expression": False}
_sql_names = ["IFNULL", "NVL"]
class Initcap(Func):
arg_types = {"this": True, "expression": False}
@ -5608,22 +5660,27 @@ def replace_children(expression: Expression, fun: t.Callable, *args, **kwargs) -
expression.args[k] = new_child_nodes if is_list_arg else seq_get(new_child_nodes, 0)
def column_table_names(expression: Expression) -> t.List[str]:
def column_table_names(expression: Expression, exclude: str = "") -> t.Set[str]:
"""
Return all table names referenced through columns in an expression.
Example:
>>> import sqlglot
>>> column_table_names(sqlglot.parse_one("a.b AND c.d AND c.e"))
['c', 'a']
>>> sorted(column_table_names(sqlglot.parse_one("a.b AND c.d AND c.e")))
['a', 'c']
Args:
expression: expression to find table names.
exclude: a table name to exclude.
Returns:
A set of unique names.
"""
return list(dict.fromkeys(column.table for column in expression.find_all(Column)))
return {
table
for table in (column.table for column in expression.find_all(Column))
if table and table != exclude
}
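A sketch of the new set-returning contract and the exclude parameter, following the doctest above:
>>> import sqlglot
>>> from sqlglot import exp
>>> sorted(exp.column_table_names(sqlglot.parse_one("a.b AND c.d AND c.e"), exclude="a"))
['c']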
def table_name(table: Table | str) -> str:
@ -5649,12 +5706,13 @@ def table_name(table: Table | str) -> str:
return ".".join(part for part in (table.text("catalog"), table.text("db"), table.name) if part)
def replace_tables(expression: E, mapping: t.Dict[str, str]) -> E:
def replace_tables(expression: E, mapping: t.Dict[str, str], copy: bool = True) -> E:
"""Replace all tables in expression according to the mapping.
Args:
expression: expression node to be transformed and replaced.
mapping: mapping of table names.
copy: whether or not to copy the expression.
Examples:
>>> from sqlglot import exp, parse_one
@ -5675,7 +5733,7 @@ def replace_tables(expression: E, mapping: t.Dict[str, str]) -> E:
)
return node
return expression.transform(_replace_tables)
return expression.transform(_replace_tables, copy=copy)
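A minimal sketch of the new copy flag: with copy=False the input tree is mutated and handed back as-is (assuming the root node itself is not a table that gets swapped out):
>>> from sqlglot import exp, parse_one
>>> tree = parse_one("SELECT * FROM a")
>>> exp.replace_tables(tree, {"a": "b"}, copy=False) is tree
True
>>> tree.sql()
'SELECT * FROM b'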
def replace_placeholders(expression: Expression, *args, **kwargs) -> Expression:

View file

@ -14,47 +14,32 @@ logger = logging.getLogger("sqlglot")
class Generator:
"""
Generator interprets the given syntax tree and produces a SQL string as an output.
Generator converts a given syntax tree to the corresponding SQL string.
Args:
time_mapping (dict): the dictionary of custom time mappings in which the key
represents a python time format and the output the target time format
time_trie (trie): a trie of the time_mapping keys
pretty (bool): if set to True the returned string will be formatted. Default: False.
quote_start (str): specifies which starting character to use to delimit quotes. Default: '.
quote_end (str): specifies which ending character to use to delimit quotes. Default: '.
identifier_start (str): specifies which starting character to use to delimit identifiers. Default: ".
identifier_end (str): specifies which ending character to use to delimit identifiers. Default: ".
bit_start (str): specifies which starting character to use to delimit bit literals. Default: None.
bit_end (str): specifies which ending character to use to delimit bit literals. Default: None.
hex_start (str): specifies which starting character to use to delimit hex literals. Default: None.
hex_end (str): specifies which ending character to use to delimit hex literals. Default: None.
byte_start (str): specifies which starting character to use to delimit byte literals. Default: None.
byte_end (str): specifies which ending character to use to delimit byte literals. Default: None.
raw_start (str): specifies which starting character to use to delimit raw literals. Default: None.
raw_end (str): specifies which ending character to use to delimit raw literals. Default: None.
identify (bool | str): 'always': always quote, 'safe': quote identifiers if they don't contain an uppercase character, True defaults to always.
normalize (bool): if set to True all identifiers will be lowercased
string_escape (str): specifies a string escape character. Default: '.
identifier_escape (str): specifies an identifier escape character. Default: ".
pad (int): determines padding in a formatted string. Default: 2.
indent (int): determines the size of indentation in a formatted string. Default: 4.
unnest_column_only (bool): if true unnest table aliases are considered only as column aliases
normalize_functions (str): normalize function names, "upper", "lower", or None
Default: "upper"
alias_post_tablesample (bool): if the table alias comes after tablesample
Default: False
identifiers_can_start_with_digit (bool): if an unquoted identifier can start with digit
Default: False
unsupported_level (ErrorLevel): determines the generator's behavior when it encounters
unsupported expressions. Default ErrorLevel.WARN.
null_ordering (str): Indicates the default null ordering method to use if not explicitly set.
Options are "nulls_are_small", "nulls_are_large", "nulls_are_last".
Default: "nulls_are_small"
max_unsupported (int): Maximum number of unsupported messages to include in a raised UnsupportedError.
pretty: Whether or not to format the produced SQL string.
Default: False.
identify: Determines when an identifier should be quoted. Possible values are:
False (default): Never quote, except in cases where it's mandatory by the dialect.
True or 'always': Always quote.
'safe': Only quote identifiers that are case insensitive.
normalize: Whether or not to normalize identifiers to lowercase.
Default: False.
pad: Determines the pad size in a formatted string.
Default: 2.
indent: Determines the indentation size in a formatted string.
Default: 2.
normalize_functions: Whether or not to normalize all function names. Possible values are:
"upper" or True (default): Convert names to uppercase.
"lower": Convert names to lowercase.
False: Disables function name normalization.
unsupported_level: Determines the generator's behavior when it encounters unsupported expressions.
Default ErrorLevel.WARN.
max_unsupported: Maximum number of unsupported messages to include in a raised UnsupportedError.
This is only relevant if unsupported_level is ErrorLevel.RAISE.
Default: 3
leading_comma (bool): if the comma is leading or trailing in select statements
leading_comma: Determines whether or not the comma is leading or trailing in select expressions.
This is only relevant when generating in pretty mode.
Default: False
max_text_width: The max number of characters in a segment before creating new lines in pretty mode.
The default is on the smaller end because the length only represents a segment and not the true
@ -86,6 +71,7 @@ class Generator:
exp.SettingsProperty: lambda self, e: f"SETTINGS{self.seg('')}{(self.expressions(e))}",
exp.SqlSecurityProperty: lambda self, e: f"SQL SECURITY {'DEFINER' if e.args.get('definer') else 'INVOKER'}",
exp.TemporaryProperty: lambda self, e: f"TEMPORARY",
exp.ToTableProperty: lambda self, e: f"TO {self.sql(e.this)}",
exp.TransientProperty: lambda self, e: "TRANSIENT",
exp.StabilityProperty: lambda self, e: e.name,
exp.VolatileProperty: lambda self, e: "VOLATILE",
@ -138,15 +124,24 @@ class Generator:
# Whether or not limit and fetch are supported (possible values: "ALL", "LIMIT", "FETCH")
LIMIT_FETCH = "ALL"
# Whether a table is allowed to be renamed with a db
# Whether or not a table is allowed to be renamed with a db
RENAME_TABLE_WITH_DB = True
# The separator for grouping sets and rollups
GROUPINGS_SEP = ","
# The string used for creating index on a table
# The string used for creating an index on a table
INDEX_ON = "ON"
# Whether or not join hints should be generated
JOIN_HINTS = True
# Whether or not table hints should be generated
TABLE_HINTS = True
# Whether or not comparing against booleans (e.g. x IS TRUE) is supported
IS_BOOL_ALLOWED = True
TYPE_MAPPING = {
exp.DataType.Type.NCHAR: "CHAR",
exp.DataType.Type.NVARCHAR: "VARCHAR",
@ -228,6 +223,7 @@ class Generator:
exp.SqlSecurityProperty: exp.Properties.Location.POST_CREATE,
exp.StabilityProperty: exp.Properties.Location.POST_SCHEMA,
exp.TemporaryProperty: exp.Properties.Location.POST_CREATE,
exp.ToTableProperty: exp.Properties.Location.POST_SCHEMA,
exp.TransientProperty: exp.Properties.Location.POST_CREATE,
exp.MergeTreeTTL: exp.Properties.Location.POST_SCHEMA,
exp.VolatileProperty: exp.Properties.Location.POST_CREATE,
@ -235,128 +231,110 @@ class Generator:
exp.WithJournalTableProperty: exp.Properties.Location.POST_NAME,
}
JOIN_HINTS = True
TABLE_HINTS = True
IS_BOOL = True
# Keywords that can't be used as unquoted identifier names
RESERVED_KEYWORDS: t.Set[str] = set()
WITH_SEPARATED_COMMENTS = (exp.Select, exp.From, exp.Where, exp.With)
UNWRAPPED_INTERVAL_VALUES = (exp.Column, exp.Literal, exp.Neg, exp.Paren)
# Expressions whose comments are separated from them for better formatting
WITH_SEPARATED_COMMENTS: t.Tuple[t.Type[exp.Expression], ...] = (
exp.Select,
exp.From,
exp.Where,
exp.With,
)
# Expressions that can remain unwrapped when appearing in the context of an INTERVAL
UNWRAPPED_INTERVAL_VALUES: t.Tuple[t.Type[exp.Expression], ...] = (
exp.Column,
exp.Literal,
exp.Neg,
exp.Paren,
)
SENTINEL_LINE_BREAK = "__SQLGLOT__LB__"
# Autofilled
INVERSE_TIME_MAPPING: t.Dict[str, str] = {}
INVERSE_TIME_TRIE: t.Dict = {}
INDEX_OFFSET = 0
UNNEST_COLUMN_ONLY = False
ALIAS_POST_TABLESAMPLE = False
IDENTIFIERS_CAN_START_WITH_DIGIT = False
STRICT_STRING_CONCAT = False
NORMALIZE_FUNCTIONS: bool | str = "upper"
NULL_ORDERING = "nulls_are_small"
# Delimiters for quotes, identifiers and the corresponding escape characters
QUOTE_START = "'"
QUOTE_END = "'"
IDENTIFIER_START = '"'
IDENTIFIER_END = '"'
STRING_ESCAPE = "'"
IDENTIFIER_ESCAPE = '"'
# Delimiters for bit, hex, byte and raw literals
BIT_START: t.Optional[str] = None
BIT_END: t.Optional[str] = None
HEX_START: t.Optional[str] = None
HEX_END: t.Optional[str] = None
BYTE_START: t.Optional[str] = None
BYTE_END: t.Optional[str] = None
RAW_START: t.Optional[str] = None
RAW_END: t.Optional[str] = None
__slots__ = (
"time_mapping",
"time_trie",
"pretty",
"quote_start",
"quote_end",
"identifier_start",
"identifier_end",
"bit_start",
"bit_end",
"hex_start",
"hex_end",
"byte_start",
"byte_end",
"raw_start",
"raw_end",
"identify",
"normalize",
"string_escape",
"identifier_escape",
"pad",
"index_offset",
"unnest_column_only",
"alias_post_tablesample",
"identifiers_can_start_with_digit",
"_indent",
"normalize_functions",
"unsupported_level",
"unsupported_messages",
"null_ordering",
"max_unsupported",
"_indent",
"leading_comma",
"max_text_width",
"comments",
"unsupported_messages",
"_escaped_quote_end",
"_escaped_identifier_end",
"_leading_comma",
"_max_text_width",
"_comments",
"_cache",
)
def __init__(
self,
time_mapping=None,
time_trie=None,
pretty=None,
quote_start=None,
quote_end=None,
identifier_start=None,
identifier_end=None,
bit_start=None,
bit_end=None,
hex_start=None,
hex_end=None,
byte_start=None,
byte_end=None,
raw_start=None,
raw_end=None,
identify=False,
normalize=False,
string_escape=None,
identifier_escape=None,
pad=2,
indent=2,
index_offset=0,
unnest_column_only=False,
alias_post_tablesample=False,
identifiers_can_start_with_digit=False,
normalize_functions="upper",
unsupported_level=ErrorLevel.WARN,
null_ordering=None,
max_unsupported=3,
leading_comma=False,
max_text_width=80,
comments=True,
pretty: t.Optional[bool] = None,
identify: str | bool = False,
normalize: bool = False,
pad: int = 2,
indent: int = 2,
normalize_functions: t.Optional[str | bool] = None,
unsupported_level: ErrorLevel = ErrorLevel.WARN,
max_unsupported: int = 3,
leading_comma: bool = False,
max_text_width: int = 80,
comments: bool = True,
):
import sqlglot
self.time_mapping = time_mapping or {}
self.time_trie = time_trie
self.pretty = pretty if pretty is not None else sqlglot.pretty
self.quote_start = quote_start or "'"
self.quote_end = quote_end or "'"
self.identifier_start = identifier_start or '"'
self.identifier_end = identifier_end or '"'
self.bit_start = bit_start
self.bit_end = bit_end
self.hex_start = hex_start
self.hex_end = hex_end
self.byte_start = byte_start
self.byte_end = byte_end
self.raw_start = raw_start
self.raw_end = raw_end
self.identify = identify
self.normalize = normalize
self.string_escape = string_escape or "'"
self.identifier_escape = identifier_escape or '"'
self.pad = pad
self.index_offset = index_offset
self.unnest_column_only = unnest_column_only
self.alias_post_tablesample = alias_post_tablesample
self.identifiers_can_start_with_digit = identifiers_can_start_with_digit
self.normalize_functions = normalize_functions
self.unsupported_level = unsupported_level
self.unsupported_messages = []
self.max_unsupported = max_unsupported
self.null_ordering = null_ordering
self._indent = indent
self._escaped_quote_end = self.string_escape + self.quote_end
self._escaped_identifier_end = self.identifier_escape + self.identifier_end
self._leading_comma = leading_comma
self._max_text_width = max_text_width
self._comments = comments
self._cache = None
self.unsupported_level = unsupported_level
self.max_unsupported = max_unsupported
self.leading_comma = leading_comma
self.max_text_width = max_text_width
self.comments = comments
# This is both a Dialect property and a Generator argument, so we prioritize the latter
self.normalize_functions = (
self.NORMALIZE_FUNCTIONS if normalize_functions is None else normalize_functions
)
self.unsupported_messages: t.List[str] = []
self._escaped_quote_end: str = self.STRING_ESCAPE + self.QUOTE_END
self._escaped_identifier_end: str = self.IDENTIFIER_ESCAPE + self.IDENTIFIER_END
self._cache: t.Optional[t.Dict[int, str]] = None
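These keyword arguments flow through the top-level API; a sketch of the surviving options via transpile (expected output assumed from the defaults documented above):
>>> import sqlglot
>>> print(sqlglot.transpile("SELECT a, b FROM t", pretty=True, identify=True)[0])
SELECT
  "a",
  "b"
FROM "t"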
def generate(
self,
@ -364,17 +342,19 @@ class Generator:
cache: t.Optional[t.Dict[int, str]] = None,
) -> str:
"""
Generates a SQL string by interpreting the given syntax tree.
Generates the SQL string corresponding to the given syntax tree.
Args
expression: the syntax tree.
cache: an optional sql string cache. this leverages the hash of an expression which is slow, so only use this if you set _hash on each node.
Args:
expression: The syntax tree.
cache: An optional sql string cache. This leverages the hash of an Expression
which can be slow to compute, so only use it if you set _hash on each node.
Returns
the SQL string.
Returns:
The SQL string corresponding to `expression`.
"""
if cache is not None:
self._cache = cache
self.unsupported_messages = []
sql = self.sql(expression).strip()
self._cache = None
@ -414,7 +394,11 @@ class Generator:
expression: t.Optional[exp.Expression] = None,
comments: t.Optional[t.List[str]] = None,
) -> str:
comments = ((expression and expression.comments) if comments is None else comments) if self._comments else None # type: ignore
comments = (
((expression and expression.comments) if comments is None else comments) # type: ignore
if self.comments
else None
)
if not comments or isinstance(expression, exp.Binary):
return sql
@ -454,7 +438,7 @@ class Generator:
return result
def normalize_func(self, name: str) -> str:
if self.normalize_functions == "upper":
if self.normalize_functions == "upper" or self.normalize_functions is True:
return name.upper()
if self.normalize_functions == "lower":
return name.lower()
@ -522,7 +506,7 @@ class Generator:
else:
raise ValueError(f"Expected an Expression. Received {type(expression)}: {expression}")
sql = self.maybe_comment(sql, expression) if self._comments and comment else sql
sql = self.maybe_comment(sql, expression) if self.comments and comment else sql
if self._cache is not None:
self._cache[expression_id] = sql
@ -770,25 +754,25 @@ class Generator:
def bitstring_sql(self, expression: exp.BitString) -> str:
this = self.sql(expression, "this")
if self.bit_start:
return f"{self.bit_start}{this}{self.bit_end}"
if self.BIT_START:
return f"{self.BIT_START}{this}{self.BIT_END}"
return f"{int(this, 2)}"
def hexstring_sql(self, expression: exp.HexString) -> str:
this = self.sql(expression, "this")
if self.hex_start:
return f"{self.hex_start}{this}{self.hex_end}"
if self.HEX_START:
return f"{self.HEX_START}{this}{self.HEX_END}"
return f"{int(this, 16)}"
def bytestring_sql(self, expression: exp.ByteString) -> str:
this = self.sql(expression, "this")
if self.byte_start:
return f"{self.byte_start}{this}{self.byte_end}"
if self.BYTE_START:
return f"{self.BYTE_START}{this}{self.BYTE_END}"
return this
def rawstring_sql(self, expression: exp.RawString) -> str:
if self.raw_start:
return f"{self.raw_start}{expression.name}{self.raw_end}"
if self.RAW_START:
return f"{self.RAW_START}{expression.name}{self.RAW_END}"
return self.sql(exp.Literal.string(expression.name.replace("\\", "\\\\")))
def datatypesize_sql(self, expression: exp.DataTypeSize) -> str:
@ -883,24 +867,27 @@ class Generator:
name = f"{expression.name} " if expression.name else ""
table = self.sql(expression, "table")
table = f"{self.INDEX_ON} {table} " if table else ""
using = self.sql(expression, "using")
using = f"USING {using} " if using else ""
index = "INDEX " if not table else ""
columns = self.expressions(expression, key="columns", flat=True)
columns = f"({columns})" if columns else ""
partition_by = self.expressions(expression, key="partition_by", flat=True)
partition_by = f" PARTITION BY {partition_by}" if partition_by else ""
return f"{unique}{primary}{amp}{index}{name}{table}({columns}){partition_by}"
return f"{unique}{primary}{amp}{index}{name}{table}{using}{columns}{partition_by}"
def identifier_sql(self, expression: exp.Identifier) -> str:
text = expression.name
lower = text.lower()
text = lower if self.normalize and not expression.quoted else text
text = text.replace(self.identifier_end, self._escaped_identifier_end)
text = text.replace(self.IDENTIFIER_END, self._escaped_identifier_end)
if (
expression.quoted
or should_identify(text, self.identify)
or lower in self.RESERVED_KEYWORDS
or (not self.identifiers_can_start_with_digit and text[:1].isdigit())
or (not self.IDENTIFIERS_CAN_START_WITH_DIGIT and text[:1].isdigit())
):
text = f"{self.identifier_start}{text}{self.identifier_end}"
text = f"{self.IDENTIFIER_START}{text}{self.IDENTIFIER_END}"
return text
def inputoutputformat_sql(self, expression: exp.InputOutputFormat) -> str:
@ -1197,7 +1184,7 @@ class Generator:
def tablesample_sql(
self, expression: exp.TableSample, seed_prefix: str = "SEED", sep=" AS "
) -> str:
if self.alias_post_tablesample and expression.this.alias:
if self.ALIAS_POST_TABLESAMPLE and expression.this.alias:
table = expression.this.copy()
table.set("alias", None)
this = self.sql(table)
@ -1372,7 +1359,15 @@ class Generator:
def limit_sql(self, expression: exp.Limit) -> str:
this = self.sql(expression, "this")
return f"{this}{self.seg('LIMIT')} {self.sql(expression, 'expression')}"
args = ", ".join(
sql
for sql in (
self.sql(expression, "offset"),
self.sql(expression, "expression"),
)
if sql
)
return f"{this}{self.seg('LIMIT')} {args}"
def offset_sql(self, expression: exp.Offset) -> str:
this = self.sql(expression, "this")
@ -1418,10 +1413,10 @@ class Generator:
def literal_sql(self, expression: exp.Literal) -> str:
text = expression.this or ""
if expression.is_string:
text = text.replace(self.quote_end, self._escaped_quote_end)
text = text.replace(self.QUOTE_END, self._escaped_quote_end)
if self.pretty:
text = text.replace("\n", self.SENTINEL_LINE_BREAK)
text = f"{self.quote_start}{text}{self.quote_end}"
text = f"{self.QUOTE_START}{text}{self.QUOTE_END}"
return text
def loaddata_sql(self, expression: exp.LoadData) -> str:
@ -1463,9 +1458,9 @@ class Generator:
nulls_first = expression.args.get("nulls_first")
nulls_last = not nulls_first
nulls_are_large = self.null_ordering == "nulls_are_large"
nulls_are_small = self.null_ordering == "nulls_are_small"
nulls_are_last = self.null_ordering == "nulls_are_last"
nulls_are_large = self.NULL_ORDERING == "nulls_are_large"
nulls_are_small = self.NULL_ORDERING == "nulls_are_small"
nulls_are_last = self.NULL_ORDERING == "nulls_are_last"
sort_order = " DESC" if desc else ""
nulls_sort_change = ""
@ -1521,7 +1516,7 @@ class Generator:
return f"{self.seg('MATCH_RECOGNIZE')} {self.wrap(body)}{alias}"
def query_modifiers(self, expression: exp.Expression, *sqls: str) -> str:
limit = expression.args.get("limit")
limit: t.Optional[exp.Fetch | exp.Limit] = expression.args.get("limit")
if self.LIMIT_FETCH == "LIMIT" and isinstance(limit, exp.Fetch):
limit = exp.Limit(expression=limit.args.get("count"))
@ -1540,12 +1535,19 @@ class Generator:
self.sql(expression, "having"),
*self.after_having_modifiers(expression),
self.sql(expression, "order"),
self.sql(expression, "offset") if fetch else self.sql(limit),
self.sql(limit) if fetch else self.sql(expression, "offset"),
*self.offset_limit_modifiers(expression, fetch, limit),
*self.after_limit_modifiers(expression),
sep="",
)
def offset_limit_modifiers(
self, expression: exp.Expression, fetch: bool, limit: t.Optional[exp.Fetch | exp.Limit]
) -> t.List[str]:
return [
self.sql(expression, "offset") if fetch else self.sql(limit),
self.sql(limit) if fetch else self.sql(expression, "offset"),
]
def after_having_modifiers(self, expression: exp.Expression) -> t.List[str]:
return [
self.sql(expression, "qualify"),
@ -1634,7 +1636,7 @@ class Generator:
def unnest_sql(self, expression: exp.Unnest) -> str:
args = self.expressions(expression, flat=True)
alias = expression.args.get("alias")
if alias and self.unnest_column_only:
if alias and self.UNNEST_COLUMN_ONLY:
columns = alias.columns
alias = self.sql(columns[0]) if columns else ""
else:
@ -1697,7 +1699,7 @@ class Generator:
return f"{this} BETWEEN {low} AND {high}"
def bracket_sql(self, expression: exp.Bracket) -> str:
expressions = apply_index_offset(expression.this, expression.expressions, self.index_offset)
expressions = apply_index_offset(expression.this, expression.expressions, self.INDEX_OFFSET)
expressions_sql = ", ".join(self.sql(e) for e in expressions)
return f"{self.sql(expression, 'this')}[{expressions_sql}]"
@ -1729,7 +1731,7 @@ class Generator:
statements.append("END")
if self.pretty and self.text_width(statements) > self._max_text_width:
if self.pretty and self.text_width(statements) > self.max_text_width:
return self.indent("\n".join(statements), skip_first=True, skip_last=True)
return " ".join(statements)
@ -1759,10 +1761,11 @@ class Generator:
else:
return self.func("TRIM", expression.this, expression.expression)
def concat_sql(self, expression: exp.Concat) -> str:
if len(expression.expressions) == 1:
return self.sql(expression.expressions[0])
return self.function_fallback_sql(expression)
def safeconcat_sql(self, expression: exp.SafeConcat) -> str:
expressions = expression.expressions
if self.STRICT_STRING_CONCAT:
expressions = (exp.cast(e, "text") for e in expressions)
return self.func("CONCAT", *expressions)
def check_sql(self, expression: exp.Check) -> str:
this = self.sql(expression, key="this")
@ -1785,9 +1788,7 @@ class Generator:
return f"PRIMARY KEY ({expressions}){options}"
def if_sql(self, expression: exp.If) -> str:
return self.case_sql(
exp.Case(ifs=[expression.copy()], default=expression.args.get("false"))
)
return self.case_sql(exp.Case(ifs=[expression], default=expression.args.get("false")))
def matchagainst_sql(self, expression: exp.MatchAgainst) -> str:
modifier = expression.args.get("modifier")
@ -1798,7 +1799,6 @@ class Generator:
return f"{self.sql(expression, 'this')}: {self.sql(expression, 'expression')}"
def jsonobject_sql(self, expression: exp.JSONObject) -> str:
expressions = self.expressions(expression)
null_handling = expression.args.get("null_handling")
null_handling = f" {null_handling}" if null_handling else ""
unique_keys = expression.args.get("unique_keys")
@ -1811,7 +1811,11 @@ class Generator:
format_json = " FORMAT JSON" if expression.args.get("format_json") else ""
encoding = self.sql(expression, "encoding")
encoding = f" ENCODING {encoding}" if encoding else ""
return f"JSON_OBJECT({expressions}{null_handling}{unique_keys}{return_type}{format_json}{encoding})"
return self.func(
"JSON_OBJECT",
*expression.expressions,
suffix=f"{null_handling}{unique_keys}{return_type}{format_json}{encoding})",
)
def openjsoncolumndef_sql(self, expression: exp.OpenJSONColumnDef) -> str:
this = self.sql(expression, "this")
@ -1930,7 +1934,7 @@ class Generator:
for i, e in enumerate(expression.flatten(unnest=False))
)
sep = "\n" if self.text_width(sqls) > self._max_text_width else " "
sep = "\n" if self.text_width(sqls) > self.max_text_width else " "
return f"{sep}{op} ".join(sqls)
def bitwiseand_sql(self, expression: exp.BitwiseAnd) -> str:
@ -2093,6 +2097,11 @@ class Generator:
def dpipe_sql(self, expression: exp.DPipe) -> str:
return self.binary(expression, "||")
def safedpipe_sql(self, expression: exp.SafeDPipe) -> str:
if self.STRICT_STRING_CONCAT:
return self.func("CONCAT", *(exp.cast(e, "text") for e in expression.flatten()))
return self.dpipe_sql(expression)
def div_sql(self, expression: exp.Div) -> str:
return self.binary(expression, "/")
@ -2127,7 +2136,7 @@ class Generator:
return self.binary(expression, "ILIKE ANY")
def is_sql(self, expression: exp.Is) -> str:
if not self.IS_BOOL and isinstance(expression.expression, exp.Boolean):
if not self.IS_BOOL_ALLOWED and isinstance(expression.expression, exp.Boolean):
return self.sql(
expression.this if expression.expression.this else exp.not_(expression.this)
)
@ -2197,12 +2206,18 @@ class Generator:
return self.func(expression.sql_name(), *args)
def func(self, name: str, *args: t.Optional[exp.Expression | str]) -> str:
return f"{self.normalize_func(name)}({self.format_args(*args)})"
def func(
self,
name: str,
*args: t.Optional[exp.Expression | str],
prefix: str = "(",
suffix: str = ")",
) -> str:
return f"{self.normalize_func(name)}{prefix}{self.format_args(*args)}{suffix}"
def format_args(self, *args: t.Optional[str | exp.Expression]) -> str:
arg_sqls = tuple(self.sql(arg) for arg in args if arg is not None)
if self.pretty and self.text_width(arg_sqls) > self._max_text_width:
if self.pretty and self.text_width(arg_sqls) > self.max_text_width:
return self.indent("\n" + f",\n".join(arg_sqls) + "\n", skip_first=True, skip_last=True)
return ", ".join(arg_sqls)
@ -2210,7 +2225,9 @@ class Generator:
return sum(len(arg) for arg in args)
def format_time(self, expression: exp.Expression) -> t.Optional[str]:
return format_time(self.sql(expression, "format"), self.time_mapping, self.time_trie)
return format_time(
self.sql(expression, "format"), self.INVERSE_TIME_MAPPING, self.INVERSE_TIME_TRIE
)
def expressions(
self,
@ -2242,7 +2259,7 @@ class Generator:
comments = self.maybe_comment("", e) if isinstance(e, exp.Expression) else ""
if self.pretty:
if self._leading_comma:
if self.leading_comma:
result_sqls.append(f"{sep if i > 0 else pad}{prefix}{sql}{comments}")
else:
result_sqls.append(

View file

@ -208,7 +208,7 @@ def while_changing(expression: Expression, func: t.Callable[[Expression], E]) ->
return expression
def tsort(dag: t.Dict[T, t.List[T]]) -> t.List[T]:
def tsort(dag: t.Dict[T, t.Set[T]]) -> t.List[T]:
"""
Sorts a given directed acyclic graph in topological order.
@ -220,22 +220,24 @@ def tsort(dag: t.Dict[T, t.List[T]]) -> t.List[T]:
"""
result = []
def visit(node: T, visited: t.Set[T]) -> None:
if node in result:
return
if node in visited:
for node, deps in tuple(dag.items()):
for dep in deps:
if not dep in dag:
dag[dep] = set()
while dag:
current = {node for node, deps in dag.items() if not deps}
if not current:
raise ValueError("Cycle error")
visited.add(node)
for node in current:
dag.pop(node)
for dep in dag.get(node, []):
visit(dep, visited)
for deps in dag.values():
deps -= current
visited.remove(node)
result.append(node)
for node in dag:
visit(node, set())
result.extend(sorted(current)) # type: ignore
return result
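The rewrite switches from a recursive visit to Kahn's algorithm over set-valued deps; a sketch (missing nodes are registered with empty deps, and ties are broken by sorting):
>>> from sqlglot.helper import tsort
>>> tsort({"b": {"a"}, "c": {"a", "b"}})
['a', 'b', 'c']
>>> tsort({"a": {"b"}, "b": {"a"}})
Traceback (most recent call last):
  ...
ValueError: Cycle error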

View file

@ -1,13 +1,25 @@
from __future__ import annotations
import typing as t
from sqlglot import exp
from sqlglot._typing import E
from sqlglot.helper import ensure_list, subclasses
from sqlglot.optimizer.scope import Scope, traverse_scope
from sqlglot.schema import ensure_schema
from sqlglot.schema import Schema, ensure_schema
if t.TYPE_CHECKING:
B = t.TypeVar("B", bound=exp.Binary)
def annotate_types(expression, schema=None, annotators=None, coerces_to=None):
def annotate_types(
expression: E,
schema: t.Optional[t.Dict | Schema] = None,
annotators: t.Optional[t.Dict[t.Type[E], t.Callable[[TypeAnnotator, E], E]]] = None,
coerces_to: t.Optional[t.Dict[exp.DataType.Type, t.Set[exp.DataType.Type]]] = None,
) -> E:
"""
Recursively infer & annotate types in an expression syntax tree against a schema.
Assumes that we've already executed the optimizer's qualify_columns step.
Infers the types of an expression, annotating its AST accordingly.
Example:
>>> import sqlglot
@ -18,12 +30,13 @@ def annotate_types(expression, schema=None, annotators=None, coerces_to=None):
<Type.DOUBLE: 'DOUBLE'>
Args:
expression (sqlglot.Expression): Expression to annotate.
schema (dict|sqlglot.optimizer.Schema): Database schema.
annotators (dict): Maps expression type to corresponding annotation function.
coerces_to (dict): Maps expression type to set of types that it can be coerced into.
expression: Expression to annotate.
schema: Database schema.
annotators: Maps expression type to corresponding annotation function.
coerces_to: Maps expression type to set of types that it can be coerced into.
Returns:
sqlglot.Expression: expression annotated with types
The expression annotated with types.
"""
schema = ensure_schema(schema)
@ -31,276 +44,241 @@ def annotate_types(expression, schema=None, annotators=None, coerces_to=None):
return TypeAnnotator(schema, annotators, coerces_to).annotate(expression)
class TypeAnnotator:
ANNOTATORS = {
**{
expr_type: lambda self, expr: self._annotate_unary(expr)
for expr_type in subclasses(exp.__name__, exp.Unary)
},
**{
expr_type: lambda self, expr: self._annotate_binary(expr)
for expr_type in subclasses(exp.__name__, exp.Binary)
},
exp.Cast: lambda self, expr: self._annotate_with_type(expr, expr.args["to"]),
exp.TryCast: lambda self, expr: self._annotate_with_type(expr, expr.args["to"]),
exp.DataType: lambda self, expr: self._annotate_with_type(expr, expr.copy()),
exp.Alias: lambda self, expr: self._annotate_unary(expr),
exp.Between: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.BOOLEAN),
exp.In: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.BOOLEAN),
exp.Literal: lambda self, expr: self._annotate_literal(expr),
exp.Boolean: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.BOOLEAN),
exp.Null: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.NULL),
exp.Anonymous: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.UNKNOWN),
exp.ApproxDistinct: lambda self, expr: self._annotate_with_type(
expr, exp.DataType.Type.BIGINT
),
exp.Avg: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
exp.Min: lambda self, expr: self._annotate_by_args(expr, "this", "expressions"),
exp.Max: lambda self, expr: self._annotate_by_args(expr, "this", "expressions"),
exp.Sum: lambda self, expr: self._annotate_by_args(
expr, "this", "expressions", promote=True
),
exp.Ceil: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT),
exp.Count: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.BIGINT),
exp.CurrentDate: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DATE),
exp.CurrentDatetime: lambda self, expr: self._annotate_with_type(
expr, exp.DataType.Type.DATETIME
),
exp.CurrentTime: lambda self, expr: self._annotate_with_type(
expr, exp.DataType.Type.TIMESTAMP
),
exp.CurrentTimestamp: lambda self, expr: self._annotate_with_type(
expr, exp.DataType.Type.TIMESTAMP
),
exp.DateAdd: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DATE),
exp.DateSub: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DATE),
exp.DateDiff: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT),
exp.DatetimeAdd: lambda self, expr: self._annotate_with_type(
expr, exp.DataType.Type.DATETIME
),
exp.DatetimeSub: lambda self, expr: self._annotate_with_type(
expr, exp.DataType.Type.DATETIME
),
exp.DatetimeDiff: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT),
exp.Extract: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT),
exp.TimestampAdd: lambda self, expr: self._annotate_with_type(
expr, exp.DataType.Type.TIMESTAMP
),
exp.TimestampSub: lambda self, expr: self._annotate_with_type(
expr, exp.DataType.Type.TIMESTAMP
),
exp.TimestampDiff: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT),
exp.TimeAdd: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.TIMESTAMP),
exp.TimeSub: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.TIMESTAMP),
exp.TimeDiff: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT),
exp.DateStrToDate: lambda self, expr: self._annotate_with_type(
expr, exp.DataType.Type.DATE
),
exp.DateToDateStr: lambda self, expr: self._annotate_with_type(
expr, exp.DataType.Type.VARCHAR
),
exp.DateToDi: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT),
exp.Day: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.TINYINT),
exp.DiToDate: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DATE),
exp.Exp: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
exp.Floor: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT),
exp.Case: lambda self, expr: self._annotate_by_args(expr, "default", "ifs"),
exp.If: lambda self, expr: self._annotate_by_args(expr, "true", "false"),
exp.Coalesce: lambda self, expr: self._annotate_by_args(expr, "this", "expressions"),
exp.IfNull: lambda self, expr: self._annotate_by_args(expr, "this", "expression"),
exp.Concat: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.VARCHAR),
exp.ConcatWs: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.VARCHAR),
exp.GroupConcat: lambda self, expr: self._annotate_with_type(
expr, exp.DataType.Type.VARCHAR
),
exp.ArrayConcat: lambda self, expr: self._annotate_with_type(
expr, exp.DataType.Type.VARCHAR
),
exp.ArraySize: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.BIGINT),
exp.Map: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.MAP),
exp.VarMap: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.MAP),
exp.Initcap: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.VARCHAR),
exp.Interval: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INTERVAL),
exp.Least: lambda self, expr: self._annotate_by_args(expr, "expressions"),
exp.Length: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.BIGINT),
exp.Levenshtein: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT),
exp.Ln: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
exp.Log: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
exp.Log2: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
exp.Log10: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
exp.Lower: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.VARCHAR),
exp.Month: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.TINYINT),
exp.Pow: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
exp.Quantile: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
exp.ApproxQuantile: lambda self, expr: self._annotate_with_type(
expr, exp.DataType.Type.DOUBLE
),
exp.RegexpLike: lambda self, expr: self._annotate_with_type(
expr, exp.DataType.Type.BOOLEAN
),
exp.Round: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
exp.SafeDivide: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
exp.Substring: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.VARCHAR),
exp.StrPosition: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT),
exp.StrToDate: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DATE),
exp.StrToTime: lambda self, expr: self._annotate_with_type(
expr, exp.DataType.Type.TIMESTAMP
),
exp.Sqrt: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
exp.Stddev: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
exp.StddevPop: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
exp.StddevSamp: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
exp.TimeToStr: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.VARCHAR),
exp.TimeToTimeStr: lambda self, expr: self._annotate_with_type(
expr, exp.DataType.Type.VARCHAR
),
exp.TimeStrToDate: lambda self, expr: self._annotate_with_type(
expr, exp.DataType.Type.DATE
),
exp.TimeStrToTime: lambda self, expr: self._annotate_with_type(
expr, exp.DataType.Type.TIMESTAMP
),
exp.Trim: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.VARCHAR),
exp.TsOrDsToDateStr: lambda self, expr: self._annotate_with_type(
expr, exp.DataType.Type.VARCHAR
),
exp.TsOrDsToDate: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DATE),
exp.TsOrDiToDi: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT),
exp.UnixToStr: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.VARCHAR),
exp.UnixToTime: lambda self, expr: self._annotate_with_type(
expr, exp.DataType.Type.TIMESTAMP
),
exp.UnixToTimeStr: lambda self, expr: self._annotate_with_type(
expr, exp.DataType.Type.VARCHAR
),
exp.Upper: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.VARCHAR),
exp.Variance: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
exp.VariancePop: lambda self, expr: self._annotate_with_type(
expr, exp.DataType.Type.DOUBLE
),
exp.Week: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.TINYINT),
exp.Year: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.TINYINT),
}
def _annotate_with_type_lambda(data_type: exp.DataType.Type) -> t.Callable[[TypeAnnotator, E], E]:
return lambda self, e: self._annotate_with_type(e, data_type)
# Reference: https://spark.apache.org/docs/3.2.0/sql-ref-ansi-compliance.html
COERCES_TO = {
# CHAR < NCHAR < VARCHAR < NVARCHAR < TEXT
exp.DataType.Type.TEXT: set(),
exp.DataType.Type.NVARCHAR: {exp.DataType.Type.TEXT},
exp.DataType.Type.VARCHAR: {exp.DataType.Type.NVARCHAR, exp.DataType.Type.TEXT},
exp.DataType.Type.NCHAR: {
exp.DataType.Type.VARCHAR,
exp.DataType.Type.NVARCHAR,
class _TypeAnnotator(type):
def __new__(cls, clsname, bases, attrs):
klass = super().__new__(cls, clsname, bases, attrs)
# Highest-to-lowest type precedence, as specified in Spark's docs (ANSI):
# https://spark.apache.org/docs/3.2.0/sql-ref-ansi-compliance.html
text_precedence = (
exp.DataType.Type.TEXT,
},
exp.DataType.Type.CHAR: {
exp.DataType.Type.NVARCHAR,
exp.DataType.Type.VARCHAR,
exp.DataType.Type.NCHAR,
exp.DataType.Type.VARCHAR,
exp.DataType.Type.NVARCHAR,
exp.DataType.Type.TEXT,
},
# TINYINT < SMALLINT < INT < BIGINT < DECIMAL < FLOAT < DOUBLE
exp.DataType.Type.DOUBLE: set(),
exp.DataType.Type.FLOAT: {exp.DataType.Type.DOUBLE},
exp.DataType.Type.DECIMAL: {exp.DataType.Type.FLOAT, exp.DataType.Type.DOUBLE},
exp.DataType.Type.BIGINT: {
exp.DataType.Type.DECIMAL,
exp.DataType.Type.FLOAT,
exp.DataType.Type.CHAR,
)
numeric_precedence = (
exp.DataType.Type.DOUBLE,
},
exp.DataType.Type.INT: {
exp.DataType.Type.FLOAT,
exp.DataType.Type.DECIMAL,
exp.DataType.Type.BIGINT,
exp.DataType.Type.DECIMAL,
exp.DataType.Type.FLOAT,
exp.DataType.Type.DOUBLE,
},
exp.DataType.Type.SMALLINT: {
exp.DataType.Type.INT,
exp.DataType.Type.BIGINT,
exp.DataType.Type.DECIMAL,
exp.DataType.Type.FLOAT,
exp.DataType.Type.DOUBLE,
},
exp.DataType.Type.TINYINT: {
exp.DataType.Type.SMALLINT,
exp.DataType.Type.INT,
exp.DataType.Type.BIGINT,
exp.DataType.Type.DECIMAL,
exp.DataType.Type.FLOAT,
exp.DataType.Type.DOUBLE,
},
# DATE < DATETIME < TIMESTAMP < TIMESTAMPTZ < TIMESTAMPLTZ
exp.DataType.Type.TIMESTAMPLTZ: set(),
exp.DataType.Type.TIMESTAMPTZ: {exp.DataType.Type.TIMESTAMPLTZ},
exp.DataType.Type.TIMESTAMP: {
exp.DataType.Type.TIMESTAMPTZ,
exp.DataType.Type.TINYINT,
)
timelike_precedence = (
exp.DataType.Type.TIMESTAMPLTZ,
},
exp.DataType.Type.DATETIME: {
exp.DataType.Type.TIMESTAMPTZ,
exp.DataType.Type.TIMESTAMP,
exp.DataType.Type.TIMESTAMPTZ,
exp.DataType.Type.TIMESTAMPLTZ,
exp.DataType.Type.DATETIME,
exp.DataType.Type.DATE,
)
for type_precedence in (text_precedence, numeric_precedence, timelike_precedence):
coerces_to = set()
for data_type in type_precedence:
klass.COERCES_TO[data_type] = coerces_to.copy()
coerces_to |= {data_type}
return klass
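The loop above autofills COERCES_TO so that each type coerces to everything of higher precedence in its chain; a sketch:
>>> from sqlglot import exp
>>> from sqlglot.optimizer.annotate_types import TypeAnnotator
>>> TypeAnnotator.COERCES_TO[exp.DataType.Type.DOUBLE]
set()
>>> exp.DataType.Type.DOUBLE in TypeAnnotator.COERCES_TO[exp.DataType.Type.TINYINT]
True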
class TypeAnnotator(metaclass=_TypeAnnotator):
TYPE_TO_EXPRESSIONS: t.Dict[exp.DataType.Type, t.Set[t.Type[exp.Expression]]] = {
exp.DataType.Type.BIGINT: {
exp.ApproxDistinct,
exp.ArraySize,
exp.Count,
exp.Length,
},
exp.DataType.Type.BOOLEAN: {
exp.Between,
exp.Boolean,
exp.In,
exp.RegexpLike,
},
exp.DataType.Type.DATE: {
exp.DataType.Type.DATETIME,
exp.DataType.Type.TIMESTAMP,
exp.DataType.Type.TIMESTAMPTZ,
exp.DataType.Type.TIMESTAMPLTZ,
exp.CurrentDate,
exp.Date,
exp.DateAdd,
exp.DateStrToDate,
exp.DateSub,
exp.DateTrunc,
exp.DiToDate,
exp.StrToDate,
exp.TimeStrToDate,
exp.TsOrDsToDate,
},
exp.DataType.Type.DATETIME: {
exp.CurrentDatetime,
exp.DatetimeAdd,
exp.DatetimeSub,
},
exp.DataType.Type.DOUBLE: {
exp.ApproxQuantile,
exp.Avg,
exp.Exp,
exp.Ln,
exp.Log,
exp.Log2,
exp.Log10,
exp.Pow,
exp.Quantile,
exp.Round,
exp.SafeDivide,
exp.Sqrt,
exp.Stddev,
exp.StddevPop,
exp.StddevSamp,
exp.Variance,
exp.VariancePop,
},
exp.DataType.Type.INT: {
exp.Ceil,
exp.DateDiff,
exp.DatetimeDiff,
exp.Extract,
exp.TimestampDiff,
exp.TimeDiff,
exp.DateToDi,
exp.Floor,
exp.Levenshtein,
exp.StrPosition,
exp.TsOrDiToDi,
},
exp.DataType.Type.TIMESTAMP: {
exp.CurrentTime,
exp.CurrentTimestamp,
exp.StrToTime,
exp.TimeAdd,
exp.TimeStrToTime,
exp.TimeSub,
exp.TimestampAdd,
exp.TimestampSub,
exp.UnixToTime,
},
exp.DataType.Type.TINYINT: {
exp.Day,
exp.Month,
exp.Week,
exp.Year,
},
exp.DataType.Type.VARCHAR: {
exp.ArrayConcat,
exp.Concat,
exp.ConcatWs,
exp.DateToDateStr,
exp.GroupConcat,
exp.Initcap,
exp.Lower,
exp.SafeConcat,
exp.Substring,
exp.TimeToStr,
exp.TimeToTimeStr,
exp.Trim,
exp.TsOrDsToDateStr,
exp.UnixToStr,
exp.UnixToTimeStr,
exp.Upper,
},
}
TRAVERSABLES = (exp.Select, exp.Union, exp.UDTF, exp.Subquery)
ANNOTATORS = {
**{
expr_type: lambda self, e: self._annotate_unary(e)
for expr_type in subclasses(exp.__name__, (exp.Unary, exp.Alias))
},
**{
expr_type: lambda self, e: self._annotate_binary(e)
for expr_type in subclasses(exp.__name__, exp.Binary)
},
**{
expr_type: _annotate_with_type_lambda(data_type)
for data_type, expressions in TYPE_TO_EXPRESSIONS.items()
for expr_type in expressions
},
exp.Anonymous: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.UNKNOWN),
exp.Cast: lambda self, e: self._annotate_with_type(e, e.args["to"]),
exp.Case: lambda self, e: self._annotate_by_args(e, "default", "ifs"),
exp.Coalesce: lambda self, e: self._annotate_by_args(e, "this", "expressions"),
exp.DataType: lambda self, e: self._annotate_with_type(e, e.copy()),
exp.If: lambda self, e: self._annotate_by_args(e, "true", "false"),
exp.Interval: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.INTERVAL),
exp.Least: lambda self, e: self._annotate_by_args(e, "expressions"),
exp.Literal: lambda self, e: self._annotate_literal(e),
exp.Map: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.MAP),
exp.Max: lambda self, e: self._annotate_by_args(e, "this", "expressions"),
exp.Min: lambda self, e: self._annotate_by_args(e, "this", "expressions"),
exp.Null: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.NULL),
exp.Sum: lambda self, e: self._annotate_by_args(e, "this", "expressions", promote=True),
exp.TryCast: lambda self, e: self._annotate_with_type(e, e.args["to"]),
exp.VarMap: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.MAP),
}
def __init__(self, schema=None, annotators=None, coerces_to=None):
# Specifies what types a given type can be coerced into (autofilled)
COERCES_TO: t.Dict[exp.DataType.Type, t.Set[exp.DataType.Type]] = {}
def __init__(
self,
schema: Schema,
annotators: t.Optional[t.Dict[t.Type[E], t.Callable[[TypeAnnotator, E], E]]] = None,
coerces_to: t.Optional[t.Dict[exp.DataType.Type, t.Set[exp.DataType.Type]]] = None,
) -> None:
self.schema = schema
self.annotators = annotators or self.ANNOTATORS
self.coerces_to = coerces_to or self.COERCES_TO
def annotate(self, expression):
if isinstance(expression, self.TRAVERSABLES):
for scope in traverse_scope(expression):
selects = {}
for name, source in scope.sources.items():
if not isinstance(source, Scope):
continue
if isinstance(source.expression, exp.UDTF):
values = []
def annotate(self, expression: E) -> E:
for scope in traverse_scope(expression):
selects = {}
for name, source in scope.sources.items():
if not isinstance(source, Scope):
continue
if isinstance(source.expression, exp.UDTF):
values = []
if isinstance(source.expression, exp.Lateral):
if isinstance(source.expression.this, exp.Explode):
values = [source.expression.this.this]
else:
values = source.expression.expressions[0].expressions
if not values:
continue
selects[name] = {
alias: column
for alias, column in zip(
source.expression.alias_column_names,
values,
)
}
if isinstance(source.expression, exp.Lateral):
if isinstance(source.expression.this, exp.Explode):
values = [source.expression.this.this]
else:
selects[name] = {
select.alias_or_name: select for select in source.expression.selects
}
# First annotate the current scope's column references
for col in scope.columns:
if not col.table:
values = source.expression.expressions[0].expressions
if not values:
continue
source = scope.sources.get(col.table)
if isinstance(source, exp.Table):
col.type = self.schema.get_column_type(source, col)
elif source and col.table in selects and col.name in selects[col.table]:
col.type = selects[col.table][col.name].type
# Then (possibly) annotate the remaining expressions in the scope
self._maybe_annotate(scope.expression)
selects[name] = {
alias: column
for alias, column in zip(
source.expression.alias_column_names,
values,
)
}
else:
selects[name] = {
select.alias_or_name: select for select in source.expression.selects
}
# First annotate the current scope's column references
for col in scope.columns:
if not col.table:
continue
source = scope.sources.get(col.table)
if isinstance(source, exp.Table):
col.type = self.schema.get_column_type(source, col)
elif source and col.table in selects and col.name in selects[col.table]:
col.type = selects[col.table][col.name].type
# Then (possibly) annotate the remaining expressions in the scope
self._maybe_annotate(scope.expression)
return self._maybe_annotate(expression) # This takes care of non-traversable expressions
def _maybe_annotate(self, expression):
def _maybe_annotate(self, expression: E) -> E:
if expression.type:
return expression # We've already inferred the expression's type
@ -312,13 +290,15 @@ class TypeAnnotator:
else self._annotate_with_type(expression, exp.DataType.Type.UNKNOWN)
)
def _annotate_args(self, expression):
def _annotate_args(self, expression: E) -> E:
for _, value in expression.iter_expressions():
self._maybe_annotate(value)
return expression
def _maybe_coerce(self, type1, type2):
def _maybe_coerce(
self, type1: exp.DataType | exp.DataType.Type, type2: exp.DataType | exp.DataType.Type
) -> exp.DataType.Type:
# We propagate the NULL / UNKNOWN types upwards if found
if isinstance(type1, exp.DataType):
type1 = type1.this
@ -330,9 +310,14 @@ class TypeAnnotator:
if exp.DataType.Type.UNKNOWN in (type1, type2):
return exp.DataType.Type.UNKNOWN
return type2 if type2 in self.coerces_to.get(type1, {}) else type1
return type2 if type2 in self.coerces_to.get(type1, {}) else type1 # type: ignore
def _annotate_binary(self, expression):
# Note: the following "no_type_check" decorators were added because mypy was yelling due
# to assigning Type values to expression.type (since its getter returns Optional[DataType]).
# This is a known mypy issue: https://github.com/python/mypy/issues/3004
@t.no_type_check
def _annotate_binary(self, expression: B) -> B:
self._annotate_args(expression)
left_type = expression.left.type.this
@ -354,7 +339,8 @@ class TypeAnnotator:
return expression
def _annotate_unary(self, expression):
@t.no_type_check
def _annotate_unary(self, expression: E) -> E:
self._annotate_args(expression)
if isinstance(expression, exp.Condition) and not isinstance(expression, exp.Paren):
@ -364,7 +350,8 @@ class TypeAnnotator:
return expression
def _annotate_literal(self, expression):
@t.no_type_check
def _annotate_literal(self, expression: exp.Literal) -> exp.Literal:
if expression.is_string:
expression.type = exp.DataType.Type.VARCHAR
elif expression.is_int:
@ -374,13 +361,16 @@ class TypeAnnotator:
return expression
def _annotate_with_type(self, expression, target_type):
@t.no_type_check
def _annotate_with_type(self, expression: E, target_type: exp.DataType.Type) -> E:
expression.type = target_type
return self._annotate_args(expression)
def _annotate_by_args(self, expression, *args, promote=False):
@t.no_type_check
def _annotate_by_args(self, expression: E, *args: str, promote: bool = False) -> E:
self._annotate_args(expression)
expressions = []
expressions: t.List[exp.Expression] = []
for arg in args:
arg_expr = expression.args.get(arg)
expressions.extend(expr for expr in ensure_list(arg_expr) if expr)

View file

@ -26,7 +26,7 @@ def canonicalize(expression: exp.Expression) -> exp.Expression:
def add_text_to_concat(node: exp.Expression) -> exp.Expression:
if isinstance(node, exp.Add) and node.type and node.type.this in exp.DataType.TEXT_TYPES:
node = exp.Concat(this=node.this, expression=node.expression)
node = exp.Concat(expressions=[node.left, node.right])
return node

View file

@ -32,7 +32,7 @@ def eliminate_joins(expression):
# Reverse the joins so we can remove chains of unused joins
for join in reversed(joins):
alias = join.this.alias_or_name
alias = join.alias_or_name
if _should_eliminate_join(scope, join, alias):
join.pop()
scope.remove_source(alias)
@ -126,7 +126,7 @@ def join_condition(join):
tuple[list[str], list[str], exp.Expression]:
Tuple of (source key, join key, remaining predicate)
"""
name = join.this.alias_or_name
name = join.alias_or_name
on = (join.args.get("on") or exp.true()).copy()
source_key = []
join_key = []

View file

@ -21,7 +21,7 @@ def isolate_table_selects(expression, schema=None):
source.replace(
exp.select("*")
.from_(
alias(source, source.name or source.alias, table=True),
alias(source, source.alias_or_name, table=True),
copy=False,
)
.subquery(source.alias, copy=False)

View file

@ -145,7 +145,7 @@ def _mergeable(outer_scope, inner_scope, leave_tables_isolated, from_or_join):
if not isinstance(from_or_join, exp.Join):
return False
alias = from_or_join.this.alias_or_name
alias = from_or_join.alias_or_name
on = from_or_join.args.get("on")
if not on:
@ -253,10 +253,6 @@ def _merge_joins(outer_scope, inner_scope, from_or_join):
"""
new_joins = []
comma_joins = inner_scope.expression.args.get("from").expressions[1:]
for subquery in comma_joins:
new_joins.append(exp.Join(this=subquery, kind="CROSS"))
outer_scope.add_source(subquery.alias_or_name, inner_scope.sources[subquery.alias_or_name])
joins = inner_scope.expression.args.get("joins") or []
for join in joins:
@ -328,13 +324,12 @@ def _merge_where(outer_scope, inner_scope, from_or_join):
if source == from_or_join.alias_or_name:
break
if set(exp.column_table_names(where.this)) <= sources:
if exp.column_table_names(where.this) <= sources:
from_or_join.on(where.this, copy=False)
from_or_join.set("on", from_or_join.args.get("on"))
return
expression.where(where.this, copy=False)
expression.set("where", expression.args.get("where"))
def _merge_order(outer_scope, inner_scope):

View file

@ -1,3 +1,7 @@
from __future__ import annotations
import typing as t
from sqlglot import exp
from sqlglot.helper import tsort
@ -13,25 +17,28 @@ def optimize_joins(expression):
>>> optimize_joins(parse_one("SELECT * FROM x CROSS JOIN y JOIN z ON x.a = z.a AND y.a = z.a")).sql()
'SELECT * FROM x JOIN z ON x.a = z.a AND TRUE JOIN y ON y.a = z.a'
"""
for select in expression.find_all(exp.Select):
references = {}
cross_joins = []
for join in select.args.get("joins", []):
name = join.this.alias_or_name
tables = other_table_names(join, name)
tables = other_table_names(join)
if tables:
for table in tables:
references[table] = references.get(table, []) + [join]
else:
cross_joins.append((name, join))
cross_joins.append((join.alias_or_name, join))
for name, join in cross_joins:
for dep in references.get(name, []):
on = dep.args["on"]
if isinstance(on, exp.Connector):
if len(other_table_names(dep)) < 2:
continue
for predicate in on.flatten():
if name in exp.column_table_names(predicate):
predicate.replace(exp.true())
@ -47,17 +54,12 @@ def reorder_joins(expression):
Reorder joins by topological sort order based on predicate references.
"""
for from_ in expression.find_all(exp.From):
head = from_.this
parent = from_.parent
joins = {join.this.alias_or_name: join for join in parent.args.get("joins", [])}
dag = {head.alias_or_name: []}
for name, join in joins.items():
dag[name] = other_table_names(join, name)
joins = {join.alias_or_name: join for join in parent.args.get("joins", [])}
dag = {name: other_table_names(join) for name, join in joins.items()}
parent.set(
"joins",
[joins[name] for name in tsort(dag) if name != head.alias_or_name],
[joins[name] for name in tsort(dag) if name != from_.alias_or_name],
)
return expression
@ -75,9 +77,6 @@ def normalize(expression):
return expression
def other_table_names(join, exclude):
return [
name
for name in (exp.column_table_names(join.args.get("on") or exp.true()))
if name != exclude
]
def other_table_names(join: exp.Join) -> t.Set[str]:
on = join.args.get("on")
return exp.column_table_names(on, join.alias_or_name) if on else set()

View file

@ -78,7 +78,7 @@ def optimize(
"schema": schema,
"dialect": dialect,
"isolate_tables": True, # needed for other optimizations to perform well
"quote_identifiers": False, # this happens in canonicalize
"quote_identifiers": False,
**kwargs,
}

View file

@ -41,7 +41,7 @@ def pushdown_predicates(expression):
# a join should only push predicates down into itself, not into other joins,
# so we limit the selected sources to the join itself
for join in select.args.get("joins") or []:
name = join.this.alias_or_name
name = join.alias_or_name
pushdown(join.args.get("on"), {name: scope.selected_sources[name]}, scope_ref_count)
return expression
@ -93,10 +93,10 @@ def pushdown_dnf(predicates, scope, scope_ref_count):
pushdown_tables = set()
for a in predicates:
a_tables = set(exp.column_table_names(a))
a_tables = exp.column_table_names(a)
for b in predicates:
a_tables &= set(exp.column_table_names(b))
a_tables &= exp.column_table_names(b)
pushdown_tables.update(a_tables)
@ -147,7 +147,7 @@ def nodes_for_predicate(predicate, sources, scope_ref_count):
tables = exp.column_table_names(predicate)
where_condition = isinstance(predicate.find_ancestor(exp.Join, exp.Where), exp.Where)
for table in tables:
for table in sorted(tables):
node, source = sources.get(table) or (None, None)
# if the predicate is in a where statement we can try to push it down

View file

@ -14,7 +14,7 @@ from sqlglot.schema import Schema, ensure_schema
def qualify_columns(
expression: exp.Expression,
schema: dict | Schema,
schema: t.Dict | Schema,
expand_alias_refs: bool = True,
infer_schema: t.Optional[bool] = None,
) -> exp.Expression:
@ -93,7 +93,7 @@ def _pop_table_column_aliases(derived_tables):
def _expand_using(scope, resolver):
joins = list(scope.find_all(exp.Join))
names = {join.this.alias for join in joins}
names = {join.alias_or_name for join in joins}
ordered = [key for key in scope.selected_sources if key not in names]
# Mapping of automatically joined column names to an ordered set of source names (dict).
@ -105,7 +105,7 @@ def _expand_using(scope, resolver):
if not using:
continue
join_table = join.this.alias_or_name
join_table = join.alias_or_name
columns = {}

View file

@ -91,11 +91,13 @@ def qualify_tables(
)
elif isinstance(source, Scope) and source.is_udtf:
udtf = source.expression
table_alias = udtf.args.get("alias") or exp.TableAlias(this=next_alias_name())
table_alias = udtf.args.get("alias") or exp.TableAlias(
this=exp.to_identifier(next_alias_name())
)
udtf.set("alias", table_alias)
if not table_alias.name:
table_alias.set("this", next_alias_name())
table_alias.set("this", exp.to_identifier(next_alias_name()))
if isinstance(udtf, exp.Values) and not table_alias.columns:
for i, e in enumerate(udtf.expressions[0].expressions):
table_alias.append("columns", exp.to_identifier(f"_col_{i}"))

View file

@ -620,7 +620,7 @@ def _traverse_tables(scope):
table_name = expression.name
source_name = expression.alias_or_name
if table_name in scope.sources:
if table_name in scope.sources and not expression.db:
# This is a reference to a parent source (e.g. a CTE), not an actual table, unless
# it is pivoted, because then we get back a new table and hence a new source.
pivots = expression.args.get("pivots")
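
A minimal sketch of the new guard: a table carrying an explicit database qualifier can no longer be mistaken for a same-named CTE:

    from sqlglot import exp, parse_one

    sql = "WITH t AS (SELECT 1 AS x) SELECT * FROM db.t"
    table = parse_one(sql).find(exp.Table)
    # expression.db is truthy, so this resolves to a real table,
    # not to the CTE named "t"
    print(table.name, table.db)  # t db
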

File diff suppressed because it is too large.


@@ -302,7 +302,7 @@ class Join(Step):
for join in joins:
source_key, join_key, condition = join_condition(join)
step.joins[join.this.alias_or_name] = {
step.joins[join.alias_or_name] = {
"side": join.side, # type: ignore
"join_key": join_key,
"source_key": source_key,


@@ -285,8 +285,6 @@ class MappingSchema(AbstractMappingSchema[t.Dict[str, str]], Schema):
elif isinstance(column_type, str):
return self._to_data_type(column_type.upper(), dialect=dialect)
raise SchemaError(f"Unknown column type '{column_type}'")
return exp.DataType.build("unknown")
def _normalize(self, schema: t.Dict) -> t.Dict:


@@ -144,6 +144,7 @@ class TokenType(AutoName):
VARIANT = auto()
OBJECT = auto()
INET = auto()
ENUM = auto()
# keywords
ALIAS = auto()
@@ -346,6 +347,7 @@ class Token:
col: The column that the token ends on.
start: The start index of the token.
end: The ending index of the token.
comments: The comments to attach to the token.
"""
self.token_type = token_type
self.text = text
@@ -391,12 +393,15 @@ class _Tokenizer(type):
klass._STRING_ESCAPES = set(klass.STRING_ESCAPES)
klass._IDENTIFIER_ESCAPES = set(klass.IDENTIFIER_ESCAPES)
klass._COMMENTS = dict(
(comment, None) if isinstance(comment, str) else (comment[0], comment[1])
for comment in klass.COMMENTS
)
klass._COMMENTS = {
**dict(
(comment, None) if isinstance(comment, str) else (comment[0], comment[1])
for comment in klass.COMMENTS
),
"{#": "#}", # Ensure Jinja comments are tokenized correctly in all dialects
}
klass.KEYWORD_TRIE = new_trie(
klass._KEYWORD_TRIE = new_trie(
key.upper()
for key in (
*klass.KEYWORDS,
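
A hedged check of the Jinja entry added above; exactly where the comment text attaches can vary, so treat the last line as an assumption:

    from sqlglot.tokens import Tokenizer

    tokens = Tokenizer().tokenize("SELECT 1 {# templated #}")
    print([t.text for t in tokens])  # ['SELECT', '1']
    print(tokens[-1].comments)       # [' templated ']
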
@@ -456,20 +461,22 @@ class Tokenizer(metaclass=_Tokenizer):
STRING_ESCAPES = ["'"]
VAR_SINGLE_TOKENS: t.Set[str] = set()
# Autofilled
IDENTIFIERS_CAN_START_WITH_DIGIT: bool = False
_COMMENTS: t.Dict[str, str] = {}
_FORMAT_STRINGS: t.Dict[str, t.Tuple[str, TokenType]] = {}
_IDENTIFIERS: t.Dict[str, str] = {}
_IDENTIFIER_ESCAPES: t.Set[str] = set()
_QUOTES: t.Dict[str, str] = {}
_STRING_ESCAPES: t.Set[str] = set()
_KEYWORD_TRIE: t.Dict = {}
KEYWORDS: t.Dict[t.Optional[str], TokenType] = {
KEYWORDS: t.Dict[str, TokenType] = {
**{f"{{%{postfix}": TokenType.BLOCK_START for postfix in ("", "+", "-")},
**{f"{prefix}%}}": TokenType.BLOCK_END for prefix in ("", "+", "-")},
"{{+": TokenType.BLOCK_START,
"{{-": TokenType.BLOCK_START,
"+}}": TokenType.BLOCK_END,
"-}}": TokenType.BLOCK_END,
**{f"{{{{{postfix}": TokenType.BLOCK_START for postfix in ("+", "-")},
**{f"{prefix}}}}}": TokenType.BLOCK_END for prefix in ("+", "-")},
"/*+": TokenType.HINT,
"==": TokenType.EQ,
"::": TokenType.DCOLON,
@@ -594,6 +601,7 @@ class Tokenizer(metaclass=_Tokenizer):
"RECURSIVE": TokenType.RECURSIVE,
"REGEXP": TokenType.RLIKE,
"REPLACE": TokenType.REPLACE,
"RETURNING": TokenType.RETURNING,
"REFERENCES": TokenType.REFERENCES,
"RIGHT": TokenType.RIGHT,
"RLIKE": TokenType.RLIKE,
@@ -732,8 +740,7 @@ class Tokenizer(metaclass=_Tokenizer):
NUMERIC_LITERALS: t.Dict[str, str] = {}
ENCODE: t.Optional[str] = None
COMMENTS = ["--", ("/*", "*/"), ("{#", "#}")]
KEYWORD_TRIE: t.Dict = {} # autofilled
COMMENTS = ["--", ("/*", "*/")]
__slots__ = (
"sql",
@@ -748,7 +755,6 @@ class Tokenizer(metaclass=_Tokenizer):
"_end",
"_peek",
"_prev_token_line",
"identifiers_can_start_with_digit",
)
def __init__(self) -> None:
@@ -894,7 +900,7 @@ class Tokenizer(metaclass=_Tokenizer):
char = chars
prev_space = False
skip = False
trie = self.KEYWORD_TRIE
trie = self._KEYWORD_TRIE
single_token = char in self.SINGLE_TOKENS
while chars:
@@ -994,7 +1000,7 @@ class Tokenizer(metaclass=_Tokenizer):
self._advance()
elif self._peek == "." and not decimal:
after = self.peek(1)
if after.isdigit() or not after.strip():
if after.isdigit() or not after.isalpha():
decimal = True
self._advance()
else:
@@ -1013,13 +1019,13 @@ class Tokenizer(metaclass=_Tokenizer):
literal += self._peek.upper()
self._advance()
token_type = self.KEYWORDS.get(self.NUMERIC_LITERALS.get(literal))
token_type = self.KEYWORDS.get(self.NUMERIC_LITERALS.get(literal, ""))
if token_type:
self._add(TokenType.NUMBER, number_text)
self._add(TokenType.DCOLON, "::")
return self._add(token_type, literal)
elif self.identifiers_can_start_with_digit: # type: ignore
elif self.IDENTIFIERS_CAN_START_WITH_DIGIT:
return self._add(TokenType.VAR)
self._add(TokenType.NUMBER, number_text)
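
A hedged illustration of the relaxed decimal check: a dot followed by a non-alphabetic character now stays inside the number, where the old "not after.strip()" test only allowed whitespace or end of input:

    from sqlglot.tokens import Tokenizer

    # '+' is neither a digit nor alphabetic, so '1.' scans as one NUMBER;
    # previously the dot would have been emitted as a separate token
    print([t.text for t in Tokenizer().tokenize("SELECT 1.+2")])
    # ['SELECT', '1.', '+', '2']
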