Merging upstream version 16.2.1.
Signed-off-by: Daniel Baumann <daniel@debian.org>
parent c12f551e31
commit 718a80b164
106 changed files with 41940 additions and 40162 deletions
sqlglot/dialects/bigquery.py

@@ -7,6 +7,7 @@ from sqlglot import exp, generator, parser, tokens, transforms
 from sqlglot.dialects.dialect import (
     Dialect,
     datestrtodate_sql,
     format_time_lambda,
     inline_array_sql,
     max_or_greatest,
     min_or_least,

@@ -103,16 +104,26 @@ def _unqualify_unnest(expression: exp.Expression) -> exp.Expression:

class BigQuery(Dialect):
    unnest_column_only = True
    time_mapping = {
        "%M": "%-M",
        "%d": "%-d",
        "%m": "%-m",
        "%y": "%-y",
        "%H": "%-H",
        "%I": "%-I",
        "%S": "%-S",
        "%j": "%-j",
    UNNEST_COLUMN_ONLY = True

    TIME_MAPPING = {
        "%D": "%m/%d/%y",
    }

    FORMAT_MAPPING = {
        "DD": "%d",
        "MM": "%m",
        "MON": "%b",
        "MONTH": "%B",
        "YYYY": "%Y",
        "YY": "%y",
        "HH": "%I",
        "HH12": "%I",
        "HH24": "%H",
        "MI": "%M",
        "SS": "%S",
        "SSSSS": "%f",
        "TZH": "%z",
    }

    class Tokenizer(tokens.Tokenizer):

@@ -142,6 +153,7 @@ class BigQuery(Dialect):
            "FLOAT64": TokenType.DOUBLE,
            "INT64": TokenType.BIGINT,
            "RECORD": TokenType.STRUCT,
            "TIMESTAMP": TokenType.TIMESTAMPTZ,
            "NOT DETERMINISTIC": TokenType.VOLATILE,
            "UNKNOWN": TokenType.NULL,
        }

@@ -155,13 +167,21 @@ class BigQuery(Dialect):

        FUNCTIONS = {
            **parser.Parser.FUNCTIONS,
            "DATE_ADD": parse_date_delta_with_interval(exp.DateAdd),
            "DATE_SUB": parse_date_delta_with_interval(exp.DateSub),
            "DATE_TRUNC": lambda args: exp.DateTrunc(
                unit=exp.Literal.string(str(seq_get(args, 1))),
                this=seq_get(args, 0),
            ),
            "DATE_ADD": parse_date_delta_with_interval(exp.DateAdd),
            "DATETIME_ADD": parse_date_delta_with_interval(exp.DatetimeAdd),
            "DATETIME_SUB": parse_date_delta_with_interval(exp.DatetimeSub),
            "DIV": lambda args: exp.IntDiv(this=seq_get(args, 0), expression=seq_get(args, 1)),
            "PARSE_DATE": lambda args: format_time_lambda(exp.StrToDate, "bigquery")(
                [seq_get(args, 1), seq_get(args, 0)]
            ),
            "PARSE_TIMESTAMP": lambda args: format_time_lambda(exp.StrToTime, "bigquery")(
                [seq_get(args, 1), seq_get(args, 0)]
            ),
            "REGEXP_CONTAINS": exp.RegexpLike.from_arg_list,
            "REGEXP_EXTRACT": lambda args: exp.RegexpExtract(
                this=seq_get(args, 0),

@@ -172,15 +192,15 @@ class BigQuery(Dialect):
                if re.compile(str(seq_get(args, 1))).groups == 1
                else None,
            ),
            "TIME_ADD": parse_date_delta_with_interval(exp.TimeAdd),
            "TIMESTAMP_ADD": parse_date_delta_with_interval(exp.TimestampAdd),
            "DATE_SUB": parse_date_delta_with_interval(exp.DateSub),
            "DATETIME_SUB": parse_date_delta_with_interval(exp.DatetimeSub),
            "TIME_SUB": parse_date_delta_with_interval(exp.TimeSub),
            "TIMESTAMP_SUB": parse_date_delta_with_interval(exp.TimestampSub),
            "PARSE_TIMESTAMP": lambda args: exp.StrToTime(
                this=seq_get(args, 1), format=seq_get(args, 0)
            "SPLIT": lambda args: exp.Split(
                # https://cloud.google.com/bigquery/docs/reference/standard-sql/string_functions#split
                this=seq_get(args, 0),
                expression=seq_get(args, 1) or exp.Literal.string(","),
            ),
            "TIME_ADD": parse_date_delta_with_interval(exp.TimeAdd),
            "TIME_SUB": parse_date_delta_with_interval(exp.TimeSub),
            "TIMESTAMP_ADD": parse_date_delta_with_interval(exp.TimestampAdd),
            "TIMESTAMP_SUB": parse_date_delta_with_interval(exp.TimestampSub),
        }

        FUNCTION_PARSERS = {

@@ -274,9 +294,18 @@ class BigQuery(Dialect):
            exp.IntDiv: rename_func("DIV"),
            exp.Max: max_or_greatest,
            exp.Min: min_or_least,
            exp.RegexpExtract: lambda self, e: self.func(
                "REGEXP_EXTRACT",
                e.this,
                e.expression,
                e.args.get("position"),
                e.args.get("occurrence"),
            ),
            exp.RegexpLike: rename_func("REGEXP_CONTAINS"),
            exp.Select: transforms.preprocess(
                [_unqualify_unnest, transforms.eliminate_distinct_on]
            ),
            exp.StrToDate: lambda self, e: f"PARSE_DATE({self.format_time(e)}, {self.sql(e, 'this')})",
            exp.StrToTime: lambda self, e: f"PARSE_TIMESTAMP({self.format_time(e)}, {self.sql(e, 'this')})",
            exp.TimeAdd: _date_add_sql("TIME", "ADD"),
            exp.TimeSub: _date_add_sql("TIME", "SUB"),

@@ -295,7 +324,6 @@ class BigQuery(Dialect):
            exp.StabilityProperty: lambda self, e: f"DETERMINISTIC"
            if e.name == "IMMUTABLE"
            else "NOT DETERMINISTIC",
            exp.RegexpLike: rename_func("REGEXP_CONTAINS"),
        }

        TYPE_MAPPING = {

@@ -315,6 +343,7 @@ class BigQuery(Dialect):
            exp.DataType.Type.TEXT: "STRING",
            exp.DataType.Type.TIMESTAMP: "DATETIME",
            exp.DataType.Type.TIMESTAMPTZ: "TIMESTAMP",
            exp.DataType.Type.TIMESTAMPLTZ: "TIMESTAMP",
            exp.DataType.Type.TINYINT: "INT64",
            exp.DataType.Type.VARBINARY: "BYTES",
            exp.DataType.Type.VARCHAR: "STRING",
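
Note: with the argument swap above, BigQuery's PARSE_TIMESTAMP(format, value) now flows through format_time_lambda into a StrToTime node whose format string is translated via the dialect time mappings. A minimal round-trip sketch; the duckdb target and the exact output shown are illustrative assumptions, not part of this commit:

    import sqlglot

    # BigQuery puts the format first; the parser swaps it into StrToTime(this=value, format=...)
    sql = "SELECT PARSE_TIMESTAMP('%Y-%m-%d', '2023-05-01')"
    print(sqlglot.transpile(sql, read="bigquery", write="duckdb"))
    # expected to print something like: ["SELECT STRPTIME('2023-05-01', '%Y-%m-%d')"]
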
sqlglot/dialects/clickhouse.py

@@ -21,8 +21,9 @@ def _lower_func(sql: str) -> str:

class ClickHouse(Dialect):
    normalize_functions = None
    null_ordering = "nulls_are_last"
    NORMALIZE_FUNCTIONS: bool | str = False
    NULL_ORDERING = "nulls_are_last"
    STRICT_STRING_CONCAT = True

    class Tokenizer(tokens.Tokenizer):
        COMMENTS = ["--", "#", "#!", ("/*", "*/")]

@@ -163,11 +164,11 @@ class ClickHouse(Dialect):
            return this

        def _parse_position(self, haystack_first: bool = False) -> exp.Expression:
        def _parse_position(self, haystack_first: bool = False) -> exp.StrPosition:
            return super()._parse_position(haystack_first=True)

        # https://clickhouse.com/docs/en/sql-reference/statements/select/with/
        def _parse_cte(self) -> exp.Expression:
        def _parse_cte(self) -> exp.CTE:
            index = self._index
            try:
                # WITH <identifier> AS <subquery expression>

@@ -187,17 +188,19 @@ class ClickHouse(Dialect):
        ) -> t.Tuple[t.Optional[Token], t.Optional[Token], t.Optional[Token]]:
            is_global = self._match(TokenType.GLOBAL) and self._prev
            kind_pre = self._match_set(self.JOIN_KINDS, advance=False) and self._prev

            if kind_pre:
                kind = self._match_set(self.JOIN_KINDS) and self._prev
                side = self._match_set(self.JOIN_SIDES) and self._prev
                return is_global, side, kind

            return (
                is_global,
                self._match_set(self.JOIN_SIDES) and self._prev,
                self._match_set(self.JOIN_KINDS) and self._prev,
            )

        def _parse_join(self, skip_join_token: bool = False) -> t.Optional[exp.Expression]:
        def _parse_join(self, skip_join_token: bool = False) -> t.Optional[exp.Join]:
            join = super()._parse_join(skip_join_token)

            if join:

@@ -205,9 +208,14 @@ class ClickHouse(Dialect):
            return join

        def _parse_function(
            self, functions: t.Optional[t.Dict[str, t.Callable]] = None, anonymous: bool = False
            self,
            functions: t.Optional[t.Dict[str, t.Callable]] = None,
            anonymous: bool = False,
            optional_parens: bool = True,
        ) -> t.Optional[exp.Expression]:
            func = super()._parse_function(functions, anonymous)
            func = super()._parse_function(
                functions=functions, anonymous=anonymous, optional_parens=optional_parens
            )

            if isinstance(func, exp.Anonymous):
                params = self._parse_func_params(func)

@@ -227,10 +235,12 @@ class ClickHouse(Dialect):
        ) -> t.Optional[t.List[t.Optional[exp.Expression]]]:
            if self._match_pair(TokenType.R_PAREN, TokenType.L_PAREN):
                return self._parse_csv(self._parse_lambda)

            if self._match(TokenType.L_PAREN):
                params = self._parse_csv(self._parse_lambda)
                self._match_r_paren(this)
                return params

            return None

        def _parse_quantile(self) -> exp.Quantile:

@@ -247,12 +257,12 @@ class ClickHouse(Dialect):

        def _parse_primary_key(
            self, wrapped_optional: bool = False, in_props: bool = False
        ) -> exp.Expression:
        ) -> exp.PrimaryKeyColumnConstraint | exp.PrimaryKey:
            return super()._parse_primary_key(
                wrapped_optional=wrapped_optional or in_props, in_props=in_props
            )

        def _parse_on_property(self) -> t.Optional[exp.Property]:
        def _parse_on_property(self) -> t.Optional[exp.Expression]:
            index = self._index
            if self._match_text_seq("CLUSTER"):
                this = self._parse_id_var()

@@ -329,6 +339,16 @@ class ClickHouse(Dialect):
            "NAMED COLLECTION",
        }

        def safeconcat_sql(self, expression: exp.SafeConcat) -> str:
            # Clickhouse errors out if we try to cast a NULL value to TEXT
            return self.func(
                "CONCAT",
                *[
                    exp.func("if", e.is_(exp.null()), e, exp.cast(e, "text"))
                    for e in expression.expressions
                ],
            )

        def cte_sql(self, expression: exp.CTE) -> str:
            if isinstance(expression.this, exp.Alias):
                return self.sql(expression, "this")
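
Note: safeconcat_sql above guards each CONCAT argument so only non-NULL values get cast to text. The same node can be built directly with sqlglot's public helpers; the rendering shown is a rough default-dialect expectation, not an exact claim about ClickHouse output:

    from sqlglot import exp

    e = exp.column("x")
    guarded = exp.func("if", e.is_(exp.null()), e, exp.cast(e, "text"))
    print(guarded.sql())
    # roughly: CASE WHEN x IS NULL THEN x ELSE CAST(x AS TEXT) END
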
sqlglot/dialects/dialect.py

@@ -25,6 +25,8 @@ class Dialects(str, Enum):

    BIGQUERY = "bigquery"
    CLICKHOUSE = "clickhouse"
    DATABRICKS = "databricks"
    DRILL = "drill"
    DUCKDB = "duckdb"
    HIVE = "hive"
    MYSQL = "mysql"

@@ -38,11 +40,9 @@ class Dialects(str, Enum):
    SQLITE = "sqlite"
    STARROCKS = "starrocks"
    TABLEAU = "tableau"
    TERADATA = "teradata"
    TRINO = "trino"
    TSQL = "tsql"
    DATABRICKS = "databricks"
    DRILL = "drill"
    TERADATA = "teradata"


class _Dialect(type):

@@ -76,16 +76,19 @@ class _Dialect(type):
        enum = Dialects.__members__.get(clsname.upper())
        cls.classes[enum.value if enum is not None else clsname.lower()] = klass

        klass.time_trie = new_trie(klass.time_mapping)
        klass.inverse_time_mapping = {v: k for k, v in klass.time_mapping.items()}
        klass.inverse_time_trie = new_trie(klass.inverse_time_mapping)
        klass.TIME_TRIE = new_trie(klass.TIME_MAPPING)
        klass.FORMAT_TRIE = (
            new_trie(klass.FORMAT_MAPPING) if klass.FORMAT_MAPPING else klass.TIME_TRIE
        )
        klass.INVERSE_TIME_MAPPING = {v: k for k, v in klass.TIME_MAPPING.items()}
        klass.INVERSE_TIME_TRIE = new_trie(klass.INVERSE_TIME_MAPPING)

        klass.tokenizer_class = getattr(klass, "Tokenizer", Tokenizer)
        klass.parser_class = getattr(klass, "Parser", Parser)
        klass.generator_class = getattr(klass, "Generator", Generator)

        klass.quote_start, klass.quote_end = list(klass.tokenizer_class._QUOTES.items())[0]
        klass.identifier_start, klass.identifier_end = list(
        klass.QUOTE_START, klass.QUOTE_END = list(klass.tokenizer_class._QUOTES.items())[0]
        klass.IDENTIFIER_START, klass.IDENTIFIER_END = list(
            klass.tokenizer_class._IDENTIFIERS.items()
        )[0]

@@ -99,43 +102,80 @@ class _Dialect(type):
                (None, None),
            )

        klass.bit_start, klass.bit_end = get_start_end(TokenType.BIT_STRING)
        klass.hex_start, klass.hex_end = get_start_end(TokenType.HEX_STRING)
        klass.byte_start, klass.byte_end = get_start_end(TokenType.BYTE_STRING)
        klass.raw_start, klass.raw_end = get_start_end(TokenType.RAW_STRING)
        klass.BIT_START, klass.BIT_END = get_start_end(TokenType.BIT_STRING)
        klass.HEX_START, klass.HEX_END = get_start_end(TokenType.HEX_STRING)
        klass.BYTE_START, klass.BYTE_END = get_start_end(TokenType.BYTE_STRING)
        klass.RAW_START, klass.RAW_END = get_start_end(TokenType.RAW_STRING)

        klass.tokenizer_class.identifiers_can_start_with_digit = (
            klass.identifiers_can_start_with_digit
        )
        dialect_properties = {
            **{
                k: v
                for k, v in vars(klass).items()
                if not callable(v) and not isinstance(v, classmethod) and not k.startswith("__")
            },
            "STRING_ESCAPE": klass.tokenizer_class.STRING_ESCAPES[0],
            "IDENTIFIER_ESCAPE": klass.tokenizer_class.IDENTIFIER_ESCAPES[0],
        }

        # Pass required dialect properties to the tokenizer, parser and generator classes
        for subclass in (klass.tokenizer_class, klass.parser_class, klass.generator_class):
            for name, value in dialect_properties.items():
                if hasattr(subclass, name):
                    setattr(subclass, name, value)

        if not klass.STRICT_STRING_CONCAT:
            klass.parser_class.BITWISE[TokenType.DPIPE] = exp.SafeDPipe

        return klass


class Dialect(metaclass=_Dialect):
    index_offset = 0
    unnest_column_only = False
    alias_post_tablesample = False
    identifiers_can_start_with_digit = False
    normalize_functions: t.Optional[str] = "upper"
    null_ordering = "nulls_are_small"
    # Determines the base index offset for arrays
    INDEX_OFFSET = 0

    date_format = "'%Y-%m-%d'"
    dateint_format = "'%Y%m%d'"
    time_format = "'%Y-%m-%d %H:%M:%S'"
    time_mapping: t.Dict[str, str] = {}
    # If true unnest table aliases are considered only as column aliases
    UNNEST_COLUMN_ONLY = False

    # autofilled
    quote_start = None
    quote_end = None
    identifier_start = None
    identifier_end = None
    # Determines whether or not the table alias comes after tablesample
    ALIAS_POST_TABLESAMPLE = False

    time_trie = None
    inverse_time_mapping = None
    inverse_time_trie = None
    tokenizer_class = None
    parser_class = None
    generator_class = None
    # Determines whether or not an unquoted identifier can start with a digit
    IDENTIFIERS_CAN_START_WITH_DIGIT = False

    # Determines whether or not CONCAT's arguments must be strings
    STRICT_STRING_CONCAT = False

    # Determines how function names are going to be normalized
    NORMALIZE_FUNCTIONS: bool | str = "upper"

    # Indicates the default null ordering method to use if not explicitly set
    # Options are: "nulls_are_small", "nulls_are_large", "nulls_are_last"
    NULL_ORDERING = "nulls_are_small"

    DATE_FORMAT = "'%Y-%m-%d'"
    DATEINT_FORMAT = "'%Y%m%d'"
    TIME_FORMAT = "'%Y-%m-%d %H:%M:%S'"

    # Custom time mappings in which the key represents dialect time format
    # and the value represents a python time format
    TIME_MAPPING: t.Dict[str, str] = {}

    # https://cloud.google.com/bigquery/docs/reference/standard-sql/format-elements#format_model_rules_date_time
    # https://docs.teradata.com/r/Teradata-Database-SQL-Functions-Operators-Expressions-and-Predicates/March-2017/Data-Type-Conversions/Character-to-DATE-Conversion/Forcing-a-FORMAT-on-CAST-for-Converting-Character-to-DATE
    # special syntax cast(x as date format 'yyyy') defaults to time_mapping
    FORMAT_MAPPING: t.Dict[str, str] = {}

    # Autofilled
    tokenizer_class = Tokenizer
    parser_class = Parser
    generator_class = Generator

    # A trie of the time_mapping keys
    TIME_TRIE: t.Dict = {}
    FORMAT_TRIE: t.Dict = {}

    INVERSE_TIME_MAPPING: t.Dict[str, str] = {}
    INVERSE_TIME_TRIE: t.Dict = {}

    def __eq__(self, other: t.Any) -> bool:
        return type(self) == other
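
Note: the rewritten metaclass replaces per-setting constructor plumbing with attribute propagation: every non-callable, non-dunder class attribute of a dialect is copied onto its tokenizer, parser and generator classes, but only where those classes already declare an attribute of the same name. A minimal standalone sketch of that pattern; all names here are illustrative, not sqlglot's:

    class Tokenizer:
        TIME_FORMAT = ""

    class Parser:
        pass  # declares no TIME_FORMAT, so it is left untouched

    class _Propagate(type):
        def __new__(cls, name, bases, attrs):
            klass = super().__new__(cls, name, bases, attrs)
            settings = {
                k: v
                for k, v in vars(klass).items()
                if not callable(v) and not isinstance(v, classmethod) and not k.startswith("__")
            }
            for target in (Tokenizer, Parser):
                for k, v in settings.items():
                    if hasattr(target, k):  # only override attrs the target already declares
                        setattr(target, k, v)
            return klass

    class MyDialect(metaclass=_Propagate):
        TIME_FORMAT = "'%Y-%m-%d'"

    assert Tokenizer.TIME_FORMAT == "'%Y-%m-%d'"
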
@@ -164,20 +204,13 @@ class Dialect(metaclass=_Dialect):
    ) -> t.Optional[exp.Expression]:
        if isinstance(expression, str):
            return exp.Literal.string(
                format_time(
                    expression[1:-1],  # the time formats are quoted
                    cls.time_mapping,
                    cls.time_trie,
                )
                # the time formats are quoted
                format_time(expression[1:-1], cls.TIME_MAPPING, cls.TIME_TRIE)
            )

        if expression and expression.is_string:
            return exp.Literal.string(
                format_time(
                    expression.this,
                    cls.time_mapping,
                    cls.time_trie,
                )
            )
            return exp.Literal.string(format_time(expression.this, cls.TIME_MAPPING, cls.TIME_TRIE))

        return expression

    def parse(self, sql: str, **opts) -> t.List[t.Optional[exp.Expression]]:

@@ -200,48 +233,14 @@ class Dialect(metaclass=_Dialect):
    @property
    def tokenizer(self) -> Tokenizer:
        if not hasattr(self, "_tokenizer"):
            self._tokenizer = self.tokenizer_class()  # type: ignore
            self._tokenizer = self.tokenizer_class()
        return self._tokenizer

    def parser(self, **opts) -> Parser:
        return self.parser_class(  # type: ignore
            **{
                "index_offset": self.index_offset,
                "unnest_column_only": self.unnest_column_only,
                "alias_post_tablesample": self.alias_post_tablesample,
                "null_ordering": self.null_ordering,
                **opts,
            },
        )
        return self.parser_class(**opts)

    def generator(self, **opts) -> Generator:
        return self.generator_class(  # type: ignore
            **{
                "quote_start": self.quote_start,
                "quote_end": self.quote_end,
                "bit_start": self.bit_start,
                "bit_end": self.bit_end,
                "hex_start": self.hex_start,
                "hex_end": self.hex_end,
                "byte_start": self.byte_start,
                "byte_end": self.byte_end,
                "raw_start": self.raw_start,
                "raw_end": self.raw_end,
                "identifier_start": self.identifier_start,
                "identifier_end": self.identifier_end,
                "string_escape": self.tokenizer_class.STRING_ESCAPES[0],
                "identifier_escape": self.tokenizer_class.IDENTIFIER_ESCAPES[0],
                "index_offset": self.index_offset,
                "time_mapping": self.inverse_time_mapping,
                "time_trie": self.inverse_time_trie,
                "unnest_column_only": self.unnest_column_only,
                "alias_post_tablesample": self.alias_post_tablesample,
                "identifiers_can_start_with_digit": self.identifiers_can_start_with_digit,
                "normalize_functions": self.normalize_functions,
                "null_ordering": self.null_ordering,
                **opts,
            }
        )
        return self.generator_class(**opts)


DialectType = t.Union[str, Dialect, t.Type[Dialect], None]

@@ -279,10 +278,7 @@ def inline_array_sql(self: Generator, expression: exp.Array) -> str:

def no_ilike_sql(self: Generator, expression: exp.ILike) -> str:
    return self.like_sql(
        exp.Like(
            this=exp.Lower(this=expression.this),
            expression=expression.args["expression"],
        )
        exp.Like(this=exp.Lower(this=expression.this), expression=expression.expression)
    )


@@ -359,6 +355,7 @@ def var_map_sql(
    for key, value in zip(keys.expressions, values.expressions):
        args.append(self.sql(key))
        args.append(self.sql(value))

    return self.func(map_func_name, *args)


@@ -381,7 +378,7 @@ def format_time_lambda(
        this=seq_get(args, 0),
        format=Dialect[dialect].format_time(
            seq_get(args, 1)
            or (Dialect[dialect].time_format if default is True else default or None)
            or (Dialect[dialect].TIME_FORMAT if default is True else default or None)
        ),
    )


@@ -437,9 +434,7 @@ def parse_date_delta_with_interval(
            expression = exp.Literal.number(expression.this)

        return expression_class(
            this=args[0],
            expression=expression,
            unit=exp.Literal.string(interval.text("unit")),
            this=args[0], expression=expression, unit=exp.Literal.string(interval.text("unit"))
        )

    return func

@@ -462,9 +457,7 @@ def timestamptrunc_sql(self: Generator, expression: exp.TimestampTrunc) -> str:

def locate_to_strposition(args: t.List) -> exp.Expression:
    return exp.StrPosition(
        this=seq_get(args, 1),
        substr=seq_get(args, 0),
        position=seq_get(args, 2),
        this=seq_get(args, 1), substr=seq_get(args, 0), position=seq_get(args, 2)
    )


@@ -546,13 +539,21 @@ def ts_or_ds_to_date_sql(dialect: str) -> t.Callable:
    def _ts_or_ds_to_date_sql(self: Generator, expression: exp.TsOrDsToDate) -> str:
        _dialect = Dialect.get_or_raise(dialect)
        time_format = self.format_time(expression)
        if time_format and time_format not in (_dialect.time_format, _dialect.date_format):
        if time_format and time_format not in (_dialect.TIME_FORMAT, _dialect.DATE_FORMAT):
            return f"CAST({str_to_time_sql(self, expression)} AS DATE)"
        return f"CAST({self.sql(expression, 'this')} AS DATE)"

    return _ts_or_ds_to_date_sql


def concat_to_dpipe_sql(self: Generator, expression: exp.Concat | exp.SafeConcat) -> str:
    this, *rest_args = expression.expressions
    for arg in rest_args:
        this = exp.DPipe(this=this, expression=arg)

    return self.sql(this)


# Spark, DuckDB use (almost) the same naming scheme for the output columns of the PIVOT operator
def pivot_column_names(aggregations: t.List[exp.Expression], dialect: DialectType) -> t.List[str]:
    names = []
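
Note: the new concat_to_dpipe_sql helper left-folds CONCAT's argument list into nested DPipe nodes (the SQL || operator) before rendering. The fold can be reproduced with the public API; the default-dialect parse is an illustrative assumption:

    from sqlglot import exp, parse_one

    node = parse_one("CONCAT(a, b, c)")
    this, *rest = node.expressions
    for arg in rest:
        this = exp.DPipe(this=this, expression=arg)
    print(this.sql())  # a || b || c
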
sqlglot/dialects/drill.py

@@ -16,21 +16,10 @@ from sqlglot.dialects.dialect import (
)


def _str_to_time_sql(self: generator.Generator, expression: exp.TsOrDsToDate) -> str:
    return f"STRPTIME({self.sql(expression, 'this')}, {self.format_time(expression)})"


def _ts_or_ds_to_date_sql(self: generator.Generator, expression: exp.TsOrDsToDate) -> str:
    time_format = self.format_time(expression)
    if time_format and time_format not in (Drill.time_format, Drill.date_format):
        return f"CAST({_str_to_time_sql(self, expression)} AS DATE)"
    return f"CAST({self.sql(expression, 'this')} AS DATE)"


def _date_add_sql(kind: str) -> t.Callable[[generator.Generator, exp.DateAdd | exp.DateSub], str]:
    def func(self: generator.Generator, expression: exp.DateAdd | exp.DateSub) -> str:
        this = self.sql(expression, "this")
        unit = exp.Var(this=expression.text("unit").upper() or "DAY")
        unit = exp.var(expression.text("unit").upper() or "DAY")
        return (
            f"DATE_{kind}({this}, {self.sql(exp.Interval(this=expression.expression, unit=unit))})"
        )

@@ -41,19 +30,19 @@ def _date_add_sql(kind: str) -> t.Callable[[generator.Generator, exp.DateAdd | e
def _str_to_date(self: generator.Generator, expression: exp.StrToDate) -> str:
    this = self.sql(expression, "this")
    time_format = self.format_time(expression)
    if time_format == Drill.date_format:
    if time_format == Drill.DATE_FORMAT:
        return f"CAST({this} AS DATE)"
    return f"TO_DATE({this}, {time_format})"


class Drill(Dialect):
    normalize_functions = None
    null_ordering = "nulls_are_last"
    date_format = "'yyyy-MM-dd'"
    dateint_format = "'yyyyMMdd'"
    time_format = "'yyyy-MM-dd HH:mm:ss'"
    NORMALIZE_FUNCTIONS: bool | str = False
    NULL_ORDERING = "nulls_are_last"
    DATE_FORMAT = "'yyyy-MM-dd'"
    DATEINT_FORMAT = "'yyyyMMdd'"
    TIME_FORMAT = "'yyyy-MM-dd HH:mm:ss'"

    time_mapping = {
    TIME_MAPPING = {
        "y": "%Y",
        "Y": "%Y",
        "YYYY": "%Y",

@@ -93,6 +82,7 @@ class Drill(Dialect):

    class Parser(parser.Parser):
        STRICT_CAST = False
        CONCAT_NULL_OUTPUTS_STRING = True

        FUNCTIONS = {
            **parser.Parser.FUNCTIONS,

@@ -135,8 +125,8 @@ class Drill(Dialect):
            exp.DateAdd: _date_add_sql("ADD"),
            exp.DateStrToDate: datestrtodate_sql,
            exp.DateSub: _date_add_sql("SUB"),
            exp.DateToDi: lambda self, e: f"CAST(TO_DATE({self.sql(e, 'this')}, {Drill.dateint_format}) AS INT)",
            exp.DiToDate: lambda self, e: f"TO_DATE(CAST({self.sql(e, 'this')} AS VARCHAR), {Drill.dateint_format})",
            exp.DateToDi: lambda self, e: f"CAST(TO_DATE({self.sql(e, 'this')}, {Drill.DATEINT_FORMAT}) AS INT)",
            exp.DiToDate: lambda self, e: f"TO_DATE(CAST({self.sql(e, 'this')} AS VARCHAR), {Drill.DATEINT_FORMAT})",
            exp.If: lambda self, e: f"`IF`({self.format_args(e.this, e.args.get('true'), e.args.get('false'))})",
            exp.ILike: lambda self, e: f" {self.sql(e, 'this')} `ILIKE` {self.sql(e, 'expression')}",
            exp.Levenshtein: rename_func("LEVENSHTEIN_DISTANCE"),

@@ -154,7 +144,7 @@ class Drill(Dialect):
            exp.TimeToUnix: rename_func("UNIX_TIMESTAMP"),
            exp.ToChar: lambda self, e: self.function_fallback_sql(e),
            exp.TryCast: no_trycast_sql,
            exp.TsOrDsAdd: lambda self, e: f"DATE_ADD(CAST({self.sql(e, 'this')} AS DATE), {self.sql(exp.Interval(this=e.expression, unit=exp.Var(this='DAY')))})",
            exp.TsOrDsAdd: lambda self, e: f"DATE_ADD(CAST({self.sql(e, 'this')} AS DATE), {self.sql(exp.Interval(this=e.expression, unit=exp.var('DAY')))})",
            exp.TsOrDsToDate: ts_or_ds_to_date_sql("drill"),
            exp.TsOrDiToDi: lambda self, e: f"CAST(SUBSTR(REPLACE(CAST({self.sql(e, 'this')} AS VARCHAR), '-', ''), 1, 8) AS INT)",
        }
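
Note: several call sites in this commit swap the raw constructor exp.Var(this=...) for the exp.var helper, which accepts a plain string and builds the same node. A two-line sketch:

    from sqlglot import exp

    assert exp.var("DAY") == exp.Var(this="DAY")
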
sqlglot/dialects/duckdb.py

@@ -56,11 +56,7 @@ def _sort_array_reverse(args: t.List) -> exp.Expression:

def _parse_date_diff(args: t.List) -> exp.Expression:
    return exp.DateDiff(
        this=seq_get(args, 2),
        expression=seq_get(args, 1),
        unit=seq_get(args, 0),
    )
    return exp.DateDiff(this=seq_get(args, 2), expression=seq_get(args, 1), unit=seq_get(args, 0))


def _struct_sql(self: generator.Generator, expression: exp.Struct) -> str:

@@ -90,7 +86,7 @@ def _regexp_extract_sql(self: generator.Generator, expression: exp.RegexpExtract

class DuckDB(Dialect):
    null_ordering = "nulls_are_last"
    NULL_ORDERING = "nulls_are_last"

    class Tokenizer(tokens.Tokenizer):
        KEYWORDS = {

@@ -118,6 +114,8 @@ class DuckDB(Dialect):
        }

    class Parser(parser.Parser):
        CONCAT_NULL_OUTPUTS_STRING = True

        FUNCTIONS = {
            **parser.Parser.FUNCTIONS,
            "ARRAY_LENGTH": exp.ArraySize.from_arg_list,

@@ -127,10 +125,7 @@ class DuckDB(Dialect):
            "DATE_DIFF": _parse_date_diff,
            "EPOCH": exp.TimeToUnix.from_arg_list,
            "EPOCH_MS": lambda args: exp.UnixToTime(
                this=exp.Div(
                    this=seq_get(args, 0),
                    expression=exp.Literal.number(1000),
                )
                this=exp.Div(this=seq_get(args, 0), expression=exp.Literal.number(1000))
            ),
            "LIST_REVERSE_SORT": _sort_array_reverse,
            "LIST_SORT": exp.SortArray.from_arg_list,

@@ -191,8 +186,8 @@ class DuckDB(Dialect):
                "DATE_DIFF", f"'{e.args.get('unit', 'day')}'", e.expression, e.this
            ),
            exp.DateStrToDate: datestrtodate_sql,
            exp.DateToDi: lambda self, e: f"CAST(STRFTIME({self.sql(e, 'this')}, {DuckDB.dateint_format}) AS INT)",
            exp.DiToDate: lambda self, e: f"CAST(STRPTIME(CAST({self.sql(e, 'this')} AS TEXT), {DuckDB.dateint_format}) AS DATE)",
            exp.DateToDi: lambda self, e: f"CAST(STRFTIME({self.sql(e, 'this')}, {DuckDB.DATEINT_FORMAT}) AS INT)",
            exp.DiToDate: lambda self, e: f"CAST(STRPTIME(CAST({self.sql(e, 'this')} AS TEXT), {DuckDB.DATEINT_FORMAT}) AS DATE)",
            exp.Explode: rename_func("UNNEST"),
            exp.IntDiv: lambda self, e: self.binary(e, "//"),
            exp.JSONExtract: arrow_json_extract_sql,

@@ -242,11 +237,27 @@ class DuckDB(Dialect):

        STAR_MAPPING = {**generator.Generator.STAR_MAPPING, "except": "EXCLUDE"}

        UNWRAPPED_INTERVAL_VALUES = (exp.Column, exp.Literal, exp.Paren)

        PROPERTIES_LOCATION = {
            **generator.Generator.PROPERTIES_LOCATION,
            exp.VolatileProperty: exp.Properties.Location.UNSUPPORTED,
        }

        def interval_sql(self, expression: exp.Interval) -> str:
            multiplier: t.Optional[int] = None
            unit = expression.text("unit").lower()

            if unit.startswith("week"):
                multiplier = 7
            if unit.startswith("quarter"):
                multiplier = 90

            if multiplier:
                return f"({multiplier} * {super().interval_sql(exp.Interval(this=expression.this, unit=exp.var('day')))})"

            return super().interval_sql(expression)

        def tablesample_sql(
            self, expression: exp.TableSample, seed_prefix: str = "SEED", sep: str = " AS "
        ) -> str:
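
Note: the new interval_sql rewrites week and quarter intervals into day multiples for DuckDB output. A rough sketch of the effect; the bigquery source dialect and the exact output spelling are illustrative assumptions:

    import sqlglot

    print(sqlglot.transpile("SELECT INTERVAL 2 WEEK", read="bigquery", write="duckdb"))
    # roughly: ['SELECT (7 * INTERVAL 2 day)']
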
sqlglot/dialects/hive.py

@@ -80,12 +80,12 @@ def _date_diff_sql(self: generator.Generator, expression: exp.DateDiff) -> str:
    _, multiplier = DATE_DELTA_INTERVAL.get(unit, ("", 1))
    multiplier_sql = f" / {multiplier}" if multiplier > 1 else ""
    diff_sql = f"{sql_func}({self.format_args(expression.this, expression.expression)})"

    return f"{diff_sql}{multiplier_sql}"


def _json_format_sql(self: generator.Generator, expression: exp.JSONFormat) -> str:
    this = expression.this

    if not this.type:
        from sqlglot.optimizer.annotate_types import annotate_types

@@ -113,7 +113,7 @@ def _str_to_unix_sql(self: generator.Generator, expression: exp.StrToUnix) -> st
def _str_to_date_sql(self: generator.Generator, expression: exp.StrToDate) -> str:
    this = self.sql(expression, "this")
    time_format = self.format_time(expression)
    if time_format not in (Hive.time_format, Hive.date_format):
    if time_format not in (Hive.TIME_FORMAT, Hive.DATE_FORMAT):
        this = f"FROM_UNIXTIME(UNIX_TIMESTAMP({this}, {time_format}))"
    return f"CAST({this} AS DATE)"


@@ -121,7 +121,7 @@ def _str_to_date_sql(self: generator.Generator, expression: exp.StrToDate) -> st
def _str_to_time_sql(self: generator.Generator, expression: exp.StrToTime) -> str:
    this = self.sql(expression, "this")
    time_format = self.format_time(expression)
    if time_format not in (Hive.time_format, Hive.date_format):
    if time_format not in (Hive.TIME_FORMAT, Hive.DATE_FORMAT):
        this = f"FROM_UNIXTIME(UNIX_TIMESTAMP({this}, {time_format}))"
    return f"CAST({this} AS TIMESTAMP)"


@@ -130,7 +130,7 @@ def _time_format(
    self: generator.Generator, expression: exp.UnixToStr | exp.StrToUnix
) -> t.Optional[str]:
    time_format = self.format_time(expression)
    if time_format == Hive.time_format:
    if time_format == Hive.TIME_FORMAT:
        return None
    return time_format


@@ -144,16 +144,16 @@ def _time_to_str(self: generator.Generator, expression: exp.TimeToStr) -> str:
def _to_date_sql(self: generator.Generator, expression: exp.TsOrDsToDate) -> str:
    this = self.sql(expression, "this")
    time_format = self.format_time(expression)
    if time_format and time_format not in (Hive.time_format, Hive.date_format):
    if time_format and time_format not in (Hive.TIME_FORMAT, Hive.DATE_FORMAT):
        return f"TO_DATE({this}, {time_format})"
    return f"TO_DATE({this})"


class Hive(Dialect):
    alias_post_tablesample = True
    identifiers_can_start_with_digit = True
    ALIAS_POST_TABLESAMPLE = True
    IDENTIFIERS_CAN_START_WITH_DIGIT = True

    time_mapping = {
    TIME_MAPPING = {
        "y": "%Y",
        "Y": "%Y",
        "YYYY": "%Y",

@@ -184,9 +184,9 @@ class Hive(Dialect):
        "EEEE": "%A",
    }

    date_format = "'yyyy-MM-dd'"
    dateint_format = "'yyyyMMdd'"
    time_format = "'yyyy-MM-dd HH:mm:ss'"
    DATE_FORMAT = "'yyyy-MM-dd'"
    DATEINT_FORMAT = "'yyyyMMdd'"
    TIME_FORMAT = "'yyyy-MM-dd HH:mm:ss'"

    class Tokenizer(tokens.Tokenizer):
        QUOTES = ["'", '"']

@@ -224,9 +224,7 @@ class Hive(Dialect):
            "BASE64": exp.ToBase64.from_arg_list,
            "COLLECT_LIST": exp.ArrayAgg.from_arg_list,
            "DATE_ADD": lambda args: exp.TsOrDsAdd(
                this=seq_get(args, 0),
                expression=seq_get(args, 1),
                unit=exp.Literal.string("DAY"),
                this=seq_get(args, 0), expression=seq_get(args, 1), unit=exp.Literal.string("DAY")
            ),
            "DATEDIFF": lambda args: exp.DateDiff(
                this=exp.TsOrDsToDate(this=seq_get(args, 0)),

@@ -234,10 +232,7 @@ class Hive(Dialect):
            ),
            "DATE_SUB": lambda args: exp.TsOrDsAdd(
                this=seq_get(args, 0),
                expression=exp.Mul(
                    this=seq_get(args, 1),
                    expression=exp.Literal.number(-1),
                ),
                expression=exp.Mul(this=seq_get(args, 1), expression=exp.Literal.number(-1)),
                unit=exp.Literal.string("DAY"),
            ),
            "DATE_FORMAT": lambda args: format_time_lambda(exp.TimeToStr, "hive")(

@@ -349,8 +344,8 @@ class Hive(Dialect):
            exp.DateDiff: _date_diff_sql,
            exp.DateStrToDate: rename_func("TO_DATE"),
            exp.DateSub: _add_date_sql,
            exp.DateToDi: lambda self, e: f"CAST(DATE_FORMAT({self.sql(e, 'this')}, {Hive.dateint_format}) AS INT)",
            exp.DiToDate: lambda self, e: f"TO_DATE(CAST({self.sql(e, 'this')} AS STRING), {Hive.dateint_format})",
            exp.DateToDi: lambda self, e: f"CAST(DATE_FORMAT({self.sql(e, 'this')}, {Hive.DATEINT_FORMAT}) AS INT)",
            exp.DiToDate: lambda self, e: f"TO_DATE(CAST({self.sql(e, 'this')} AS STRING), {Hive.DATEINT_FORMAT})",
            exp.FileFormatProperty: lambda self, e: f"STORED AS {self.sql(e, 'this') if isinstance(e.this, exp.InputOutputFormat) else e.name.upper()}",
            exp.FromBase64: rename_func("UNBASE64"),
            exp.If: if_sql,

@@ -415,10 +410,7 @@ class Hive(Dialect):
        )

        def with_properties(self, properties: exp.Properties) -> str:
            return self.properties(
                properties,
                prefix=self.seg("TBLPROPERTIES"),
            )
            return self.properties(properties, prefix=self.seg("TBLPROPERTIES"))

        def datatype_sql(self, expression: exp.DataType) -> str:
            if (
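
Note: Hive's DATE_SUB is parsed as an addition with the amount negated (a Mul by -1), so downstream rewrites only have to reason about one operation. A quick check of that mapping; the bare-expression parse is an illustrative assumption:

    import sqlglot
    from sqlglot import exp

    node = sqlglot.parse_one("DATE_SUB(d, 7)", read="hive")
    assert isinstance(node, exp.TsOrDsAdd)
    assert isinstance(node.expression, exp.Mul)  # the subtrahend is parsed as 7 * -1
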
sqlglot/dialects/mysql.py

@@ -94,10 +94,10 @@ def _date_add_sql(kind: str) -> t.Callable[[generator.Generator, exp.DateAdd | e

class MySQL(Dialect):
    time_format = "'%Y-%m-%d %T'"
    TIME_FORMAT = "'%Y-%m-%d %T'"

    # https://prestodb.io/docs/current/functions/datetime.html#mysql-date-functions
    time_mapping = {
    TIME_MAPPING = {
        "%M": "%B",
        "%c": "%-m",
        "%e": "%-d",

@@ -128,6 +128,7 @@ class MySQL(Dialect):
            "MEDIUMBLOB": TokenType.MEDIUMBLOB,
            "MEDIUMTEXT": TokenType.MEDIUMTEXT,
            "SEPARATOR": TokenType.SEPARATOR,
            "ENUM": TokenType.ENUM,
            "START": TokenType.BEGIN,
            "_ARMSCII8": TokenType.INTRODUCER,
            "_ASCII": TokenType.INTRODUCER,

@@ -279,6 +280,16 @@ class MySQL(Dialect):
            "SWAPS",
        }

        TYPE_TOKENS = {
            *parser.Parser.TYPE_TOKENS,
            TokenType.SET,
        }

        ENUM_TYPE_TOKENS = {
            *parser.Parser.ENUM_TYPE_TOKENS,
            TokenType.SET,
        }

        LOG_DEFAULTS_TO_LN = True

        def _parse_show_mysql(

@@ -372,12 +383,7 @@ class MySQL(Dialect):
            else:
                collate = None

            return self.expression(
                exp.SetItem,
                this=charset,
                collate=collate,
                kind="NAMES",
            )
            return self.expression(exp.SetItem, this=charset, collate=collate, kind="NAMES")

    class Generator(generator.Generator):
        LOCKING_READS_SUPPORTED = True

@@ -472,9 +478,7 @@ class MySQL(Dialect):

        def _prefixed_sql(self, prefix: str, expression: exp.Expression, arg: str) -> str:
            sql = self.sql(expression, arg)
            if not sql:
                return ""
            return f" {prefix} {sql}"
            return f" {prefix} {sql}" if sql else ""

        def _oldstyle_limit_sql(self, expression: exp.Show) -> str:
            limit = self.sql(expression, "limit")
sqlglot/dialects/oracle.py

@@ -24,21 +24,15 @@ def _parse_xml_table(self: parser.Parser) -> exp.XMLTable:
    if self._match_text_seq("COLUMNS"):
        columns = self._parse_csv(lambda: self._parse_column_def(self._parse_field(any_token=True)))

    return self.expression(
        exp.XMLTable,
        this=this,
        passing=passing,
        columns=columns,
        by_ref=by_ref,
    )
    return self.expression(exp.XMLTable, this=this, passing=passing, columns=columns, by_ref=by_ref)


class Oracle(Dialect):
    alias_post_tablesample = True
    ALIAS_POST_TABLESAMPLE = True

    # https://docs.oracle.com/database/121/SQLRF/sql_elements004.htm#SQLRF00212
    # https://docs.python.org/3/library/datetime.html#strftime-and-strptime-format-codes
    time_mapping = {
    TIME_MAPPING = {
        "AM": "%p",  # Meridian indicator with or without periods
        "A.M.": "%p",  # Meridian indicator with or without periods
        "PM": "%p",  # Meridian indicator with or without periods

@@ -87,7 +81,7 @@ class Oracle(Dialect):
            column.set("join_mark", self._match(TokenType.JOIN_MARKER))
            return column

        def _parse_hint(self) -> t.Optional[exp.Expression]:
        def _parse_hint(self) -> t.Optional[exp.Hint]:
            if self._match(TokenType.HINT):
                start = self._curr
                while self._curr and not self._match_pair(TokenType.STAR, TokenType.SLASH):

@@ -129,7 +123,7 @@ class Oracle(Dialect):
            exp.Group: transforms.preprocess([transforms.unalias_group]),
            exp.Hint: lambda self, e: f" /*+ {self.expressions(e).strip()} */",
            exp.ILike: no_ilike_sql,
            exp.IfNull: rename_func("NVL"),
            exp.Coalesce: rename_func("NVL"),
            exp.Select: transforms.preprocess([transforms.eliminate_distinct_on]),
            exp.StrToTime: lambda self, e: f"TO_TIMESTAMP({self.sql(e, 'this')}, {self.format_time(e)})",
            exp.Subquery: lambda self, e: self.subquery_sql(e, sep=" "),

@@ -179,7 +173,6 @@ class Oracle(Dialect):
            "MATCH_RECOGNIZE": TokenType.MATCH_RECOGNIZE,
            "MINUS": TokenType.EXCEPT,
            "NVARCHAR2": TokenType.NVARCHAR,
            "RETURNING": TokenType.RETURNING,
            "SAMPLE": TokenType.TABLE_SAMPLE,
            "START": TokenType.BEGIN,
            "TOP": TokenType.TOP,
sqlglot/dialects/postgres.py

@@ -183,9 +183,10 @@ def _to_timestamp(args: t.List) -> exp.Expression:

class Postgres(Dialect):
    null_ordering = "nulls_are_large"
    time_format = "'YYYY-MM-DD HH24:MI:SS'"
    time_mapping = {
    INDEX_OFFSET = 1
    NULL_ORDERING = "nulls_are_large"
    TIME_FORMAT = "'YYYY-MM-DD HH24:MI:SS'"
    TIME_MAPPING = {
        "AM": "%p",
        "PM": "%p",
        "D": "%u",  # 1-based day of week

@@ -241,7 +242,6 @@ class Postgres(Dialect):
            "REFRESH": TokenType.COMMAND,
            "REINDEX": TokenType.COMMAND,
            "RESET": TokenType.COMMAND,
            "RETURNING": TokenType.RETURNING,
            "REVOKE": TokenType.COMMAND,
            "SERIAL": TokenType.SERIAL,
            "SMALLSERIAL": TokenType.SMALLSERIAL,

@@ -258,6 +258,7 @@ class Postgres(Dialect):

    class Parser(parser.Parser):
        STRICT_CAST = False
        CONCAT_NULL_OUTPUTS_STRING = True

        FUNCTIONS = {
            **parser.Parser.FUNCTIONS,

@@ -268,6 +269,7 @@ class Postgres(Dialect):
            "NOW": exp.CurrentTimestamp.from_arg_list,
            "TO_CHAR": format_time_lambda(exp.TimeToStr, "postgres"),
            "TO_TIMESTAMP": _to_timestamp,
            "UNNEST": exp.Explode.from_arg_list,
        }

        FUNCTION_PARSERS = {

@@ -303,7 +305,7 @@ class Postgres(Dialect):
            value = self._parse_bitwise()

            if part and part.is_string:
                part = exp.Var(this=part.name)
                part = exp.var(part.name)

            return self.expression(exp.Extract, this=part, expression=value)

@@ -328,6 +330,7 @@ class Postgres(Dialect):
            **generator.Generator.TRANSFORMS,
            exp.BitwiseXor: lambda self, e: self.binary(e, "#"),
            exp.ColumnDef: transforms.preprocess([_auto_increment_to_serial, _serial_to_generated]),
            exp.Explode: rename_func("UNNEST"),
            exp.JSONExtract: arrow_json_extract_sql,
            exp.JSONExtractScalar: arrow_json_extract_scalar_sql,
            exp.JSONBExtract: lambda self, e: self.binary(e, "#>"),
sqlglot/dialects/presto.py

@@ -102,7 +102,7 @@ def _str_to_time_sql(

def _ts_or_ds_to_date_sql(self: generator.Generator, expression: exp.TsOrDsToDate) -> str:
    time_format = self.format_time(expression)
    if time_format and time_format not in (Presto.time_format, Presto.date_format):
    if time_format and time_format not in (Presto.TIME_FORMAT, Presto.DATE_FORMAT):
        return f"CAST({_str_to_time_sql(self, expression)} AS DATE)"
    return f"CAST(SUBSTR(CAST({self.sql(expression, 'this')} AS VARCHAR), 1, 10) AS DATE)"


@@ -119,7 +119,7 @@ def _ts_or_ds_add_sql(self: generator.Generator, expression: exp.TsOrDsAdd) -> s
            exp.Literal.number(1),
            exp.Literal.number(10),
        ),
        Presto.date_format,
        Presto.DATE_FORMAT,
    )

    return self.func(

@@ -145,9 +145,7 @@ def _approx_percentile(args: t.List) -> exp.Expression:
        )
    if len(args) == 3:
        return exp.ApproxQuantile(
            this=seq_get(args, 0),
            quantile=seq_get(args, 1),
            accuracy=seq_get(args, 2),
            this=seq_get(args, 0), quantile=seq_get(args, 1), accuracy=seq_get(args, 2)
        )
    return exp.ApproxQuantile.from_arg_list(args)


@@ -160,10 +158,8 @@ def _from_unixtime(args: t.List) -> exp.Expression:
            minutes=seq_get(args, 2),
        )
    if len(args) == 2:
        return exp.UnixToTime(
            this=seq_get(args, 0),
            zone=seq_get(args, 1),
        )
        return exp.UnixToTime(this=seq_get(args, 0), zone=seq_get(args, 1))

    return exp.UnixToTime.from_arg_list(args)


@@ -173,21 +169,17 @@ def _unnest_sequence(expression: exp.Expression) -> exp.Expression:
        unnest = exp.Unnest(expressions=[expression.this])

        if expression.alias:
            return exp.alias_(
                unnest,
                alias="_u",
                table=[expression.alias],
                copy=False,
            )
            return exp.alias_(unnest, alias="_u", table=[expression.alias], copy=False)
        return unnest
    return expression


class Presto(Dialect):
    index_offset = 1
    null_ordering = "nulls_are_last"
    time_format = MySQL.time_format
    time_mapping = MySQL.time_mapping
    INDEX_OFFSET = 1
    NULL_ORDERING = "nulls_are_last"
    TIME_FORMAT = MySQL.TIME_FORMAT
    TIME_MAPPING = MySQL.TIME_MAPPING
    STRICT_STRING_CONCAT = True

    class Tokenizer(tokens.Tokenizer):
        KEYWORDS = {

@@ -205,14 +197,10 @@ class Presto(Dialect):
            "CARDINALITY": exp.ArraySize.from_arg_list,
            "CONTAINS": exp.ArrayContains.from_arg_list,
            "DATE_ADD": lambda args: exp.DateAdd(
                this=seq_get(args, 2),
                expression=seq_get(args, 1),
                unit=seq_get(args, 0),
                this=seq_get(args, 2), expression=seq_get(args, 1), unit=seq_get(args, 0)
            ),
            "DATE_DIFF": lambda args: exp.DateDiff(
                this=seq_get(args, 2),
                expression=seq_get(args, 1),
                unit=seq_get(args, 0),
                this=seq_get(args, 2), expression=seq_get(args, 1), unit=seq_get(args, 0)
            ),
            "DATE_FORMAT": format_time_lambda(exp.TimeToStr, "presto"),
            "DATE_PARSE": format_time_lambda(exp.StrToTime, "presto"),

@@ -225,9 +213,7 @@ class Presto(Dialect):
            "NOW": exp.CurrentTimestamp.from_arg_list,
            "SEQUENCE": exp.GenerateSeries.from_arg_list,
            "STRPOS": lambda args: exp.StrPosition(
                this=seq_get(args, 0),
                substr=seq_get(args, 1),
                instance=seq_get(args, 2),
                this=seq_get(args, 0), substr=seq_get(args, 1), instance=seq_get(args, 2)
            ),
            "TO_UNIXTIME": exp.TimeToUnix.from_arg_list,
            "TO_HEX": exp.Hex.from_arg_list,

@@ -242,7 +228,7 @@ class Presto(Dialect):
        INTERVAL_ALLOWS_PLURAL_FORM = False
        JOIN_HINTS = False
        TABLE_HINTS = False
        IS_BOOL = False
        IS_BOOL_ALLOWED = False
        STRUCT_DELIMITER = ("(", ")")

        PROPERTIES_LOCATION = {

@@ -284,10 +270,10 @@ class Presto(Dialect):
            exp.DateDiff: lambda self, e: self.func(
                "DATE_DIFF", exp.Literal.string(e.text("unit") or "day"), e.expression, e.this
            ),
            exp.DateStrToDate: lambda self, e: f"CAST(DATE_PARSE({self.sql(e, 'this')}, {Presto.date_format}) AS DATE)",
            exp.DateToDi: lambda self, e: f"CAST(DATE_FORMAT({self.sql(e, 'this')}, {Presto.dateint_format}) AS INT)",
            exp.DateStrToDate: lambda self, e: f"CAST(DATE_PARSE({self.sql(e, 'this')}, {Presto.DATE_FORMAT}) AS DATE)",
            exp.DateToDi: lambda self, e: f"CAST(DATE_FORMAT({self.sql(e, 'this')}, {Presto.DATEINT_FORMAT}) AS INT)",
            exp.Decode: _decode_sql,
            exp.DiToDate: lambda self, e: f"CAST(DATE_PARSE(CAST({self.sql(e, 'this')} AS VARCHAR), {Presto.dateint_format}) AS DATE)",
            exp.DiToDate: lambda self, e: f"CAST(DATE_PARSE(CAST({self.sql(e, 'this')} AS VARCHAR), {Presto.DATEINT_FORMAT}) AS DATE)",
            exp.Encode: _encode_sql,
            exp.FileFormatProperty: lambda self, e: f"FORMAT='{e.name.upper()}'",
            exp.Group: transforms.preprocess([transforms.unalias_group]),

@@ -322,7 +308,7 @@ class Presto(Dialect):
            exp.TimestampTrunc: timestamptrunc_sql,
            exp.TimeStrToDate: timestrtotime_sql,
            exp.TimeStrToTime: timestrtotime_sql,
            exp.TimeStrToUnix: lambda self, e: f"TO_UNIXTIME(DATE_PARSE({self.sql(e, 'this')}, {Presto.time_format}))",
            exp.TimeStrToUnix: lambda self, e: f"TO_UNIXTIME(DATE_PARSE({self.sql(e, 'this')}, {Presto.TIME_FORMAT}))",
            exp.TimeToStr: lambda self, e: f"DATE_FORMAT({self.sql(e, 'this')}, {self.format_time(e)})",
            exp.TimeToUnix: rename_func("TO_UNIXTIME"),
            exp.TryCast: transforms.preprocess([transforms.epoch_cast_to_ts]),

@@ -367,8 +353,16 @@ class Presto(Dialect):
                to = target_type.copy()

                if target_type is start.to:
                    end = exp.Cast(this=end, to=to)
                    end = exp.cast(end, to)
                else:
                    start = exp.Cast(this=start, to=to)
                    start = exp.cast(start, to)

            return self.func("SEQUENCE", start, end, step)

        def offset_limit_modifiers(
            self, expression: exp.Expression, fetch: bool, limit: t.Optional[exp.Fetch | exp.Limit]
        ) -> t.List[str]:
            return [
                self.sql(expression, "offset"),
                self.sql(limit),
            ]
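
Note: exp.Cast(this=..., to=...) requires a prebuilt DataType node, while the exp.cast helper adopted above accepts the target type as a string and builds it. A two-line sketch:

    from sqlglot import exp

    c = exp.cast(exp.column("x"), "date")
    print(c.sql())  # CAST(x AS DATE)
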
sqlglot/dialects/redshift.py

@@ -3,7 +3,7 @@ from __future__ import annotations
import typing as t

from sqlglot import exp, transforms
from sqlglot.dialects.dialect import rename_func
from sqlglot.dialects.dialect import concat_to_dpipe_sql, rename_func
from sqlglot.dialects.postgres import Postgres
from sqlglot.helper import seq_get
from sqlglot.tokens import TokenType

@@ -14,9 +14,9 @@ def _json_sql(self: Postgres.Generator, expression: exp.JSONExtract | exp.JSONEx

class Redshift(Postgres):
    time_format = "'YYYY-MM-DD HH:MI:SS'"
    time_mapping = {
        **Postgres.time_mapping,
    TIME_FORMAT = "'YYYY-MM-DD HH:MI:SS'"
    TIME_MAPPING = {
        **Postgres.TIME_MAPPING,
        "MON": "%b",
        "HH": "%H",
    }

@@ -51,7 +51,7 @@ class Redshift(Postgres):
                and this.expressions
                and this.expressions[0].this == exp.column("MAX")
            ):
                this.set("expressions", [exp.Var(this="MAX")])
                this.set("expressions", [exp.var("MAX")])

            return this

@@ -94,6 +94,7 @@ class Redshift(Postgres):

        TRANSFORMS = {
            **Postgres.Generator.TRANSFORMS,
            exp.Concat: concat_to_dpipe_sql,
            exp.CurrentTimestamp: lambda self, e: "SYSDATE",
            exp.DateAdd: lambda self, e: self.func(
                "DATEADD", exp.var(e.text("unit") or "day"), e.expression, e.this

@@ -106,6 +107,7 @@ class Redshift(Postgres):
            exp.FromBase: rename_func("STRTOL"),
            exp.JSONExtract: _json_sql,
            exp.JSONExtractScalar: _json_sql,
            exp.SafeConcat: concat_to_dpipe_sql,
            exp.Select: transforms.preprocess([transforms.eliminate_distinct_on]),
            exp.SortKeyProperty: lambda self, e: f"{'COMPOUND ' if e.args['compound'] else ''}SORTKEY({self.format_args(*e.this)})",
            exp.TsOrDsToDate: lambda self, e: self.sql(e.this),

@@ -170,6 +172,6 @@ class Redshift(Postgres):
            precision = expression.args.get("expressions")

            if not precision:
                expression.append("expressions", exp.Var(this="MAX"))
                expression.append("expressions", exp.var("MAX"))

            return super().datatype_sql(expression)
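
Note: with both exp.Concat and exp.SafeConcat routed through concat_to_dpipe_sql, CONCAT calls re-render as || chains on Redshift. Roughly (exact output spelling is an illustrative assumption):

    import sqlglot

    print(sqlglot.transpile("SELECT CONCAT(a, b)", read="redshift", write="redshift"))
    # roughly: ['SELECT a || b']
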
sqlglot/dialects/snowflake.py

@@ -167,10 +167,10 @@ def _parse_convert_timezone(args: t.List) -> exp.Expression:

class Snowflake(Dialect):
    null_ordering = "nulls_are_large"
    time_format = "'yyyy-mm-dd hh24:mi:ss'"
    NULL_ORDERING = "nulls_are_large"
    TIME_FORMAT = "'YYYY-MM-DD HH24:MI:SS'"

    time_mapping = {
    TIME_MAPPING = {
        "YYYY": "%Y",
        "yyyy": "%Y",
        "YY": "%y",

@@ -210,14 +210,10 @@ class Snowflake(Dialect):
            "CONVERT_TIMEZONE": _parse_convert_timezone,
            "DATE_TRUNC": date_trunc_to_time,
            "DATEADD": lambda args: exp.DateAdd(
                this=seq_get(args, 2),
                expression=seq_get(args, 1),
                unit=seq_get(args, 0),
                this=seq_get(args, 2), expression=seq_get(args, 1), unit=seq_get(args, 0)
            ),
            "DATEDIFF": lambda args: exp.DateDiff(
                this=seq_get(args, 2),
                expression=seq_get(args, 1),
                unit=seq_get(args, 0),
                this=seq_get(args, 2), expression=seq_get(args, 1), unit=seq_get(args, 0)
            ),
            "DIV0": _div0_to_if,
            "IFF": exp.If.from_arg_list,

@@ -246,9 +242,7 @@ class Snowflake(Dialect):
        COLUMN_OPERATORS = {
            **parser.Parser.COLUMN_OPERATORS,
            TokenType.COLON: lambda self, this, path: self.expression(
                exp.Bracket,
                this=this,
                expressions=[path],
                exp.Bracket, this=this, expressions=[path]
            ),
        }

@@ -275,6 +269,7 @@ class Snowflake(Dialect):
        QUOTES = ["'", "$$"]
        STRING_ESCAPES = ["\\", "'"]
        HEX_STRINGS = [("x'", "'"), ("X'", "'")]
        COMMENTS = ["--", "//", ("/*", "*/")]

        KEYWORDS = {
            **tokens.Tokenizer.KEYWORDS,
sqlglot/dialects/spark2.py

@@ -38,7 +38,7 @@ def _parse_as_cast(to_type: str) -> t.Callable[[t.List], exp.Expression]:
def _str_to_date(self: Hive.Generator, expression: exp.StrToDate) -> str:
    this = self.sql(expression, "this")
    time_format = self.format_time(expression)
    if time_format == Hive.date_format:
    if time_format == Hive.DATE_FORMAT:
        return f"TO_DATE({this})"
    return f"TO_DATE({this}, {time_format})"


@@ -133,13 +133,13 @@ class Spark2(Hive):
            "WEEKOFYEAR": lambda args: exp.WeekOfYear(
                this=exp.TsOrDsToDate(this=seq_get(args, 0)),
            ),
            "DATE": lambda args: exp.Cast(this=seq_get(args, 0), to=exp.DataType.build("date")),
            "DATE_TRUNC": lambda args: exp.TimestampTrunc(
                this=seq_get(args, 1),
                unit=exp.var(seq_get(args, 0)),
            ),
            "TRUNC": lambda args: exp.DateTrunc(unit=seq_get(args, 1), this=seq_get(args, 0)),
            "BOOLEAN": _parse_as_cast("boolean"),
            "DATE": _parse_as_cast("date"),
            "DOUBLE": _parse_as_cast("double"),
            "FLOAT": _parse_as_cast("float"),
            "INT": _parse_as_cast("int"),

@@ -162,11 +162,9 @@ class Spark2(Hive):
        def _parse_add_column(self) -> t.Optional[exp.Expression]:
            return self._match_text_seq("ADD", "COLUMNS") and self._parse_schema()

        def _parse_drop_column(self) -> t.Optional[exp.Expression]:
        def _parse_drop_column(self) -> t.Optional[exp.Drop | exp.Command]:
            return self._match_text_seq("DROP", "COLUMNS") and self.expression(
                exp.Drop,
                this=self._parse_schema(),
                kind="COLUMNS",
                exp.Drop, this=self._parse_schema(), kind="COLUMNS"
            )

        def _pivot_column_names(self, aggregations: t.List[exp.Expression]) -> t.List[str]:
sqlglot/dialects/sqlite.py

@@ -5,6 +5,7 @@ from sqlglot.dialects.dialect import (
    Dialect,
    arrow_json_extract_scalar_sql,
    arrow_json_extract_sql,
    concat_to_dpipe_sql,
    count_if_to_sum,
    no_ilike_sql,
    no_pivot_sql,

@@ -62,10 +63,6 @@ class SQLite(Dialect):
        IDENTIFIERS = ['"', ("[", "]"), "`"]
        HEX_STRINGS = [("x'", "'"), ("X'", "'"), ("0x", ""), ("0X", "")]

        KEYWORDS = {
            **tokens.Tokenizer.KEYWORDS,
        }

    class Parser(parser.Parser):
        FUNCTIONS = {
            **parser.Parser.FUNCTIONS,

@@ -100,6 +97,7 @@ class SQLite(Dialect):

        TRANSFORMS = {
            **generator.Generator.TRANSFORMS,
            exp.Concat: concat_to_dpipe_sql,
            exp.CountIf: count_if_to_sum,
            exp.Create: transforms.preprocess([_transform_create]),
            exp.CurrentDate: lambda *_: "CURRENT_DATE",

@@ -116,6 +114,7 @@ class SQLite(Dialect):
            exp.LogicalOr: rename_func("MAX"),
            exp.LogicalAnd: rename_func("MIN"),
            exp.Pivot: no_pivot_sql,
            exp.SafeConcat: concat_to_dpipe_sql,
            exp.Select: transforms.preprocess(
                [transforms.eliminate_distinct_on, transforms.eliminate_qualify]
            ),
sqlglot/dialects/tableau.py

@@ -1,7 +1,7 @@
from __future__ import annotations

from sqlglot import exp, generator, parser, transforms
from sqlglot.dialects.dialect import Dialect
from sqlglot.dialects.dialect import Dialect, rename_func


class Tableau(Dialect):

@@ -11,6 +11,7 @@ class Tableau(Dialect):

        TRANSFORMS = {
            **generator.Generator.TRANSFORMS,
            exp.Coalesce: rename_func("IFNULL"),
            exp.Select: transforms.preprocess([transforms.eliminate_distinct_on]),
        }

@@ -25,9 +26,6 @@ class Tableau(Dialect):
            false = self.sql(expression, "false")
            return f"IF {this} THEN {true} ELSE {false} END"

        def coalesce_sql(self, expression: exp.Coalesce) -> str:
            return f"IFNULL({self.sql(expression, 'this')}, {self.expressions(expression)})"

        def count_sql(self, expression: exp.Count) -> str:
            this = expression.this
            if isinstance(this, exp.Distinct):
@@ -1,18 +1,32 @@
from __future__ import annotations

import typing as t

from sqlglot import exp, generator, parser, tokens, transforms
from sqlglot.dialects.dialect import (
    Dialect,
    format_time_lambda,
    max_or_greatest,
    min_or_least,
)
from sqlglot.dialects.dialect import Dialect, max_or_greatest, min_or_least
from sqlglot.tokens import TokenType


class Teradata(Dialect):
    TIME_MAPPING = {
        "Y": "%Y",
        "YYYY": "%Y",
        "YY": "%y",
        "MMMM": "%B",
        "MMM": "%b",
        "DD": "%d",
        "D": "%-d",
        "HH": "%H",
        "H": "%-H",
        "MM": "%M",
        "M": "%-M",
        "SS": "%S",
        "S": "%-S",
        "SSSSSS": "%f",
        "E": "%a",
        "EE": "%a",
        "EEE": "%a",
        "EEEE": "%A",
    }

    class Tokenizer(tokens.Tokenizer):
        # https://docs.teradata.com/r/Teradata-Database-SQL-Functions-Operators-Expressions-and-Predicates/March-2017/Comparison-Operators-and-Functions/Comparison-Operators/ANSI-Compliance
        KEYWORDS = {

@@ -31,7 +45,7 @@ class Teradata(Dialect):
            "ST_GEOMETRY": TokenType.GEOMETRY,
        }

        # teradata does not support % for modulus
        # Teradata does not support % as a modulo operator
        SINGLE_TOKENS = {**tokens.Tokenizer.SINGLE_TOKENS}
        SINGLE_TOKENS.pop("%")

@@ -101,7 +115,7 @@ class Teradata(Dialect):

        # FROM before SET in Teradata UPDATE syntax
        # https://docs.teradata.com/r/Enterprise_IntelliFlex_VMware/Teradata-VantageTM-SQL-Data-Manipulation-Language-17.20/Statement-Syntax/UPDATE/UPDATE-Syntax-Basic-Form-FROM-Clause
        def _parse_update(self) -> exp.Expression:
        def _parse_update(self) -> exp.Update:
            return self.expression(
                exp.Update,
                **{  # type: ignore

@@ -122,14 +136,6 @@ class Teradata(Dialect):

            return self.expression(exp.RangeN, this=this, expressions=expressions, each=each)

        def _parse_cast(self, strict: bool) -> exp.Expression:
            cast = t.cast(exp.Cast, super()._parse_cast(strict))
            if cast.to.this == exp.DataType.Type.DATE and self._match(TokenType.FORMAT):
                return format_time_lambda(exp.TimeToStr, "teradata")(
                    [cast.this, self._parse_string()]
                )
            return cast

    class Generator(generator.Generator):
        JOIN_HINTS = False
        TABLE_HINTS = False

@@ -151,7 +157,7 @@ class Teradata(Dialect):
            exp.Max: max_or_greatest,
            exp.Min: min_or_least,
            exp.Select: transforms.preprocess([transforms.eliminate_distinct_on]),
            exp.TimeToStr: lambda self, e: f"CAST({self.sql(e, 'this')} AS DATE FORMAT {self.format_time(e)})",
            exp.StrToDate: lambda self, e: f"CAST({self.sql(e, 'this')} AS DATE FORMAT {self.format_time(e)})",
            exp.ToChar: lambda self, e: self.function_fallback_sql(e),
        }
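The generator now renders exp.StrToDate (rather than exp.TimeToStr) with Teradata's CAST ... FORMAT idiom, with format_time translating strftime tokens back through TIME_MAPPING. A minimal hedged sketch built directly from expression nodes:

from sqlglot import exp

# "%Y" maps back to Teradata's "YYYY" format token via the inverse of TIME_MAPPING
node = exp.StrToDate(this=exp.column("d"), format=exp.Literal.string("%Y"))
print(node.sql(dialect="teradata"))
# expected along the lines of: CAST(d AS DATE FORMAT 'YYYY')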
@@ -64,9 +64,9 @@ def _format_time_lambda(
            format=exp.Literal.string(
                format_time(
                    args[0].name,
                    {**TSQL.time_mapping, **FULL_FORMAT_TIME_MAPPING}
                    {**TSQL.TIME_MAPPING, **FULL_FORMAT_TIME_MAPPING}
                    if full_format_mapping
                    else TSQL.time_mapping,
                    else TSQL.TIME_MAPPING,
                )
            ),
        )

@@ -86,9 +86,9 @@ def _parse_format(args: t.List) -> exp.Expression:
    return exp.TimeToStr(
        this=args[0],
        format=exp.Literal.string(
            format_time(fmt.name, TSQL.format_time_mapping)
            format_time(fmt.name, TSQL.FORMAT_TIME_MAPPING)
            if len(fmt.name) == 1
            else format_time(fmt.name, TSQL.time_mapping)
            else format_time(fmt.name, TSQL.TIME_MAPPING)
        ),
    )

@@ -138,7 +138,7 @@ def _format_sql(self: generator.Generator, expression: exp.NumberToStr | exp.TimeToStr) -> str:
        if isinstance(expression, exp.NumberToStr)
        else exp.Literal.string(
            format_time(
                expression.text("format"), t.cast(t.Dict[str, str], TSQL.inverse_time_mapping)
                expression.text("format"), t.cast(t.Dict[str, str], TSQL.INVERSE_TIME_MAPPING)
            )
        )
    )

@@ -166,10 +166,10 @@ def _string_agg_sql(self: generator.Generator, expression: exp.GroupConcat) -> str:


class TSQL(Dialect):
    null_ordering = "nulls_are_small"
    time_format = "'yyyy-mm-dd hh:mm:ss'"
    NULL_ORDERING = "nulls_are_small"
    TIME_FORMAT = "'yyyy-mm-dd hh:mm:ss'"

    time_mapping = {
    TIME_MAPPING = {
        "year": "%Y",
        "qq": "%q",
        "q": "%q",

@@ -213,7 +213,7 @@ class TSQL(Dialect):
        "yy": "%y",
    }

    convert_format_mapping = {
    CONVERT_FORMAT_MAPPING = {
        "0": "%b %d %Y %-I:%M%p",
        "1": "%m/%d/%y",
        "2": "%y.%m.%d",

@@ -253,8 +253,8 @@ class TSQL(Dialect):
        "120": "%Y-%m-%d %H:%M:%S",
        "121": "%Y-%m-%d %H:%M:%S.%f",
    }
    # not sure if complete
    format_time_mapping = {

    FORMAT_TIME_MAPPING = {
        "y": "%B %Y",
        "d": "%m/%d/%Y",
        "H": "%-H",

@@ -312,9 +312,7 @@ class TSQL(Dialect):
        FUNCTIONS = {
            **parser.Parser.FUNCTIONS,
            "CHARINDEX": lambda args: exp.StrPosition(
                this=seq_get(args, 1),
                substr=seq_get(args, 0),
                position=seq_get(args, 2),
                this=seq_get(args, 1), substr=seq_get(args, 0), position=seq_get(args, 2)
            ),
            "DATEADD": parse_date_delta(exp.DateAdd, unit_mapping=DATE_DELTA_INTERVAL),
            "DATEDIFF": parse_date_delta(exp.DateDiff, unit_mapping=DATE_DELTA_INTERVAL),

@@ -363,6 +361,8 @@ class TSQL(Dialect):
        LOG_BASE_FIRST = False
        LOG_DEFAULTS_TO_LN = True

        CONCAT_NULL_OUTPUTS_STRING = True

        def _parse_system_time(self) -> t.Optional[exp.Expression]:
            if not self._match_text_seq("FOR", "SYSTEM_TIME"):
                return None

@@ -400,7 +400,7 @@ class TSQL(Dialect):
                table.set("system_time", self._parse_system_time())
            return table

        def _parse_returns(self) -> exp.Expression:
        def _parse_returns(self) -> exp.ReturnsProperty:
            table = self._parse_id_var(any_token=False, tokens=self.RETURNS_TABLE_TOKENS)
            returns = super()._parse_returns()
            returns.set("table", table)

@@ -423,12 +423,12 @@ class TSQL(Dialect):
            format_val = self._parse_number()
            format_val_name = format_val.name if format_val else ""

            if format_val_name not in TSQL.convert_format_mapping:
            if format_val_name not in TSQL.CONVERT_FORMAT_MAPPING:
                raise ValueError(
                    f"CONVERT function at T-SQL does not support format style {format_val_name}"
                )

            format_norm = exp.Literal.string(TSQL.convert_format_mapping[format_val_name])
            format_norm = exp.Literal.string(TSQL.CONVERT_FORMAT_MAPPING[format_val_name])

            # Check whether the convert entails a string to date format
            if to.this == DataType.Type.DATE:
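The CONVERT parser resolves the numeric style argument through CONVERT_FORMAT_MAPPING and rejects unknown styles with a ValueError. A hedged sketch of a supported style (120 is the ODBC canonical format above):

import sqlglot

ast = sqlglot.parse_one("SELECT CONVERT(VARCHAR, d, 120) FROM t", read="tsql")
print(ast.sql(dialect="duckdb"))
# expected along the lines of: SELECT STRFTIME(d, '%Y-%m-%d %H:%M:%S') FROM t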
@@ -151,6 +151,7 @@ ENV = {
    "CAST": cast,
    "COALESCE": lambda *args: next((a for a in args if a is not None), None),
    "CONCAT": null_if_any(lambda *args: "".join(args)),
    "SAFECONCAT": null_if_any(lambda *args: "".join(str(arg) for arg in args)),
    "CONCATWS": null_if_any(lambda this, *args: this.join(args)),
    "DATESTRTODATE": null_if_any(lambda arg: datetime.date.fromisoformat(arg)),
    "DIV": null_if_any(lambda e, this: e / this),

@@ -159,7 +160,6 @@ ENV = {
    "EXTRACT": null_if_any(lambda this, e: getattr(e, this)),
    "GT": null_if_any(lambda this, e: this > e),
    "GTE": null_if_any(lambda this, e: this >= e),
    "IFNULL": lambda e, alt: alt if e is None else e,
    "IF": lambda predicate, true, false: true if predicate else false,
    "INTDIV": null_if_any(lambda e, this: e // this),
    "INTERVAL": interval,
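These ENV entries lean on a null_if_any helper to give plain Python callables SQL NULL semantics. A minimal sketch of that behavior, assuming the helper simply short-circuits on None arguments (the real implementation may differ):

def null_if_any(func):
    # Return None if any positional argument is None, otherwise call through.
    def wrapper(*args):
        if any(a is None for a in args):
            return None
        return func(*args)
    return wrapper

assert null_if_any(lambda a, b: a + b)(1, None) is None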
@@ -394,7 +394,7 @@ def _lambda_sql(self, e: exp.Lambda) -> str:
    names = {e.name.lower() for e in e.expressions}

    e = e.transform(
        lambda n: exp.Var(this=n.name)
        lambda n: exp.var(n.name)
        if isinstance(n, exp.Identifier) and n.name.lower() in names
        else n
    )
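exp.var is the convenience constructor for exp.Var nodes, so the transform above builds the same node with less noise:

from sqlglot import exp

# shorthand and explicit construction produce structurally equal nodes
assert exp.var("x") == exp.Var(this="x")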
@@ -1500,6 +1500,7 @@ class Index(Expression):
    arg_types = {
        "this": False,
        "table": False,
        "using": False,
        "where": False,
        "columns": False,
        "unique": False,

@@ -1623,7 +1624,7 @@ class Lambda(Expression):


class Limit(Expression):
    arg_types = {"this": False, "expression": True}
    arg_types = {"this": False, "expression": True, "offset": False}


class Literal(Condition):
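The new offset slot on exp.Limit lets dialects with a two-argument LIMIT carry both values on one node; MySQL's LIMIT offset, count form is the motivating syntax. A hedged round-trip sketch:

import sqlglot

ast = sqlglot.parse_one("SELECT x FROM t LIMIT 5, 10", read="mysql")
print(ast.sql(dialect="mysql"))  # expected: SELECT x FROM t LIMIT 5, 10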
@@ -1869,6 +1870,10 @@ class EngineProperty(Property):
    arg_types = {"this": True}


class ToTableProperty(Property):
    arg_types = {"this": True}


class ExecuteAsProperty(Property):
    arg_types = {"this": True}

@@ -3072,12 +3077,35 @@ class Select(Subqueryable):
        Returns:
            The modified expression.
        """
        inst = _maybe_copy(self, copy)
        inst.set("locks", [Lock(update=update)])

        return inst

    def hint(self, *hints: ExpOrStr, dialect: DialectType = None, copy: bool = True) -> Select:
        """
        Set hints for this expression.

        Examples:
            >>> Select().select("x").from_("tbl").hint("BROADCAST(y)").sql(dialect="spark")
            'SELECT /*+ BROADCAST(y) */ x FROM tbl'

        Args:
            hints: The SQL code strings to parse as the hints.
                If an `Expression` instance is passed, it will be used as-is.
            dialect: The dialect used to parse the hints.
            copy: If `False`, modify this expression instance in-place.

        Returns:
            The modified expression.
        """
        inst = _maybe_copy(self, copy)
        inst.set(
            "hint", Hint(expressions=[maybe_parse(h, copy=copy, dialect=dialect) for h in hints])
        )

        return inst

    @property
    def named_selects(self) -> t.List[str]:
        return [e.output_name for e in self.expressions if e.alias_or_name]

@@ -3244,6 +3272,7 @@ class DataType(Expression):
        DATE = auto()
        DATETIME = auto()
        DATETIME64 = auto()
        ENUM = auto()
        INT4RANGE = auto()
        INT4MULTIRANGE = auto()
        INT8RANGE = auto()

@@ -3284,6 +3313,7 @@ class DataType(Expression):
        OBJECT = auto()
        ROWVERSION = auto()
        SERIAL = auto()
        SET = auto()
        SMALLINT = auto()
        SMALLMONEY = auto()
        SMALLSERIAL = auto()

@@ -3334,6 +3364,7 @@ class DataType(Expression):
    NUMERIC_TYPES = {*INTEGER_TYPES, *FLOAT_TYPES}

    TEMPORAL_TYPES = {
        Type.TIME,
        Type.TIMESTAMP,
        Type.TIMESTAMPTZ,
        Type.TIMESTAMPLTZ,

@@ -3342,6 +3373,8 @@ class DataType(Expression):
        Type.DATETIME64,
    }

    META_TYPES = {"UNKNOWN", "NULL"}

    @classmethod
    def build(
        cls, dtype: str | DataType | DataType.Type, dialect: DialectType = None, **kwargs

@@ -3349,8 +3382,9 @@ class DataType(Expression):
        from sqlglot import parse_one

        if isinstance(dtype, str):
            if dtype.upper() in cls.Type.__members__:
                data_type_exp: t.Optional[Expression] = DataType(this=DataType.Type[dtype.upper()])
            upper = dtype.upper()
            if upper in DataType.META_TYPES:
                data_type_exp: t.Optional[Expression] = DataType(this=DataType.Type[upper])
            else:
                data_type_exp = parse_one(dtype, read=dialect, into=DataType)
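With META_TYPES, only the pseudo-types UNKNOWN and NULL are built straight from the enum; every other type string now goes through the parser, which handles parameters and dialect-specific names. A hedged sketch:

from sqlglot import exp

print(exp.DataType.build("unknown").this)       # DataType.Type.UNKNOWN, via META_TYPES
print(exp.DataType.build("VARCHAR(10)").sql())  # parsed, so the parameter survives: VARCHAR(10)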
@@ -3483,6 +3517,10 @@ class Dot(Binary):
    def name(self) -> str:
        return self.expression.name

    @property
    def output_name(self) -> str:
        return self.name

    @classmethod
    def build(self, expressions: t.Sequence[Expression]) -> Dot:
        """Build a Dot object with a sequence of expressions."""

@@ -3502,6 +3540,10 @@ class DPipe(Binary):
    pass


class SafeDPipe(DPipe):
    pass


class EQ(Binary, Predicate):
    pass

@@ -3615,6 +3657,10 @@ class Not(Unary):
class Paren(Unary):
    arg_types = {"this": True, "with": False}

    @property
    def output_name(self) -> str:
        return self.this.name


class Neg(Unary):
    pass

@@ -3904,6 +3950,7 @@ class Ceil(Func):
class Coalesce(Func):
    arg_types = {"this": True, "expressions": False}
    is_var_len_args = True
    _sql_names = ["COALESCE", "IFNULL", "NVL"]


class Concat(Func):

@@ -3911,12 +3958,17 @@ class Concat(Func):
    is_var_len_args = True


class SafeConcat(Concat):
    pass


class ConcatWs(Concat):
    _sql_names = ["CONCAT_WS"]


class Count(AggFunc):
    arg_types = {"this": False}
    arg_types = {"this": False, "expressions": False}
    is_var_len_args = True


class CountIf(AggFunc):

@@ -4049,6 +4101,11 @@ class DateToDi(Func):
    pass


class Date(Func):
    arg_types = {"expressions": True}
    is_var_len_args = True


class Day(Func):
    pass

@@ -4102,11 +4159,6 @@ class If(Func):
    arg_types = {"this": True, "true": True, "false": False}


class IfNull(Func):
    arg_types = {"this": True, "expression": False}
    _sql_names = ["IFNULL", "NVL"]


class Initcap(Func):
    arg_types = {"this": True, "expression": False}

@@ -5608,22 +5660,27 @@ def replace_children(expression: Expression, fun: t.Callable, *args, **kwargs) -> None:
    expression.args[k] = new_child_nodes if is_list_arg else seq_get(new_child_nodes, 0)


def column_table_names(expression: Expression) -> t.List[str]:
def column_table_names(expression: Expression, exclude: str = "") -> t.Set[str]:
    """
    Return all table names referenced through columns in an expression.

    Example:
        >>> import sqlglot
        >>> column_table_names(sqlglot.parse_one("a.b AND c.d AND c.e"))
        ['c', 'a']
        >>> sorted(column_table_names(sqlglot.parse_one("a.b AND c.d AND c.e")))
        ['a', 'c']

    Args:
        expression: expression to find table names.
        exclude: a table name to exclude

    Returns:
        A list of unique names.
    """
    return list(dict.fromkeys(column.table for column in expression.find_all(Column)))
    return {
        table
        for table in (column.table for column in expression.find_all(Column))
        if table and table != exclude
    }


def table_name(table: Table | str) -> str:

@@ -5649,12 +5706,13 @@ def table_name(table: Table | str) -> str:
    return ".".join(part for part in (table.text("catalog"), table.text("db"), table.name) if part)


def replace_tables(expression: E, mapping: t.Dict[str, str]) -> E:
def replace_tables(expression: E, mapping: t.Dict[str, str], copy: bool = True) -> E:
    """Replace all tables in expression according to the mapping.

    Args:
        expression: expression node to be transformed and replaced.
        mapping: mapping of table names.
        copy: whether or not to copy the expression.

    Examples:
        >>> from sqlglot import exp, parse_one

@@ -5675,7 +5733,7 @@ def replace_tables(expression: E, mapping: t.Dict[str, str]) -> E:
            )
        return node

    return expression.transform(_replace_tables)
    return expression.transform(_replace_tables, copy=copy)
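The new copy flag lets callers that own the tree rewrite it in place and skip the copy. A hedged sketch, assuming the root node itself is not replaced:

from sqlglot import parse_one
from sqlglot.expressions import replace_tables

ast = parse_one("SELECT * FROM a JOIN b ON a.id = b.id")
# copy=False rewrites the tree in place instead of working on a copy
replace_tables(ast, {"a": "c.a", "b": "c.b"}, copy=False)
print(ast.sql())  # the original ast now references c.a and c.b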


def replace_placeholders(expression: Expression, *args, **kwargs) -> Expression:
@@ -14,47 +14,32 @@ logger = logging.getLogger("sqlglot")

class Generator:
    """
    Generator interprets the given syntax tree and produces a SQL string as an output.
    Generator converts a given syntax tree to the corresponding SQL string.

    Args:
        time_mapping (dict): the dictionary of custom time mappings in which the key
            represents a python time format and the output the target time format
        time_trie (trie): a trie of the time_mapping keys
        pretty (bool): if set to True the returned string will be formatted. Default: False.
        quote_start (str): specifies which starting character to use to delimit quotes. Default: '.
        quote_end (str): specifies which ending character to use to delimit quotes. Default: '.
        identifier_start (str): specifies which starting character to use to delimit identifiers. Default: ".
        identifier_end (str): specifies which ending character to use to delimit identifiers. Default: ".
        bit_start (str): specifies which starting character to use to delimit bit literals. Default: None.
        bit_end (str): specifies which ending character to use to delimit bit literals. Default: None.
        hex_start (str): specifies which starting character to use to delimit hex literals. Default: None.
        hex_end (str): specifies which ending character to use to delimit hex literals. Default: None.
        byte_start (str): specifies which starting character to use to delimit byte literals. Default: None.
        byte_end (str): specifies which ending character to use to delimit byte literals. Default: None.
        raw_start (str): specifies which starting character to use to delimit raw literals. Default: None.
        raw_end (str): specifies which ending character to use to delimit raw literals. Default: None.
        identify (bool | str): 'always': always quote, 'safe': quote identifiers if they don't contain an upcase, True defaults to always.
        normalize (bool): if set to True all identifiers will lower cased
        string_escape (str): specifies a string escape character. Default: '.
        identifier_escape (str): specifies an identifier escape character. Default: ".
        pad (int): determines padding in a formatted string. Default: 2.
        indent (int): determines the size of indentation in a formatted string. Default: 4.
        unnest_column_only (bool): if true unnest table aliases are considered only as column aliases
        normalize_functions (str): normalize function names, "upper", "lower", or None
            Default: "upper"
        alias_post_tablesample (bool): if the table alias comes after tablesample
            Default: False
        identifiers_can_start_with_digit (bool): if an unquoted identifier can start with digit
            Default: False
        unsupported_level (ErrorLevel): determines the generator's behavior when it encounters
            unsupported expressions. Default ErrorLevel.WARN.
        null_ordering (str): Indicates the default null ordering method to use if not explicitly set.
            Options are "nulls_are_small", "nulls_are_large", "nulls_are_last".
            Default: "nulls_are_small"
        max_unsupported (int): Maximum number of unsupported messages to include in a raised UnsupportedError.
        pretty: Whether or not to format the produced SQL string.
            Default: False.
        identify: Determines when an identifier should be quoted. Possible values are:
            False (default): Never quote, except in cases where it's mandatory by the dialect.
            True or 'always': Always quote.
            'safe': Only quote identifiers that are case insensitive.
        normalize: Whether or not to normalize identifiers to lowercase.
            Default: False.
        pad: Determines the pad size in a formatted string.
            Default: 2.
        indent: Determines the indentation size in a formatted string.
            Default: 2.
        normalize_functions: Whether or not to normalize all function names. Possible values are:
            "upper" or True (default): Convert names to uppercase.
            "lower": Convert names to lowercase.
            False: Disables function name normalization.
        unsupported_level: Determines the generator's behavior when it encounters unsupported expressions.
            Default ErrorLevel.WARN.
        max_unsupported: Maximum number of unsupported messages to include in a raised UnsupportedError.
            This is only relevant if unsupported_level is ErrorLevel.RAISE.
            Default: 3
        leading_comma (bool): if the the comma is leading or trailing in select statements
        leading_comma: Determines whether or not the comma is leading or trailing in select expressions.
            This is only relevant when generating in pretty mode.
            Default: False
        max_text_width: The max number of characters in a segment before creating new lines in pretty mode.
            The default is on the smaller end because the length only represents a segment and not the true
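These constructor options are normally reached through the top-level API rather than by instantiating Generator directly. A hedged sketch of two common ones:

import sqlglot

sql = "SELECT a, b FROM t WHERE a > 1"
# pretty-printing and identifier quoting are forwarded down to the Generator
print(sqlglot.transpile(sql, pretty=True, identify=True)[0])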
@@ -86,6 +71,7 @@ class Generator:
        exp.SettingsProperty: lambda self, e: f"SETTINGS{self.seg('')}{(self.expressions(e))}",
        exp.SqlSecurityProperty: lambda self, e: f"SQL SECURITY {'DEFINER' if e.args.get('definer') else 'INVOKER'}",
        exp.TemporaryProperty: lambda self, e: f"TEMPORARY",
        exp.ToTableProperty: lambda self, e: f"TO {self.sql(e.this)}",
        exp.TransientProperty: lambda self, e: "TRANSIENT",
        exp.StabilityProperty: lambda self, e: e.name,
        exp.VolatileProperty: lambda self, e: "VOLATILE",

@@ -138,15 +124,24 @@ class Generator:
    # Whether or not limit and fetch are supported (possible values: "ALL", "LIMIT", "FETCH")
    LIMIT_FETCH = "ALL"

    # Whether a table is allowed to be renamed with a db
    # Whether or not a table is allowed to be renamed with a db
    RENAME_TABLE_WITH_DB = True

    # The separator for grouping sets and rollups
    GROUPINGS_SEP = ","

    # The string used for creating index on a table
    # The string used for creating an index on a table
    INDEX_ON = "ON"

    # Whether or not join hints should be generated
    JOIN_HINTS = True

    # Whether or not table hints should be generated
    TABLE_HINTS = True

    # Whether or not comparing against booleans (e.g. x IS TRUE) is supported
    IS_BOOL_ALLOWED = True

    TYPE_MAPPING = {
        exp.DataType.Type.NCHAR: "CHAR",
        exp.DataType.Type.NVARCHAR: "VARCHAR",

@@ -228,6 +223,7 @@ class Generator:
        exp.SqlSecurityProperty: exp.Properties.Location.POST_CREATE,
        exp.StabilityProperty: exp.Properties.Location.POST_SCHEMA,
        exp.TemporaryProperty: exp.Properties.Location.POST_CREATE,
        exp.ToTableProperty: exp.Properties.Location.POST_SCHEMA,
        exp.TransientProperty: exp.Properties.Location.POST_CREATE,
        exp.MergeTreeTTL: exp.Properties.Location.POST_SCHEMA,
        exp.VolatileProperty: exp.Properties.Location.POST_CREATE,

@@ -235,128 +231,110 @@ class Generator:
        exp.WithJournalTableProperty: exp.Properties.Location.POST_NAME,
    }

    JOIN_HINTS = True
    TABLE_HINTS = True
    IS_BOOL = True

    # Keywords that can't be used as unquoted identifier names
    RESERVED_KEYWORDS: t.Set[str] = set()
    WITH_SEPARATED_COMMENTS = (exp.Select, exp.From, exp.Where, exp.With)
    UNWRAPPED_INTERVAL_VALUES = (exp.Column, exp.Literal, exp.Neg, exp.Paren)

    # Expressions whose comments are separated from them for better formatting
    WITH_SEPARATED_COMMENTS: t.Tuple[t.Type[exp.Expression], ...] = (
        exp.Select,
        exp.From,
        exp.Where,
        exp.With,
    )

    # Expressions that can remain unwrapped when appearing in the context of an INTERVAL
    UNWRAPPED_INTERVAL_VALUES: t.Tuple[t.Type[exp.Expression], ...] = (
        exp.Column,
        exp.Literal,
        exp.Neg,
        exp.Paren,
    )

    SENTINEL_LINE_BREAK = "__SQLGLOT__LB__"

    # Autofilled
    INVERSE_TIME_MAPPING: t.Dict[str, str] = {}
    INVERSE_TIME_TRIE: t.Dict = {}
    INDEX_OFFSET = 0
    UNNEST_COLUMN_ONLY = False
    ALIAS_POST_TABLESAMPLE = False
    IDENTIFIERS_CAN_START_WITH_DIGIT = False
    STRICT_STRING_CONCAT = False
    NORMALIZE_FUNCTIONS: bool | str = "upper"
    NULL_ORDERING = "nulls_are_small"

    # Delimiters for quotes, identifiers and the corresponding escape characters
    QUOTE_START = "'"
    QUOTE_END = "'"
    IDENTIFIER_START = '"'
    IDENTIFIER_END = '"'
    STRING_ESCAPE = "'"
    IDENTIFIER_ESCAPE = '"'

    # Delimiters for bit, hex, byte and raw literals
    BIT_START: t.Optional[str] = None
    BIT_END: t.Optional[str] = None
    HEX_START: t.Optional[str] = None
    HEX_END: t.Optional[str] = None
    BYTE_START: t.Optional[str] = None
    BYTE_END: t.Optional[str] = None
    RAW_START: t.Optional[str] = None
    RAW_END: t.Optional[str] = None

    __slots__ = (
        "time_mapping",
        "time_trie",
        "pretty",
        "quote_start",
        "quote_end",
        "identifier_start",
        "identifier_end",
        "bit_start",
        "bit_end",
        "hex_start",
        "hex_end",
        "byte_start",
        "byte_end",
        "raw_start",
        "raw_end",
        "identify",
        "normalize",
        "string_escape",
        "identifier_escape",
        "pad",
        "index_offset",
        "unnest_column_only",
        "alias_post_tablesample",
        "identifiers_can_start_with_digit",
        "_indent",
        "normalize_functions",
        "unsupported_level",
        "unsupported_messages",
        "null_ordering",
        "max_unsupported",
        "_indent",
        "leading_comma",
        "max_text_width",
        "comments",
        "unsupported_messages",
        "_escaped_quote_end",
        "_escaped_identifier_end",
        "_leading_comma",
        "_max_text_width",
        "_comments",
        "_cache",
    )

    def __init__(
        self,
        time_mapping=None,
        time_trie=None,
        pretty=None,
        quote_start=None,
        quote_end=None,
        identifier_start=None,
        identifier_end=None,
        bit_start=None,
        bit_end=None,
        hex_start=None,
        hex_end=None,
        byte_start=None,
        byte_end=None,
        raw_start=None,
        raw_end=None,
        identify=False,
        normalize=False,
        string_escape=None,
        identifier_escape=None,
        pad=2,
        indent=2,
        index_offset=0,
        unnest_column_only=False,
        alias_post_tablesample=False,
        identifiers_can_start_with_digit=False,
        normalize_functions="upper",
        unsupported_level=ErrorLevel.WARN,
        null_ordering=None,
        max_unsupported=3,
        leading_comma=False,
        max_text_width=80,
        comments=True,
        pretty: t.Optional[bool] = None,
        identify: str | bool = False,
        normalize: bool = False,
        pad: int = 2,
        indent: int = 2,
        normalize_functions: t.Optional[str | bool] = None,
        unsupported_level: ErrorLevel = ErrorLevel.WARN,
        max_unsupported: int = 3,
        leading_comma: bool = False,
        max_text_width: int = 80,
        comments: bool = True,
    ):
        import sqlglot

        self.time_mapping = time_mapping or {}
        self.time_trie = time_trie
        self.pretty = pretty if pretty is not None else sqlglot.pretty
        self.quote_start = quote_start or "'"
        self.quote_end = quote_end or "'"
        self.identifier_start = identifier_start or '"'
        self.identifier_end = identifier_end or '"'
        self.bit_start = bit_start
        self.bit_end = bit_end
        self.hex_start = hex_start
        self.hex_end = hex_end
        self.byte_start = byte_start
        self.byte_end = byte_end
        self.raw_start = raw_start
        self.raw_end = raw_end
        self.identify = identify
        self.normalize = normalize
        self.string_escape = string_escape or "'"
        self.identifier_escape = identifier_escape or '"'
        self.pad = pad
        self.index_offset = index_offset
        self.unnest_column_only = unnest_column_only
        self.alias_post_tablesample = alias_post_tablesample
        self.identifiers_can_start_with_digit = identifiers_can_start_with_digit
        self.normalize_functions = normalize_functions
        self.unsupported_level = unsupported_level
        self.unsupported_messages = []
        self.max_unsupported = max_unsupported
        self.null_ordering = null_ordering
        self._indent = indent
        self._escaped_quote_end = self.string_escape + self.quote_end
        self._escaped_identifier_end = self.identifier_escape + self.identifier_end
        self._leading_comma = leading_comma
        self._max_text_width = max_text_width
        self._comments = comments
        self._cache = None
        self.unsupported_level = unsupported_level
        self.max_unsupported = max_unsupported
        self.leading_comma = leading_comma
        self.max_text_width = max_text_width
        self.comments = comments

        # This is both a Dialect property and a Generator argument, so we prioritize the latter
        self.normalize_functions = (
            self.NORMALIZE_FUNCTIONS if normalize_functions is None else normalize_functions
        )

        self.unsupported_messages: t.List[str] = []
        self._escaped_quote_end: str = self.STRING_ESCAPE + self.QUOTE_END
        self._escaped_identifier_end: str = self.IDENTIFIER_ESCAPE + self.IDENTIFIER_END
        self._cache: t.Optional[t.Dict[int, str]] = None

    def generate(
        self,
@@ -364,17 +342,19 @@ class Generator:
        cache: t.Optional[t.Dict[int, str]] = None,
    ) -> str:
        """
        Generates a SQL string by interpreting the given syntax tree.
        Generates the SQL string corresponding to the given syntax tree.

        Args
            expression: the syntax tree.
            cache: an optional sql string cache. this leverages the hash of an expression which is slow, so only use this if you set _hash on each node.
        Args:
            expression: The syntax tree.
            cache: An optional sql string cache. This leverages the hash of an Expression
                which can be slow to compute, so only use it if you set _hash on each node.

        Returns
            the SQL string.
        Returns:
            The SQL string corresponding to `expression`.
        """
        if cache is not None:
            self._cache = cache

        self.unsupported_messages = []
        sql = self.sql(expression).strip()
        self._cache = None

@@ -414,7 +394,11 @@ class Generator:
        expression: t.Optional[exp.Expression] = None,
        comments: t.Optional[t.List[str]] = None,
    ) -> str:
        comments = ((expression and expression.comments) if comments is None else comments) if self._comments else None  # type: ignore
        comments = (
            ((expression and expression.comments) if comments is None else comments)  # type: ignore
            if self.comments
            else None
        )

        if not comments or isinstance(expression, exp.Binary):
            return sql

@@ -454,7 +438,7 @@ class Generator:
        return result

    def normalize_func(self, name: str) -> str:
        if self.normalize_functions == "upper":
        if self.normalize_functions == "upper" or self.normalize_functions is True:
            return name.upper()
        if self.normalize_functions == "lower":
            return name.lower()
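normalize_functions now accepts booleans as well as strings, with True behaving like "upper" and False disabling normalization. A hedged sketch through the public API:

import sqlglot

sql = "select coalesce(a, b) from t"
print(sqlglot.transpile(sql)[0])                             # function names uppercased by default
print(sqlglot.transpile(sql, normalize_functions=False)[0])  # original function casing kept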
@@ -522,7 +506,7 @@ class Generator:
        else:
            raise ValueError(f"Expected an Expression. Received {type(expression)}: {expression}")

        sql = self.maybe_comment(sql, expression) if self._comments and comment else sql
        sql = self.maybe_comment(sql, expression) if self.comments and comment else sql

        if self._cache is not None:
            self._cache[expression_id] = sql

@@ -770,25 +754,25 @@ class Generator:

    def bitstring_sql(self, expression: exp.BitString) -> str:
        this = self.sql(expression, "this")
        if self.bit_start:
            return f"{self.bit_start}{this}{self.bit_end}"
        if self.BIT_START:
            return f"{self.BIT_START}{this}{self.BIT_END}"
        return f"{int(this, 2)}"

    def hexstring_sql(self, expression: exp.HexString) -> str:
        this = self.sql(expression, "this")
        if self.hex_start:
            return f"{self.hex_start}{this}{self.hex_end}"
        if self.HEX_START:
            return f"{self.HEX_START}{this}{self.HEX_END}"
        return f"{int(this, 16)}"

    def bytestring_sql(self, expression: exp.ByteString) -> str:
        this = self.sql(expression, "this")
        if self.byte_start:
            return f"{self.byte_start}{this}{self.byte_end}"
        if self.BYTE_START:
            return f"{self.BYTE_START}{this}{self.BYTE_END}"
        return this

    def rawstring_sql(self, expression: exp.RawString) -> str:
        if self.raw_start:
            return f"{self.raw_start}{expression.name}{self.raw_end}"
        if self.RAW_START:
            return f"{self.RAW_START}{expression.name}{self.RAW_END}"
        return self.sql(exp.Literal.string(expression.name.replace("\\", "\\\\")))

    def datatypesize_sql(self, expression: exp.DataTypeSize) -> str:

@@ -883,24 +867,27 @@ class Generator:
        name = f"{expression.name} " if expression.name else ""
        table = self.sql(expression, "table")
        table = f"{self.INDEX_ON} {table} " if table else ""
        using = self.sql(expression, "using")
        using = f"USING {using} " if using else ""
        index = "INDEX " if not table else ""
        columns = self.expressions(expression, key="columns", flat=True)
        columns = f"({columns})" if columns else ""
        partition_by = self.expressions(expression, key="partition_by", flat=True)
        partition_by = f" PARTITION BY {partition_by}" if partition_by else ""
        return f"{unique}{primary}{amp}{index}{name}{table}({columns}){partition_by}"
        return f"{unique}{primary}{amp}{index}{name}{table}{using}{columns}{partition_by}"

    def identifier_sql(self, expression: exp.Identifier) -> str:
        text = expression.name
        lower = text.lower()
        text = lower if self.normalize and not expression.quoted else text
        text = text.replace(self.identifier_end, self._escaped_identifier_end)
        text = text.replace(self.IDENTIFIER_END, self._escaped_identifier_end)
        if (
            expression.quoted
            or should_identify(text, self.identify)
            or lower in self.RESERVED_KEYWORDS
            or (not self.identifiers_can_start_with_digit and text[:1].isdigit())
            or (not self.IDENTIFIERS_CAN_START_WITH_DIGIT and text[:1].isdigit())
        ):
            text = f"{self.identifier_start}{text}{self.identifier_end}"
            text = f"{self.IDENTIFIER_START}{text}{self.IDENTIFIER_END}"
        return text

    def inputoutputformat_sql(self, expression: exp.InputOutputFormat) -> str:

@@ -1197,7 +1184,7 @@ class Generator:
    def tablesample_sql(
        self, expression: exp.TableSample, seed_prefix: str = "SEED", sep=" AS "
    ) -> str:
        if self.alias_post_tablesample and expression.this.alias:
        if self.ALIAS_POST_TABLESAMPLE and expression.this.alias:
            table = expression.this.copy()
            table.set("alias", None)
            this = self.sql(table)

@@ -1372,7 +1359,15 @@ class Generator:

    def limit_sql(self, expression: exp.Limit) -> str:
        this = self.sql(expression, "this")
        return f"{this}{self.seg('LIMIT')} {self.sql(expression, 'expression')}"
        args = ", ".join(
            sql
            for sql in (
                self.sql(expression, "offset"),
                self.sql(expression, "expression"),
            )
            if sql
        )
        return f"{this}{self.seg('LIMIT')} {args}"

    def offset_sql(self, expression: exp.Offset) -> str:
        this = self.sql(expression, "this")

@@ -1418,10 +1413,10 @@ class Generator:
    def literal_sql(self, expression: exp.Literal) -> str:
        text = expression.this or ""
        if expression.is_string:
            text = text.replace(self.quote_end, self._escaped_quote_end)
            text = text.replace(self.QUOTE_END, self._escaped_quote_end)
            if self.pretty:
                text = text.replace("\n", self.SENTINEL_LINE_BREAK)
            text = f"{self.quote_start}{text}{self.quote_end}"
            text = f"{self.QUOTE_START}{text}{self.QUOTE_END}"
        return text

    def loaddata_sql(self, expression: exp.LoadData) -> str:

@@ -1463,9 +1458,9 @@ class Generator:

        nulls_first = expression.args.get("nulls_first")
        nulls_last = not nulls_first
        nulls_are_large = self.null_ordering == "nulls_are_large"
        nulls_are_small = self.null_ordering == "nulls_are_small"
        nulls_are_last = self.null_ordering == "nulls_are_last"
        nulls_are_large = self.NULL_ORDERING == "nulls_are_large"
        nulls_are_small = self.NULL_ORDERING == "nulls_are_small"
        nulls_are_last = self.NULL_ORDERING == "nulls_are_last"

        sort_order = " DESC" if desc else ""
        nulls_sort_change = ""

@@ -1521,7 +1516,7 @@ class Generator:
        return f"{self.seg('MATCH_RECOGNIZE')} {self.wrap(body)}{alias}"

    def query_modifiers(self, expression: exp.Expression, *sqls: str) -> str:
        limit = expression.args.get("limit")
        limit: t.Optional[exp.Fetch | exp.Limit] = expression.args.get("limit")

        if self.LIMIT_FETCH == "LIMIT" and isinstance(limit, exp.Fetch):
            limit = exp.Limit(expression=limit.args.get("count"))

@@ -1540,12 +1535,19 @@ class Generator:
            self.sql(expression, "having"),
            *self.after_having_modifiers(expression),
            self.sql(expression, "order"),
            self.sql(expression, "offset") if fetch else self.sql(limit),
            self.sql(limit) if fetch else self.sql(expression, "offset"),
            *self.offset_limit_modifiers(expression, fetch, limit),
            *self.after_limit_modifiers(expression),
            sep="",
        )

    def offset_limit_modifiers(
        self, expression: exp.Expression, fetch: bool, limit: t.Optional[exp.Fetch | exp.Limit]
    ) -> t.List[str]:
        return [
            self.sql(expression, "offset") if fetch else self.sql(limit),
            self.sql(limit) if fetch else self.sql(expression, "offset"),
        ]

    def after_having_modifiers(self, expression: exp.Expression) -> t.List[str]:
        return [
            self.sql(expression, "qualify"),

@@ -1634,7 +1636,7 @@ class Generator:
    def unnest_sql(self, expression: exp.Unnest) -> str:
        args = self.expressions(expression, flat=True)
        alias = expression.args.get("alias")
        if alias and self.unnest_column_only:
        if alias and self.UNNEST_COLUMN_ONLY:
            columns = alias.columns
            alias = self.sql(columns[0]) if columns else ""
        else:

@@ -1697,7 +1699,7 @@ class Generator:
        return f"{this} BETWEEN {low} AND {high}"

    def bracket_sql(self, expression: exp.Bracket) -> str:
        expressions = apply_index_offset(expression.this, expression.expressions, self.index_offset)
        expressions = apply_index_offset(expression.this, expression.expressions, self.INDEX_OFFSET)
        expressions_sql = ", ".join(self.sql(e) for e in expressions)

        return f"{self.sql(expression, 'this')}[{expressions_sql}]"

@@ -1729,7 +1731,7 @@ class Generator:

        statements.append("END")

        if self.pretty and self.text_width(statements) > self._max_text_width:
        if self.pretty and self.text_width(statements) > self.max_text_width:
            return self.indent("\n".join(statements), skip_first=True, skip_last=True)

        return " ".join(statements)

@@ -1759,10 +1761,11 @@ class Generator:
        else:
            return self.func("TRIM", expression.this, expression.expression)

    def concat_sql(self, expression: exp.Concat) -> str:
        if len(expression.expressions) == 1:
            return self.sql(expression.expressions[0])
        return self.function_fallback_sql(expression)
    def safeconcat_sql(self, expression: exp.SafeConcat) -> str:
        expressions = expression.expressions
        if self.STRICT_STRING_CONCAT:
            expressions = (exp.cast(e, "text") for e in expressions)
        return self.func("CONCAT", *expressions)

    def check_sql(self, expression: exp.Check) -> str:
        this = self.sql(expression, key="this")
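SafeConcat models dialects whose CONCAT coerces its operands; when the target dialect sets STRICT_STRING_CONCAT, the generator casts each argument to text instead. A minimal sketch using a hypothetical strict subclass for illustration:

from sqlglot import exp, generator

class StrictGen(generator.Generator):
    STRICT_STRING_CONCAT = True  # hypothetical subclass, for illustration only

node = exp.SafeConcat(expressions=[exp.column("a"), exp.column("b")])
print(StrictGen().generate(node))
# expected along the lines of: CONCAT(CAST(a AS TEXT), CAST(b AS TEXT))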
@@ -1785,9 +1788,7 @@ class Generator:
        return f"PRIMARY KEY ({expressions}){options}"

    def if_sql(self, expression: exp.If) -> str:
        return self.case_sql(
            exp.Case(ifs=[expression.copy()], default=expression.args.get("false"))
        )
        return self.case_sql(exp.Case(ifs=[expression], default=expression.args.get("false")))

    def matchagainst_sql(self, expression: exp.MatchAgainst) -> str:
        modifier = expression.args.get("modifier")

@@ -1798,7 +1799,6 @@ class Generator:
        return f"{self.sql(expression, 'this')}: {self.sql(expression, 'expression')}"

    def jsonobject_sql(self, expression: exp.JSONObject) -> str:
        expressions = self.expressions(expression)
        null_handling = expression.args.get("null_handling")
        null_handling = f" {null_handling}" if null_handling else ""
        unique_keys = expression.args.get("unique_keys")

@@ -1811,7 +1811,11 @@ class Generator:
        format_json = " FORMAT JSON" if expression.args.get("format_json") else ""
        encoding = self.sql(expression, "encoding")
        encoding = f" ENCODING {encoding}" if encoding else ""
        return f"JSON_OBJECT({expressions}{null_handling}{unique_keys}{return_type}{format_json}{encoding})"
        return self.func(
            "JSON_OBJECT",
            *expression.expressions,
            suffix=f"{null_handling}{unique_keys}{return_type}{format_json}{encoding})",
        )

    def openjsoncolumndef_sql(self, expression: exp.OpenJSONColumnDef) -> str:
        this = self.sql(expression, "this")
@@ -1930,7 +1934,7 @@ class Generator:
            for i, e in enumerate(expression.flatten(unnest=False))
        )

        sep = "\n" if self.text_width(sqls) > self._max_text_width else " "
        sep = "\n" if self.text_width(sqls) > self.max_text_width else " "
        return f"{sep}{op} ".join(sqls)

    def bitwiseand_sql(self, expression: exp.BitwiseAnd) -> str:

@@ -2093,6 +2097,11 @@ class Generator:
    def dpipe_sql(self, expression: exp.DPipe) -> str:
        return self.binary(expression, "||")

    def safedpipe_sql(self, expression: exp.SafeDPipe) -> str:
        if self.STRICT_STRING_CONCAT:
            return self.func("CONCAT", *(exp.cast(e, "text") for e in expression.flatten()))
        return self.dpipe_sql(expression)

    def div_sql(self, expression: exp.Div) -> str:
        return self.binary(expression, "/")

@@ -2127,7 +2136,7 @@ class Generator:
        return self.binary(expression, "ILIKE ANY")

    def is_sql(self, expression: exp.Is) -> str:
        if not self.IS_BOOL and isinstance(expression.expression, exp.Boolean):
        if not self.IS_BOOL_ALLOWED and isinstance(expression.expression, exp.Boolean):
            return self.sql(
                expression.this if expression.expression.this else exp.not_(expression.this)
            )

@@ -2197,12 +2206,18 @@ class Generator:

        return self.func(expression.sql_name(), *args)

    def func(self, name: str, *args: t.Optional[exp.Expression | str]) -> str:
        return f"{self.normalize_func(name)}({self.format_args(*args)})"
    def func(
        self,
        name: str,
        *args: t.Optional[exp.Expression | str],
        prefix: str = "(",
        suffix: str = ")",
    ) -> str:
        return f"{self.normalize_func(name)}{prefix}{self.format_args(*args)}{suffix}"

    def format_args(self, *args: t.Optional[str | exp.Expression]) -> str:
        arg_sqls = tuple(self.sql(arg) for arg in args if arg is not None)
        if self.pretty and self.text_width(arg_sqls) > self._max_text_width:
        if self.pretty and self.text_width(arg_sqls) > self.max_text_width:
            return self.indent("\n" + f",\n".join(arg_sqls) + "\n", skip_first=True, skip_last=True)
        return ", ".join(arg_sqls)
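The new prefix/suffix hooks let callers such as jsonobject_sql append trailing clauses inside the parentheses without string surgery. A minimal hedged sketch with a standalone Generator:

from sqlglot import exp, generator

gen = generator.Generator()
# the custom suffix replaces the default closing ")" and carries the extra clause
print(gen.func("JSON_OBJECT", exp.column("a"), suffix=" NULL ON NULL)"))
# JSON_OBJECT(a NULL ON NULL)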
@@ -2210,7 +2225,9 @@ class Generator:
        return sum(len(arg) for arg in args)

    def format_time(self, expression: exp.Expression) -> t.Optional[str]:
        return format_time(self.sql(expression, "format"), self.time_mapping, self.time_trie)
        return format_time(
            self.sql(expression, "format"), self.INVERSE_TIME_MAPPING, self.INVERSE_TIME_TRIE
        )

    def expressions(
        self,

@@ -2242,7 +2259,7 @@ class Generator:
            comments = self.maybe_comment("", e) if isinstance(e, exp.Expression) else ""

            if self.pretty:
                if self._leading_comma:
                if self.leading_comma:
                    result_sqls.append(f"{sep if i > 0 else pad}{prefix}{sql}{comments}")
                else:
                    result_sqls.append(
@@ -208,7 +208,7 @@ def while_changing(expression: Expression, func: t.Callable[[Expression], E]) -> E:
    return expression


def tsort(dag: t.Dict[T, t.List[T]]) -> t.List[T]:
def tsort(dag: t.Dict[T, t.Set[T]]) -> t.List[T]:
    """
    Sorts a given directed acyclic graph in topological order.

@@ -220,22 +220,24 @@ def tsort(dag: t.Dict[T, t.List[T]]) -> t.List[T]:
    """
    result = []

    def visit(node: T, visited: t.Set[T]) -> None:
        if node in result:
            return
        if node in visited:
    for node, deps in tuple(dag.items()):
        for dep in deps:
            if not dep in dag:
                dag[dep] = set()

    while dag:
        current = {node for node, deps in dag.items() if not deps}

        if not current:
            raise ValueError("Cycle error")

        visited.add(node)
        for node in current:
            dag.pop(node)

        for dep in dag.get(node, []):
            visit(dep, visited)
        for deps in dag.values():
            deps -= current

        visited.remove(node)
        result.append(node)

    for node in dag:
        visit(node, set())
        result.extend(sorted(current))  # type: ignore

    return result
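The rewritten tsort is an iterative Kahn's algorithm over a dict of dependency sets, with sorted() making each layer's output deterministic. A quick hedged usage sketch:

from sqlglot.helper import tsort

# b must precede a; b and c have no dependencies and sort deterministically
print(tsort({"a": {"b"}, "b": set(), "c": set()}))  # ['b', 'c', 'a']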
@@ -1,13 +1,25 @@
from __future__ import annotations

import typing as t

from sqlglot import exp
from sqlglot._typing import E
from sqlglot.helper import ensure_list, subclasses
from sqlglot.optimizer.scope import Scope, traverse_scope
from sqlglot.schema import ensure_schema
from sqlglot.schema import Schema, ensure_schema

if t.TYPE_CHECKING:
    B = t.TypeVar("B", bound=exp.Binary)


def annotate_types(expression, schema=None, annotators=None, coerces_to=None):
def annotate_types(
    expression: E,
    schema: t.Optional[t.Dict | Schema] = None,
    annotators: t.Optional[t.Dict[t.Type[E], t.Callable[[TypeAnnotator, E], E]]] = None,
    coerces_to: t.Optional[t.Dict[exp.DataType.Type, t.Set[exp.DataType.Type]]] = None,
) -> E:
    """
    Recursively infer & annotate types in an expression syntax tree against a schema.
    Assumes that we've already executed the optimizer's qualify_columns step.
    Infers the types of an expression, annotating its AST accordingly.

    Example:
        >>> import sqlglot

@@ -18,12 +30,13 @@ def annotate_types(expression, schema=None, annotators=None, coerces_to=None):
        <Type.DOUBLE: 'DOUBLE'>

    Args:
        expression (sqlglot.Expression): Expression to annotate.
        schema (dict|sqlglot.optimizer.Schema): Database schema.
        annotators (dict): Maps expression type to corresponding annotation function.
        coerces_to (dict): Maps expression type to set of types that it can be coerced into.
        expression: Expression to annotate.
        schema: Database schema.
        annotators: Maps expression type to corresponding annotation function.
        coerces_to: Maps expression type to set of types that it can be coerced into.

    Returns:
        sqlglot.Expression: expression annotated with types
        The expression annotated with types.
    """
    schema = ensure_schema(schema)

@@ -31,276 +44,241 @@ def annotate_types(expression, schema=None, annotators=None, coerces_to=None):
    return TypeAnnotator(schema, annotators, coerces_to).annotate(expression)
class TypeAnnotator:
|
||||
ANNOTATORS = {
|
||||
**{
|
||||
expr_type: lambda self, expr: self._annotate_unary(expr)
|
||||
for expr_type in subclasses(exp.__name__, exp.Unary)
|
||||
},
|
||||
**{
|
||||
expr_type: lambda self, expr: self._annotate_binary(expr)
|
||||
for expr_type in subclasses(exp.__name__, exp.Binary)
|
||||
},
|
||||
exp.Cast: lambda self, expr: self._annotate_with_type(expr, expr.args["to"]),
|
||||
exp.TryCast: lambda self, expr: self._annotate_with_type(expr, expr.args["to"]),
|
||||
exp.DataType: lambda self, expr: self._annotate_with_type(expr, expr.copy()),
|
||||
exp.Alias: lambda self, expr: self._annotate_unary(expr),
|
||||
exp.Between: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.BOOLEAN),
|
||||
exp.In: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.BOOLEAN),
|
||||
exp.Literal: lambda self, expr: self._annotate_literal(expr),
|
||||
exp.Boolean: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.BOOLEAN),
|
||||
exp.Null: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.NULL),
|
||||
exp.Anonymous: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.UNKNOWN),
|
||||
exp.ApproxDistinct: lambda self, expr: self._annotate_with_type(
|
||||
expr, exp.DataType.Type.BIGINT
|
||||
),
|
||||
exp.Avg: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
|
||||
exp.Min: lambda self, expr: self._annotate_by_args(expr, "this", "expressions"),
|
||||
exp.Max: lambda self, expr: self._annotate_by_args(expr, "this", "expressions"),
|
||||
exp.Sum: lambda self, expr: self._annotate_by_args(
|
||||
expr, "this", "expressions", promote=True
|
||||
),
|
||||
exp.Ceil: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT),
|
||||
exp.Count: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.BIGINT),
|
||||
exp.CurrentDate: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DATE),
|
||||
exp.CurrentDatetime: lambda self, expr: self._annotate_with_type(
|
||||
expr, exp.DataType.Type.DATETIME
|
||||
),
|
||||
exp.CurrentTime: lambda self, expr: self._annotate_with_type(
|
||||
expr, exp.DataType.Type.TIMESTAMP
|
||||
),
|
||||
exp.CurrentTimestamp: lambda self, expr: self._annotate_with_type(
|
||||
expr, exp.DataType.Type.TIMESTAMP
|
||||
),
|
||||
exp.DateAdd: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DATE),
|
||||
exp.DateSub: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DATE),
|
||||
exp.DateDiff: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT),
|
||||
exp.DatetimeAdd: lambda self, expr: self._annotate_with_type(
|
||||
expr, exp.DataType.Type.DATETIME
|
||||
),
|
||||
exp.DatetimeSub: lambda self, expr: self._annotate_with_type(
|
||||
expr, exp.DataType.Type.DATETIME
|
||||
),
|
||||
exp.DatetimeDiff: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT),
|
||||
exp.Extract: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT),
|
||||
exp.TimestampAdd: lambda self, expr: self._annotate_with_type(
|
||||
expr, exp.DataType.Type.TIMESTAMP
|
||||
),
|
||||
exp.TimestampSub: lambda self, expr: self._annotate_with_type(
|
||||
expr, exp.DataType.Type.TIMESTAMP
|
||||
),
|
||||
exp.TimestampDiff: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT),
|
||||
exp.TimeAdd: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.TIMESTAMP),
|
||||
exp.TimeSub: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.TIMESTAMP),
|
||||
exp.TimeDiff: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT),
|
||||
exp.DateStrToDate: lambda self, expr: self._annotate_with_type(
|
||||
expr, exp.DataType.Type.DATE
|
||||
),
|
||||
exp.DateToDateStr: lambda self, expr: self._annotate_with_type(
|
||||
expr, exp.DataType.Type.VARCHAR
|
||||
),
|
||||
exp.DateToDi: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT),
|
||||
exp.Day: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.TINYINT),
|
||||
exp.DiToDate: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DATE),
|
||||
exp.Exp: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
|
||||
exp.Floor: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT),
|
||||
exp.Case: lambda self, expr: self._annotate_by_args(expr, "default", "ifs"),
|
||||
exp.If: lambda self, expr: self._annotate_by_args(expr, "true", "false"),
|
||||
exp.Coalesce: lambda self, expr: self._annotate_by_args(expr, "this", "expressions"),
|
||||
exp.IfNull: lambda self, expr: self._annotate_by_args(expr, "this", "expression"),
|
||||
exp.Concat: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.VARCHAR),
|
||||
exp.ConcatWs: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.VARCHAR),
|
||||
exp.GroupConcat: lambda self, expr: self._annotate_with_type(
|
||||
expr, exp.DataType.Type.VARCHAR
|
||||
),
|
||||
exp.ArrayConcat: lambda self, expr: self._annotate_with_type(
|
||||
expr, exp.DataType.Type.VARCHAR
|
||||
),
|
||||
exp.ArraySize: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.BIGINT),
|
||||
exp.Map: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.MAP),
|
||||
exp.VarMap: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.MAP),
|
||||
exp.Initcap: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.VARCHAR),
|
||||
exp.Interval: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INTERVAL),
|
||||
exp.Least: lambda self, expr: self._annotate_by_args(expr, "expressions"),
|
||||
exp.Length: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.BIGINT),
|
||||
exp.Levenshtein: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT),
|
||||
exp.Ln: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
|
||||
        exp.Log: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
        exp.Log2: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
        exp.Log10: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
        exp.Lower: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.VARCHAR),
        exp.Month: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.TINYINT),
        exp.Pow: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
        exp.Quantile: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
        exp.ApproxQuantile: lambda self, expr: self._annotate_with_type(
            expr, exp.DataType.Type.DOUBLE
        ),
        exp.RegexpLike: lambda self, expr: self._annotate_with_type(
            expr, exp.DataType.Type.BOOLEAN
        ),
        exp.Round: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
        exp.SafeDivide: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
        exp.Substring: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.VARCHAR),
        exp.StrPosition: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT),
        exp.StrToDate: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DATE),
        exp.StrToTime: lambda self, expr: self._annotate_with_type(
            expr, exp.DataType.Type.TIMESTAMP
        ),
        exp.Sqrt: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
        exp.Stddev: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
        exp.StddevPop: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
        exp.StddevSamp: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
        exp.TimeToStr: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.VARCHAR),
        exp.TimeToTimeStr: lambda self, expr: self._annotate_with_type(
            expr, exp.DataType.Type.VARCHAR
        ),
        exp.TimeStrToDate: lambda self, expr: self._annotate_with_type(
            expr, exp.DataType.Type.DATE
        ),
        exp.TimeStrToTime: lambda self, expr: self._annotate_with_type(
            expr, exp.DataType.Type.TIMESTAMP
        ),
        exp.Trim: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.VARCHAR),
        exp.TsOrDsToDateStr: lambda self, expr: self._annotate_with_type(
            expr, exp.DataType.Type.VARCHAR
        ),
        exp.TsOrDsToDate: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DATE),
        exp.TsOrDiToDi: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT),
        exp.UnixToStr: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.VARCHAR),
        exp.UnixToTime: lambda self, expr: self._annotate_with_type(
            expr, exp.DataType.Type.TIMESTAMP
        ),
        exp.UnixToTimeStr: lambda self, expr: self._annotate_with_type(
            expr, exp.DataType.Type.VARCHAR
        ),
        exp.Upper: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.VARCHAR),
        exp.Variance: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
        exp.VariancePop: lambda self, expr: self._annotate_with_type(
            expr, exp.DataType.Type.DOUBLE
        ),
        exp.Week: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.TINYINT),
        exp.Year: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.TINYINT),
    }
def _annotate_with_type_lambda(data_type: exp.DataType.Type) -> t.Callable[[TypeAnnotator, E], E]:
    return lambda self, e: self._annotate_with_type(e, data_type)


    # Reference: https://spark.apache.org/docs/3.2.0/sql-ref-ansi-compliance.html
    COERCES_TO = {
        # CHAR < NCHAR < VARCHAR < NVARCHAR < TEXT
        exp.DataType.Type.TEXT: set(),
        exp.DataType.Type.NVARCHAR: {exp.DataType.Type.TEXT},
        exp.DataType.Type.VARCHAR: {exp.DataType.Type.NVARCHAR, exp.DataType.Type.TEXT},
        exp.DataType.Type.NCHAR: {
            exp.DataType.Type.VARCHAR,
            exp.DataType.Type.NVARCHAR,
            exp.DataType.Type.TEXT,
        },
        exp.DataType.Type.CHAR: {
            exp.DataType.Type.NCHAR,
            exp.DataType.Type.VARCHAR,
            exp.DataType.Type.NVARCHAR,
            exp.DataType.Type.TEXT,
        },
        # TINYINT < SMALLINT < INT < BIGINT < DECIMAL < FLOAT < DOUBLE
        exp.DataType.Type.DOUBLE: set(),
        exp.DataType.Type.FLOAT: {exp.DataType.Type.DOUBLE},
        exp.DataType.Type.DECIMAL: {exp.DataType.Type.FLOAT, exp.DataType.Type.DOUBLE},
        exp.DataType.Type.BIGINT: {
            exp.DataType.Type.DECIMAL,
            exp.DataType.Type.FLOAT,
            exp.DataType.Type.DOUBLE,
        },
        exp.DataType.Type.INT: {
            exp.DataType.Type.BIGINT,
            exp.DataType.Type.DECIMAL,
            exp.DataType.Type.FLOAT,
            exp.DataType.Type.DOUBLE,
        },
        exp.DataType.Type.SMALLINT: {
            exp.DataType.Type.INT,
            exp.DataType.Type.BIGINT,
            exp.DataType.Type.DECIMAL,
            exp.DataType.Type.FLOAT,
            exp.DataType.Type.DOUBLE,
        },
        exp.DataType.Type.TINYINT: {
            exp.DataType.Type.SMALLINT,
            exp.DataType.Type.INT,
            exp.DataType.Type.BIGINT,
            exp.DataType.Type.DECIMAL,
            exp.DataType.Type.FLOAT,
            exp.DataType.Type.DOUBLE,
        },
        # DATE < DATETIME < TIMESTAMP < TIMESTAMPTZ < TIMESTAMPLTZ
        exp.DataType.Type.TIMESTAMPLTZ: set(),
        exp.DataType.Type.TIMESTAMPTZ: {exp.DataType.Type.TIMESTAMPLTZ},
        exp.DataType.Type.TIMESTAMP: {
            exp.DataType.Type.TIMESTAMPTZ,
            exp.DataType.Type.TIMESTAMPLTZ,
        },
        exp.DataType.Type.DATETIME: {
            exp.DataType.Type.TIMESTAMP,
            exp.DataType.Type.TIMESTAMPTZ,
            exp.DataType.Type.TIMESTAMPLTZ,
        },
        exp.DataType.Type.DATE: {
            exp.DataType.Type.DATETIME,
            exp.DataType.Type.TIMESTAMP,
            exp.DataType.Type.TIMESTAMPTZ,
            exp.DataType.Type.TIMESTAMPLTZ,
        },
    }


class _TypeAnnotator(type):
    def __new__(cls, clsname, bases, attrs):
        klass = super().__new__(cls, clsname, bases, attrs)

        # Highest-to-lowest type precedence, as specified in Spark's docs (ANSI):
        # https://spark.apache.org/docs/3.2.0/sql-ref-ansi-compliance.html
        text_precedence = (
            exp.DataType.Type.TEXT,
            exp.DataType.Type.NVARCHAR,
            exp.DataType.Type.VARCHAR,
            exp.DataType.Type.NCHAR,
            exp.DataType.Type.CHAR,
        )
        numeric_precedence = (
            exp.DataType.Type.DOUBLE,
            exp.DataType.Type.FLOAT,
            exp.DataType.Type.DECIMAL,
            exp.DataType.Type.BIGINT,
            exp.DataType.Type.INT,
            exp.DataType.Type.SMALLINT,
            exp.DataType.Type.TINYINT,
        )
        timelike_precedence = (
            exp.DataType.Type.TIMESTAMPLTZ,
            exp.DataType.Type.TIMESTAMPTZ,
            exp.DataType.Type.TIMESTAMP,
            exp.DataType.Type.DATETIME,
            exp.DataType.Type.DATE,
        )

        for type_precedence in (text_precedence, numeric_precedence, timelike_precedence):
            coerces_to = set()
            for data_type in type_precedence:
                klass.COERCES_TO[data_type] = coerces_to.copy()
                coerces_to |= {data_type}

        return klass
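Reviewer note: the `_TypeAnnotator` metaclass replaces the hand-written COERCES_TO table above with three precedence chains that it expands mechanically: each type may coerce into every *wider* type that precedes it in its chain. A minimal standalone sketch of that fill, with plain strings standing in for `exp.DataType.Type` members:

```python
# Sketch of the COERCES_TO autofill: every type in a precedence chain
# (widest first) coerces to all types that appear before it.
def build_coerces_to(*chains):
    coerces_to = {}
    for chain in chains:
        seen = set()
        for data_type in chain:
            coerces_to[data_type] = seen.copy()  # wider types collected so far
            seen |= {data_type}
    return coerces_to

mapping = build_coerces_to(("DOUBLE", "FLOAT", "DECIMAL", "BIGINT", "INT"))
assert mapping["INT"] == {"DOUBLE", "FLOAT", "DECIMAL", "BIGINT"}
assert mapping["DOUBLE"] == set()  # nothing is wider than the chain head
```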
class TypeAnnotator(metaclass=_TypeAnnotator):
    TYPE_TO_EXPRESSIONS: t.Dict[exp.DataType.Type, t.Set[t.Type[exp.Expression]]] = {
        exp.DataType.Type.BIGINT: {
            exp.ApproxDistinct,
            exp.ArraySize,
            exp.Count,
            exp.Length,
        },
        exp.DataType.Type.BOOLEAN: {
            exp.Between,
            exp.Boolean,
            exp.In,
            exp.RegexpLike,
        },
        exp.DataType.Type.DATE: {
            exp.CurrentDate,
            exp.Date,
            exp.DateAdd,
            exp.DateStrToDate,
            exp.DateSub,
            exp.DateTrunc,
            exp.DiToDate,
            exp.StrToDate,
            exp.TimeStrToDate,
            exp.TsOrDsToDate,
        },
        exp.DataType.Type.DATETIME: {
            exp.CurrentDatetime,
            exp.DatetimeAdd,
            exp.DatetimeSub,
        },
        exp.DataType.Type.DOUBLE: {
            exp.ApproxQuantile,
            exp.Avg,
            exp.Exp,
            exp.Ln,
            exp.Log,
            exp.Log2,
            exp.Log10,
            exp.Pow,
            exp.Quantile,
            exp.Round,
            exp.SafeDivide,
            exp.Sqrt,
            exp.Stddev,
            exp.StddevPop,
            exp.StddevSamp,
            exp.Variance,
            exp.VariancePop,
        },
        exp.DataType.Type.INT: {
            exp.Ceil,
            exp.DateDiff,
            exp.DatetimeDiff,
            exp.Extract,
            exp.TimestampDiff,
            exp.TimeDiff,
            exp.DateToDi,
            exp.Floor,
            exp.Levenshtein,
            exp.StrPosition,
            exp.TsOrDiToDi,
        },
        exp.DataType.Type.TIMESTAMP: {
            exp.CurrentTime,
            exp.CurrentTimestamp,
            exp.StrToTime,
            exp.TimeAdd,
            exp.TimeStrToTime,
            exp.TimeSub,
            exp.TimestampAdd,
            exp.TimestampSub,
            exp.UnixToTime,
        },
        exp.DataType.Type.TINYINT: {
            exp.Day,
            exp.Month,
            exp.Week,
            exp.Year,
        },
        exp.DataType.Type.VARCHAR: {
            exp.ArrayConcat,
            exp.Concat,
            exp.ConcatWs,
            exp.DateToDateStr,
            exp.GroupConcat,
            exp.Initcap,
            exp.Lower,
            exp.SafeConcat,
            exp.Substring,
            exp.TimeToStr,
            exp.TimeToTimeStr,
            exp.Trim,
            exp.TsOrDsToDateStr,
            exp.UnixToStr,
            exp.UnixToTimeStr,
            exp.Upper,
        },
    }

    TRAVERSABLES = (exp.Select, exp.Union, exp.UDTF, exp.Subquery)

    ANNOTATORS = {
        **{
            expr_type: lambda self, e: self._annotate_unary(e)
            for expr_type in subclasses(exp.__name__, (exp.Unary, exp.Alias))
        },
        **{
            expr_type: lambda self, e: self._annotate_binary(e)
            for expr_type in subclasses(exp.__name__, exp.Binary)
        },
        **{
            expr_type: _annotate_with_type_lambda(data_type)
            for data_type, expressions in TYPE_TO_EXPRESSIONS.items()
            for expr_type in expressions
        },
        exp.Anonymous: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.UNKNOWN),
        exp.Cast: lambda self, e: self._annotate_with_type(e, e.args["to"]),
        exp.Case: lambda self, e: self._annotate_by_args(e, "default", "ifs"),
        exp.Coalesce: lambda self, e: self._annotate_by_args(e, "this", "expressions"),
        exp.DataType: lambda self, e: self._annotate_with_type(e, e.copy()),
        exp.If: lambda self, e: self._annotate_by_args(e, "true", "false"),
        exp.Interval: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.INTERVAL),
        exp.Least: lambda self, e: self._annotate_by_args(e, "expressions"),
        exp.Literal: lambda self, e: self._annotate_literal(e),
        exp.Map: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.MAP),
        exp.Max: lambda self, e: self._annotate_by_args(e, "this", "expressions"),
        exp.Min: lambda self, e: self._annotate_by_args(e, "this", "expressions"),
        exp.Null: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.NULL),
        exp.Sum: lambda self, e: self._annotate_by_args(e, "this", "expressions", promote=True),
        exp.TryCast: lambda self, e: self._annotate_with_type(e, e.args["to"]),
        exp.VarMap: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.MAP),
    }
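Reviewer note: ANNOTATORS is now derived by inverting TYPE_TO_EXPRESSIONS, with `_annotate_with_type_lambda` binding each `data_type` eagerly; a bare lambda inside the comprehension would late-bind and tag every expression class with the last type seen. A toy version of the inversion (the classes and type names here are illustrative):

```python
# Sketch: invert a {result_type: {expression_class, ...}} table into a
# per-class dispatch dict, mirroring how ANNOTATORS is built.
class Lower: ...
class Upper: ...
class Sqrt: ...

TYPE_TO_EXPRESSIONS = {"VARCHAR": {Lower, Upper}, "DOUBLE": {Sqrt}}

ANNOTATORS = {
    expr_type: (lambda dt: lambda node: (node, dt))(data_type)  # bind data_type now
    for data_type, expressions in TYPE_TO_EXPRESSIONS.items()
    for expr_type in expressions
}

node, inferred = ANNOTATORS[Sqrt](Sqrt())
assert inferred == "DOUBLE"
```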
    def __init__(self, schema=None, annotators=None, coerces_to=None):
        self.schema = schema
        self.annotators = annotators or self.ANNOTATORS
        self.coerces_to = coerces_to or self.COERCES_TO

    def annotate(self, expression):
        if isinstance(expression, self.TRAVERSABLES):
            for scope in traverse_scope(expression):
                selects = {}
                for name, source in scope.sources.items():
                    if not isinstance(source, Scope):
                        continue
                    if isinstance(source.expression, exp.UDTF):
                        values = []

                        if isinstance(source.expression, exp.Lateral):
                            if isinstance(source.expression.this, exp.Explode):
                                values = [source.expression.this.this]
                        else:
                            values = source.expression.expressions[0].expressions

                        if not values:
                            continue

                        selects[name] = {
                            alias: column
                            for alias, column in zip(
                                source.expression.alias_column_names,
                                values,
                            )
                        }
                    else:
                        selects[name] = {
                            select.alias_or_name: select for select in source.expression.selects
                        }
                # First annotate the current scope's column references
                for col in scope.columns:
                    if not col.table:
                        continue

                    source = scope.sources.get(col.table)
                    if isinstance(source, exp.Table):
                        col.type = self.schema.get_column_type(source, col)
                    elif source and col.table in selects and col.name in selects[col.table]:
                        col.type = selects[col.table][col.name].type
                # Then (possibly) annotate the remaining expressions in the scope
                self._maybe_annotate(scope.expression)

    # Specifies what types a given type can be coerced into (autofilled)
    COERCES_TO: t.Dict[exp.DataType.Type, t.Set[exp.DataType.Type]] = {}

    def __init__(
        self,
        schema: Schema,
        annotators: t.Optional[t.Dict[t.Type[E], t.Callable[[TypeAnnotator, E], E]]] = None,
        coerces_to: t.Optional[t.Dict[exp.DataType.Type, t.Set[exp.DataType.Type]]] = None,
    ) -> None:
        self.schema = schema
        self.annotators = annotators or self.ANNOTATORS
        self.coerces_to = coerces_to or self.COERCES_TO

    def annotate(self, expression: E) -> E:
        for scope in traverse_scope(expression):
            selects = {}
            for name, source in scope.sources.items():
                if not isinstance(source, Scope):
                    continue
                if isinstance(source.expression, exp.UDTF):
                    values = []

                    if isinstance(source.expression, exp.Lateral):
                        if isinstance(source.expression.this, exp.Explode):
                            values = [source.expression.this.this]
                    else:
                        values = source.expression.expressions[0].expressions

                    if not values:
                        continue

                    selects[name] = {
                        alias: column
                        for alias, column in zip(
                            source.expression.alias_column_names,
                            values,
                        )
                    }
                else:
                    selects[name] = {
                        select.alias_or_name: select for select in source.expression.selects
                    }

            # First annotate the current scope's column references
            for col in scope.columns:
                if not col.table:
                    continue

                source = scope.sources.get(col.table)
                if isinstance(source, exp.Table):
                    col.type = self.schema.get_column_type(source, col)
                elif source and col.table in selects and col.name in selects[col.table]:
                    col.type = selects[col.table][col.name].type

            # Then (possibly) annotate the remaining expressions in the scope
            self._maybe_annotate(scope.expression)

        return self._maybe_annotate(expression)  # This takes care of non-traversable expressions

    def _maybe_annotate(self, expression):
    def _maybe_annotate(self, expression: E) -> E:
        if expression.type:
            return expression  # We've already inferred the expression's type
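Reviewer note: in practice this class is reached through the module-level `annotate_types` helper, which wraps it with an optional schema. A typical call, following the example in the library's own docstring:

```python
import sqlglot
from sqlglot.optimizer.annotate_types import annotate_types

# Annotate a parsed expression tree; literals and function results get types.
expression = annotate_types(sqlglot.parse_one("SELECT 5 + 5.3 AS col"))
print(expression.expressions[0].type.this)  # expected: Type.DOUBLE (INT coerces to DOUBLE)
```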
@@ -312,13 +290,15 @@ class TypeAnnotator:
            else self._annotate_with_type(expression, exp.DataType.Type.UNKNOWN)
        )

    def _annotate_args(self, expression):
    def _annotate_args(self, expression: E) -> E:
        for _, value in expression.iter_expressions():
            self._maybe_annotate(value)

        return expression

    def _maybe_coerce(self, type1, type2):
    def _maybe_coerce(
        self, type1: exp.DataType | exp.DataType.Type, type2: exp.DataType | exp.DataType.Type
    ) -> exp.DataType.Type:
        # We propagate the NULL / UNKNOWN types upwards if found
        if isinstance(type1, exp.DataType):
            type1 = type1.this

@@ -330,9 +310,14 @@ class TypeAnnotator:
        if exp.DataType.Type.UNKNOWN in (type1, type2):
            return exp.DataType.Type.UNKNOWN

        return type2 if type2 in self.coerces_to.get(type1, {}) else type1
        return type2 if type2 in self.coerces_to.get(type1, {}) else type1  # type: ignore

    def _annotate_binary(self, expression):
    # Note: the following "no_type_check" decorators were added because mypy was yelling due
    # to assigning Type values to expression.type (since its getter returns Optional[DataType]).
    # This is a known mypy issue: https://github.com/python/mypy/issues/3004

    @t.no_type_check
    def _annotate_binary(self, expression: B) -> B:
        self._annotate_args(expression)

        left_type = expression.left.type.this

@@ -354,7 +339,8 @@ class TypeAnnotator:

        return expression

    def _annotate_unary(self, expression):
    @t.no_type_check
    def _annotate_unary(self, expression: E) -> E:
        self._annotate_args(expression)

        if isinstance(expression, exp.Condition) and not isinstance(expression, exp.Paren):

@@ -364,7 +350,8 @@ class TypeAnnotator:

        return expression

    def _annotate_literal(self, expression):
    @t.no_type_check
    def _annotate_literal(self, expression: exp.Literal) -> exp.Literal:
        if expression.is_string:
            expression.type = exp.DataType.Type.VARCHAR
        elif expression.is_int:

@@ -374,13 +361,16 @@ class TypeAnnotator:

        return expression

    def _annotate_with_type(self, expression, target_type):
    @t.no_type_check
    def _annotate_with_type(self, expression: E, target_type: exp.DataType.Type) -> E:
        expression.type = target_type
        return self._annotate_args(expression)

    def _annotate_by_args(self, expression, *args, promote=False):
    @t.no_type_check
    def _annotate_by_args(self, expression: E, *args: str, promote: bool = False) -> E:
        self._annotate_args(expression)
        expressions = []

        expressions: t.List[exp.Expression] = []
        for arg in args:
            arg_expr = expression.args.get(arg)
            expressions.extend(expr for expr in ensure_list(arg_expr) if expr)
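Reviewer note: `_maybe_coerce` encodes a "wider type wins" rule over COERCES_TO, with NULL / UNKNOWN short-circuiting upwards. The same rule in isolation, with toy string types:

```python
# Toy version of the _maybe_coerce rule, using plain strings for types.
COERCES_TO = {"INT": {"BIGINT", "DOUBLE"}, "DOUBLE": set()}

def maybe_coerce(type1, type2):
    if "NULL" in (type1, type2):
        return "NULL"
    if "UNKNOWN" in (type1, type2):
        return "UNKNOWN"
    return type2 if type2 in COERCES_TO.get(type1, set()) else type1

assert maybe_coerce("INT", "DOUBLE") == "DOUBLE"  # INT widens to DOUBLE
assert maybe_coerce("DOUBLE", "INT") == "DOUBLE"  # DOUBLE stays put
```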
@@ -26,7 +26,7 @@ def canonicalize(expression: exp.Expression) -> exp.Expression:

def add_text_to_concat(node: exp.Expression) -> exp.Expression:
    if isinstance(node, exp.Add) and node.type and node.type.this in exp.DataType.TEXT_TYPES:
        node = exp.Concat(this=node.this, expression=node.expression)
        node = exp.Concat(expressions=[node.left, node.right])
    return node
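Reviewer note: this hunk tracks `exp.Concat`'s move from `this`/`expression` arguments to a single variadic `expressions` list. Building the node new-style, as a sketch against the 16.x API:

```python
from sqlglot import exp

# New-style Concat: all operands go into one `expressions` list.
a, b = exp.column("a"), exp.column("b")
concat = exp.Concat(expressions=[a, b])
print(concat.sql())  # CONCAT(a, b)
```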
@@ -32,7 +32,7 @@ def eliminate_joins(expression):

    # Reverse the joins so we can remove chains of unused joins
    for join in reversed(joins):
        alias = join.this.alias_or_name
        alias = join.alias_or_name
        if _should_eliminate_join(scope, join, alias):
            join.pop()
            scope.remove_source(alias)

@@ -126,7 +126,7 @@ def join_condition(join):
        tuple[list[str], list[str], exp.Expression]:
            Tuple of (source key, join key, remaining predicate)
    """
    name = join.this.alias_or_name
    name = join.alias_or_name
    on = (join.args.get("on") or exp.true()).copy()
    source_key = []
    join_key = []
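Reviewer note: this and several later hunks swap `join.this.alias_or_name` for `join.alias_or_name`, relying on `exp.Join` now proxying `alias_or_name` to its joined source. A quick check of the shorthand:

```python
import sqlglot
from sqlglot import exp

join = sqlglot.parse_one("SELECT * FROM x JOIN y AS z ON x.id = z.id").args["joins"][0]
assert isinstance(join, exp.Join)
print(join.alias_or_name)  # z — delegates to the joined table's alias
```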
@@ -21,7 +21,7 @@ def isolate_table_selects(expression, schema=None):
        source.replace(
            exp.select("*")
            .from_(
                alias(source, source.name or source.alias, table=True),
                alias(source, source.alias_or_name, table=True),
                copy=False,
            )
            .subquery(source.alias, copy=False)
@@ -145,7 +145,7 @@ def _mergeable(outer_scope, inner_scope, leave_tables_isolated, from_or_join):
    if not isinstance(from_or_join, exp.Join):
        return False

    alias = from_or_join.this.alias_or_name
    alias = from_or_join.alias_or_name

    on = from_or_join.args.get("on")
    if not on:

@@ -253,10 +253,6 @@ def _merge_joins(outer_scope, inner_scope, from_or_join):
    """

    new_joins = []
    comma_joins = inner_scope.expression.args.get("from").expressions[1:]
    for subquery in comma_joins:
        new_joins.append(exp.Join(this=subquery, kind="CROSS"))
        outer_scope.add_source(subquery.alias_or_name, inner_scope.sources[subquery.alias_or_name])

    joins = inner_scope.expression.args.get("joins") or []
    for join in joins:
@@ -328,13 +324,12 @@ def _merge_where(outer_scope, inner_scope, from_or_join):
            if source == from_or_join.alias_or_name:
                break

        if set(exp.column_table_names(where.this)) <= sources:
        if exp.column_table_names(where.this) <= sources:
            from_or_join.on(where.this, copy=False)
            from_or_join.set("on", from_or_join.args.get("on"))
            return

    expression.where(where.this, copy=False)
    expression.set("where", expression.args.get("where"))


def _merge_order(outer_scope, inner_scope):
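Reviewer note: `exp.column_table_names` now returns a set, so the `set()` wrapper and the subset test collapse into one expression. For example:

```python
from sqlglot import exp

predicate = exp.condition("x.a = y.b AND y.c > 1")
print(exp.column_table_names(predicate))  # {'x', 'y'} — a set, so `<=` works directly
```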
@@ -1,3 +1,7 @@
from __future__ import annotations

import typing as t

from sqlglot import exp
from sqlglot.helper import tsort

@@ -13,25 +17,28 @@ def optimize_joins(expression):
    >>> optimize_joins(parse_one("SELECT * FROM x CROSS JOIN y JOIN z ON x.a = z.a AND y.a = z.a")).sql()
    'SELECT * FROM x JOIN z ON x.a = z.a AND TRUE JOIN y ON y.a = z.a'
    """
    for select in expression.find_all(exp.Select):
        references = {}
        cross_joins = []

        for join in select.args.get("joins", []):
            name = join.this.alias_or_name
            tables = other_table_names(join, name)
            tables = other_table_names(join)

            if tables:
                for table in tables:
                    references[table] = references.get(table, []) + [join]
            else:
                cross_joins.append((name, join))
                cross_joins.append((join.alias_or_name, join))

    for name, join in cross_joins:
        for dep in references.get(name, []):
            on = dep.args["on"]

            if isinstance(on, exp.Connector):
                if len(other_table_names(dep)) < 2:
                    continue

                for predicate in on.flatten():
                    if name in exp.column_table_names(predicate):
                        predicate.replace(exp.true())

@@ -47,17 +54,12 @@ def reorder_joins(expression):
    Reorder joins by topological sort order based on predicate references.
    """
    for from_ in expression.find_all(exp.From):
        head = from_.this
        parent = from_.parent
        joins = {join.this.alias_or_name: join for join in parent.args.get("joins", [])}
        dag = {head.alias_or_name: []}

        for name, join in joins.items():
            dag[name] = other_table_names(join, name)

        joins = {join.alias_or_name: join for join in parent.args.get("joins", [])}
        dag = {name: other_table_names(join) for name, join in joins.items()}
        parent.set(
            "joins",
            [joins[name] for name in tsort(dag) if name != head.alias_or_name],
            [joins[name] for name in tsort(dag) if name != from_.alias_or_name],
        )
    return expression

@@ -75,9 +77,6 @@ def normalize(expression):
    return expression


def other_table_names(join, exclude):
    return [
        name
        for name in (exp.column_table_names(join.args.get("on") or exp.true()))
        if name != exclude
    ]
def other_table_names(join: exp.Join) -> t.Set[str]:
    on = join.args.get("on")
    return exp.column_table_names(on, join.alias_or_name) if on else set()
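Reviewer note: the rewritten `other_table_names` leans on `column_table_names`' new exclude parameter and returns a set. End to end, the pass still behaves as its doctest above states; in runnable form:

```python
from sqlglot import parse_one
from sqlglot.optimizer.optimize_joins import optimize_joins

sql = "SELECT * FROM x CROSS JOIN y JOIN z ON x.a = z.a AND y.a = z.a"
print(optimize_joins(parse_one(sql)).sql())
# SELECT * FROM x JOIN z ON x.a = z.a AND TRUE JOIN y ON y.a = z.a
```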
@@ -78,7 +78,7 @@ def optimize(
        "schema": schema,
        "dialect": dialect,
        "isolate_tables": True,  # needed for other optimizations to perform well
        "quote_identifiers": False,  # this happens in canonicalize
        "quote_identifiers": False,
        **kwargs,
    }
@@ -41,7 +41,7 @@ def pushdown_predicates(expression):
            # joins should only pushdown into itself, not to other joins
            # so we limit the selected sources to only itself
            for join in select.args.get("joins") or []:
                name = join.this.alias_or_name
                name = join.alias_or_name
                pushdown(join.args.get("on"), {name: scope.selected_sources[name]}, scope_ref_count)

    return expression

@@ -93,10 +93,10 @@ def pushdown_dnf(predicates, scope, scope_ref_count):
    pushdown_tables = set()

    for a in predicates:
        a_tables = set(exp.column_table_names(a))
        a_tables = exp.column_table_names(a)

        for b in predicates:
            a_tables &= set(exp.column_table_names(b))
            a_tables &= exp.column_table_names(b)

        pushdown_tables.update(a_tables)

@@ -147,7 +147,7 @@ def nodes_for_predicate(predicate, sources, scope_ref_count):
    tables = exp.column_table_names(predicate)
    where_condition = isinstance(predicate.find_ancestor(exp.Join, exp.Where), exp.Where)

    for table in tables:
    for table in sorted(tables):
        node, source = sources.get(table) or (None, None)

        # if the predicate is in a where statement we can try to push it down
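Reviewer note: since `column_table_names` now yields a set, `nodes_for_predicate` iterates `sorted(tables)` so the pushdown rewrites happen in a stable order across runs. The pattern in isolation:

```python
# Iterating a set is unordered; sorting first makes the rewrite order deterministic.
tables = {"y", "x", "z"}
for table in sorted(tables):
    print(table)  # always x, y, z, regardless of set iteration order
```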
@@ -14,7 +14,7 @@ from sqlglot.schema import Schema, ensure_schema

def qualify_columns(
    expression: exp.Expression,
    schema: dict | Schema,
    schema: t.Dict | Schema,
    expand_alias_refs: bool = True,
    infer_schema: t.Optional[bool] = None,
) -> exp.Expression:

@@ -93,7 +93,7 @@ def _pop_table_column_aliases(derived_tables):

def _expand_using(scope, resolver):
    joins = list(scope.find_all(exp.Join))
    names = {join.this.alias for join in joins}
    names = {join.alias_or_name for join in joins}
    ordered = [key for key in scope.selected_sources if key not in names]

    # Mapping of automatically joined column names to an ordered set of source names (dict).

@@ -105,7 +105,7 @@ def _expand_using(scope, resolver):
        if not using:
            continue

        join_table = join.this.alias_or_name
        join_table = join.alias_or_name

        columns = {}
@@ -91,11 +91,13 @@ def qualify_tables(
            )
        elif isinstance(source, Scope) and source.is_udtf:
            udtf = source.expression
            table_alias = udtf.args.get("alias") or exp.TableAlias(this=next_alias_name())
            table_alias = udtf.args.get("alias") or exp.TableAlias(
                this=exp.to_identifier(next_alias_name())
            )
            udtf.set("alias", table_alias)

            if not table_alias.name:
                table_alias.set("this", next_alias_name())
                table_alias.set("this", exp.to_identifier(next_alias_name()))
            if isinstance(udtf, exp.Values) and not table_alias.columns:
                for i, e in enumerate(udtf.expressions[0].expressions):
                    table_alias.append("columns", exp.to_identifier(f"_col_{i}"))
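Reviewer note: the fix wraps generated alias names in `exp.to_identifier`, so the TableAlias carries a real Identifier node rather than a bare string. A sketch:

```python
from sqlglot import exp

alias = exp.TableAlias(this=exp.to_identifier("_q_0"))
print(type(alias.this).__name__)  # Identifier — an AST node, not a plain str
print(alias.name)                 # _q_0
```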
@@ -620,7 +620,7 @@ def _traverse_tables(scope):
        table_name = expression.name
        source_name = expression.alias_or_name

        if table_name in scope.sources:
        if table_name in scope.sources and not expression.db:
            # This is a reference to a parent source (e.g. a CTE), not an actual table, unless
            # it is pivoted, because then we get back a new table and hence a new source.
            pivots = expression.args.get("pivots")
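Reviewer note: the added `not expression.db` guard keeps a schema-qualified table from being mistaken for a same-named CTE; `db.x` is a real table even when a CTE `x` is in scope. Illustration:

```python
import sqlglot

sql = "WITH x AS (SELECT 1 AS a) SELECT * FROM db.x"
tables = list(sqlglot.parse_one(sql).find_all(sqlglot.exp.Table))
print([(t.name, t.db) for t in tables])  # [('x', 'db')] — db is set, so not the CTE
```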
File diff suppressed because it is too large
@@ -302,7 +302,7 @@ class Join(Step):

        for join in joins:
            source_key, join_key, condition = join_condition(join)
            step.joins[join.this.alias_or_name] = {
            step.joins[join.alias_or_name] = {
                "side": join.side,  # type: ignore
                "join_key": join_key,
                "source_key": source_key,
@@ -285,8 +285,6 @@ class MappingSchema(AbstractMappingSchema[t.Dict[str, str]], Schema):
        elif isinstance(column_type, str):
            return self._to_data_type(column_type.upper(), dialect=dialect)

        raise SchemaError(f"Unknown column type '{column_type}'")
        return exp.DataType.build("unknown")

    def _normalize(self, schema: t.Dict) -> t.Dict:
@@ -144,6 +144,7 @@ class TokenType(AutoName):
    VARIANT = auto()
    OBJECT = auto()
    INET = auto()
    ENUM = auto()

    # keywords
    ALIAS = auto()
@@ -346,6 +347,7 @@ class Token:
            col: The column that the token ends on.
            start: The start index of the token.
            end: The ending index of the token.
            comments: The comments to attach to the token.
        """
        self.token_type = token_type
        self.text = text
@@ -391,12 +393,15 @@ class _Tokenizer(type):

        klass._STRING_ESCAPES = set(klass.STRING_ESCAPES)
        klass._IDENTIFIER_ESCAPES = set(klass.IDENTIFIER_ESCAPES)
        klass._COMMENTS = dict(
            (comment, None) if isinstance(comment, str) else (comment[0], comment[1])
            for comment in klass.COMMENTS
        )
        klass._COMMENTS = {
            **dict(
                (comment, None) if isinstance(comment, str) else (comment[0], comment[1])
                for comment in klass.COMMENTS
            ),
            "{#": "#}",  # Ensure Jinja comments are tokenized correctly in all dialects
        }

        klass.KEYWORD_TRIE = new_trie(
        klass._KEYWORD_TRIE = new_trie(
            key.upper()
            for key in (
                *klass.KEYWORDS,
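Reviewer note: COMMENTS mixes plain strings (line comments) and (start, end) tuples (block comments); the metaclass normalizes both into one delimiter map, and now force-adds Jinja's `{# ... #}` for every dialect. The normalization in isolation:

```python
# Normalize mixed comment specs into {start: end_or_None}, as the metaclass does.
COMMENTS = ["--", ("/*", "*/")]

comments = {
    **dict(
        (comment, None) if isinstance(comment, str) else (comment[0], comment[1])
        for comment in COMMENTS
    ),
    "{#": "#}",  # Jinja block comments, added for all dialects
}
print(comments)  # {'--': None, '/*': '*/', '{#': '#}'}
```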
@@ -456,20 +461,22 @@ class Tokenizer(metaclass=_Tokenizer):
    STRING_ESCAPES = ["'"]
    VAR_SINGLE_TOKENS: t.Set[str] = set()

    # Autofilled
    IDENTIFIERS_CAN_START_WITH_DIGIT: bool = False

    _COMMENTS: t.Dict[str, str] = {}
    _FORMAT_STRINGS: t.Dict[str, t.Tuple[str, TokenType]] = {}
    _IDENTIFIERS: t.Dict[str, str] = {}
    _IDENTIFIER_ESCAPES: t.Set[str] = set()
    _QUOTES: t.Dict[str, str] = {}
    _STRING_ESCAPES: t.Set[str] = set()
    _KEYWORD_TRIE: t.Dict = {}

    KEYWORDS: t.Dict[t.Optional[str], TokenType] = {
    KEYWORDS: t.Dict[str, TokenType] = {
        **{f"{{%{postfix}": TokenType.BLOCK_START for postfix in ("", "+", "-")},
        **{f"{prefix}%}}": TokenType.BLOCK_END for prefix in ("", "+", "-")},
        "{{+": TokenType.BLOCK_START,
        "{{-": TokenType.BLOCK_START,
        "+}}": TokenType.BLOCK_END,
        "-}}": TokenType.BLOCK_END,
        **{f"{{{{{postfix}": TokenType.BLOCK_START for postfix in ("+", "-")},
        **{f"{prefix}}}}}": TokenType.BLOCK_END for prefix in ("+", "-")},
        "/*+": TokenType.HINT,
        "==": TokenType.EQ,
        "::": TokenType.DCOLON,
@@ -594,6 +601,7 @@ class Tokenizer(metaclass=_Tokenizer):
        "RECURSIVE": TokenType.RECURSIVE,
        "REGEXP": TokenType.RLIKE,
        "REPLACE": TokenType.REPLACE,
        "RETURNING": TokenType.RETURNING,
        "REFERENCES": TokenType.REFERENCES,
        "RIGHT": TokenType.RIGHT,
        "RLIKE": TokenType.RLIKE,
@@ -732,8 +740,7 @@ class Tokenizer(metaclass=_Tokenizer):
    NUMERIC_LITERALS: t.Dict[str, str] = {}
    ENCODE: t.Optional[str] = None

    COMMENTS = ["--", ("/*", "*/"), ("{#", "#}")]
    KEYWORD_TRIE: t.Dict = {}  # autofilled
    COMMENTS = ["--", ("/*", "*/")]

    __slots__ = (
        "sql",
@@ -748,7 +755,6 @@ class Tokenizer(metaclass=_Tokenizer):
        "_end",
        "_peek",
        "_prev_token_line",
        "identifiers_can_start_with_digit",
    )

    def __init__(self) -> None:
@@ -894,7 +900,7 @@ class Tokenizer(metaclass=_Tokenizer):
        char = chars
        prev_space = False
        skip = False
        trie = self.KEYWORD_TRIE
        trie = self._KEYWORD_TRIE
        single_token = char in self.SINGLE_TOKENS

        while chars:
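Reviewer note: the renamed `_KEYWORD_TRIE` drives a longest-match scan over multi-word keywords, advancing one character at a time. A self-contained sketch of the idea (sqlglot's own `new_trie`/`in_trie` helpers differ in detail):

```python
# Minimal prefix-trie walk, illustrating the tokenizer's longest-match keyword scan.
def new_trie(keys):
    trie = {}
    for key in keys:
        node = trie
        for char in key:
            node = node.setdefault(char, {})
        node[0] = True  # sentinel: a complete keyword ends here
    return trie

trie = new_trie(["GROUP BY", "GROUP"])
text = "GROUP BY x"
node, match = trie, None
for i, char in enumerate(text):
    if char not in node:
        break
    node = node[char]
    if 0 in node:
        match = text[: i + 1]  # longest complete keyword seen so far
print(match)  # GROUP BY
```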
@@ -994,7 +1000,7 @@ class Tokenizer(metaclass=_Tokenizer):
                self._advance()
            elif self._peek == "." and not decimal:
                after = self.peek(1)
                if after.isdigit() or not after.strip():
                if after.isdigit() or not after.isalpha():
                    decimal = True
                    self._advance()
            else:
@@ -1013,13 +1019,13 @@ class Tokenizer(metaclass=_Tokenizer):
                literal += self._peek.upper()
                self._advance()

            token_type = self.KEYWORDS.get(self.NUMERIC_LITERALS.get(literal))
            token_type = self.KEYWORDS.get(self.NUMERIC_LITERALS.get(literal, ""))

            if token_type:
                self._add(TokenType.NUMBER, number_text)
                self._add(TokenType.DCOLON, "::")
                return self._add(token_type, literal)
            elif self.identifiers_can_start_with_digit:  # type: ignore
            elif self.IDENTIFIERS_CAN_START_WITH_DIGIT:
                return self._add(TokenType.VAR)

            self._add(TokenType.NUMBER, number_text)