Merging upstream version 20.3.0.
Signed-off-by: Daniel Baumann <daniel@debian.org>
This commit is contained in: parent 2945bcc4f7, commit 4d9376ba93
132 changed files with 55125 additions and 51576 deletions
@@ -12,7 +12,7 @@ classes as needed.

### Implementing a custom Dialect

-Consider the following example:
+Creating a new SQL dialect may seem complicated at first, but it is actually quite simple in SQLGlot:

```python
from sqlglot import exp

@@ -23,9 +23,10 @@ from sqlglot.tokens import Tokenizer, TokenType

class Custom(Dialect):
    class Tokenizer(Tokenizer):
-        QUOTES = ["'", '"']
-        IDENTIFIERS = ["`"]
+        QUOTES = ["'", '"']  # Strings can be delimited by either single or double quotes
+        IDENTIFIERS = ["`"]  # Identifiers can be delimited by backticks

+        # Associates certain meaningful words with tokens that capture their intent
        KEYWORDS = {
            **Tokenizer.KEYWORDS,
            "INT64": TokenType.BIGINT,

@@ -33,8 +34,12 @@ class Custom(Dialect):
        }

    class Generator(Generator):
-        TRANSFORMS = {exp.Array: lambda self, e: f"[{self.expressions(e)}]"}
+        # Specifies how AST nodes, i.e. subclasses of exp.Expression, should be converted into SQL
+        TRANSFORMS = {
+            exp.Array: lambda self, e: f"[{self.expressions(e)}]",
+        }

+        # Specifies how AST nodes representing data types should be converted into SQL
        TYPE_MAPPING = {
            exp.DataType.Type.TINYINT: "INT64",
            exp.DataType.Type.SMALLINT: "INT64",

@@ -48,10 +53,9 @@ class Custom(Dialect):
        }
```

-This is a typical example of adding a new dialect implementation in SQLGlot: we specify its identifier and string
-delimiters, as well as what tokens it uses for its types and how they're associated with SQLGlot types. Since
-the `Expression` classes are common for each dialect supported in SQLGlot, we may also need to override the generation
-logic for some expressions; this is usually done by adding new entries to the `TRANSFORMS` mapping.
+The above example demonstrates how certain parts of the base `Dialect` class can be overridden to match a different
+specification. Even though it is a fairly realistic starting point, we strongly encourage the reader to study existing
+dialect implementations in order to understand how their various components can be modified, depending on the use-case.

----
"""
@@ -215,6 +215,7 @@ def _unix_to_time_sql(self: BigQuery.Generator, expression: exp.UnixToTime) -> str:


class BigQuery(Dialect):
+    WEEK_OFFSET = -1
    UNNEST_COLUMN_ONLY = True
    SUPPORTS_USER_DEFINED_TYPES = False
    SUPPORTS_SEMI_ANTI_JOIN = False

@@ -437,11 +438,7 @@ class BigQuery(Dialect):
            elif isinstance(this, exp.Literal):
                table_name = this.name

-                if (
-                    self._curr
-                    and self._prev.end == self._curr.start - 1
-                    and self._parse_var(any_token=True)
-                ):
+                if self._is_connected() and self._parse_var(any_token=True):
                    table_name += self._prev.text

                this = exp.Identifier(this=table_name, quoted=True)
@ -83,6 +83,11 @@ class ClickHouse(Dialect):
|
|||
}
|
||||
|
||||
class Parser(parser.Parser):
|
||||
# Tested in ClickHouse's playground, it seems that the following two queries do the same thing
|
||||
# * select x from t1 union all select x from t2 limit 1;
|
||||
# * select x from t1 union all (select x from t2 limit 1);
|
||||
MODIFIERS_ATTACHED_TO_UNION = False
|
||||
|
||||
FUNCTIONS = {
|
||||
**parser.Parser.FUNCTIONS,
|
||||
"ANY": exp.AnyValue.from_arg_list,
|
||||
|
|
|
@ -21,11 +21,14 @@ DATE_ADD_OR_SUB = t.Union[exp.DateAdd, exp.TsOrDsAdd, exp.DateSub]
|
|||
|
||||
|
||||
class Dialects(str, Enum):
|
||||
"""Dialects supported by SQLGLot."""
|
||||
|
||||
DIALECT = ""
|
||||
|
||||
BIGQUERY = "bigquery"
|
||||
CLICKHOUSE = "clickhouse"
|
||||
DATABRICKS = "databricks"
|
||||
DORIS = "doris"
|
||||
DRILL = "drill"
|
||||
DUCKDB = "duckdb"
|
||||
HIVE = "hive"
|
||||
|
@ -43,16 +46,22 @@ class Dialects(str, Enum):
|
|||
TERADATA = "teradata"
|
||||
TRINO = "trino"
|
||||
TSQL = "tsql"
|
||||
Doris = "doris"
|
||||
|
||||
|
||||
class NormalizationStrategy(str, AutoName):
|
||||
"""Specifies the strategy according to which identifiers should be normalized."""
|
||||
|
||||
LOWERCASE = auto() # Unquoted identifiers are lowercased
|
||||
UPPERCASE = auto() # Unquoted identifiers are uppercased
|
||||
CASE_SENSITIVE = auto() # Always case-sensitive, regardless of quotes
|
||||
CASE_INSENSITIVE = auto() # Always case-insensitive, regardless of quotes
|
||||
LOWERCASE = auto()
|
||||
"""Unquoted identifiers are lowercased."""
|
||||
|
||||
UPPERCASE = auto()
|
||||
"""Unquoted identifiers are uppercased."""
|
||||
|
||||
CASE_SENSITIVE = auto()
|
||||
"""Always case-sensitive, regardless of quotes."""
|
||||
|
||||
CASE_INSENSITIVE = auto()
|
||||
"""Always case-insensitive, regardless of quotes."""
|
||||
|
||||
|
||||
class _Dialect(type):
|
||||
|
@ -117,6 +126,7 @@ class _Dialect(type):
|
|||
klass.BIT_START, klass.BIT_END = get_start_end(TokenType.BIT_STRING)
|
||||
klass.HEX_START, klass.HEX_END = get_start_end(TokenType.HEX_STRING)
|
||||
klass.BYTE_START, klass.BYTE_END = get_start_end(TokenType.BYTE_STRING)
|
||||
klass.UNICODE_START, klass.UNICODE_END = get_start_end(TokenType.UNICODE_STRING)
|
||||
|
||||
if enum not in ("", "bigquery"):
|
||||
klass.generator_class.SELECT_KINDS = ()
|
||||
|
@ -131,74 +141,84 @@ class _Dialect(type):
|
|||
|
||||
|
||||
class Dialect(metaclass=_Dialect):
|
||||
# Determines the base index offset for arrays
|
||||
INDEX_OFFSET = 0
|
||||
"""Determines the base index offset for arrays."""
|
||||
|
||||
WEEK_OFFSET = 0
|
||||
"""Determines the day of week of DATE_TRUNC(week). Defaults to 0 (Monday). -1 would be Sunday."""
|
||||
|
||||
# If true unnest table aliases are considered only as column aliases
|
||||
UNNEST_COLUMN_ONLY = False
|
||||
"""Determines whether or not `UNNEST` table aliases are treated as column aliases."""
|
||||
|
||||
# Determines whether or not the table alias comes after tablesample
|
||||
ALIAS_POST_TABLESAMPLE = False
|
||||
"""Determines whether or not the table alias comes after tablesample."""
|
||||
|
||||
# Specifies the strategy according to which identifiers should be normalized.
|
||||
NORMALIZATION_STRATEGY = NormalizationStrategy.LOWERCASE
|
||||
"""Specifies the strategy according to which identifiers should be normalized."""
|
||||
|
||||
# Determines whether or not an unquoted identifier can start with a digit
|
||||
IDENTIFIERS_CAN_START_WITH_DIGIT = False
|
||||
"""Determines whether or not an unquoted identifier can start with a digit."""
|
||||
|
||||
# Determines whether or not the DPIPE token ('||') is a string concatenation operator
|
||||
DPIPE_IS_STRING_CONCAT = True
|
||||
"""Determines whether or not the DPIPE token (`||`) is a string concatenation operator."""
|
||||
|
||||
# Determines whether or not CONCAT's arguments must be strings
|
||||
STRICT_STRING_CONCAT = False
|
||||
"""Determines whether or not `CONCAT`'s arguments must be strings."""
|
||||
|
||||
# Determines whether or not user-defined data types are supported
|
||||
SUPPORTS_USER_DEFINED_TYPES = True
|
||||
"""Determines whether or not user-defined data types are supported."""
|
||||
|
||||
# Determines whether or not SEMI/ANTI JOINs are supported
|
||||
SUPPORTS_SEMI_ANTI_JOIN = True
|
||||
"""Determines whether or not `SEMI` or `ANTI` joins are supported."""
|
||||
|
||||
# Determines how function names are going to be normalized
|
||||
NORMALIZE_FUNCTIONS: bool | str = "upper"
|
||||
"""Determines how function names are going to be normalized."""
|
||||
|
||||
# Determines whether the base comes first in the LOG function
|
||||
LOG_BASE_FIRST = True
|
||||
"""Determines whether the base comes first in the `LOG` function."""
|
||||
|
||||
# Indicates the default null ordering method to use if not explicitly set
|
||||
# Options are: "nulls_are_small", "nulls_are_large", "nulls_are_last"
|
||||
NULL_ORDERING = "nulls_are_small"
|
||||
"""
|
||||
Indicates the default `NULL` ordering method to use if not explicitly set.
|
||||
Possible values: `"nulls_are_small"`, `"nulls_are_large"`, `"nulls_are_last"`
|
||||
"""
|
||||
|
||||
# Whether the behavior of a / b depends on the types of a and b.
|
||||
# False means a / b is always float division.
|
||||
# True means a / b is integer division if both a and b are integers.
|
||||
TYPED_DIVISION = False
|
||||
"""
|
||||
Whether the behavior of `a / b` depends on the types of `a` and `b`.
|
||||
False means `a / b` is always float division.
|
||||
True means `a / b` is integer division if both `a` and `b` are integers.
|
||||
"""
|
||||
|
||||
# False means 1 / 0 throws an error.
|
||||
# True means 1 / 0 returns null.
|
||||
SAFE_DIVISION = False
|
||||
"""Determines whether division by zero throws an error (`False`) or returns NULL (`True`)."""
|
||||
|
||||
# A NULL arg in CONCAT yields NULL by default, but in some dialects it yields an empty string
|
||||
CONCAT_COALESCE = False
|
||||
"""A `NULL` arg in `CONCAT` yields `NULL` by default, but in some dialects it yields an empty string."""
|
||||
|
||||
DATE_FORMAT = "'%Y-%m-%d'"
|
||||
DATEINT_FORMAT = "'%Y%m%d'"
|
||||
TIME_FORMAT = "'%Y-%m-%d %H:%M:%S'"
|
||||
|
||||
# Custom time mappings in which the key represents dialect time format
|
||||
# and the value represents a python time format
|
||||
TIME_MAPPING: t.Dict[str, str] = {}
|
||||
"""Associates this dialect's time formats with their equivalent Python `strftime` format."""
|
||||
|
||||
# https://cloud.google.com/bigquery/docs/reference/standard-sql/format-elements#format_model_rules_date_time
|
||||
# https://docs.teradata.com/r/Teradata-Database-SQL-Functions-Operators-Expressions-and-Predicates/March-2017/Data-Type-Conversions/Character-to-DATE-Conversion/Forcing-a-FORMAT-on-CAST-for-Converting-Character-to-DATE
|
||||
# special syntax cast(x as date format 'yyyy') defaults to time_mapping
|
||||
FORMAT_MAPPING: t.Dict[str, str] = {}
|
||||
"""
|
||||
Helper which is used for parsing the special syntax `CAST(x AS DATE FORMAT 'yyyy')`.
|
||||
If empty, the corresponding trie will be constructed off of `TIME_MAPPING`.
|
||||
"""
|
||||
|
||||
# Mapping of an unescaped escape sequence to the corresponding character
|
||||
ESCAPE_SEQUENCES: t.Dict[str, str] = {}
|
||||
"""Mapping of an unescaped escape sequence to the corresponding character."""
|
||||
|
||||
# Columns that are auto-generated by the engine corresponding to this dialect
|
||||
# Such columns may be excluded from SELECT * queries, for example
|
||||
PSEUDOCOLUMNS: t.Set[str] = set()
|
||||
"""
|
||||
Columns that are auto-generated by the engine corresponding to this dialect.
|
||||
For example, such columns may be excluded from `SELECT *` queries.
|
||||
"""
|
||||
|
||||
# --- Autofilled ---
|
||||
|
||||
|
@ -221,13 +241,15 @@ class Dialect(metaclass=_Dialect):
|
|||
IDENTIFIER_START = '"'
|
||||
IDENTIFIER_END = '"'
|
||||
|
||||
# Delimiters for bit, hex and byte literals
|
||||
# Delimiters for bit, hex, byte and unicode literals
|
||||
BIT_START: t.Optional[str] = None
|
||||
BIT_END: t.Optional[str] = None
|
||||
HEX_START: t.Optional[str] = None
|
||||
HEX_END: t.Optional[str] = None
|
||||
BYTE_START: t.Optional[str] = None
|
||||
BYTE_END: t.Optional[str] = None
|
||||
UNICODE_START: t.Optional[str] = None
|
||||
UNICODE_END: t.Optional[str] = None
|
||||
|
||||
@classmethod
|
||||
def get_or_raise(cls, dialect: DialectType) -> Dialect:
|
||||
|
@ -275,6 +297,7 @@ class Dialect(metaclass=_Dialect):
|
|||
def format_time(
|
||||
cls, expression: t.Optional[str | exp.Expression]
|
||||
) -> t.Optional[exp.Expression]:
|
||||
"""Converts a time format in this dialect to its equivalent Python `strftime` format."""
|
||||
if isinstance(expression, str):
|
||||
return exp.Literal.string(
|
||||
# the time formats are quoted
|
||||
|
@ -306,9 +329,9 @@ class Dialect(metaclass=_Dialect):
|
|||
"""
|
||||
Transforms an identifier in a way that resembles how it'd be resolved by this dialect.
|
||||
|
||||
For example, an identifier like FoO would be resolved as foo in Postgres, because it
|
||||
For example, an identifier like `FoO` would be resolved as `foo` in Postgres, because it
|
||||
lowercases all unquoted identifiers. On the other hand, Snowflake uppercases them, so
|
||||
it would resolve it as FOO. If it was quoted, it'd need to be treated as case-sensitive,
|
||||
it would resolve it as `FOO`. If it was quoted, it'd need to be treated as case-sensitive,
|
||||
and so any normalization would be prohibited in order to avoid "breaking" the identifier.
|
||||
|
||||
There are also dialects like Spark, which are case-insensitive even when quotes are
|
||||
|
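To make the normalization behavior described in that docstring concrete, here is a hedged sketch (not from the diff) using `Dialect.get_or_raise` as declared earlier in this file:

```python
from sqlglot import exp
from sqlglot.dialects.dialect import Dialect

ident = exp.to_identifier("FoO")  # unquoted identifier

# Postgres lowercases unquoted identifiers; Snowflake uppercases them.
print(Dialect.get_or_raise("postgres").normalize_identifier(ident.copy()).name)   # foo
print(Dialect.get_or_raise("snowflake").normalize_identifier(ident.copy()).name)  # FOO
```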
@@ -356,8 +379,8 @@ class Dialect(metaclass=_Dialect):
        Args:
            text: The text to check.
            identify:
-                "always" or `True`: Always returns true.
-                "safe": True if the identifier is case-insensitive.
+                `"always"` or `True`: Always returns `True`.
+                `"safe"`: Only returns `True` if the identifier is case-insensitive.

        Returns:
            Whether or not the given text can be identified.

@@ -371,6 +394,14 @@ class Dialect(metaclass=_Dialect):
        return False

    def quote_identifier(self, expression: E, identify: bool = True) -> E:
+        """
+        Adds quotes to a given identifier.
+
+        Args:
+            expression: The expression of interest. If it's not an `Identifier`, this method is a no-op.
+            identify: If set to `False`, the quotes will only be added if the identifier is deemed
+                "unsafe", with respect to its characters and this dialect's normalization strategy.
+        """
        if isinstance(expression, exp.Identifier):
            name = expression.this
            expression.set(
@@ -81,7 +81,6 @@ class Drill(Dialect):
    class Tokenizer(tokens.Tokenizer):
        IDENTIFIERS = ["`"]
        STRING_ESCAPES = ["\\"]
-        ENCODE = "utf-8"

    class Parser(parser.Parser):
        STRICT_CAST = False
@@ -84,11 +84,35 @@ def _parse_date_diff(args: t.List) -> exp.Expression:
    return exp.DateDiff(this=seq_get(args, 2), expression=seq_get(args, 1), unit=seq_get(args, 0))


+def _parse_make_timestamp(args: t.List) -> exp.Expression:
+    if len(args) == 1:
+        return exp.UnixToTime(this=seq_get(args, 0), scale=exp.UnixToTime.MICROS)
+
+    return exp.TimestampFromParts(
+        year=seq_get(args, 0),
+        month=seq_get(args, 1),
+        day=seq_get(args, 2),
+        hour=seq_get(args, 3),
+        min=seq_get(args, 4),
+        sec=seq_get(args, 5),
+    )
+
+
def _struct_sql(self: DuckDB.Generator, expression: exp.Struct) -> str:
-    args = [
-        f"'{e.name or e.this.name}': {self.sql(e.expressions[0]) if isinstance(e, exp.Bracket) else self.sql(e, 'expression')}"
-        for e in expression.expressions
-    ]
+    args: t.List[str] = []
+    for expr in expression.expressions:
+        if isinstance(expr, exp.Alias):
+            key = expr.alias
+            value = expr.this
+        else:
+            key = expr.name or expr.this.name
+            if isinstance(expr, exp.Bracket):
+                value = expr.expressions[0]
+            else:
+                value = expr.expression
+
+        args.append(f"{self.sql(exp.Literal.string(key))}: {self.sql(value)}")

    return f"{{{', '.join(args)}}}"
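A hedged round-trip sketch of what the rewritten `_struct_sql` produces; the exact input spelling is an assumption, but DuckDB's struct-literal syntax with string-quoted keys is the target either way:

```python
from sqlglot import parse_one

# Struct keys are now emitted via exp.Literal.string(key), i.e. quoted.
sql = parse_one("SELECT {'a': 1, 'b': x} AS s", read="duckdb").sql(dialect="duckdb")
print(sql)  # expected identity: SELECT {'a': 1, 'b': x} AS s
```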
@@ -189,9 +213,7 @@ class DuckDB(Dialect):
            "LIST_REVERSE_SORT": _sort_array_reverse,
            "LIST_SORT": exp.SortArray.from_arg_list,
            "LIST_VALUE": exp.Array.from_arg_list,
-            "MAKE_TIMESTAMP": lambda args: exp.UnixToTime(
-                this=seq_get(args, 0), scale=exp.UnixToTime.MICROS
-            ),
+            "MAKE_TIMESTAMP": _parse_make_timestamp,
            "MEDIAN": lambda args: exp.PercentileCont(
                this=seq_get(args, 0), expression=exp.Literal.number(0.5)
            ),

@@ -339,6 +361,7 @@ class DuckDB(Dialect):
            exp.StrToUnix: lambda self, e: f"EPOCH(STRPTIME({self.sql(e, 'this')}, {self.format_time(e)}))",
            exp.Struct: _struct_sql,
            exp.Timestamp: no_timestamp_sql,
+            exp.TimestampFromParts: rename_func("MAKE_TIMESTAMP"),
            exp.TimestampTrunc: timestamptrunc_sql,
            exp.TimeStrToDate: lambda self, e: f"CAST({self.sql(e, 'this')} AS DATE)",
            exp.TimeStrToTime: timestrtotime_sql,
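A small sketch of the two `MAKE_TIMESTAMP` parse paths wired up above (the six-part constructor vs. the single-argument epoch-micros form):

```python
from sqlglot import exp, parse_one

# Six arguments now parse into the new exp.TimestampFromParts node...
node = parse_one("MAKE_TIMESTAMP(2023, 12, 25, 10, 30, 0)", read="duckdb")
assert isinstance(node, exp.TimestampFromParts)

# ...while the single-argument form still means "microseconds since the epoch".
node = parse_one("MAKE_TIMESTAMP(1701000000000000)", read="duckdb")
assert isinstance(node, exp.UnixToTime)
```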
@@ -240,7 +240,6 @@ class Hive(Dialect):
        QUOTES = ["'", '"']
        IDENTIFIERS = ["`"]
        STRING_ESCAPES = ["\\"]
-        ENCODE = "utf-8"

        SINGLE_TOKENS = {
            **tokens.Tokenizer.SINGLE_TOKENS,
@@ -650,7 +650,7 @@ class MySQL(Dialect):
            exp.Min: min_or_least,
            exp.Month: _remove_ts_or_ds_to_date(),
            exp.NullSafeEQ: lambda self, e: self.binary(e, "<=>"),
-            exp.NullSafeNEQ: lambda self, e: self.not_sql(self.binary(e, "<=>")),
+            exp.NullSafeNEQ: lambda self, e: f"NOT {self.binary(e, '<=>')}",
            exp.Pivot: no_pivot_sql,
            exp.Select: transforms.preprocess(
                [
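A hedged illustration of the `NullSafeNEQ` transform; the ANSI input spelling below is an assumption based on SQLGlot mapping `IS DISTINCT FROM` to `exp.NullSafeNEQ`:

```python
import sqlglot

sql = sqlglot.transpile(
    "SELECT a IS DISTINCT FROM b FROM t",
    read="postgres",
    write="mysql",
)[0]
print(sql)  # expected to render the null-safe comparison as: NOT a <=> b
```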
@@ -277,6 +277,7 @@ class Postgres(Dialect):
            "CONSTRAINT TRIGGER": TokenType.COMMAND,
            "DECLARE": TokenType.COMMAND,
            "DO": TokenType.COMMAND,
+            "EXEC": TokenType.COMMAND,
            "HSTORE": TokenType.HSTORE,
            "JSONB": TokenType.JSONB,
            "MONEY": TokenType.MONEY,
@@ -186,6 +186,27 @@ def _unix_to_time_sql(self: Presto.Generator, expression: exp.UnixToTime) -> str:
        return ""


+def _to_int(expression: exp.Expression) -> exp.Expression:
+    if not expression.type:
+        from sqlglot.optimizer.annotate_types import annotate_types
+
+        annotate_types(expression)
+    if expression.type and expression.type.this not in exp.DataType.INTEGER_TYPES:
+        return exp.cast(expression, to=exp.DataType.Type.BIGINT)
+    return expression
+
+
+def _parse_to_char(args: t.List) -> exp.TimeToStr:
+    fmt = seq_get(args, 1)
+    if isinstance(fmt, exp.Literal):
+        # We uppercase this to match Teradata's format mapping keys
+        fmt.set("this", fmt.this.upper())
+
+    # We use "teradata" on purpose here, because the time formats are different in Presto.
+    # See https://prestodb.io/docs/current/functions/teradata.html?highlight=to_char#to_char
+    return format_time_lambda(exp.TimeToStr, "teradata")(args)
+
+
class Presto(Dialect):
    INDEX_OFFSET = 1
    NULL_ORDERING = "nulls_are_last"
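A sketch of the new `TO_CHAR` parse path: Presto's `TO_CHAR` is a Teradata-compatibility function, so its format string is interpreted through the Teradata time mapping (the output in the comment is an expectation, not quoted from the diff):

```python
import sqlglot

sql = sqlglot.transpile(
    "SELECT TO_CHAR(ts, 'yyyy-mm-dd')",
    read="presto",
    write="presto",
)[0]
print(sql)  # expected: SELECT DATE_FORMAT(ts, '%Y-%m-%d')
```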
@@ -201,6 +222,12 @@ class Presto(Dialect):
    NORMALIZATION_STRATEGY = NormalizationStrategy.CASE_INSENSITIVE

    class Tokenizer(tokens.Tokenizer):
+        UNICODE_STRINGS = [
+            (prefix + q, q)
+            for q in t.cast(t.List[str], tokens.Tokenizer.QUOTES)
+            for prefix in ("U&", "u&")
+        ]
+
        KEYWORDS = {
            **tokens.Tokenizer.KEYWORDS,
            "START": TokenType.BEGIN,

@@ -253,8 +280,9 @@ class Presto(Dialect):
            "STRPOS": lambda args: exp.StrPosition(
                this=seq_get(args, 0), substr=seq_get(args, 1), instance=seq_get(args, 2)
            ),
-            "TO_UNIXTIME": exp.TimeToUnix.from_arg_list,
+            "TO_CHAR": _parse_to_char,
            "TO_HEX": exp.Hex.from_arg_list,
+            "TO_UNIXTIME": exp.TimeToUnix.from_arg_list,
            "TO_UTF8": lambda args: exp.Encode(
                this=seq_get(args, 0), charset=exp.Literal.string("utf-8")
            ),

@@ -315,7 +343,12 @@ class Presto(Dialect):
            exp.Cast: transforms.preprocess([transforms.epoch_cast_to_ts]),
            exp.CurrentTimestamp: lambda *_: "CURRENT_TIMESTAMP",
            exp.DateAdd: lambda self, e: self.func(
-                "DATE_ADD", exp.Literal.string(e.text("unit") or "day"), e.expression, e.this
+                "DATE_ADD",
+                exp.Literal.string(e.text("unit") or "day"),
+                _to_int(e.expression),
+                e.this,
            ),
            exp.DateDiff: lambda self, e: self.func(
                "DATE_DIFF", exp.Literal.string(e.text("unit") or "day"), e.expression, e.this
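A hedged sketch of what the `_to_int` wrapping changes in generated SQL: amounts whose type cannot be proven integral get an explicit cast (the exact output is an expectation):

```python
import sqlglot

# The amount is a column of unknown type, so the Presto output wraps it in
# CAST(... AS BIGINT) via _to_int before calling DATE_ADD.
sql = sqlglot.transpile(
    "SELECT DATE_ADD('day', n, d)",
    read="presto",
    write="presto",
)[0]
print(sql)  # expected: SELECT DATE_ADD('day', CAST(n AS BIGINT), d)
```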
@@ -325,7 +358,7 @@ class Presto(Dialect):
            exp.DateSub: lambda self, e: self.func(
                "DATE_ADD",
                exp.Literal.string(e.text("unit") or "day"),
-                e.expression * -1,
+                _to_int(e.expression * -1),
                e.this,
            ),
            exp.Decode: lambda self, e: encode_decode_sql(self, e, "FROM_UTF8"),

@@ -354,6 +387,7 @@ class Presto(Dialect):
            exp.Right: right_to_substring_sql,
            exp.SafeDivide: no_safe_divide_sql,
            exp.Schema: _schema_sql,
            exp.SchemaCommentProperty: lambda self, e: self.naked_property(e),
            exp.Select: transforms.preprocess(
                [
                    transforms.eliminate_qualify,

@@ -377,6 +411,7 @@ class Presto(Dialect):
            exp.TimeStrToUnix: lambda self, e: f"TO_UNIXTIME(DATE_PARSE({self.sql(e, 'this')}, {Presto.TIME_FORMAT}))",
            exp.TimeToStr: lambda self, e: f"DATE_FORMAT({self.sql(e, 'this')}, {self.format_time(e)})",
            exp.TimeToUnix: rename_func("TO_UNIXTIME"),
+            exp.ToChar: lambda self, e: f"DATE_FORMAT({self.sql(e, 'this')}, {self.format_time(e)})",
            exp.TryCast: transforms.preprocess([transforms.epoch_cast_to_ts]),
            exp.TsOrDiToDi: lambda self, e: f"CAST(SUBSTR(REPLACE(CAST({self.sql(e, 'this')} AS VARCHAR), '-', ''), 1, 8) AS INT)",
            exp.TsOrDsAdd: _ts_or_ds_add_sql,
@@ -293,7 +293,6 @@ class Snowflake(Dialect):
            "SQUARE": lambda args: exp.Pow(this=seq_get(args, 0), expression=exp.Literal.number(2)),
            "TIMEDIFF": _parse_datediff,
            "TIMESTAMPDIFF": _parse_datediff,
-            "TO_ARRAY": exp.Array.from_arg_list,
            "TO_TIMESTAMP": _parse_to_timestamp,
            "TO_VARCHAR": exp.ToChar.from_arg_list,
            "ZEROIFNULL": _zeroifnull_to_if,

@@ -369,36 +368,58 @@ class Snowflake(Dialect):

            return lateral

+        def _parse_at_before(self, table: exp.Table) -> exp.Table:
+            # https://docs.snowflake.com/en/sql-reference/constructs/at-before
+            index = self._index
+            if self._match_texts(("AT", "BEFORE")):
+                this = self._prev.text.upper()
+                kind = (
+                    self._match(TokenType.L_PAREN)
+                    and self._match_texts(self.HISTORICAL_DATA_KIND)
+                    and self._prev.text.upper()
+                )
+                expression = self._match(TokenType.FARROW) and self._parse_bitwise()
+
+                if expression:
+                    self._match_r_paren()
+                    when = self.expression(
+                        exp.HistoricalData, this=this, kind=kind, expression=expression
+                    )
+                    table.set("when", when)
+                else:
+                    self._retreat(index)
+
+            return table
+
        def _parse_table_parts(self, schema: bool = False) -> exp.Table:
            # https://docs.snowflake.com/en/user-guide/querying-stage
-            table: t.Optional[exp.Expression] = None
-            if self._match_text_seq("@"):
-                table_name = "@"
-                while self._curr:
-                    self._advance()
-                    table_name += self._prev.text
-                    if not self._match_set(self.STAGED_FILE_SINGLE_TOKENS, advance=False):
-                        break
-                    while self._match_set(self.STAGED_FILE_SINGLE_TOKENS):
-                        table_name += self._prev.text
-
-                table = exp.var(table_name)
-            elif self._match(TokenType.STRING, advance=False):
+            if self._match(TokenType.STRING, advance=False):
                table = self._parse_string()
+            elif self._match_text_seq("@", advance=False):
+                table = self._parse_location_path()
+            else:
+                table = None

            if table:
                file_format = None
                pattern = None

-                if self._match_text_seq("(", "FILE_FORMAT", "=>"):
-                    file_format = self._parse_string() or super()._parse_table_parts()
-                    if self._match_text_seq(",", "PATTERN", "=>"):
+                self._match(TokenType.L_PAREN)
+                while self._curr and not self._match(TokenType.R_PAREN):
+                    if self._match_text_seq("FILE_FORMAT", "=>"):
+                        file_format = self._parse_string() or super()._parse_table_parts()
+                    elif self._match_text_seq("PATTERN", "=>"):
                        pattern = self._parse_string()
-                    self._match_r_paren()
+                    else:
+                        break

-                return self.expression(exp.Table, this=table, format=file_format, pattern=pattern)
+                    self._match(TokenType.COMMA)

-            return super()._parse_table_parts(schema=schema)
+                table = self.expression(exp.Table, this=table, format=file_format, pattern=pattern)
+            else:
+                table = super()._parse_table_parts(schema=schema)
+
+            return self._parse_at_before(table)

        def _parse_id_var(
            self,
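A hedged round-trip sketch of the new time-travel clause parsing; the resulting `exp.HistoricalData` node is rendered back via the `historicaldata_sql` generator method added later in this diff:

```python
import sqlglot

sql = sqlglot.transpile(
    "SELECT * FROM my_table AT (TIMESTAMP => '2024-01-01 00:00:00'::TIMESTAMP)",
    read="snowflake",
    write="snowflake",
)[0]
print(sql)  # expected to round-trip the AT (TIMESTAMP => ...) clause unchanged
```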
@@ -438,17 +459,17 @@ class Snowflake(Dialect):

        def _parse_location(self) -> exp.LocationProperty:
            self._match(TokenType.EQ)
+            return self.expression(exp.LocationProperty, this=self._parse_location_path())

-            parts = [self._parse_var(any_token=True)]
+        def _parse_location_path(self) -> exp.Var:
+            parts = [self._advance_any(ignore_reserved=True)]

-            while self._match(TokenType.SLASH):
-                if self._curr and self._prev.end + 1 == self._curr.start:
-                    parts.append(self._parse_var(any_token=True))
-                else:
-                    parts.append(exp.Var(this=""))
-            return self.expression(
-                exp.LocationProperty, this=exp.var("/".join(str(p) for p in parts))
-            )
+            # We avoid consuming a comma token because external tables like @foo and @bar
+            # can be joined in a query with a comma separator.
+            while self._is_connected() and not self._match(TokenType.COMMA, advance=False):
+                parts.append(self._advance_any(ignore_reserved=True))
+
+            return exp.var("".join(part.text for part in parts if part))

        class Tokenizer(tokens.Tokenizer):
            STRING_ESCAPES = ["\\", "'"]

@@ -562,6 +583,7 @@ class Snowflake(Dialect):
                "TO_CHAR", exp.cast(e.this, "timestamp"), self.format_time(e)
            ),
            exp.TimeToUnix: lambda self, e: f"EXTRACT(epoch_second FROM {self.sql(e, 'this')})",
+            exp.ToArray: rename_func("TO_ARRAY"),
            exp.ToChar: lambda self, e: self.function_fallback_sql(e),
            exp.Trim: lambda self, e: self.func("TRIM", e.this, e.expression),
            exp.TsOrDsAdd: date_delta_sql("DATEADD", cast=True),
@@ -12,22 +12,30 @@ class Teradata(Dialect):
    TYPED_DIVISION = True

    TIME_MAPPING = {
-        "Y": "%Y",
-        "YYYY": "%Y",
-        "YY": "%y",
-        "MMMM": "%B",
-        "MMM": "%b",
-        "DD": "%d",
-        "D": "%-d",
-        "HH": "%H",
-        "H": "%-H",
-        "MM": "%M",
+        "Y4": "%Y",
+        "YYYY": "%Y",
+        "M4": "%B",
+        "M3": "%b",
        "M": "%-M",
+        "SS": "%S",
+        "MI": "%M",
+        "MM": "%m",
+        "MMM": "%b",
+        "MMMM": "%B",
+        "D": "%-d",
+        "DD": "%d",
+        "D3": "%j",
+        "DDD": "%j",
+        "H": "%-H",
+        "HH": "%H",
+        "HH24": "%H",
        "S": "%-S",
        "SS": "%S",
        "SSSSSS": "%f",
        "E": "%a",
        "EE": "%a",
+        "E3": "%a",
+        "E4": "%A",
        "EEE": "%a",
        "EEEE": "%A",
    }
@@ -701,6 +701,13 @@ class TSQL(Dialect):
            exp.VolatileProperty: exp.Properties.Location.UNSUPPORTED,
        }

+        def set_operation(self, expression: exp.Union, op: str) -> str:
+            limit = expression.args.get("limit")
+            if limit:
+                return self.sql(expression.limit(limit.pop(), copy=False))
+
+            return super().set_operation(expression, op)
+
        def setitem_sql(self, expression: exp.SetItem) -> str:
            this = expression.this
            if isinstance(this, exp.EQ) and not isinstance(this.left, exp.Parameter):
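T-SQL has no `LIMIT`, so a limit attached to a set operation is re-applied through `Expression.limit`, which wraps the union in a form T-SQL can render with `TOP`. A hedged sketch (exact output shape is an expectation):

```python
import sqlglot

sql = sqlglot.transpile(
    "SELECT a FROM t1 UNION ALL SELECT a FROM t2 LIMIT 10",
    write="tsql",
)[0]
print(sql)  # expected to re-wrap the union and emit a SELECT TOP 10 ... query
```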
@@ -343,6 +343,9 @@ class PythonExecutor:
        else:
            sink.rows = left.rows + right.rows

+        if not math.isinf(step.limit):
+            sink.rows = sink.rows[0 : step.limit]
+
        return self.context({step.name: sink})
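A sketch of the executor path this touches: a `LIMIT` on a set operation is now applied when the union sink is built. The tables below are made-up sample data:

```python
from sqlglot.executor import execute

result = execute(
    "SELECT a FROM x UNION ALL SELECT a FROM y LIMIT 1",
    tables={"x": [{"a": 1}], "y": [{"a": 2}]},
)
print(len(result.rows))  # expected: 1
```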
@@ -1105,14 +1105,7 @@ class Create(DDL):
# https://cloud.google.com/bigquery/docs/reference/standard-sql/data-definition-language#create_table_clone_statement
# https://cloud.google.com/bigquery/docs/reference/standard-sql/data-definition-language#create_table_copy
class Clone(Expression):
-    arg_types = {
-        "this": True,
-        "when": False,
-        "kind": False,
-        "shallow": False,
-        "expression": False,
-        "copy": False,
-    }
+    arg_types = {"this": True, "shallow": False, "copy": False}


class Describe(Expression):

@@ -1213,6 +1206,10 @@ class RawString(Condition):
    pass


+class UnicodeString(Condition):
+    arg_types = {"this": True, "escape": False}
+
+
class Column(Condition):
    arg_types = {"this": True, "table": False, "db": False, "catalog": False, "join_mark": False}

@@ -1967,7 +1964,12 @@ class Offset(Expression):


class Order(Expression):
-    arg_types = {"this": False, "expressions": True}
+    arg_types = {"this": False, "expressions": True, "interpolate": False}


+# https://clickhouse.com/docs/en/sql-reference/statements/select/order-by#order-by-expr-with-fill-modifier
+class WithFill(Expression):
+    arg_types = {"from": False, "to": False, "step": False}
+
+
# hive specific sorts

@@ -1985,7 +1987,7 @@ class Sort(Order):


class Ordered(Expression):
-    arg_types = {"this": True, "desc": False, "nulls_first": True}
+    arg_types = {"this": True, "desc": False, "nulls_first": True, "with_fill": False}


class Property(Expression):

@@ -2522,6 +2524,11 @@ class IndexTableHint(Expression):
    arg_types = {"this": True, "expressions": False, "target": False}


+# https://docs.snowflake.com/en/sql-reference/constructs/at-before
+class HistoricalData(Expression):
+    arg_types = {"this": True, "kind": True, "expression": True}
+
+
class Table(Expression):
    arg_types = {
        "this": True,

@@ -2538,6 +2545,7 @@ class Table(Expression):
        "pattern": False,
        "index": False,
        "ordinality": False,
+        "when": False,
    }

    @property

@@ -4310,6 +4318,11 @@ class Array(Func):
    is_var_len_args = True


+# https://docs.snowflake.com/en/sql-reference/functions/to_array
+class ToArray(Func):
+    pass
+
+
# https://docs.snowflake.com/en/sql-reference/functions/to_char
# https://docs.oracle.com/en/database/oracle/oracle-database/23/sqlrf/TO_CHAR-number.html
class ToChar(Func):

@@ -5233,6 +5246,19 @@ class UnixToTimeStr(Func):
    pass


+class TimestampFromParts(Func):
+    """Constructs a timestamp given its constituent parts."""
+
+    arg_types = {
+        "year": True,
+        "month": True,
+        "day": True,
+        "hour": True,
+        "min": True,
+        "sec": True,
+    }
+
+
class Upper(Func):
    _sql_names = ["UPPER", "UCASE"]
@@ -862,15 +862,7 @@ class Generator:
        this = self.sql(expression, "this")
        shallow = "SHALLOW " if expression.args.get("shallow") else ""
        keyword = "COPY" if expression.args.get("copy") and self.SUPPORTS_TABLE_COPY else "CLONE"
-        this = f"{shallow}{keyword} {this}"
-        when = self.sql(expression, "when")
-
-        if when:
-            kind = self.sql(expression, "kind")
-            expr = self.sql(expression, "expression")
-            return f"{this} {when} ({kind} => {expr})"
-
-        return this
+        return f"{shallow}{keyword} {this}"

    def describe_sql(self, expression: exp.Describe) -> str:
        return f"DESCRIBE {self.sql(expression, 'this')}"

@@ -923,6 +915,14 @@ class Generator:
            return f"{self.dialect.BYTE_START}{this}{self.dialect.BYTE_END}"
        return this

+    def unicodestring_sql(self, expression: exp.UnicodeString) -> str:
+        this = self.sql(expression, "this")
+        if self.dialect.UNICODE_START:
+            escape = self.sql(expression, "escape")
+            escape = f" UESCAPE {escape}" if escape else ""
+            return f"{self.dialect.UNICODE_START}{this}{self.dialect.UNICODE_END}{escape}"
+        return this
+
    def rawstring_sql(self, expression: exp.RawString) -> str:
        string = self.escape_str(expression.this.replace("\\", "\\\\"))
        return f"{self.dialect.QUOTE_START}{string}{self.dialect.QUOTE_END}"
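A hedged sketch of the new unicode-string round-trip; the escape spelling is an assumption based on the `UESCAPE` handling above:

```python
import sqlglot

# Presto/Trino-style unicode string literal with a custom escape character.
sql = sqlglot.transpile(
    "SELECT U&'d!0061ta' UESCAPE '!'",
    read="presto",
    write="presto",
)[0]
print(sql)  # expected identity: SELECT U&'d!0061ta' UESCAPE '!'
```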
@@ -1400,6 +1400,12 @@ class Generator:
        target = f" FOR {target}" if target else ""
        return f"{this}{target} ({self.expressions(expression, flat=True)})"

+    def historicaldata_sql(self, expression: exp.HistoricalData) -> str:
+        this = self.sql(expression, "this")
+        kind = self.sql(expression, "kind")
+        expr = self.sql(expression, "expression")
+        return f"{this} ({kind} => {expr})"
+
    def table_sql(self, expression: exp.Table, sep: str = " AS ") -> str:
        table = ".".join(
            self.sql(part)

@@ -1436,6 +1442,10 @@ class Generator:
            ordinality = f" WITH ORDINALITY{alias}"
            alias = ""

+        when = self.sql(expression, "when")
+        if when:
+            table = f"{table} {when}"
+
        return f"{table}{version}{file_format}{alias}{index}{hints}{pivots}{joins}{laterals}{ordinality}"

    def tablesample_sql(

@@ -1784,7 +1794,24 @@ class Generator:
    def order_sql(self, expression: exp.Order, flat: bool = False) -> str:
        this = self.sql(expression, "this")
        this = f"{this} " if this else this
-        return self.op_expressions(f"{this}ORDER BY", expression, flat=this or flat)  # type: ignore
+        order = self.op_expressions(f"{this}ORDER BY", expression, flat=this or flat)  # type: ignore
+        interpolated_values = [
+            f"{self.sql(named_expression, 'alias')} AS {self.sql(named_expression, 'this')}"
+            for named_expression in expression.args.get("interpolate") or []
+        ]
+        interpolate = (
+            f" INTERPOLATE ({', '.join(interpolated_values)})" if interpolated_values else ""
+        )
+        return f"{order}{interpolate}"
+
+    def withfill_sql(self, expression: exp.WithFill) -> str:
+        from_sql = self.sql(expression, "from")
+        from_sql = f" FROM {from_sql}" if from_sql else ""
+        to_sql = self.sql(expression, "to")
+        to_sql = f" TO {to_sql}" if to_sql else ""
+        step_sql = self.sql(expression, "step")
+        step_sql = f" STEP {step_sql}" if step_sql else ""
+        return f"WITH FILL{from_sql}{to_sql}{step_sql}"
+
    def cluster_sql(self, expression: exp.Cluster) -> str:
        return self.op_expressions("CLUSTER BY", expression)
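A hedged round-trip sketch of the new `WITH FILL` / `INTERPOLATE` support (ClickHouse syntax; identity output is an expectation):

```python
import sqlglot

sql = sqlglot.transpile(
    "SELECT n FROM t ORDER BY n WITH FILL FROM 1 TO 10 STEP 2 INTERPOLATE (x AS x + 1)",
    read="clickhouse",
    write="clickhouse",
)[0]
print(sql)  # expected to preserve the WITH FILL ... INTERPOLATE (...) modifiers
```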
@@ -1826,7 +1853,10 @@ class Generator:
            this = f"CASE WHEN {this} IS NULL THEN 1 ELSE 0 END{null_sort_order}, {this}"
            nulls_sort_change = ""

-        return f"{this}{sort_order}{nulls_sort_change}"
+        with_fill = self.sql(expression, "with_fill")
+        with_fill = f" {with_fill}" if with_fill else ""
+
+        return f"{this}{sort_order}{nulls_sort_change}{with_fill}"

    def matchrecognize_sql(self, expression: exp.MatchRecognize) -> str:
        partition = self.partition_by_sql(expression)

@@ -3048,11 +3078,24 @@ class Generator:
    def operator_sql(self, expression: exp.Operator) -> str:
        return self.binary(expression, f"OPERATOR({self.sql(expression, 'operator')})")

+    def toarray_sql(self, expression: exp.ToArray) -> str:
+        arg = expression.this
+        if not arg.type:
+            from sqlglot.optimizer.annotate_types import annotate_types
+
+            arg = annotate_types(arg)
+
+        if arg.is_type(exp.DataType.Type.ARRAY):
+            return self.sql(arg)
+
+        cond_for_null = arg.is_(exp.null())
+        return self.sql(exp.func("IF", cond_for_null, exp.null(), exp.Array(expressions=[arg])))
+
    def _simplify_unless_literal(self, expression: E) -> E:
        if not isinstance(expression, exp.Literal):
            from sqlglot.optimizer.simplify import simplify

-            expression = simplify(expression)
+            expression = simplify(expression, dialect=self.dialect)

        return expression
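A hedged sketch of the `TO_ARRAY` fallback added above: when the argument is not provably an array, the generated SQL guards against `NULL` before wrapping it. The rendered spelling varies by target dialect, so treat the comment as approximate:

```python
import sqlglot

sql = sqlglot.transpile("SELECT TO_ARRAY(x)", read="snowflake", write="duckdb")[0]
print(sql)  # expected: something like SELECT CASE WHEN x IS NULL THEN NULL ELSE [x] END
```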
@@ -95,9 +95,6 @@ def eliminate_subqueries(expression):


def _eliminate(scope, existing_ctes, taken):
-    if scope.is_union:
-        return _eliminate_union(scope, existing_ctes, taken)
-
    if scope.is_derived_table:
        return _eliminate_derived_table(scope, existing_ctes, taken)

@@ -105,36 +102,6 @@ def _eliminate(scope, existing_ctes, taken):
        return _eliminate_cte(scope, existing_ctes, taken)


-def _eliminate_union(scope, existing_ctes, taken):
-    duplicate_cte_alias = existing_ctes.get(scope.expression)
-
-    alias = duplicate_cte_alias or find_new_name(taken=taken, base="cte")
-
-    taken[alias] = scope
-
-    # Try to maintain the selections
-    expressions = scope.expression.selects
-    selects = [
-        exp.alias_(exp.column(e.alias_or_name, table=alias), alias=e.alias_or_name, copy=False)
-        for e in expressions
-        if e.alias_or_name
-    ]
-    # If not all selections have an alias, just select *
-    if len(selects) != len(expressions):
-        selects = ["*"]
-
-    scope.expression.replace(
-        exp.select(*selects).from_(exp.alias_(exp.table_(alias), alias=alias, copy=False))
-    )
-
-    if not duplicate_cte_alias:
-        existing_ctes[scope.expression] = alias
-
-    return exp.CTE(
-        this=scope.expression,
-        alias=exp.TableAlias(this=exp.to_identifier(alias)),
-    )
-
-
def _eliminate_derived_table(scope, existing_ctes, taken):
    # This makes sure that we don't:
    # - drop the "pivot" arg from a pivoted subquery
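For context, a hedged sketch of what `eliminate_subqueries` does to a derived table (this mirrors the module's existing doctest; union scopes are simply no longer rewritten into CTEs):

```python
import sqlglot
from sqlglot.optimizer.eliminate_subqueries import eliminate_subqueries

expression = sqlglot.parse_one("SELECT a FROM (SELECT * FROM x) AS y")
print(eliminate_subqueries(expression).sql())
# 'WITH y AS (SELECT * FROM x) SELECT a FROM y AS y'
```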
@@ -174,6 +174,22 @@ def _mergeable(outer_scope, inner_scope, leave_tables_isolated, from_or_join):
            for col in inner_projections[selection].find_all(exp.Column)
        )

+    def _is_recursive():
+        # Recursive CTEs look like this:
+        #     WITH RECURSIVE cte AS (
+        #         SELECT * FROM x  <-- inner scope
+        #         UNION ALL
+        #         SELECT * FROM cte  <-- outer scope
+        #     )
+        cte = inner_scope.expression.parent
+        node = outer_scope.expression.parent
+
+        while node:
+            if node is cte:
+                return True
+            node = node.parent
+        return False
+
    return (
        isinstance(outer_scope.expression, exp.Select)
        and not outer_scope.expression.is_star

@@ -197,6 +213,7 @@ def _mergeable(outer_scope, inner_scope, leave_tables_isolated, from_or_join):
        )
        and not _outer_select_joins_on_inner_select_join()
        and not _is_a_window_expression_in_unmergable_operation()
+        and not _is_recursive()
    )
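A hedged usage sketch of the optimizer pass this guard belongs to (mirroring the module's doctest): ordinary derived tables still merge, while the new `_is_recursive` check only blocks the self-referencing arm of a recursive CTE.

```python
import sqlglot
from sqlglot.optimizer.merge_subqueries import merge_subqueries

expression = sqlglot.parse_one("SELECT x.a FROM (SELECT x.a FROM x AS x) AS x")
print(merge_subqueries(expression).sql())
# 'SELECT x.a FROM x AS x'
```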
@@ -4,7 +4,7 @@ from sqlglot.optimizer.scope import build_scope, find_in_scope
from sqlglot.optimizer.simplify import simplify


-def pushdown_predicates(expression):
+def pushdown_predicates(expression, dialect=None):
    """
    Rewrite sqlglot AST to pushdown predicates in FROMS and JOINS

@@ -36,7 +36,7 @@ def pushdown_predicates(expression):
            if isinstance(parent, exp.Join) and parent.side == "RIGHT":
                selected_sources = {k: (node, source)}
                break
-        pushdown(where.this, selected_sources, scope_ref_count)
+        pushdown(where.this, selected_sources, scope_ref_count, dialect)

        # joins should only pushdown into itself, not to other joins
        # so we limit the selected sources to only itself

@@ -44,17 +44,20 @@ def pushdown_predicates(expression):
            name = join.alias_or_name
            if name in scope.selected_sources:
                pushdown(
-                    join.args.get("on"), {name: scope.selected_sources[name]}, scope_ref_count
+                    join.args.get("on"),
+                    {name: scope.selected_sources[name]},
+                    scope_ref_count,
+                    dialect,
                )

    return expression


-def pushdown(condition, sources, scope_ref_count):
+def pushdown(condition, sources, scope_ref_count, dialect):
    if not condition:
        return

-    condition = condition.replace(simplify(condition))
+    condition = condition.replace(simplify(condition, dialect=dialect))
    cnf_like = normalized(condition) or not normalized(condition, dnf=True)

    predicates = list(
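The pass's behavior, as a usage sketch mirroring the module's existing doctest; the new `dialect` argument defaults to `None`, so existing callers are unaffected:

```python
import sqlglot
from sqlglot.optimizer.pushdown_predicates import pushdown_predicates

sql = "SELECT y.a AS a FROM (SELECT x.a AS a FROM x AS x) AS y WHERE y.a = 1"
expression = sqlglot.parse_one(sql)
print(pushdown_predicates(expression).sql())
# 'SELECT y.a AS a FROM (SELECT x.a AS a FROM x AS x WHERE x.a = 1) AS y WHERE TRUE'
```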
@@ -37,6 +37,7 @@ class Scope:
            For example:
                SELECT c FROM x LATERAL VIEW EXPLODE (a) AS c;
            The LATERAL VIEW EXPLODE gets x as a source.
+        cte_sources (dict[str, Scope]): Sources from CTES
        outer_column_list (list[str]): If this is a derived table or CTE, and the outer query
            defines a column list of it's alias of this scope, this is that list of columns.
            For example:

@@ -61,11 +62,14 @@ class Scope:
        parent=None,
        scope_type=ScopeType.ROOT,
        lateral_sources=None,
+        cte_sources=None,
    ):
        self.expression = expression
        self.sources = sources or {}
-        self.lateral_sources = lateral_sources.copy() if lateral_sources else {}
+        self.lateral_sources = lateral_sources or {}
+        self.cte_sources = cte_sources or {}
        self.sources.update(self.lateral_sources)
+        self.sources.update(self.cte_sources)
        self.outer_column_list = outer_column_list or []
        self.parent = parent
        self.scope_type = scope_type

@@ -92,13 +96,17 @@ class Scope:
        self._pivots = None
        self._references = None

-    def branch(self, expression, scope_type, chain_sources=None, **kwargs):
+    def branch(
+        self, expression, scope_type, sources=None, cte_sources=None, lateral_sources=None, **kwargs
+    ):
        """Branch from the current scope to a new, inner scope"""
        return Scope(
            expression=expression.unnest(),
-            sources={**self.cte_sources, **(chain_sources or {})},
+            sources=sources.copy() if sources else None,
            parent=self,
            scope_type=scope_type,
+            cte_sources={**self.cte_sources, **(cte_sources or {})},
+            lateral_sources=lateral_sources.copy() if lateral_sources else None,
            **kwargs,
        )

@@ -305,20 +313,6 @@ class Scope:

        return self._references

-    @property
-    def cte_sources(self):
-        """
-        Sources that are CTEs.
-
-        Returns:
-            dict[str, Scope]: Mapping of source alias to Scope
-        """
-        return {
-            alias: scope
-            for alias, scope in self.sources.items()
-            if isinstance(scope, Scope) and scope.is_cte
-        }
-
    @property
    def external_columns(self):
        """

@@ -515,7 +509,10 @@ def _traverse_scope(scope):
    elif isinstance(scope.expression, exp.Union):
        yield from _traverse_union(scope)
    elif isinstance(scope.expression, exp.Subquery):
-        yield from _traverse_subqueries(scope)
+        if scope.is_root:
+            yield from _traverse_select(scope)
+        else:
+            yield from _traverse_subqueries(scope)
    elif isinstance(scope.expression, exp.Table):
        yield from _traverse_tables(scope)
    elif isinstance(scope.expression, exp.UDTF):

@@ -572,7 +569,7 @@ def _traverse_ctes(scope):
        for child_scope in _traverse_scope(
            scope.branch(
                cte.this,
-                chain_sources=sources,
+                cte_sources=sources,
                outer_column_list=cte.alias_column_names,
                scope_type=ScopeType.CTE,
            )

@@ -584,12 +581,14 @@ def _traverse_ctes(scope):

        if recursive_scope:
-            child_scope.add_source(alias, recursive_scope)
+            child_scope.cte_sources[alias] = recursive_scope

        # append the final child_scope yielded
        if child_scope:
            scope.cte_scopes.append(child_scope)

    scope.sources.update(sources)
+    scope.cte_sources.update(sources)


def _is_derived_table(expression: exp.Subquery) -> bool:

@@ -725,7 +724,7 @@ def _traverse_ddl(scope):
    yield from _traverse_ctes(scope)

    query_scope = scope.branch(
-        scope.expression.expression, scope_type=ScopeType.DERIVED_TABLE, chain_sources=scope.sources
+        scope.expression.expression, scope_type=ScopeType.DERIVED_TABLE, sources=scope.sources
    )
    query_scope._collect()
    query_scope._ctes = scope.ctes + query_scope._ctes
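A hedged sketch of the now-materialized `cte_sources` attribute, which this refactor turns from a computed property into real state populated during traversal:

```python
from sqlglot import parse_one
from sqlglot.optimizer.scope import build_scope

root = build_scope(parse_one("WITH c AS (SELECT 1 AS a) SELECT a FROM c"))
print(list(root.cte_sources))  # expected: ['c']
```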
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
import datetime
import functools
import itertools

@@ -6,10 +8,17 @@ from collections import deque
from decimal import Decimal

import sqlglot
-from sqlglot import exp
+from sqlglot import Dialect, exp
from sqlglot.helper import first, is_iterable, merge_ranges, while_changing
from sqlglot.optimizer.scope import find_all_in_scope, walk_in_scope

+if t.TYPE_CHECKING:
+    from sqlglot.dialects.dialect import DialectType
+
+DateTruncBinaryTransform = t.Callable[
+    [exp.Expression, datetime.date, str, Dialect], t.Optional[exp.Expression]
+]
+
# Final means that an expression should not be simplified
FINAL = "final"

@@ -18,7 +27,9 @@ class UnsupportedUnit(Exception):
    pass


-def simplify(expression, constant_propagation=False):
+def simplify(
+    expression: exp.Expression, constant_propagation: bool = False, dialect: DialectType = None
+):
    """
    Rewrite sqlglot AST to simplify expressions.

@@ -36,15 +47,18 @@ def simplify(expression, constant_propagation=False):
        sqlglot.Expression: simplified expression
    """

+    dialect = Dialect.get_or_raise(dialect)
+
    # group by expressions cannot be simplified, for example
    # select x + 1 + 1 FROM y GROUP BY x + 1 + 1
    # the projection must exactly match the group by key
    for group in expression.find_all(exp.Group):
        select = group.parent
        assert select
        groups = set(group.expressions)
        group.meta[FINAL] = True

-        for e in select.selects:
+        for e in select.expressions:
            for node, *_ in e.walk():
                if node in groups:
                    e.meta[FINAL] = True

@@ -84,7 +98,8 @@ def simplify(expression, constant_propagation=False):
            node = simplify_literals(node, root)
            node = simplify_equality(node)
            node = simplify_parens(node)
-            node = simplify_datetrunc_predicate(node)
+            node = simplify_datetrunc(node, dialect)
+            node = sort_comparison(node)

            if root:
                expression.replace(node)

@@ -117,14 +132,30 @@ def rewrite_between(expression: exp.Expression) -> exp.Expression:
    This is done because comparison simplification is only done on lt/lte/gt/gte.
    """
    if isinstance(expression, exp.Between):
-        return exp.and_(
+        negate = isinstance(expression.parent, exp.Not)
+
+        expression = exp.and_(
            exp.GTE(this=expression.this.copy(), expression=expression.args["low"]),
            exp.LTE(this=expression.this.copy(), expression=expression.args["high"]),
            copy=False,
        )

+        if negate:
+            expression = exp.paren(expression, copy=False)
+
+    return expression
+
+
+COMPLEMENT_COMPARISONS = {
+    exp.LT: exp.GTE,
+    exp.GT: exp.LTE,
+    exp.LTE: exp.GT,
+    exp.GTE: exp.LT,
+    exp.EQ: exp.NEQ,
+    exp.NEQ: exp.EQ,
+}
+

def simplify_not(expression):
    """
    Demorgan's Law

@@ -132,10 +163,15 @@ def simplify_not(expression):
    NOT (x AND y) -> NOT x OR NOT y
    """
    if isinstance(expression, exp.Not):
-        if is_null(expression.this):
+        this = expression.this
+        if is_null(this):
            return exp.null()
-        if isinstance(expression.this, exp.Paren):
-            condition = expression.this.unnest()
+        if this.__class__ in COMPLEMENT_COMPARISONS:
+            return COMPLEMENT_COMPARISONS[this.__class__](
+                this=this.this, expression=this.expression
+            )
+        if isinstance(this, exp.Paren):
+            condition = this.unnest()
            if isinstance(condition, exp.And):
                return exp.or_(
                    exp.not_(condition.left, copy=False),

@@ -150,14 +186,14 @@ def simplify_not(expression):
            )
        if is_null(condition):
            return exp.null()
-        if always_true(expression.this):
+        if always_true(this):
            return exp.false()
-        if is_false(expression.this):
+        if is_false(this):
            return exp.true()
-        if isinstance(expression.this, exp.Not):
+        if isinstance(this, exp.Not):
            # double negation
            # NOT NOT x -> x
-            return expression.this.this
+            return this.this
        return expression
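A sketch combining the two changes above: `rewrite_between` now parenthesizes a `BETWEEN` that sits under `NOT`, and `COMPLEMENT_COMPARISONS` lets `simplify_not` flip the resulting comparisons instead of leaving a `NOT (...)` wrapper:

```python
from sqlglot import parse_one
from sqlglot.optimizer.simplify import simplify

print(simplify(parse_one("x NOT BETWEEN 1 AND 5")).sql())
# expected: x < 1 OR x > 5
```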
@ -249,12 +285,6 @@ def _simplify_comparison(expression, left, right, or_=False):
|
|||
except StopIteration:
|
||||
return expression
|
||||
|
||||
# make sure the comparison is always of the form x > 1 instead of 1 < x
|
||||
if left.__class__ in INVERSE_COMPARISONS and l == ll:
|
||||
left = INVERSE_COMPARISONS[left.__class__](this=lr, expression=ll)
|
||||
if right.__class__ in INVERSE_COMPARISONS and r == rl:
|
||||
right = INVERSE_COMPARISONS[right.__class__](this=rr, expression=rl)
|
||||
|
||||
if l.is_number and r.is_number:
|
||||
l = float(l.name)
|
||||
r = float(r.name)
|
||||
|
@ -397,13 +427,7 @@ def propagate_constants(expression, root=True):
|
|||
# TODO: create a helper that can be used to detect nested literal expressions such
|
||||
# as CAST(123456 AS BIGINT), since we usually want to treat those as literals too
|
||||
if isinstance(l, exp.Column) and isinstance(r, exp.Literal):
|
||||
pass
|
||||
elif isinstance(r, exp.Column) and isinstance(l, exp.Literal):
|
||||
l, r = r, l
|
||||
else:
|
||||
continue
|
||||
|
||||
constant_mapping[l] = (id(l), r)
|
||||
constant_mapping[l] = (id(l), r)
|
||||
|
||||
if constant_mapping:
|
||||
for column in find_all_in_scope(expression, exp.Column):
|
||||
|
@ -458,11 +482,7 @@ def simplify_equality(expression: exp.Expression) -> exp.Expression:
|
|||
if isinstance(expression, COMPARISONS):
|
||||
l, r = expression.left, expression.right
|
||||
|
||||
if l.__class__ in INVERSE_OPS:
|
||||
pass
|
||||
elif r.__class__ in INVERSE_OPS:
|
||||
l, r = r, l
|
||||
else:
|
||||
if not l.__class__ in INVERSE_OPS:
|
||||
return expression
|
||||
|
||||
if r.is_number:
|
||||
|
@ -650,7 +670,7 @@ def simplify_coalesce(expression):
|
|||
|
||||
# Find the first constant arg
|
||||
for arg_index, arg in enumerate(coalesce.expressions):
|
||||
if _is_constant(other):
|
||||
if _is_constant(arg):
|
||||
break
|
||||
else:
|
||||
return expression
|
||||
|
@ -752,7 +772,7 @@ def simplify_conditionals(expression):
|
|||
DateRange = t.Tuple[datetime.date, datetime.date]
|
||||
|
||||
|
||||
def _datetrunc_range(date: datetime.date, unit: str) -> t.Optional[DateRange]:
|
||||
def _datetrunc_range(date: datetime.date, unit: str, dialect: Dialect) -> t.Optional[DateRange]:
|
||||
"""
|
||||
Get the date range for a DATE_TRUNC equality comparison:
|
||||
|
||||
|
@ -761,7 +781,7 @@ def _datetrunc_range(date: datetime.date, unit: str) -> t.Optional[DateRange]:
|
|||
Returns:
|
||||
tuple of [min, max) or None if a value can never be equal to `date` for `unit`
|
||||
"""
|
||||
floor = date_floor(date, unit)
|
||||
floor = date_floor(date, unit, dialect)
|
||||
|
||||
if date != floor:
|
||||
# This will always be False, except for NULL values.
|
||||
|
@ -780,9 +800,9 @@ def _datetrunc_eq_expression(left: exp.Expression, drange: DateRange) -> exp.Exp
|
|||
|
||||
|
||||
def _datetrunc_eq(
|
||||
left: exp.Expression, date: datetime.date, unit: str
|
||||
left: exp.Expression, date: datetime.date, unit: str, dialect: Dialect
|
||||
) -> t.Optional[exp.Expression]:
|
||||
drange = _datetrunc_range(date, unit)
|
||||
drange = _datetrunc_range(date, unit, dialect)
|
||||
if not drange:
|
||||
return None
|
||||
|
||||
|
@ -790,9 +810,9 @@ def _datetrunc_eq(
|
|||
|
||||
|
||||
def _datetrunc_neq(
|
||||
left: exp.Expression, date: datetime.date, unit: str
|
||||
left: exp.Expression, date: datetime.date, unit: str, dialect: Dialect
|
||||
) -> t.Optional[exp.Expression]:
|
||||
drange = _datetrunc_range(date, unit)
|
||||
drange = _datetrunc_range(date, unit, dialect)
|
||||
if not drange:
|
||||
return None
|
||||
|
||||
|
@ -803,41 +823,39 @@ def _datetrunc_neq(
|
|||
)
|
||||
|
||||
|
||||
DateTruncBinaryTransform = t.Callable[
|
||||
[exp.Expression, datetime.date, str], t.Optional[exp.Expression]
|
||||
]
|
||||
DATETRUNC_BINARY_COMPARISONS: t.Dict[t.Type[exp.Expression], DateTruncBinaryTransform] = {
|
||||
exp.LT: lambda l, d, u: l < date_literal(date_floor(d, u)),
|
||||
exp.GT: lambda l, d, u: l >= date_literal(date_floor(d, u) + interval(u)),
|
||||
exp.LTE: lambda l, d, u: l < date_literal(date_floor(d, u) + interval(u)),
|
||||
exp.GTE: lambda l, d, u: l >= date_literal(date_ceil(d, u)),
|
||||
exp.LT: lambda l, dt, u, d: l
|
||||
< date_literal(dt if dt == date_floor(dt, u, d) else date_floor(dt, u, d) + interval(u)),
|
||||
exp.GT: lambda l, dt, u, d: l >= date_literal(date_floor(dt, u, d) + interval(u)),
|
||||
exp.LTE: lambda l, dt, u, d: l < date_literal(date_floor(dt, u, d) + interval(u)),
|
||||
exp.GTE: lambda l, dt, u, d: l >= date_literal(date_ceil(dt, u, d)),
|
||||
exp.EQ: _datetrunc_eq,
|
||||
exp.NEQ: _datetrunc_neq,
|
||||
}
|
||||
DATETRUNC_COMPARISONS = {exp.In, *DATETRUNC_BINARY_COMPARISONS}
|
||||
DATETRUNCS = (exp.DateTrunc, exp.TimestampTrunc)
|
||||
|
||||
|
||||
def _is_datetrunc_predicate(left: exp.Expression, right: exp.Expression) -> bool:
|
||||
return isinstance(left, (exp.DateTrunc, exp.TimestampTrunc)) and _is_date_literal(right)
|
||||
return isinstance(left, DATETRUNCS) and _is_date_literal(right)
|
||||
|
||||
|
||||
@catch(ModuleNotFoundError, UnsupportedUnit)
|
||||
def simplify_datetrunc_predicate(expression: exp.Expression) -> exp.Expression:
|
||||
def simplify_datetrunc(expression: exp.Expression, dialect: Dialect) -> exp.Expression:
|
||||
"""Simplify expressions like `DATE_TRUNC('year', x) >= CAST('2021-01-01' AS DATE)`"""
|
||||
comparison = expression.__class__
|
||||
|
||||
if comparison not in DATETRUNC_COMPARISONS:
|
||||
if isinstance(expression, DATETRUNCS):
|
||||
date = extract_date(expression.this)
|
||||
if date and expression.unit:
|
||||
return date_literal(date_floor(date, expression.unit.name.lower(), dialect))
|
||||
elif comparison not in DATETRUNC_COMPARISONS:
|
||||
return expression
|
||||
|
||||
if isinstance(expression, exp.Binary):
|
||||
l, r = expression.left, expression.right
|
||||
|
||||
if _is_datetrunc_predicate(l, r):
|
||||
pass
|
||||
elif _is_datetrunc_predicate(r, l):
|
||||
comparison = INVERSE_COMPARISONS.get(comparison, comparison)
|
||||
l, r = r, l
|
||||
else:
|
||||
if not _is_datetrunc_predicate(l, r):
|
||||
return expression
|
||||
|
||||
l = t.cast(exp.DateTrunc, l)
|
||||
|
@@ -847,7 +865,7 @@ def simplify_datetrunc_predicate(expression: exp.Expression) -> exp.Expression:
         if not date:
             return expression

-        return DATETRUNC_BINARY_COMPARISONS[comparison](l.this, date, unit) or expression
+        return DATETRUNC_BINARY_COMPARISONS[comparison](l.this, date, unit, dialect) or expression
     elif isinstance(expression, exp.In):
         l = expression.this
         rs = expression.expressions
@@ -861,7 +879,7 @@ def simplify_datetrunc_predicate(expression: exp.Expression) -> exp.Expression:
             date = extract_date(r)
             if not date:
                 return expression
-            drange = _datetrunc_range(date, unit)
+            drange = _datetrunc_range(date, unit, dialect)
             if drange:
                 ranges.append(drange)
@@ -875,6 +893,23 @@ def simplify_datetrunc_predicate(expression: exp.Expression) -> exp.Expression:
     return expression


+def sort_comparison(expression: exp.Expression) -> exp.Expression:
+    if expression.__class__ in COMPLEMENT_COMPARISONS:
+        l, r = expression.this, expression.expression
+        l_column = isinstance(l, exp.Column)
+        r_column = isinstance(r, exp.Column)
+        l_const = _is_constant(l)
+        r_const = _is_constant(r)
+
+        if (l_column and not r_column) or (r_const and not l_const):
+            return expression
+        if (r_column and not l_column) or (l_const and not r_const) or (gen(l) > gen(r)):
+            return INVERSE_COMPARISONS.get(expression.__class__, expression.__class__)(
+                this=r, expression=l
+            )
+    return expression
+
+
 # CROSS joins result in an empty table if the right table is empty.
 # So we can only simplify certain types of joins to CROSS.
 # Or in other words, LEFT JOIN x ON TRUE != CROSS JOIN x
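The new `sort_comparison` helper canonicalizes binary comparisons so that columns land on the left-hand side, giving later rules a single shape to match. A quick sketch of the intended effect via `simplify`, which applies it; output hedged:

```python
import sqlglot
from sqlglot.optimizer.simplify import simplify

# A constant-on-the-left comparison is flipped through INVERSE_COMPARISONS:
# 5 < x becomes x > 5, so downstream rules only handle one orientation.
print(simplify(sqlglot.parse_one("SELECT * FROM t WHERE 5 < x")).sql())
# roughly: SELECT * FROM t WHERE x > 5
```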
@@ -1034,7 +1069,7 @@ def interval(unit: str, n: int = 1):
     raise UnsupportedUnit(f"Unsupported unit: {unit}")


-def date_floor(d: datetime.date, unit: str) -> datetime.date:
+def date_floor(d: datetime.date, unit: str, dialect: Dialect) -> datetime.date:
     if unit == "year":
         return d.replace(month=1, day=1)
     if unit == "quarter":
@@ -1050,15 +1085,15 @@ def date_floor(d: datetime.date, unit: str) -> datetime.date:
         return d.replace(month=d.month, day=1)
     if unit == "week":
         # Assuming week starts on Monday (0) and ends on Sunday (6)
-        return d - datetime.timedelta(days=d.weekday())
+        return d - datetime.timedelta(days=d.weekday() - dialect.WEEK_OFFSET)
     if unit == "day":
         return d

     raise UnsupportedUnit(f"Unsupported unit: {unit}")


-def date_ceil(d: datetime.date, unit: str) -> datetime.date:
-    floor = date_floor(d, unit)
+def date_ceil(d: datetime.date, unit: str, dialect: Dialect) -> datetime.date:
+    floor = date_floor(d, unit, dialect)

     if floor == d:
         return d
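The week branch is the only dialect-sensitive one: `WEEK_OFFSET` shifts the floor for dialects whose weeks do not start on Monday (the bigquery.py hunk earlier in this diff sets it to -1). A self-contained sketch of just that branch; `week_floor` here is a hypothetical stand-in for the code above:

```python
import datetime

def week_floor(d: datetime.date, week_offset: int) -> datetime.date:
    # Mirrors the "week" branch above: weekday() is 0 for Monday, so an
    # offset of -1 pulls the floor back one extra day, to Sunday.
    return d - datetime.timedelta(days=d.weekday() - week_offset)

wed = datetime.date(2021, 1, 6)                          # a Wednesday
assert week_floor(wed, 0) == datetime.date(2021, 1, 4)   # Monday-based week
assert week_floor(wed, -1) == datetime.date(2021, 1, 3)  # Sunday-based week
```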
@@ -65,6 +65,8 @@ def unnest(select, parent_select, next_alias_name):
             )
         ):
             column = exp.Max(this=column)
+        elif not isinstance(select.parent, exp.Subquery):
+            return

         _replace(select.parent, column)
         parent_select.join(select, join_type="CROSS", join_alias=alias, copy=False)
@@ -568,6 +568,7 @@ class Parser(metaclass=_Parser):
         exp.Sort: lambda self: self._parse_sort(exp.Sort, TokenType.SORT_BY),
         exp.Table: lambda self: self._parse_table_parts(),
         exp.TableAlias: lambda self: self._parse_table_alias(),
+        exp.When: lambda self: seq_get(self._parse_when_matched(), 0),
         exp.Where: lambda self: self._parse_where(),
         exp.Window: lambda self: self._parse_named_window(),
         exp.With: lambda self: self._parse_with(),
@@ -635,6 +636,11 @@ class Parser(metaclass=_Parser):
         TokenType.HEREDOC_STRING: lambda self, token: self.expression(
             exp.RawString, this=token.text
         ),
+        TokenType.UNICODE_STRING: lambda self, token: self.expression(
+            exp.UnicodeString,
+            this=token.text,
+            escape=self._match_text_seq("UESCAPE") and self._parse_string(),
+        ),
         TokenType.SESSION_PARAMETER: lambda self, _: self._parse_session_parameter(),
     }
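The new `UNICODE_STRING` path covers Unicode escape literals (`U&'...'`, optionally followed by a `UESCAPE` clause) as found in PostgreSQL and Presto/Trino. A hedged round-trip sketch, assuming the Presto tokenizer registers `U&'` strings as this release appears to do:

```python
import sqlglot

# \XXXX escapes inside U&'...' encode Unicode code points.
expr = sqlglot.parse_one("SELECT U&'d\\0061t\\0061'", read="presto")
print(expr.sql("presto"))  # expected: SELECT U&'d\0061t\0061'
```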
@@ -907,7 +913,7 @@ class Parser(metaclass=_Parser):
     INSERT_ALTERNATIVES = {"ABORT", "FAIL", "IGNORE", "REPLACE", "ROLLBACK"}

     CLONE_KEYWORDS = {"CLONE", "COPY"}
-    CLONE_KINDS = {"TIMESTAMP", "OFFSET", "STATEMENT"}
+    HISTORICAL_DATA_KIND = {"TIMESTAMP", "OFFSET", "STATEMENT", "STREAM"}

     OPCLASS_FOLLOW_KEYWORDS = {"ASC", "DESC", "NULLS"}
     OPTYPE_FOLLOW_TOKENS = {TokenType.COMMA, TokenType.R_PAREN}
@@ -947,6 +953,10 @@ class Parser(metaclass=_Parser):
     # Whether the TRIM function expects the characters to trim as its first argument
     TRIM_PATTERN_FIRST = False

+    # Whether query modifiers such as LIMIT are attached to the UNION node (vs its right operand)
+    MODIFIERS_ATTACHED_TO_UNION = True
+    UNION_MODIFIERS = {"order", "limit", "offset"}
+
     __slots__ = (
         "error_level",
         "error_message_context",
@@ -1162,6 +1172,9 @@ class Parser(metaclass=_Parser):
     def _find_sql(self, start: Token, end: Token) -> str:
         return self.sql[start.start : end.end + 1]

+    def _is_connected(self) -> bool:
+        return self._prev and self._curr and self._prev.end + 1 == self._curr.start
+
     def _advance(self, times: int = 1) -> None:
         self._index += times
         self._curr = seq_get(self._tokens, self._index)
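`_is_connected` reports whether the previous and current tokens are adjacent in the raw SQL, with no whitespace between them; dialect parsers use it to stitch multi-token names back together (the bigquery.py hunk earlier in this diff is one caller). A hedged sketch of that use case; the exact quoting in the output may vary:

```python
import sqlglot

# Unquoted dash-separated BigQuery project names tokenize as several tokens;
# the adjacency check lets the parser reassemble them into one identifier.
expr = sqlglot.parse_one("SELECT * FROM my-project.mydataset.mytable", read="bigquery")
print(expr.sql("bigquery"))  # roughly: SELECT * FROM `my-project`.mydataset.mytable
```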
@@ -1404,23 +1417,8 @@ class Parser(metaclass=_Parser):

         if self._match_texts(self.CLONE_KEYWORDS):
             copy = self._prev.text.lower() == "copy"
-            clone = self._parse_table(schema=True)
-            when = self._match_texts(("AT", "BEFORE")) and self._prev.text.upper()
-            clone_kind = (
-                self._match(TokenType.L_PAREN)
-                and self._match_texts(self.CLONE_KINDS)
-                and self._prev.text.upper()
-            )
-            clone_expression = self._match(TokenType.FARROW) and self._parse_bitwise()
-            self._match(TokenType.R_PAREN)
             clone = self.expression(
-                exp.Clone,
-                this=clone,
-                when=when,
-                kind=clone_kind,
-                shallow=shallow,
-                expression=clone_expression,
-                copy=copy,
+                exp.Clone, this=self._parse_table(schema=True), shallow=shallow, copy=copy
             )

         return self.expression(
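The slimmed-down branch keeps plain `CREATE TABLE ... CLONE`/`COPY` parsing here, while the removed AT/BEFORE handling corresponds to the `HISTORICAL_DATA_KIND` set introduced in an earlier hunk. A hedged Snowflake round-trip sketch:

```python
import sqlglot

# Zero-copy cloning is Snowflake syntax; the simplified branch still covers it.
expr = sqlglot.parse_one("CREATE TABLE t2 CLONE t1", read="snowflake")
print(expr.sql("snowflake"))  # expected: CREATE TABLE t2 CLONE t1
```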
@@ -2471,13 +2469,7 @@ class Parser(metaclass=_Parser):
         pattern = None

         define = (
-            self._parse_csv(
-                lambda: self.expression(
-                    exp.Alias,
-                    alias=self._parse_id_var(any_token=True),
-                    this=self._match(TokenType.ALIAS) and self._parse_conjunction(),
-                )
-            )
+            self._parse_csv(self._parse_name_as_expression)
             if self._match_text_seq("DEFINE")
             else None
         )
@@ -3124,6 +3116,18 @@ class Parser(metaclass=_Parser):

         return self.expression(exp.Connect, start=start, connect=connect)

+    def _parse_name_as_expression(self) -> exp.Alias:
+        return self.expression(
+            exp.Alias,
+            alias=self._parse_id_var(any_token=True),
+            this=self._match(TokenType.ALIAS) and self._parse_conjunction(),
+        )
+
+    def _parse_interpolate(self) -> t.Optional[t.List[exp.Expression]]:
+        if self._match_text_seq("INTERPOLATE"):
+            return self._parse_wrapped_csv(self._parse_name_as_expression)
+        return None
+
     def _parse_order(
         self, this: t.Optional[exp.Expression] = None, skip_order_token: bool = False
     ) -> t.Optional[exp.Expression]:
@@ -3131,7 +3135,10 @@ class Parser(metaclass=_Parser):
             return this

         return self.expression(
-            exp.Order, this=this, expressions=self._parse_csv(self._parse_ordered)
+            exp.Order,
+            this=this,
+            expressions=self._parse_csv(self._parse_ordered),
+            interpolate=self._parse_interpolate(),
         )

     def _parse_sort(self, exp_class: t.Type[E], token: TokenType) -> t.Optional[E]:
@@ -3161,7 +3168,21 @@ class Parser(metaclass=_Parser):
         ):
             nulls_first = True

-        return self.expression(exp.Ordered, this=this, desc=desc, nulls_first=nulls_first)
+        if self._match_text_seq("WITH", "FILL"):
+            with_fill = self.expression(
+                exp.WithFill,
+                **{  # type: ignore
+                    "from": self._match(TokenType.FROM) and self._parse_bitwise(),
+                    "to": self._match_text_seq("TO") and self._parse_bitwise(),
+                    "step": self._match_text_seq("STEP") and self._parse_bitwise(),
+                },
+            )
+        else:
+            with_fill = None
+
+        return self.expression(
+            exp.Ordered, this=this, desc=desc, nulls_first=nulls_first, with_fill=with_fill
+        )

     def _parse_limit(
         self, this: t.Optional[exp.Expression] = None, top: bool = False
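`WITH FILL` (and the `INTERPOLATE` clause parsed above) are ClickHouse ORDER BY extensions; the new `with_fill` argument on `exp.Ordered` lets them survive a round trip. A hedged sketch:

```python
import sqlglot

# ClickHouse can fill gaps in an ordered result between FROM and TO in STEP increments.
sql = "SELECT n FROM t ORDER BY n WITH FILL FROM 1 TO 10 STEP 2"
print(sqlglot.parse_one(sql, read="clickhouse").sql("clickhouse"))
# expected round trip: SELECT n FROM t ORDER BY n WITH FILL FROM 1 TO 10 STEP 2
```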
@@ -3253,28 +3274,40 @@ class Parser(metaclass=_Parser):
         return locks

     def _parse_set_operations(self, this: t.Optional[exp.Expression]) -> t.Optional[exp.Expression]:
-        if not self._match_set(self.SET_OPERATIONS):
-            return this
-
-        token_type = self._prev.token_type
-
-        if token_type == TokenType.UNION:
-            expression = exp.Union
-        elif token_type == TokenType.EXCEPT:
-            expression = exp.Except
-        else:
-            expression = exp.Intersect
-
-        return self.expression(
-            expression,
-            comments=self._prev.comments,
-            this=this,
-            distinct=self._match(TokenType.DISTINCT) or not self._match(TokenType.ALL),
-            by_name=self._match_text_seq("BY", "NAME"),
-            expression=self._parse_set_operations(
-                self._parse_select(nested=True, parse_set_operation=False)
-            ),
-        )
+        while this and self._match_set(self.SET_OPERATIONS):
+            token_type = self._prev.token_type
+
+            if token_type == TokenType.UNION:
+                operation = exp.Union
+            elif token_type == TokenType.EXCEPT:
+                operation = exp.Except
+            else:
+                operation = exp.Intersect
+
+            comments = self._prev.comments
+            distinct = self._match(TokenType.DISTINCT) or not self._match(TokenType.ALL)
+            by_name = self._match_text_seq("BY", "NAME")
+            expression = self._parse_select(nested=True, parse_set_operation=False)
+
+            this = self.expression(
+                operation,
+                comments=comments,
+                this=this,
+                distinct=distinct,
+                by_name=by_name,
+                expression=expression,
+            )
+
+        if isinstance(this, exp.Union) and self.MODIFIERS_ATTACHED_TO_UNION:
+            expression = this.expression
+
+            if expression:
+                for arg in self.UNION_MODIFIERS:
+                    expr = expression.args.get(arg)
+                    if expr:
+                        this.set(arg, expr.pop())
+
+        return this

     def _parse_expression(self) -> t.Optional[exp.Expression]:
         return self._parse_alias(self._parse_conjunction())
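The loop makes set operations left-associative without recursion, and the block at the end implements `MODIFIERS_ATTACHED_TO_UNION`: a trailing ORDER BY/LIMIT/OFFSET is popped off the right-hand SELECT and attached to the Union node itself. A sketch of the resulting tree shape:

```python
import sqlglot

u = sqlglot.parse_one("SELECT x FROM t1 UNION ALL SELECT x FROM t2 LIMIT 1")
print(type(u).__name__)                  # Union
print(u.args.get("limit") is not None)   # True: the LIMIT hangs off the set operation
print(u.expression.args.get("limit"))    # None: it was popped from the right-hand SELECT
```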
@@ -3595,7 +3628,7 @@ class Parser(metaclass=_Parser):
                 exp.DataType, this=exp.DataType.Type.INTERVAL, expressions=span
             )
         else:
-            this = self.expression(exp.Interval, unit=unit)
+            this = self.expression(exp.DataType, this=self.expression(exp.Interval, unit=unit))

         if maybe_func and check_func:
             index2 = self._index
@@ -4891,8 +4924,8 @@ class Parser(metaclass=_Parser):
             return self.expression(exp.Var, this=self._prev.text)
         return self._parse_placeholder()

-    def _advance_any(self) -> t.Optional[Token]:
-        if self._curr and self._curr.token_type not in self.RESERVED_TOKENS:
+    def _advance_any(self, ignore_reserved: bool = False) -> t.Optional[Token]:
+        if self._curr and (ignore_reserved or self._curr.token_type not in self.RESERVED_TOKENS):
             self._advance()
             return self._prev
         return None
@@ -425,16 +425,27 @@ class SetOperation(Step):
         cls, expression: exp.Expression, ctes: t.Optional[t.Dict[str, Step]] = None
     ) -> Step:
         assert isinstance(expression, exp.Union)
+
         left = Step.from_expression(expression.left, ctes)
+        # SELECT 1 UNION SELECT 2  <-- these subqueries don't have names
+        left.name = left.name or "left"
         right = Step.from_expression(expression.right, ctes)
+        right.name = right.name or "right"
+
         step = cls(
             op=expression.__class__,
             left=left.name,
             right=right.name,
             distinct=bool(expression.args.get("distinct")),
         )
+
         step.add_dependency(left)
         step.add_dependency(right)

+        limit = expression.args.get("limit")
+
+        if limit:
+            step.limit = int(limit.text("expression"))
+
         return step

     def _to_s(self, indent: str) -> t.List[str]:
@@ -1,9 +1,10 @@
 from __future__ import annotations

+import os
 import typing as t
 from enum import auto

-from sqlglot.errors import TokenError
+from sqlglot.errors import SqlglotError, TokenError
 from sqlglot.helper import AutoName
 from sqlglot.trie import TrieResult, in_trie, new_trie
@@ -11,6 +12,19 @@ if t.TYPE_CHECKING:
     from sqlglot.dialects.dialect import DialectType


+try:
+    from sqlglotrs import (  # type: ignore
+        Tokenizer as RsTokenizer,
+        TokenizerDialectSettings as RsTokenizerDialectSettings,
+        TokenizerSettings as RsTokenizerSettings,
+        TokenTypeSettings as RsTokenTypeSettings,
+    )
+
+    USE_RS_TOKENIZER = os.environ.get("SQLGLOTRS_TOKENIZER", "1") == "1"
+except ImportError:
+    USE_RS_TOKENIZER = False
+
+
 class TokenType(AutoName):
     L_PAREN = auto()
     R_PAREN = auto()
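The Rust tokenizer (the optional `sqlglotrs` package) is opt-out rather than opt-in: when it imports cleanly it is used unless the `SQLGLOTRS_TOKENIZER` environment variable says otherwise. A sketch of the toggle, assuming `sqlglotrs` is installed:

```python
import os

# Must be set before sqlglot is imported: USE_RS_TOKENIZER is computed
# once, at module import time.
os.environ["SQLGLOTRS_TOKENIZER"] = "0"

import sqlglot.tokens

print(sqlglot.tokens.USE_RS_TOKENIZER)  # False: the pure-Python tokenizer is used
```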
@@ -83,6 +97,7 @@ class TokenType(AutoName):
     NATIONAL_STRING = auto()
     RAW_STRING = auto()
     HEREDOC_STRING = auto()
+    UNICODE_STRING = auto()

     # types
     BIT = auto()
@@ -347,6 +362,10 @@ class TokenType(AutoName):
     TIMESTAMP_SNAPSHOT = auto()


+_ALL_TOKEN_TYPES = list(TokenType)
+_TOKEN_TYPE_TO_INDEX = {token_type: i for i, token_type in enumerate(_ALL_TOKEN_TYPES)}
+
+
 class Token:
     __slots__ = ("token_type", "text", "line", "col", "start", "end", "comments")
@@ -432,6 +451,7 @@ class _Tokenizer(type):
             **_quotes_to_format(TokenType.HEX_STRING, klass.HEX_STRINGS),
             **_quotes_to_format(TokenType.RAW_STRING, klass.RAW_STRINGS),
             **_quotes_to_format(TokenType.HEREDOC_STRING, klass.HEREDOC_STRINGS),
+            **_quotes_to_format(TokenType.UNICODE_STRING, klass.UNICODE_STRINGS),
         }

         klass._STRING_ESCAPES = set(klass.STRING_ESCAPES)
@@ -455,6 +475,46 @@ class _Tokenizer(type):
             if " " in key or any(single in key for single in klass.SINGLE_TOKENS)
         )

+        if USE_RS_TOKENIZER:
+            settings = RsTokenizerSettings(
+                white_space={k: _TOKEN_TYPE_TO_INDEX[v] for k, v in klass.WHITE_SPACE.items()},
+                single_tokens={k: _TOKEN_TYPE_TO_INDEX[v] for k, v in klass.SINGLE_TOKENS.items()},
+                keywords={k: _TOKEN_TYPE_TO_INDEX[v] for k, v in klass.KEYWORDS.items()},
+                numeric_literals=klass.NUMERIC_LITERALS,
+                identifiers=klass._IDENTIFIERS,
+                identifier_escapes=klass._IDENTIFIER_ESCAPES,
+                string_escapes=klass._STRING_ESCAPES,
+                quotes=klass._QUOTES,
+                format_strings={
+                    k: (v1, _TOKEN_TYPE_TO_INDEX[v2])
+                    for k, (v1, v2) in klass._FORMAT_STRINGS.items()
+                },
+                has_bit_strings=bool(klass.BIT_STRINGS),
+                has_hex_strings=bool(klass.HEX_STRINGS),
+                comments=klass._COMMENTS,
+                var_single_tokens=klass.VAR_SINGLE_TOKENS,
+                commands={_TOKEN_TYPE_TO_INDEX[v] for v in klass.COMMANDS},
+                command_prefix_tokens={
+                    _TOKEN_TYPE_TO_INDEX[v] for v in klass.COMMAND_PREFIX_TOKENS
+                },
+            )
+            token_types = RsTokenTypeSettings(
+                bit_string=_TOKEN_TYPE_TO_INDEX[TokenType.BIT_STRING],
+                break_=_TOKEN_TYPE_TO_INDEX[TokenType.BREAK],
+                dcolon=_TOKEN_TYPE_TO_INDEX[TokenType.DCOLON],
+                heredoc_string=_TOKEN_TYPE_TO_INDEX[TokenType.HEREDOC_STRING],
+                hex_string=_TOKEN_TYPE_TO_INDEX[TokenType.HEX_STRING],
+                identifier=_TOKEN_TYPE_TO_INDEX[TokenType.IDENTIFIER],
+                number=_TOKEN_TYPE_TO_INDEX[TokenType.NUMBER],
+                parameter=_TOKEN_TYPE_TO_INDEX[TokenType.PARAMETER],
+                semicolon=_TOKEN_TYPE_TO_INDEX[TokenType.SEMICOLON],
+                string=_TOKEN_TYPE_TO_INDEX[TokenType.STRING],
+                var=_TOKEN_TYPE_TO_INDEX[TokenType.VAR],
+            )
+            klass._RS_TOKENIZER = RsTokenizer(settings, token_types)
+        else:
+            klass._RS_TOKENIZER = None
+
         return klass
@@ -499,6 +559,7 @@ class Tokenizer(metaclass=_Tokenizer):
     HEX_STRINGS: t.List[str | t.Tuple[str, str]] = []
     RAW_STRINGS: t.List[str | t.Tuple[str, str]] = []
     HEREDOC_STRINGS: t.List[str | t.Tuple[str, str]] = []
+    UNICODE_STRINGS: t.List[str | t.Tuple[str, str]] = []
     IDENTIFIERS: t.List[str | t.Tuple[str, str]] = ['"']
     IDENTIFIER_ESCAPES = ['"']
     QUOTES: t.List[t.Tuple[str, str] | str] = ["'"]
@@ -513,6 +574,7 @@ class Tokenizer(metaclass=_Tokenizer):
     _QUOTES: t.Dict[str, str] = {}
     _STRING_ESCAPES: t.Set[str] = set()
     _KEYWORD_TRIE: t.Dict = {}
+    _RS_TOKENIZER: t.Optional[t.Any] = None

     KEYWORDS: t.Dict[str, TokenType] = {
         **{f"{{%{postfix}": TokenType.BLOCK_START for postfix in ("", "+", "-")},
@@ -804,7 +866,6 @@ class Tokenizer(metaclass=_Tokenizer):

     # handle numeric literals like in hive (3L = BIGINT)
     NUMERIC_LITERALS: t.Dict[str, str] = {}
-    ENCODE: t.Optional[str] = None

     COMMENTS = ["--", ("/*", "*/")]
@@ -822,12 +883,20 @@ class Tokenizer(metaclass=_Tokenizer):
         "_end",
         "_peek",
         "_prev_token_line",
+        "_rs_dialect_settings",
     )

     def __init__(self, dialect: DialectType = None) -> None:
         from sqlglot.dialects import Dialect

         self.dialect = Dialect.get_or_raise(dialect)

+        if USE_RS_TOKENIZER:
+            self._rs_dialect_settings = RsTokenizerDialectSettings(
+                escape_sequences=self.dialect.ESCAPE_SEQUENCES,
+                identifiers_can_start_with_digit=self.dialect.IDENTIFIERS_CAN_START_WITH_DIGIT,
+            )
+
         self.reset()

     def reset(self) -> None:
@@ -847,6 +916,9 @@ class Tokenizer(metaclass=_Tokenizer):

     def tokenize(self, sql: str) -> t.List[Token]:
         """Returns a list of tokens corresponding to the SQL string `sql`."""
+        if USE_RS_TOKENIZER:
+            return self.tokenize_rs(sql)
+
         self.reset()
         self.sql = sql
         self.size = len(sql)
@@ -910,6 +982,7 @@ class Tokenizer(metaclass=_Tokenizer):
             # Ensures we don't count an extra line if we get a \r\n line break sequence
             if self._char == "\r" and self._peek == "\n":
                 i = 2
+                self._start += 1

             self._col = 1
             self._line += 1
@@ -1184,8 +1257,6 @@ class Tokenizer(metaclass=_Tokenizer):
                 raise TokenError(
                     f"Numeric string contains invalid characters from {self._line}:{self._start}"
                 )
-        else:
-            text = text.encode(self.ENCODE).decode(self.ENCODE) if self.ENCODE else text

         self._add(token_type, text)
         return True
@@ -1254,3 +1325,15 @@ class Tokenizer(metaclass=_Tokenizer):
                 text += self.sql[current : self._current - 1]

         return text
+
+    def tokenize_rs(self, sql: str) -> t.List[Token]:
+        if not self._RS_TOKENIZER:
+            raise SqlglotError("Rust tokenizer is not available")
+
+        try:
+            tokens = self._RS_TOKENIZER.tokenize(sql, self._rs_dialect_settings)
+            for token in tokens:
+                token.token_type = _ALL_TOKEN_TYPES[token.token_type_index]
+            return tokens
+        except Exception as e:
+            raise TokenError(str(e))
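Callers never pick a backend explicitly: `tokenize()` delegates to `tokenize_rs()` whenever the Rust tokenizer was built at class-creation time, and both paths return the same `Token` objects (the Rust side ships only an index, mapped back through `_ALL_TOKEN_TYPES`). A usage sketch:

```python
from sqlglot.tokens import Tokenizer

tokens = Tokenizer().tokenize("SELECT 1")
print([t.token_type.name for t in tokens])  # ['SELECT', 'NUMBER']
```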