
Merging upstream version 20.3.0.

Signed-off-by: Daniel Baumann <daniel@debian.org>
Daniel Baumann 2025-02-13 21:17:51 +01:00
parent 2945bcc4f7
commit 4d9376ba93
Signed by: daniel
GPG key ID: FBB4F0E80A80222F
132 changed files with 55125 additions and 51576 deletions

View file

@ -12,7 +12,7 @@ classes as needed.
### Implementing a custom Dialect
Consider the following example:
Creating a new SQL dialect may seem complicated at first, but it is actually quite simple in SQLGlot:
```python
from sqlglot import exp
@ -23,9 +23,10 @@ from sqlglot.tokens import Tokenizer, TokenType
class Custom(Dialect):
class Tokenizer(Tokenizer):
QUOTES = ["'", '"']
IDENTIFIERS = ["`"]
QUOTES = ["'", '"'] # Strings can be delimited by either single or double quotes
IDENTIFIERS = ["`"] # Identifiers can be delimited by backticks
# Associates certain meaningful words with tokens that capture their intent
KEYWORDS = {
**Tokenizer.KEYWORDS,
"INT64": TokenType.BIGINT,
@ -33,8 +34,12 @@ class Custom(Dialect):
}
class Generator(Generator):
TRANSFORMS = {exp.Array: lambda self, e: f"[{self.expressions(e)}]"}
# Specifies how AST nodes, i.e. subclasses of exp.Expression, should be converted into SQL
TRANSFORMS = {
exp.Array: lambda self, e: f"[{self.expressions(e)}]",
}
# Specifies how AST nodes representing data types should be converted into SQL
TYPE_MAPPING = {
exp.DataType.Type.TINYINT: "INT64",
exp.DataType.Type.SMALLINT: "INT64",
@ -48,10 +53,9 @@ class Custom(Dialect):
}
```
This is a typical example of adding a new dialect implementation in SQLGlot: we specify its identifier and string
delimiters, as well as what tokens it uses for its types and how they're associated with SQLGlot types. Since
the `Expression` classes are common for each dialect supported in SQLGlot, we may also need to override the generation
logic for some expressions; this is usually done by adding new entries to the `TRANSFORMS` mapping.
The above example demonstrates how certain parts of the base `Dialect` class can be overridden to match a different
specification. Even though it is a fairly realistic starting point, we strongly encourage the reader to study existing
dialect implementations in order to understand how their various components can be modified, depending on the use-case.
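As a quick check, a dialect defined this way is registered under its (lowercased) class name and can be referenced by name right away. A minimal, unverified sketch, assuming the top-level `sqlglot.transpile` helper:

```python
import sqlglot

# Reading with the custom dialect: INT64 is tokenized as a BIGINT type
print(sqlglot.transpile("SELECT CAST(x AS INT64)", read="custom", write="duckdb")[0])

# Writing with it: TYPE_MAPPING controls how types are rendered, e.g. SMALLINT -> INT64
print(sqlglot.transpile("SELECT CAST(x AS SMALLINT)", read="duckdb", write="custom")[0])
```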
----
"""

View file

@ -215,6 +215,7 @@ def _unix_to_time_sql(self: BigQuery.Generator, expression: exp.UnixToTime) -> s
class BigQuery(Dialect):
WEEK_OFFSET = -1
UNNEST_COLUMN_ONLY = True
SUPPORTS_USER_DEFINED_TYPES = False
SUPPORTS_SEMI_ANTI_JOIN = False
@ -437,11 +438,7 @@ class BigQuery(Dialect):
elif isinstance(this, exp.Literal):
table_name = this.name
if (
self._curr
and self._prev.end == self._curr.start - 1
and self._parse_var(any_token=True)
):
if self._is_connected() and self._parse_var(any_token=True):
table_name += self._prev.text
this = exp.Identifier(this=table_name, quoted=True)

View file

@ -83,6 +83,11 @@ class ClickHouse(Dialect):
}
class Parser(parser.Parser):
# Based on tests in ClickHouse's playground, the following two queries seem to do the same thing:
# * select x from t1 union all select x from t2 limit 1;
# * select x from t1 union all (select x from t2 limit 1);
MODIFIERS_ATTACHED_TO_UNION = False
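A rough illustration of the effect, assuming the top-level `sqlglot.parse_one` helper (attribute checks only, no generated SQL asserted):

```python
import sqlglot

sql = "SELECT x FROM t1 UNION ALL SELECT x FROM t2 LIMIT 1"

default = sqlglot.parse_one(sql)                         # MODIFIERS_ATTACHED_TO_UNION = True
clickhouse = sqlglot.parse_one(sql, read="clickhouse")   # MODIFIERS_ATTACHED_TO_UNION = False

print(default.args.get("limit") is not None)     # the LIMIT hangs off the Union node
print(clickhouse.args.get("limit") is None)      # ClickHouse keeps it on the right-hand SELECT
```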
FUNCTIONS = {
**parser.Parser.FUNCTIONS,
"ANY": exp.AnyValue.from_arg_list,

View file

@ -21,11 +21,14 @@ DATE_ADD_OR_SUB = t.Union[exp.DateAdd, exp.TsOrDsAdd, exp.DateSub]
class Dialects(str, Enum):
"""Dialects supported by SQLGLot."""
DIALECT = ""
BIGQUERY = "bigquery"
CLICKHOUSE = "clickhouse"
DATABRICKS = "databricks"
DORIS = "doris"
DRILL = "drill"
DUCKDB = "duckdb"
HIVE = "hive"
@ -43,16 +46,22 @@ class Dialects(str, Enum):
TERADATA = "teradata"
TRINO = "trino"
TSQL = "tsql"
Doris = "doris"
class NormalizationStrategy(str, AutoName):
"""Specifies the strategy according to which identifiers should be normalized."""
LOWERCASE = auto() # Unquoted identifiers are lowercased
UPPERCASE = auto() # Unquoted identifiers are uppercased
CASE_SENSITIVE = auto() # Always case-sensitive, regardless of quotes
CASE_INSENSITIVE = auto() # Always case-insensitive, regardless of quotes
LOWERCASE = auto()
"""Unquoted identifiers are lowercased."""
UPPERCASE = auto()
"""Unquoted identifiers are uppercased."""
CASE_SENSITIVE = auto()
"""Always case-sensitive, regardless of quotes."""
CASE_INSENSITIVE = auto()
"""Always case-insensitive, regardless of quotes."""
class _Dialect(type):
@ -117,6 +126,7 @@ class _Dialect(type):
klass.BIT_START, klass.BIT_END = get_start_end(TokenType.BIT_STRING)
klass.HEX_START, klass.HEX_END = get_start_end(TokenType.HEX_STRING)
klass.BYTE_START, klass.BYTE_END = get_start_end(TokenType.BYTE_STRING)
klass.UNICODE_START, klass.UNICODE_END = get_start_end(TokenType.UNICODE_STRING)
if enum not in ("", "bigquery"):
klass.generator_class.SELECT_KINDS = ()
@ -131,74 +141,84 @@ class _Dialect(type):
class Dialect(metaclass=_Dialect):
# Determines the base index offset for arrays
INDEX_OFFSET = 0
"""Determines the base index offset for arrays."""
WEEK_OFFSET = 0
"""Determines the day of week of DATE_TRUNC(week). Defaults to 0 (Monday). -1 would be Sunday."""
# If true unnest table aliases are considered only as column aliases
UNNEST_COLUMN_ONLY = False
"""Determines whether or not `UNNEST` table aliases are treated as column aliases."""
# Determines whether or not the table alias comes after tablesample
ALIAS_POST_TABLESAMPLE = False
"""Determines whether or not the table alias comes after tablesample."""
# Specifies the strategy according to which identifiers should be normalized.
NORMALIZATION_STRATEGY = NormalizationStrategy.LOWERCASE
"""Specifies the strategy according to which identifiers should be normalized."""
# Determines whether or not an unquoted identifier can start with a digit
IDENTIFIERS_CAN_START_WITH_DIGIT = False
"""Determines whether or not an unquoted identifier can start with a digit."""
# Determines whether or not the DPIPE token ('||') is a string concatenation operator
DPIPE_IS_STRING_CONCAT = True
"""Determines whether or not the DPIPE token (`||`) is a string concatenation operator."""
# Determines whether or not CONCAT's arguments must be strings
STRICT_STRING_CONCAT = False
"""Determines whether or not `CONCAT`'s arguments must be strings."""
# Determines whether or not user-defined data types are supported
SUPPORTS_USER_DEFINED_TYPES = True
"""Determines whether or not user-defined data types are supported."""
# Determines whether or not SEMI/ANTI JOINs are supported
SUPPORTS_SEMI_ANTI_JOIN = True
"""Determines whether or not `SEMI` or `ANTI` joins are supported."""
# Determines how function names are going to be normalized
NORMALIZE_FUNCTIONS: bool | str = "upper"
"""Determines how function names are going to be normalized."""
# Determines whether the base comes first in the LOG function
LOG_BASE_FIRST = True
"""Determines whether the base comes first in the `LOG` function."""
# Indicates the default null ordering method to use if not explicitly set
# Options are: "nulls_are_small", "nulls_are_large", "nulls_are_last"
NULL_ORDERING = "nulls_are_small"
"""
Indicates the default `NULL` ordering method to use if not explicitly set.
Possible values: `"nulls_are_small"`, `"nulls_are_large"`, `"nulls_are_last"`
"""
# Whether the behavior of a / b depends on the types of a and b.
# False means a / b is always float division.
# True means a / b is integer division if both a and b are integers.
TYPED_DIVISION = False
"""
Whether the behavior of `a / b` depends on the types of `a` and `b`.
False means `a / b` is always float division.
True means `a / b` is integer division if both `a` and `b` are integers.
"""
# False means 1 / 0 throws an error.
# True means 1 / 0 returns null.
SAFE_DIVISION = False
"""Determines whether division by zero throws an error (`False`) or returns NULL (`True`)."""
# A NULL arg in CONCAT yields NULL by default, but in some dialects it yields an empty string
CONCAT_COALESCE = False
"""A `NULL` arg in `CONCAT` yields `NULL` by default, but in some dialects it yields an empty string."""
DATE_FORMAT = "'%Y-%m-%d'"
DATEINT_FORMAT = "'%Y%m%d'"
TIME_FORMAT = "'%Y-%m-%d %H:%M:%S'"
# Custom time mappings in which the key represents dialect time format
# and the value represents a python time format
TIME_MAPPING: t.Dict[str, str] = {}
"""Associates this dialect's time formats with their equivalent Python `strftime` format."""
# https://cloud.google.com/bigquery/docs/reference/standard-sql/format-elements#format_model_rules_date_time
# https://docs.teradata.com/r/Teradata-Database-SQL-Functions-Operators-Expressions-and-Predicates/March-2017/Data-Type-Conversions/Character-to-DATE-Conversion/Forcing-a-FORMAT-on-CAST-for-Converting-Character-to-DATE
# special syntax cast(x as date format 'yyyy') defaults to time_mapping
FORMAT_MAPPING: t.Dict[str, str] = {}
"""
Helper which is used for parsing the special syntax `CAST(x AS DATE FORMAT 'yyyy')`.
If empty, the corresponding trie will be constructed off of `TIME_MAPPING`.
"""
# Mapping of an unescaped escape sequence to the corresponding character
ESCAPE_SEQUENCES: t.Dict[str, str] = {}
"""Mapping of an unescaped escape sequence to the corresponding character."""
# Columns that are auto-generated by the engine corresponding to this dialect
# Such columns may be excluded from SELECT * queries, for example
PSEUDOCOLUMNS: t.Set[str] = set()
"""
Columns that are auto-generated by the engine corresponding to this dialect.
For example, such columns may be excluded from `SELECT *` queries.
"""
# --- Autofilled ---
@ -221,13 +241,15 @@ class Dialect(metaclass=_Dialect):
IDENTIFIER_START = '"'
IDENTIFIER_END = '"'
# Delimiters for bit, hex and byte literals
# Delimiters for bit, hex, byte and unicode literals
BIT_START: t.Optional[str] = None
BIT_END: t.Optional[str] = None
HEX_START: t.Optional[str] = None
HEX_END: t.Optional[str] = None
BYTE_START: t.Optional[str] = None
BYTE_END: t.Optional[str] = None
UNICODE_START: t.Optional[str] = None
UNICODE_END: t.Optional[str] = None
@classmethod
def get_or_raise(cls, dialect: DialectType) -> Dialect:
@ -275,6 +297,7 @@ class Dialect(metaclass=_Dialect):
def format_time(
cls, expression: t.Optional[str | exp.Expression]
) -> t.Optional[exp.Expression]:
"""Converts a time format in this dialect to its equivalent Python `strftime` format."""
if isinstance(expression, str):
return exp.Literal.string(
# the time formats are quoted
@ -306,9 +329,9 @@ class Dialect(metaclass=_Dialect):
"""
Transforms an identifier in a way that resembles how it'd be resolved by this dialect.
For example, an identifier like FoO would be resolved as foo in Postgres, because it
For example, an identifier like `FoO` would be resolved as `foo` in Postgres, because it
lowercases all unquoted identifiers. On the other hand, Snowflake uppercases them, so
it would resolve it as FOO. If it was quoted, it'd need to be treated as case-sensitive,
it would resolve it as `FOO`. If it was quoted, it'd need to be treated as case-sensitive,
and so any normalization would be prohibited in order to avoid "breaking" the identifier.
There are also dialects like Spark, which are case-insensitive even when quotes are
@ -356,8 +379,8 @@ class Dialect(metaclass=_Dialect):
Args:
text: The text to check.
identify:
"always" or `True`: Always returns true.
"safe": True if the identifier is case-insensitive.
`"always"` or `True`: Always returns `True`.
`"safe"`: Only returns `True` if the identifier is case-insensitive.
Returns:
Whether or not the given text can be identified.
@ -371,6 +394,14 @@ class Dialect(metaclass=_Dialect):
return False
def quote_identifier(self, expression: E, identify: bool = True) -> E:
"""
Adds quotes to a given identifier.
Args:
expression: The expression of interest. If it's not an `Identifier`, this method is a no-op.
identify: If set to `False`, the quotes will only be added if the identifier is deemed
"unsafe", with respect to its characters and this dialect's normalization strategy.
"""
if isinstance(expression, exp.Identifier):
name = expression.this
expression.set(

View file

@ -81,7 +81,6 @@ class Drill(Dialect):
class Tokenizer(tokens.Tokenizer):
IDENTIFIERS = ["`"]
STRING_ESCAPES = ["\\"]
ENCODE = "utf-8"
class Parser(parser.Parser):
STRICT_CAST = False

View file

@ -84,11 +84,35 @@ def _parse_date_diff(args: t.List) -> exp.Expression:
return exp.DateDiff(this=seq_get(args, 2), expression=seq_get(args, 1), unit=seq_get(args, 0))
def _parse_make_timestamp(args: t.List) -> exp.Expression:
if len(args) == 1:
return exp.UnixToTime(this=seq_get(args, 0), scale=exp.UnixToTime.MICROS)
return exp.TimestampFromParts(
year=seq_get(args, 0),
month=seq_get(args, 1),
day=seq_get(args, 2),
hour=seq_get(args, 3),
min=seq_get(args, 4),
sec=seq_get(args, 5),
)
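A hedged sketch of the two arities, assuming the top-level `sqlglot.parse_one` helper:

```python
import sqlglot
from sqlglot import exp

six = sqlglot.parse_one("SELECT MAKE_TIMESTAMP(2023, 1, 2, 3, 4, 5.5)", read="duckdb")
one = sqlglot.parse_one("SELECT MAKE_TIMESTAMP(1701222000000000)", read="duckdb")

print(isinstance(six.selects[0], exp.TimestampFromParts))  # 6-arg form -> TimestampFromParts
print(isinstance(one.selects[0], exp.UnixToTime))          # 1-arg form -> UnixToTime
print(six.sql(dialect="duckdb"))                           # regenerated as MAKE_TIMESTAMP(...)
```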
def _struct_sql(self: DuckDB.Generator, expression: exp.Struct) -> str:
args = [
f"'{e.name or e.this.name}': {self.sql(e.expressions[0]) if isinstance(e, exp.Bracket) else self.sql(e, 'expression')}"
for e in expression.expressions
]
args: t.List[str] = []
for expr in expression.expressions:
if isinstance(expr, exp.Alias):
key = expr.alias
value = expr.this
else:
key = expr.name or expr.this.name
if isinstance(expr, exp.Bracket):
value = expr.expressions[0]
else:
value = expr.expression
args.append(f"{self.sql(exp.Literal.string(key))}: {self.sql(value)}")
return f"{{{', '.join(args)}}}"
@ -189,9 +213,7 @@ class DuckDB(Dialect):
"LIST_REVERSE_SORT": _sort_array_reverse,
"LIST_SORT": exp.SortArray.from_arg_list,
"LIST_VALUE": exp.Array.from_arg_list,
"MAKE_TIMESTAMP": lambda args: exp.UnixToTime(
this=seq_get(args, 0), scale=exp.UnixToTime.MICROS
),
"MAKE_TIMESTAMP": _parse_make_timestamp,
"MEDIAN": lambda args: exp.PercentileCont(
this=seq_get(args, 0), expression=exp.Literal.number(0.5)
),
@ -339,6 +361,7 @@ class DuckDB(Dialect):
exp.StrToUnix: lambda self, e: f"EPOCH(STRPTIME({self.sql(e, 'this')}, {self.format_time(e)}))",
exp.Struct: _struct_sql,
exp.Timestamp: no_timestamp_sql,
exp.TimestampFromParts: rename_func("MAKE_TIMESTAMP"),
exp.TimestampTrunc: timestamptrunc_sql,
exp.TimeStrToDate: lambda self, e: f"CAST({self.sql(e, 'this')} AS DATE)",
exp.TimeStrToTime: timestrtotime_sql,

View file

@ -240,7 +240,6 @@ class Hive(Dialect):
QUOTES = ["'", '"']
IDENTIFIERS = ["`"]
STRING_ESCAPES = ["\\"]
ENCODE = "utf-8"
SINGLE_TOKENS = {
**tokens.Tokenizer.SINGLE_TOKENS,

View file

@ -650,7 +650,7 @@ class MySQL(Dialect):
exp.Min: min_or_least,
exp.Month: _remove_ts_or_ds_to_date(),
exp.NullSafeEQ: lambda self, e: self.binary(e, "<=>"),
exp.NullSafeNEQ: lambda self, e: self.not_sql(self.binary(e, "<=>")),
exp.NullSafeNEQ: lambda self, e: f"NOT {self.binary(e, '<=>')}",
exp.Pivot: no_pivot_sql,
exp.Select: transforms.preprocess(
[

View file

@ -277,6 +277,7 @@ class Postgres(Dialect):
"CONSTRAINT TRIGGER": TokenType.COMMAND,
"DECLARE": TokenType.COMMAND,
"DO": TokenType.COMMAND,
"EXEC": TokenType.COMMAND,
"HSTORE": TokenType.HSTORE,
"JSONB": TokenType.JSONB,
"MONEY": TokenType.MONEY,

View file

@ -186,6 +186,27 @@ def _unix_to_time_sql(self: Presto.Generator, expression: exp.UnixToTime) -> str
return ""
def _to_int(expression: exp.Expression) -> exp.Expression:
if not expression.type:
from sqlglot.optimizer.annotate_types import annotate_types
annotate_types(expression)
if expression.type and expression.type.this not in exp.DataType.INTEGER_TYPES:
return exp.cast(expression, to=exp.DataType.Type.BIGINT)
return expression
def _parse_to_char(args: t.List) -> exp.TimeToStr:
fmt = seq_get(args, 1)
if isinstance(fmt, exp.Literal):
# We uppercase this to match Teradata's format mapping keys
fmt.set("this", fmt.this.upper())
# We use "teradata" on purpose here, because the time formats are different in Presto.
# See https://prestodb.io/docs/current/functions/teradata.html?highlight=to_char#to_char
return format_time_lambda(exp.TimeToStr, "teradata")(args)
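An unverified sketch of the intent; the exact output format string depends on the Teradata and Presto time mappings:

```python
import sqlglot

e = sqlglot.parse_one("SELECT TO_CHAR(ts, 'yyyy-mm-dd')", read="presto")
# The format literal is uppercased and interpreted with Teradata's mapping,
# then rendered back through Presto's TimeToStr handling (a DATE_FORMAT call).
print(e.sql(dialect="presto"))
```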
class Presto(Dialect):
INDEX_OFFSET = 1
NULL_ORDERING = "nulls_are_last"
@ -201,6 +222,12 @@ class Presto(Dialect):
NORMALIZATION_STRATEGY = NormalizationStrategy.CASE_INSENSITIVE
class Tokenizer(tokens.Tokenizer):
UNICODE_STRINGS = [
(prefix + q, q)
for q in t.cast(t.List[str], tokens.Tokenizer.QUOTES)
for prefix in ("U&", "u&")
]
KEYWORDS = {
**tokens.Tokenizer.KEYWORDS,
"START": TokenType.BEGIN,
@ -253,8 +280,9 @@ class Presto(Dialect):
"STRPOS": lambda args: exp.StrPosition(
this=seq_get(args, 0), substr=seq_get(args, 1), instance=seq_get(args, 2)
),
"TO_UNIXTIME": exp.TimeToUnix.from_arg_list,
"TO_CHAR": _parse_to_char,
"TO_HEX": exp.Hex.from_arg_list,
"TO_UNIXTIME": exp.TimeToUnix.from_arg_list,
"TO_UTF8": lambda args: exp.Encode(
this=seq_get(args, 0), charset=exp.Literal.string("utf-8")
),
@ -315,7 +343,12 @@ class Presto(Dialect):
exp.Cast: transforms.preprocess([transforms.epoch_cast_to_ts]),
exp.CurrentTimestamp: lambda *_: "CURRENT_TIMESTAMP",
exp.DateAdd: lambda self, e: self.func(
"DATE_ADD", exp.Literal.string(e.text("unit") or "day"), e.expression, e.this
"DATE_ADD",
exp.Literal.string(e.text("unit") or "day"),
_to_int(
e.expression,
),
e.this,
),
exp.DateDiff: lambda self, e: self.func(
"DATE_DIFF", exp.Literal.string(e.text("unit") or "day"), e.expression, e.this
@ -325,7 +358,7 @@ class Presto(Dialect):
exp.DateSub: lambda self, e: self.func(
"DATE_ADD",
exp.Literal.string(e.text("unit") or "day"),
e.expression * -1,
_to_int(e.expression * -1),
e.this,
),
exp.Decode: lambda self, e: encode_decode_sql(self, e, "FROM_UTF8"),
@ -354,6 +387,7 @@ class Presto(Dialect):
exp.Right: right_to_substring_sql,
exp.SafeDivide: no_safe_divide_sql,
exp.Schema: _schema_sql,
exp.SchemaCommentProperty: lambda self, e: self.naked_property(e),
exp.Select: transforms.preprocess(
[
transforms.eliminate_qualify,
@ -377,6 +411,7 @@ class Presto(Dialect):
exp.TimeStrToUnix: lambda self, e: f"TO_UNIXTIME(DATE_PARSE({self.sql(e, 'this')}, {Presto.TIME_FORMAT}))",
exp.TimeToStr: lambda self, e: f"DATE_FORMAT({self.sql(e, 'this')}, {self.format_time(e)})",
exp.TimeToUnix: rename_func("TO_UNIXTIME"),
exp.ToChar: lambda self, e: f"DATE_FORMAT({self.sql(e, 'this')}, {self.format_time(e)})",
exp.TryCast: transforms.preprocess([transforms.epoch_cast_to_ts]),
exp.TsOrDiToDi: lambda self, e: f"CAST(SUBSTR(REPLACE(CAST({self.sql(e, 'this')} AS VARCHAR), '-', ''), 1, 8) AS INT)",
exp.TsOrDsAdd: _ts_or_ds_add_sql,

View file

@ -293,7 +293,6 @@ class Snowflake(Dialect):
"SQUARE": lambda args: exp.Pow(this=seq_get(args, 0), expression=exp.Literal.number(2)),
"TIMEDIFF": _parse_datediff,
"TIMESTAMPDIFF": _parse_datediff,
"TO_ARRAY": exp.Array.from_arg_list,
"TO_TIMESTAMP": _parse_to_timestamp,
"TO_VARCHAR": exp.ToChar.from_arg_list,
"ZEROIFNULL": _zeroifnull_to_if,
@ -369,36 +368,58 @@ class Snowflake(Dialect):
return lateral
def _parse_at_before(self, table: exp.Table) -> exp.Table:
# https://docs.snowflake.com/en/sql-reference/constructs/at-before
index = self._index
if self._match_texts(("AT", "BEFORE")):
this = self._prev.text.upper()
kind = (
self._match(TokenType.L_PAREN)
and self._match_texts(self.HISTORICAL_DATA_KIND)
and self._prev.text.upper()
)
expression = self._match(TokenType.FARROW) and self._parse_bitwise()
if expression:
self._match_r_paren()
when = self.expression(
exp.HistoricalData, this=this, kind=kind, expression=expression
)
table.set("when", when)
else:
self._retreat(index)
return table
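For reference, a sketch of exercising the Snowflake time-travel syntax this targets, assuming `sqlglot.parse_one` (the statement ID is a placeholder):

```python
import sqlglot
from sqlglot import exp

sql = "SELECT * FROM my_table BEFORE (STATEMENT => '8e5d0ca9-005e-44e6-b858-a8f5b37c5726')"
table = sqlglot.parse_one(sql, read="snowflake").find(exp.Table)

print(table.args.get("when"))          # the new HistoricalData node
print(table.sql(dialect="snowflake"))  # regenerated through historicaldata_sql / table_sql
```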
def _parse_table_parts(self, schema: bool = False) -> exp.Table:
# https://docs.snowflake.com/en/user-guide/querying-stage
table: t.Optional[exp.Expression] = None
if self._match_text_seq("@"):
table_name = "@"
while self._curr:
self._advance()
table_name += self._prev.text
if not self._match_set(self.STAGED_FILE_SINGLE_TOKENS, advance=False):
break
while self._match_set(self.STAGED_FILE_SINGLE_TOKENS):
table_name += self._prev.text
table = exp.var(table_name)
elif self._match(TokenType.STRING, advance=False):
if self._match(TokenType.STRING, advance=False):
table = self._parse_string()
elif self._match_text_seq("@", advance=False):
table = self._parse_location_path()
else:
table = None
if table:
file_format = None
pattern = None
if self._match_text_seq("(", "FILE_FORMAT", "=>"):
file_format = self._parse_string() or super()._parse_table_parts()
if self._match_text_seq(",", "PATTERN", "=>"):
self._match(TokenType.L_PAREN)
while self._curr and not self._match(TokenType.R_PAREN):
if self._match_text_seq("FILE_FORMAT", "=>"):
file_format = self._parse_string() or super()._parse_table_parts()
elif self._match_text_seq("PATTERN", "=>"):
pattern = self._parse_string()
self._match_r_paren()
else:
break
return self.expression(exp.Table, this=table, format=file_format, pattern=pattern)
self._match(TokenType.COMMA)
return super()._parse_table_parts(schema=schema)
table = self.expression(exp.Table, this=table, format=file_format, pattern=pattern)
else:
table = super()._parse_table_parts(schema=schema)
return self._parse_at_before(table)
def _parse_id_var(
self,
@ -438,17 +459,17 @@ class Snowflake(Dialect):
def _parse_location(self) -> exp.LocationProperty:
self._match(TokenType.EQ)
return self.expression(exp.LocationProperty, this=self._parse_location_path())
parts = [self._parse_var(any_token=True)]
def _parse_location_path(self) -> exp.Var:
parts = [self._advance_any(ignore_reserved=True)]
while self._match(TokenType.SLASH):
if self._curr and self._prev.end + 1 == self._curr.start:
parts.append(self._parse_var(any_token=True))
else:
parts.append(exp.Var(this=""))
return self.expression(
exp.LocationProperty, this=exp.var("/".join(str(p) for p in parts))
)
# We avoid consuming a comma token because external tables like @foo and @bar
# can be joined in a query with a comma separator.
while self._is_connected() and not self._match(TokenType.COMMA, advance=False):
parts.append(self._advance_any(ignore_reserved=True))
return exp.var("".join(part.text for part in parts if part))
class Tokenizer(tokens.Tokenizer):
STRING_ESCAPES = ["\\", "'"]
@ -562,6 +583,7 @@ class Snowflake(Dialect):
"TO_CHAR", exp.cast(e.this, "timestamp"), self.format_time(e)
),
exp.TimeToUnix: lambda self, e: f"EXTRACT(epoch_second FROM {self.sql(e, 'this')})",
exp.ToArray: rename_func("TO_ARRAY"),
exp.ToChar: lambda self, e: self.function_fallback_sql(e),
exp.Trim: lambda self, e: self.func("TRIM", e.this, e.expression),
exp.TsOrDsAdd: date_delta_sql("DATEADD", cast=True),

View file

@ -12,22 +12,30 @@ class Teradata(Dialect):
TYPED_DIVISION = True
TIME_MAPPING = {
"Y": "%Y",
"YYYY": "%Y",
"YY": "%y",
"MMMM": "%B",
"MMM": "%b",
"DD": "%d",
"D": "%-d",
"HH": "%H",
"H": "%-H",
"MM": "%M",
"Y4": "%Y",
"YYYY": "%Y",
"M4": "%B",
"M3": "%b",
"M": "%-M",
"SS": "%S",
"MI": "%M",
"MM": "%m",
"MMM": "%b",
"MMMM": "%B",
"D": "%-d",
"DD": "%d",
"D3": "%j",
"DDD": "%j",
"H": "%-H",
"HH": "%H",
"HH24": "%H",
"S": "%-S",
"SS": "%S",
"SSSSSS": "%f",
"E": "%a",
"EE": "%a",
"E3": "%a",
"E4": "%A",
"EEE": "%a",
"EEEE": "%A",
}

View file

@ -701,6 +701,13 @@ class TSQL(Dialect):
exp.VolatileProperty: exp.Properties.Location.UNSUPPORTED,
}
def set_operation(self, expression: exp.Union, op: str) -> str:
limit = expression.args.get("limit")
if limit:
return self.sql(expression.limit(limit.pop(), copy=False))
return super().set_operation(expression, op)
def setitem_sql(self, expression: exp.SetItem) -> str:
this = expression.this
if isinstance(this, exp.EQ) and not isinstance(this.left, exp.Parameter):

View file

@ -343,6 +343,9 @@ class PythonExecutor:
else:
sink.rows = left.rows + right.rows
if not math.isinf(step.limit):
sink.rows = sink.rows[0 : step.limit]
return self.context({step.name: sink})

View file

@ -1105,14 +1105,7 @@ class Create(DDL):
# https://cloud.google.com/bigquery/docs/reference/standard-sql/data-definition-language#create_table_clone_statement
# https://cloud.google.com/bigquery/docs/reference/standard-sql/data-definition-language#create_table_copy
class Clone(Expression):
arg_types = {
"this": True,
"when": False,
"kind": False,
"shallow": False,
"expression": False,
"copy": False,
}
arg_types = {"this": True, "shallow": False, "copy": False}
class Describe(Expression):
@ -1213,6 +1206,10 @@ class RawString(Condition):
pass
class UnicodeString(Condition):
arg_types = {"this": True, "escape": False}
class Column(Condition):
arg_types = {"this": True, "table": False, "db": False, "catalog": False, "join_mark": False}
@ -1967,7 +1964,12 @@ class Offset(Expression):
class Order(Expression):
arg_types = {"this": False, "expressions": True}
arg_types = {"this": False, "expressions": True, "interpolate": False}
# https://clickhouse.com/docs/en/sql-reference/statements/select/order-by#order-by-expr-with-fill-modifier
class WithFill(Expression):
arg_types = {"from": False, "to": False, "step": False}
# hive specific sorts
@ -1985,7 +1987,7 @@ class Sort(Order):
class Ordered(Expression):
arg_types = {"this": True, "desc": False, "nulls_first": True}
arg_types = {"this": True, "desc": False, "nulls_first": True, "with_fill": False}
class Property(Expression):
@ -2522,6 +2524,11 @@ class IndexTableHint(Expression):
arg_types = {"this": True, "expressions": False, "target": False}
# https://docs.snowflake.com/en/sql-reference/constructs/at-before
class HistoricalData(Expression):
arg_types = {"this": True, "kind": True, "expression": True}
class Table(Expression):
arg_types = {
"this": True,
@ -2538,6 +2545,7 @@ class Table(Expression):
"pattern": False,
"index": False,
"ordinality": False,
"when": False,
}
@property
@ -4310,6 +4318,11 @@ class Array(Func):
is_var_len_args = True
# https://docs.snowflake.com/en/sql-reference/functions/to_array
class ToArray(Func):
pass
# https://docs.snowflake.com/en/sql-reference/functions/to_char
# https://docs.oracle.com/en/database/oracle/oracle-database/23/sqlrf/TO_CHAR-number.html
class ToChar(Func):
@ -5233,6 +5246,19 @@ class UnixToTimeStr(Func):
pass
class TimestampFromParts(Func):
"""Constructs a timestamp given its constituent parts."""
arg_types = {
"year": True,
"month": True,
"day": True,
"hour": True,
"min": True,
"sec": True,
}
class Upper(Func):
_sql_names = ["UPPER", "UCASE"]

View file

@ -862,15 +862,7 @@ class Generator:
this = self.sql(expression, "this")
shallow = "SHALLOW " if expression.args.get("shallow") else ""
keyword = "COPY" if expression.args.get("copy") and self.SUPPORTS_TABLE_COPY else "CLONE"
this = f"{shallow}{keyword} {this}"
when = self.sql(expression, "when")
if when:
kind = self.sql(expression, "kind")
expr = self.sql(expression, "expression")
return f"{this} {when} ({kind} => {expr})"
return this
return f"{shallow}{keyword} {this}"
def describe_sql(self, expression: exp.Describe) -> str:
return f"DESCRIBE {self.sql(expression, 'this')}"
@ -923,6 +915,14 @@ class Generator:
return f"{self.dialect.BYTE_START}{this}{self.dialect.BYTE_END}"
return this
def unicodestring_sql(self, expression: exp.UnicodeString) -> str:
this = self.sql(expression, "this")
if self.dialect.UNICODE_START:
escape = self.sql(expression, "escape")
escape = f" UESCAPE {escape}" if escape else ""
return f"{self.dialect.UNICODE_START}{this}{self.dialect.UNICODE_END}{escape}"
return this
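A rough sketch tying this to the new tokenizer and parser support for Presto-style `U&'...'` literals:

```python
import sqlglot

e = sqlglot.parse_one("SELECT U&'d\\0061t\\0061'", read="presto")
# Should stay a U&'...' literal on the way out, since Presto defines UNICODE_START/END
print(e.sql(dialect="presto"))
```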
def rawstring_sql(self, expression: exp.RawString) -> str:
string = self.escape_str(expression.this.replace("\\", "\\\\"))
return f"{self.dialect.QUOTE_START}{string}{self.dialect.QUOTE_END}"
@ -1400,6 +1400,12 @@ class Generator:
target = f" FOR {target}" if target else ""
return f"{this}{target} ({self.expressions(expression, flat=True)})"
def historicaldata_sql(self, expression: exp.HistoricalData) -> str:
this = self.sql(expression, "this")
kind = self.sql(expression, "kind")
expr = self.sql(expression, "expression")
return f"{this} ({kind} => {expr})"
def table_sql(self, expression: exp.Table, sep: str = " AS ") -> str:
table = ".".join(
self.sql(part)
@ -1436,6 +1442,10 @@ class Generator:
ordinality = f" WITH ORDINALITY{alias}"
alias = ""
when = self.sql(expression, "when")
if when:
table = f"{table} {when}"
return f"{table}{version}{file_format}{alias}{index}{hints}{pivots}{joins}{laterals}{ordinality}"
def tablesample_sql(
@ -1784,7 +1794,24 @@ class Generator:
def order_sql(self, expression: exp.Order, flat: bool = False) -> str:
this = self.sql(expression, "this")
this = f"{this} " if this else this
return self.op_expressions(f"{this}ORDER BY", expression, flat=this or flat) # type: ignore
order = self.op_expressions(f"{this}ORDER BY", expression, flat=this or flat) # type: ignore
interpolated_values = [
f"{self.sql(named_expression, 'alias')} AS {self.sql(named_expression, 'this')}"
for named_expression in expression.args.get("interpolate") or []
]
interpolate = (
f" INTERPOLATE ({', '.join(interpolated_values)})" if interpolated_values else ""
)
return f"{order}{interpolate}"
def withfill_sql(self, expression: exp.WithFill) -> str:
from_sql = self.sql(expression, "from")
from_sql = f" FROM {from_sql}" if from_sql else ""
to_sql = self.sql(expression, "to")
to_sql = f" TO {to_sql}" if to_sql else ""
step_sql = self.sql(expression, "step")
step_sql = f" STEP {step_sql}" if step_sql else ""
return f"WITH FILL{from_sql}{to_sql}{step_sql}"
def cluster_sql(self, expression: exp.Cluster) -> str:
return self.op_expressions("CLUSTER BY", expression)
@ -1826,7 +1853,10 @@ class Generator:
this = f"CASE WHEN {this} IS NULL THEN 1 ELSE 0 END{null_sort_order}, {this}"
nulls_sort_change = ""
return f"{this}{sort_order}{nulls_sort_change}"
with_fill = self.sql(expression, "with_fill")
with_fill = f" {with_fill}" if with_fill else ""
return f"{this}{sort_order}{nulls_sort_change}{with_fill}"
def matchrecognize_sql(self, expression: exp.MatchRecognize) -> str:
partition = self.partition_by_sql(expression)
@ -3048,11 +3078,24 @@ class Generator:
def operator_sql(self, expression: exp.Operator) -> str:
return self.binary(expression, f"OPERATOR({self.sql(expression, 'operator')})")
def toarray_sql(self, expression: exp.ToArray) -> str:
arg = expression.this
if not arg.type:
from sqlglot.optimizer.annotate_types import annotate_types
arg = annotate_types(arg)
if arg.is_type(exp.DataType.Type.ARRAY):
return self.sql(arg)
cond_for_null = arg.is_(exp.null())
return self.sql(exp.func("IF", cond_for_null, exp.null(), exp.Array(expressions=[arg])))
def _simplify_unless_literal(self, expression: E) -> E:
if not isinstance(expression, exp.Literal):
from sqlglot.optimizer.simplify import simplify
expression = simplify(expression)
expression = simplify(expression, dialect=self.dialect)
return expression

View file

@ -95,9 +95,6 @@ def eliminate_subqueries(expression):
def _eliminate(scope, existing_ctes, taken):
if scope.is_union:
return _eliminate_union(scope, existing_ctes, taken)
if scope.is_derived_table:
return _eliminate_derived_table(scope, existing_ctes, taken)
@ -105,36 +102,6 @@ def _eliminate(scope, existing_ctes, taken):
return _eliminate_cte(scope, existing_ctes, taken)
def _eliminate_union(scope, existing_ctes, taken):
duplicate_cte_alias = existing_ctes.get(scope.expression)
alias = duplicate_cte_alias or find_new_name(taken=taken, base="cte")
taken[alias] = scope
# Try to maintain the selections
expressions = scope.expression.selects
selects = [
exp.alias_(exp.column(e.alias_or_name, table=alias), alias=e.alias_or_name, copy=False)
for e in expressions
if e.alias_or_name
]
# If not all selections have an alias, just select *
if len(selects) != len(expressions):
selects = ["*"]
scope.expression.replace(
exp.select(*selects).from_(exp.alias_(exp.table_(alias), alias=alias, copy=False))
)
if not duplicate_cte_alias:
existing_ctes[scope.expression] = alias
return exp.CTE(
this=scope.expression,
alias=exp.TableAlias(this=exp.to_identifier(alias)),
)
def _eliminate_derived_table(scope, existing_ctes, taken):
# This makes sure that we don't:
# - drop the "pivot" arg from a pivoted subquery

View file

@ -174,6 +174,22 @@ def _mergeable(outer_scope, inner_scope, leave_tables_isolated, from_or_join):
for col in inner_projections[selection].find_all(exp.Column)
)
def _is_recursive():
# Recursive CTEs look like this:
# WITH RECURSIVE cte AS (
# SELECT * FROM x <-- inner scope
# UNION ALL
# SELECT * FROM cte <-- outer scope
# )
cte = inner_scope.expression.parent
node = outer_scope.expression.parent
while node:
if node is cte:
return True
node = node.parent
return False
return (
isinstance(outer_scope.expression, exp.Select)
and not outer_scope.expression.is_star
@ -197,6 +213,7 @@ def _mergeable(outer_scope, inner_scope, leave_tables_isolated, from_or_join):
)
and not _outer_select_joins_on_inner_select_join()
and not _is_a_window_expression_in_unmergable_operation()
and not _is_recursive()
)

View file

@ -4,7 +4,7 @@ from sqlglot.optimizer.scope import build_scope, find_in_scope
from sqlglot.optimizer.simplify import simplify
def pushdown_predicates(expression):
def pushdown_predicates(expression, dialect=None):
"""
Rewrite sqlglot AST to pushdown predicates in FROMS and JOINS
@ -36,7 +36,7 @@ def pushdown_predicates(expression):
if isinstance(parent, exp.Join) and parent.side == "RIGHT":
selected_sources = {k: (node, source)}
break
pushdown(where.this, selected_sources, scope_ref_count)
pushdown(where.this, selected_sources, scope_ref_count, dialect)
# joins should only pushdown into itself, not to other joins
# so we limit the selected sources to only itself
@ -44,17 +44,20 @@ def pushdown_predicates(expression):
name = join.alias_or_name
if name in scope.selected_sources:
pushdown(
join.args.get("on"), {name: scope.selected_sources[name]}, scope_ref_count
join.args.get("on"),
{name: scope.selected_sources[name]},
scope_ref_count,
dialect,
)
return expression
def pushdown(condition, sources, scope_ref_count):
def pushdown(condition, sources, scope_ref_count, dialect):
if not condition:
return
condition = condition.replace(simplify(condition))
condition = condition.replace(simplify(condition, dialect=dialect))
cnf_like = normalized(condition) or not normalized(condition, dnf=True)
predicates = list(

View file

@ -37,6 +37,7 @@ class Scope:
For example:
SELECT c FROM x LATERAL VIEW EXPLODE (a) AS c;
The LATERAL VIEW EXPLODE gets x as a source.
cte_sources (dict[str, Scope]): Sources from CTEs
outer_column_list (list[str]): If this is a derived table or CTE, and the outer query
defines a column list for its alias of this scope, this is that list of columns.
For example:
@ -61,11 +62,14 @@ class Scope:
parent=None,
scope_type=ScopeType.ROOT,
lateral_sources=None,
cte_sources=None,
):
self.expression = expression
self.sources = sources or {}
self.lateral_sources = lateral_sources.copy() if lateral_sources else {}
self.lateral_sources = lateral_sources or {}
self.cte_sources = cte_sources or {}
self.sources.update(self.lateral_sources)
self.sources.update(self.cte_sources)
self.outer_column_list = outer_column_list or []
self.parent = parent
self.scope_type = scope_type
@ -92,13 +96,17 @@ class Scope:
self._pivots = None
self._references = None
def branch(self, expression, scope_type, chain_sources=None, **kwargs):
def branch(
self, expression, scope_type, sources=None, cte_sources=None, lateral_sources=None, **kwargs
):
"""Branch from the current scope to a new, inner scope"""
return Scope(
expression=expression.unnest(),
sources={**self.cte_sources, **(chain_sources or {})},
sources=sources.copy() if sources else None,
parent=self,
scope_type=scope_type,
cte_sources={**self.cte_sources, **(cte_sources or {})},
lateral_sources=lateral_sources.copy() if lateral_sources else None,
**kwargs,
)
@ -305,20 +313,6 @@ class Scope:
return self._references
@property
def cte_sources(self):
"""
Sources that are CTEs.
Returns:
dict[str, Scope]: Mapping of source alias to Scope
"""
return {
alias: scope
for alias, scope in self.sources.items()
if isinstance(scope, Scope) and scope.is_cte
}
@property
def external_columns(self):
"""
@ -515,7 +509,10 @@ def _traverse_scope(scope):
elif isinstance(scope.expression, exp.Union):
yield from _traverse_union(scope)
elif isinstance(scope.expression, exp.Subquery):
yield from _traverse_subqueries(scope)
if scope.is_root:
yield from _traverse_select(scope)
else:
yield from _traverse_subqueries(scope)
elif isinstance(scope.expression, exp.Table):
yield from _traverse_tables(scope)
elif isinstance(scope.expression, exp.UDTF):
@ -572,7 +569,7 @@ def _traverse_ctes(scope):
for child_scope in _traverse_scope(
scope.branch(
cte.this,
chain_sources=sources,
cte_sources=sources,
outer_column_list=cte.alias_column_names,
scope_type=ScopeType.CTE,
)
@ -584,12 +581,14 @@ def _traverse_ctes(scope):
if recursive_scope:
child_scope.add_source(alias, recursive_scope)
child_scope.cte_sources[alias] = recursive_scope
# append the final child_scope yielded
if child_scope:
scope.cte_scopes.append(child_scope)
scope.sources.update(sources)
scope.cte_sources.update(sources)
def _is_derived_table(expression: exp.Subquery) -> bool:
@ -725,7 +724,7 @@ def _traverse_ddl(scope):
yield from _traverse_ctes(scope)
query_scope = scope.branch(
scope.expression.expression, scope_type=ScopeType.DERIVED_TABLE, chain_sources=scope.sources
scope.expression.expression, scope_type=ScopeType.DERIVED_TABLE, sources=scope.sources
)
query_scope._collect()
query_scope._ctes = scope.ctes + query_scope._ctes

View file

@ -1,3 +1,5 @@
from __future__ import annotations
import datetime
import functools
import itertools
@ -6,10 +8,17 @@ from collections import deque
from decimal import Decimal
import sqlglot
from sqlglot import exp
from sqlglot import Dialect, exp
from sqlglot.helper import first, is_iterable, merge_ranges, while_changing
from sqlglot.optimizer.scope import find_all_in_scope, walk_in_scope
if t.TYPE_CHECKING:
from sqlglot.dialects.dialect import DialectType
DateTruncBinaryTransform = t.Callable[
[exp.Expression, datetime.date, str, Dialect], t.Optional[exp.Expression]
]
# Final means that an expression should not be simplified
FINAL = "final"
@ -18,7 +27,9 @@ class UnsupportedUnit(Exception):
pass
def simplify(expression, constant_propagation=False):
def simplify(
expression: exp.Expression, constant_propagation: bool = False, dialect: DialectType = None
):
"""
Rewrite sqlglot AST to simplify expressions.
@ -36,15 +47,18 @@ def simplify(expression, constant_propagation=False):
sqlglot.Expression: simplified expression
"""
dialect = Dialect.get_or_raise(dialect)
# group by expressions cannot be simplified, for example
# select x + 1 + 1 FROM y GROUP BY x + 1 + 1
# the projection must exactly match the group by key
for group in expression.find_all(exp.Group):
select = group.parent
assert select
groups = set(group.expressions)
group.meta[FINAL] = True
for e in select.selects:
for e in select.expressions:
for node, *_ in e.walk():
if node in groups:
e.meta[FINAL] = True
@ -84,7 +98,8 @@ def simplify(expression, constant_propagation=False):
node = simplify_literals(node, root)
node = simplify_equality(node)
node = simplify_parens(node)
node = simplify_datetrunc_predicate(node)
node = simplify_datetrunc(node, dialect)
node = sort_comparison(node)
if root:
expression.replace(node)
@ -117,14 +132,30 @@ def rewrite_between(expression: exp.Expression) -> exp.Expression:
This is done because comparison simplification is only done on lt/lte/gt/gte.
"""
if isinstance(expression, exp.Between):
return exp.and_(
negate = isinstance(expression.parent, exp.Not)
expression = exp.and_(
exp.GTE(this=expression.this.copy(), expression=expression.args["low"]),
exp.LTE(this=expression.this.copy(), expression=expression.args["high"]),
copy=False,
)
if negate:
expression = exp.paren(expression, copy=False)
return expression
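A quick sketch of the rewrite in action, assuming `simplify` is applied to a parsed predicate:

```python
from sqlglot import parse_one
from sqlglot.optimizer.simplify import simplify

# BETWEEN is first rewritten to >=/<= so the comparison-level rules can fire
print(simplify(parse_one("x BETWEEN 1 AND 10")).sql())
# expected along the lines of: x >= 1 AND x <= 10
```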
COMPLEMENT_COMPARISONS = {
exp.LT: exp.GTE,
exp.GT: exp.LTE,
exp.LTE: exp.GT,
exp.GTE: exp.LT,
exp.EQ: exp.NEQ,
exp.NEQ: exp.EQ,
}
def simplify_not(expression):
"""
De Morgan's Law
@ -132,10 +163,15 @@ def simplify_not(expression):
NOT (x AND y) -> NOT x OR NOT y
"""
if isinstance(expression, exp.Not):
if is_null(expression.this):
this = expression.this
if is_null(this):
return exp.null()
if isinstance(expression.this, exp.Paren):
condition = expression.this.unnest()
if this.__class__ in COMPLEMENT_COMPARISONS:
return COMPLEMENT_COMPARISONS[this.__class__](
this=this.this, expression=this.expression
)
if isinstance(this, exp.Paren):
condition = this.unnest()
if isinstance(condition, exp.And):
return exp.or_(
exp.not_(condition.left, copy=False),
@ -150,14 +186,14 @@ def simplify_not(expression):
)
if is_null(condition):
return exp.null()
if always_true(expression.this):
if always_true(this):
return exp.false()
if is_false(expression.this):
if is_false(this):
return exp.true()
if isinstance(expression.this, exp.Not):
if isinstance(this, exp.Not):
# double negation
# NOT NOT x -> x
return expression.this.this
return this.this
return expression
@ -249,12 +285,6 @@ def _simplify_comparison(expression, left, right, or_=False):
except StopIteration:
return expression
# make sure the comparison is always of the form x > 1 instead of 1 < x
if left.__class__ in INVERSE_COMPARISONS and l == ll:
left = INVERSE_COMPARISONS[left.__class__](this=lr, expression=ll)
if right.__class__ in INVERSE_COMPARISONS and r == rl:
right = INVERSE_COMPARISONS[right.__class__](this=rr, expression=rl)
if l.is_number and r.is_number:
l = float(l.name)
r = float(r.name)
@ -397,13 +427,7 @@ def propagate_constants(expression, root=True):
# TODO: create a helper that can be used to detect nested literal expressions such
# as CAST(123456 AS BIGINT), since we usually want to treat those as literals too
if isinstance(l, exp.Column) and isinstance(r, exp.Literal):
pass
elif isinstance(r, exp.Column) and isinstance(l, exp.Literal):
l, r = r, l
else:
continue
constant_mapping[l] = (id(l), r)
constant_mapping[l] = (id(l), r)
if constant_mapping:
for column in find_all_in_scope(expression, exp.Column):
@ -458,11 +482,7 @@ def simplify_equality(expression: exp.Expression) -> exp.Expression:
if isinstance(expression, COMPARISONS):
l, r = expression.left, expression.right
if l.__class__ in INVERSE_OPS:
pass
elif r.__class__ in INVERSE_OPS:
l, r = r, l
else:
if not l.__class__ in INVERSE_OPS:
return expression
if r.is_number:
@ -650,7 +670,7 @@ def simplify_coalesce(expression):
# Find the first constant arg
for arg_index, arg in enumerate(coalesce.expressions):
if _is_constant(other):
if _is_constant(arg):
break
else:
return expression
@ -752,7 +772,7 @@ def simplify_conditionals(expression):
DateRange = t.Tuple[datetime.date, datetime.date]
def _datetrunc_range(date: datetime.date, unit: str) -> t.Optional[DateRange]:
def _datetrunc_range(date: datetime.date, unit: str, dialect: Dialect) -> t.Optional[DateRange]:
"""
Get the date range for a DATE_TRUNC equality comparison:
@ -761,7 +781,7 @@ def _datetrunc_range(date: datetime.date, unit: str) -> t.Optional[DateRange]:
Returns:
tuple of [min, max) or None if a value can never be equal to `date` for `unit`
"""
floor = date_floor(date, unit)
floor = date_floor(date, unit, dialect)
if date != floor:
# This will always be False, except for NULL values.
@ -780,9 +800,9 @@ def _datetrunc_eq_expression(left: exp.Expression, drange: DateRange) -> exp.Exp
def _datetrunc_eq(
left: exp.Expression, date: datetime.date, unit: str
left: exp.Expression, date: datetime.date, unit: str, dialect: Dialect
) -> t.Optional[exp.Expression]:
drange = _datetrunc_range(date, unit)
drange = _datetrunc_range(date, unit, dialect)
if not drange:
return None
@ -790,9 +810,9 @@ def _datetrunc_eq(
def _datetrunc_neq(
left: exp.Expression, date: datetime.date, unit: str
left: exp.Expression, date: datetime.date, unit: str, dialect: Dialect
) -> t.Optional[exp.Expression]:
drange = _datetrunc_range(date, unit)
drange = _datetrunc_range(date, unit, dialect)
if not drange:
return None
@ -803,41 +823,39 @@ def _datetrunc_neq(
)
DateTruncBinaryTransform = t.Callable[
[exp.Expression, datetime.date, str], t.Optional[exp.Expression]
]
DATETRUNC_BINARY_COMPARISONS: t.Dict[t.Type[exp.Expression], DateTruncBinaryTransform] = {
exp.LT: lambda l, d, u: l < date_literal(date_floor(d, u)),
exp.GT: lambda l, d, u: l >= date_literal(date_floor(d, u) + interval(u)),
exp.LTE: lambda l, d, u: l < date_literal(date_floor(d, u) + interval(u)),
exp.GTE: lambda l, d, u: l >= date_literal(date_ceil(d, u)),
exp.LT: lambda l, dt, u, d: l
< date_literal(dt if dt == date_floor(dt, u, d) else date_floor(dt, u, d) + interval(u)),
exp.GT: lambda l, dt, u, d: l >= date_literal(date_floor(dt, u, d) + interval(u)),
exp.LTE: lambda l, dt, u, d: l < date_literal(date_floor(dt, u, d) + interval(u)),
exp.GTE: lambda l, dt, u, d: l >= date_literal(date_ceil(dt, u, d)),
exp.EQ: _datetrunc_eq,
exp.NEQ: _datetrunc_neq,
}
DATETRUNC_COMPARISONS = {exp.In, *DATETRUNC_BINARY_COMPARISONS}
DATETRUNCS = (exp.DateTrunc, exp.TimestampTrunc)
def _is_datetrunc_predicate(left: exp.Expression, right: exp.Expression) -> bool:
return isinstance(left, (exp.DateTrunc, exp.TimestampTrunc)) and _is_date_literal(right)
return isinstance(left, DATETRUNCS) and _is_date_literal(right)
@catch(ModuleNotFoundError, UnsupportedUnit)
def simplify_datetrunc_predicate(expression: exp.Expression) -> exp.Expression:
def simplify_datetrunc(expression: exp.Expression, dialect: Dialect) -> exp.Expression:
"""Simplify expressions like `DATE_TRUNC('year', x) >= CAST('2021-01-01' AS DATE)`"""
comparison = expression.__class__
if comparison not in DATETRUNC_COMPARISONS:
if isinstance(expression, DATETRUNCS):
date = extract_date(expression.this)
if date and expression.unit:
return date_literal(date_floor(date, expression.unit.name.lower(), dialect))
elif comparison not in DATETRUNC_COMPARISONS:
return expression
if isinstance(expression, exp.Binary):
l, r = expression.left, expression.right
if _is_datetrunc_predicate(l, r):
pass
elif _is_datetrunc_predicate(r, l):
comparison = INVERSE_COMPARISONS.get(comparison, comparison)
l, r = r, l
else:
if not _is_datetrunc_predicate(l, r):
return expression
l = t.cast(exp.DateTrunc, l)
@ -847,7 +865,7 @@ def simplify_datetrunc_predicate(expression: exp.Expression) -> exp.Expression:
if not date:
return expression
return DATETRUNC_BINARY_COMPARISONS[comparison](l.this, date, unit) or expression
return DATETRUNC_BINARY_COMPARISONS[comparison](l.this, date, unit, dialect) or expression
elif isinstance(expression, exp.In):
l = expression.this
rs = expression.expressions
@ -861,7 +879,7 @@ def simplify_datetrunc_predicate(expression: exp.Expression) -> exp.Expression:
date = extract_date(r)
if not date:
return expression
drange = _datetrunc_range(date, unit)
drange = _datetrunc_range(date, unit, dialect)
if drange:
ranges.append(drange)
@ -875,6 +893,23 @@ def simplify_datetrunc_predicate(expression: exp.Expression) -> exp.Expression:
return expression
def sort_comparison(expression: exp.Expression) -> exp.Expression:
if expression.__class__ in COMPLEMENT_COMPARISONS:
l, r = expression.this, expression.expression
l_column = isinstance(l, exp.Column)
r_column = isinstance(r, exp.Column)
l_const = _is_constant(l)
r_const = _is_constant(r)
if (l_column and not r_column) or (r_const and not l_const):
return expression
if (r_column and not l_column) or (l_const and not r_const) or (gen(l) > gen(r)):
return INVERSE_COMPARISONS.get(expression.__class__, expression.__class__)(
this=r, expression=l
)
return expression
# CROSS joins result in an empty table if the right table is empty.
# So we can only simplify certain types of joins to CROSS.
# Or in other words, LEFT JOIN x ON TRUE != CROSS JOIN x
@ -1034,7 +1069,7 @@ def interval(unit: str, n: int = 1):
raise UnsupportedUnit(f"Unsupported unit: {unit}")
def date_floor(d: datetime.date, unit: str) -> datetime.date:
def date_floor(d: datetime.date, unit: str, dialect: Dialect) -> datetime.date:
if unit == "year":
return d.replace(month=1, day=1)
if unit == "quarter":
@ -1050,15 +1085,15 @@ def date_floor(d: datetime.date, unit: str) -> datetime.date:
return d.replace(month=d.month, day=1)
if unit == "week":
# Assuming week starts on Monday (0) and ends on Sunday (6)
return d - datetime.timedelta(days=d.weekday())
return d - datetime.timedelta(days=d.weekday() - dialect.WEEK_OFFSET)
if unit == "day":
return d
raise UnsupportedUnit(f"Unsupported unit: {unit}")
def date_ceil(d: datetime.date, unit: str) -> datetime.date:
floor = date_floor(d, unit)
def date_ceil(d: datetime.date, unit: str, dialect: Dialect) -> datetime.date:
floor = date_floor(d, unit, dialect)
if floor == d:
return d

View file

@ -65,6 +65,8 @@ def unnest(select, parent_select, next_alias_name):
)
):
column = exp.Max(this=column)
elif not isinstance(select.parent, exp.Subquery):
return
_replace(select.parent, column)
parent_select.join(select, join_type="CROSS", join_alias=alias, copy=False)

View file

@ -568,6 +568,7 @@ class Parser(metaclass=_Parser):
exp.Sort: lambda self: self._parse_sort(exp.Sort, TokenType.SORT_BY),
exp.Table: lambda self: self._parse_table_parts(),
exp.TableAlias: lambda self: self._parse_table_alias(),
exp.When: lambda self: seq_get(self._parse_when_matched(), 0),
exp.Where: lambda self: self._parse_where(),
exp.Window: lambda self: self._parse_named_window(),
exp.With: lambda self: self._parse_with(),
@ -635,6 +636,11 @@ class Parser(metaclass=_Parser):
TokenType.HEREDOC_STRING: lambda self, token: self.expression(
exp.RawString, this=token.text
),
TokenType.UNICODE_STRING: lambda self, token: self.expression(
exp.UnicodeString,
this=token.text,
escape=self._match_text_seq("UESCAPE") and self._parse_string(),
),
TokenType.SESSION_PARAMETER: lambda self, _: self._parse_session_parameter(),
}
@ -907,7 +913,7 @@ class Parser(metaclass=_Parser):
INSERT_ALTERNATIVES = {"ABORT", "FAIL", "IGNORE", "REPLACE", "ROLLBACK"}
CLONE_KEYWORDS = {"CLONE", "COPY"}
CLONE_KINDS = {"TIMESTAMP", "OFFSET", "STATEMENT"}
HISTORICAL_DATA_KIND = {"TIMESTAMP", "OFFSET", "STATEMENT", "STREAM"}
OPCLASS_FOLLOW_KEYWORDS = {"ASC", "DESC", "NULLS"}
OPTYPE_FOLLOW_TOKENS = {TokenType.COMMA, TokenType.R_PAREN}
@ -947,6 +953,10 @@ class Parser(metaclass=_Parser):
# Whether the TRIM function expects the characters to trim as its first argument
TRIM_PATTERN_FIRST = False
# Whether query modifiers such as LIMIT are attached to the UNION node (vs its right operand)
MODIFIERS_ATTACHED_TO_UNION = True
UNION_MODIFIERS = {"order", "limit", "offset"}
__slots__ = (
"error_level",
"error_message_context",
@ -1162,6 +1172,9 @@ class Parser(metaclass=_Parser):
def _find_sql(self, start: Token, end: Token) -> str:
return self.sql[start.start : end.end + 1]
def _is_connected(self) -> bool:
return self._prev and self._curr and self._prev.end + 1 == self._curr.start
def _advance(self, times: int = 1) -> None:
self._index += times
self._curr = seq_get(self._tokens, self._index)
@ -1404,23 +1417,8 @@ class Parser(metaclass=_Parser):
if self._match_texts(self.CLONE_KEYWORDS):
copy = self._prev.text.lower() == "copy"
clone = self._parse_table(schema=True)
when = self._match_texts(("AT", "BEFORE")) and self._prev.text.upper()
clone_kind = (
self._match(TokenType.L_PAREN)
and self._match_texts(self.CLONE_KINDS)
and self._prev.text.upper()
)
clone_expression = self._match(TokenType.FARROW) and self._parse_bitwise()
self._match(TokenType.R_PAREN)
clone = self.expression(
exp.Clone,
this=clone,
when=when,
kind=clone_kind,
shallow=shallow,
expression=clone_expression,
copy=copy,
exp.Clone, this=self._parse_table(schema=True), shallow=shallow, copy=copy
)
return self.expression(
@ -2471,13 +2469,7 @@ class Parser(metaclass=_Parser):
pattern = None
define = (
self._parse_csv(
lambda: self.expression(
exp.Alias,
alias=self._parse_id_var(any_token=True),
this=self._match(TokenType.ALIAS) and self._parse_conjunction(),
)
)
self._parse_csv(self._parse_name_as_expression)
if self._match_text_seq("DEFINE")
else None
)
@ -3124,6 +3116,18 @@ class Parser(metaclass=_Parser):
return self.expression(exp.Connect, start=start, connect=connect)
def _parse_name_as_expression(self) -> exp.Alias:
return self.expression(
exp.Alias,
alias=self._parse_id_var(any_token=True),
this=self._match(TokenType.ALIAS) and self._parse_conjunction(),
)
def _parse_interpolate(self) -> t.Optional[t.List[exp.Expression]]:
if self._match_text_seq("INTERPOLATE"):
return self._parse_wrapped_csv(self._parse_name_as_expression)
return None
def _parse_order(
self, this: t.Optional[exp.Expression] = None, skip_order_token: bool = False
) -> t.Optional[exp.Expression]:
@ -3131,7 +3135,10 @@ class Parser(metaclass=_Parser):
return this
return self.expression(
exp.Order, this=this, expressions=self._parse_csv(self._parse_ordered)
exp.Order,
this=this,
expressions=self._parse_csv(self._parse_ordered),
interpolate=self._parse_interpolate(),
)
def _parse_sort(self, exp_class: t.Type[E], token: TokenType) -> t.Optional[E]:
@ -3161,7 +3168,21 @@ class Parser(metaclass=_Parser):
):
nulls_first = True
return self.expression(exp.Ordered, this=this, desc=desc, nulls_first=nulls_first)
if self._match_text_seq("WITH", "FILL"):
with_fill = self.expression(
exp.WithFill,
**{ # type: ignore
"from": self._match(TokenType.FROM) and self._parse_bitwise(),
"to": self._match_text_seq("TO") and self._parse_bitwise(),
"step": self._match_text_seq("STEP") and self._parse_bitwise(),
},
)
else:
with_fill = None
return self.expression(
exp.Ordered, this=this, desc=desc, nulls_first=nulls_first, with_fill=with_fill
)
def _parse_limit(
self, this: t.Optional[exp.Expression] = None, top: bool = False
@ -3253,28 +3274,40 @@ class Parser(metaclass=_Parser):
return locks
def _parse_set_operations(self, this: t.Optional[exp.Expression]) -> t.Optional[exp.Expression]:
if not self._match_set(self.SET_OPERATIONS):
return this
while this and self._match_set(self.SET_OPERATIONS):
token_type = self._prev.token_type
token_type = self._prev.token_type
if token_type == TokenType.UNION:
operation = exp.Union
elif token_type == TokenType.EXCEPT:
operation = exp.Except
else:
operation = exp.Intersect
if token_type == TokenType.UNION:
expression = exp.Union
elif token_type == TokenType.EXCEPT:
expression = exp.Except
else:
expression = exp.Intersect
comments = self._prev.comments
distinct = self._match(TokenType.DISTINCT) or not self._match(TokenType.ALL)
by_name = self._match_text_seq("BY", "NAME")
expression = self._parse_select(nested=True, parse_set_operation=False)
return self.expression(
expression,
comments=self._prev.comments,
this=this,
distinct=self._match(TokenType.DISTINCT) or not self._match(TokenType.ALL),
by_name=self._match_text_seq("BY", "NAME"),
expression=self._parse_set_operations(
self._parse_select(nested=True, parse_set_operation=False)
),
)
this = self.expression(
operation,
comments=comments,
this=this,
distinct=distinct,
by_name=by_name,
expression=expression,
)
if isinstance(this, exp.Union) and self.MODIFIERS_ATTACHED_TO_UNION:
expression = this.expression
if expression:
for arg in self.UNION_MODIFIERS:
expr = expression.args.get(arg)
if expr:
this.set(arg, expr.pop())
return this
def _parse_expression(self) -> t.Optional[exp.Expression]:
return self._parse_alias(self._parse_conjunction())
@ -3595,7 +3628,7 @@ class Parser(metaclass=_Parser):
exp.DataType, this=exp.DataType.Type.INTERVAL, expressions=span
)
else:
this = self.expression(exp.Interval, unit=unit)
this = self.expression(exp.DataType, this=self.expression(exp.Interval, unit=unit))
if maybe_func and check_func:
index2 = self._index
@ -4891,8 +4924,8 @@ class Parser(metaclass=_Parser):
return self.expression(exp.Var, this=self._prev.text)
return self._parse_placeholder()
def _advance_any(self) -> t.Optional[Token]:
if self._curr and self._curr.token_type not in self.RESERVED_TOKENS:
def _advance_any(self, ignore_reserved: bool = False) -> t.Optional[Token]:
if self._curr and (ignore_reserved or self._curr.token_type not in self.RESERVED_TOKENS):
self._advance()
return self._prev
return None

View file

@ -425,16 +425,27 @@ class SetOperation(Step):
cls, expression: exp.Expression, ctes: t.Optional[t.Dict[str, Step]] = None
) -> Step:
assert isinstance(expression, exp.Union)
left = Step.from_expression(expression.left, ctes)
# SELECT 1 UNION SELECT 2 <-- these subqueries don't have names
left.name = left.name or "left"
right = Step.from_expression(expression.right, ctes)
right.name = right.name or "right"
step = cls(
op=expression.__class__,
left=left.name,
right=right.name,
distinct=bool(expression.args.get("distinct")),
)
step.add_dependency(left)
step.add_dependency(right)
limit = expression.args.get("limit")
if limit:
step.limit = int(limit.text("expression"))
return step
def _to_s(self, indent: str) -> t.List[str]:

View file

@ -1,9 +1,10 @@
from __future__ import annotations
import os
import typing as t
from enum import auto
from sqlglot.errors import TokenError
from sqlglot.errors import SqlglotError, TokenError
from sqlglot.helper import AutoName
from sqlglot.trie import TrieResult, in_trie, new_trie
@ -11,6 +12,19 @@ if t.TYPE_CHECKING:
from sqlglot.dialects.dialect import DialectType
try:
from sqlglotrs import ( # type: ignore
Tokenizer as RsTokenizer,
TokenizerDialectSettings as RsTokenizerDialectSettings,
TokenizerSettings as RsTokenizerSettings,
TokenTypeSettings as RsTokenTypeSettings,
)
USE_RS_TOKENIZER = os.environ.get("SQLGLOTRS_TOKENIZER", "1") == "1"
except ImportError:
USE_RS_TOKENIZER = False
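For completeness, a sketch of opting out of the Rust tokenizer; the flag is read once at import time, so it must be set before `sqlglot` is imported:

```python
import os

# Force the pure-Python tokenizer (the default is "1", i.e. use sqlglotrs when installed)
os.environ["SQLGLOTRS_TOKENIZER"] = "0"

import sqlglot  # noqa: E402  (imported after setting the env var on purpose)

print(sqlglot.tokenize("SELECT 1")[0])
```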
class TokenType(AutoName):
L_PAREN = auto()
R_PAREN = auto()
@ -83,6 +97,7 @@ class TokenType(AutoName):
NATIONAL_STRING = auto()
RAW_STRING = auto()
HEREDOC_STRING = auto()
UNICODE_STRING = auto()
# types
BIT = auto()
@ -347,6 +362,10 @@ class TokenType(AutoName):
TIMESTAMP_SNAPSHOT = auto()
_ALL_TOKEN_TYPES = list(TokenType)
_TOKEN_TYPE_TO_INDEX = {token_type: i for i, token_type in enumerate(_ALL_TOKEN_TYPES)}
class Token:
__slots__ = ("token_type", "text", "line", "col", "start", "end", "comments")
@ -432,6 +451,7 @@ class _Tokenizer(type):
**_quotes_to_format(TokenType.HEX_STRING, klass.HEX_STRINGS),
**_quotes_to_format(TokenType.RAW_STRING, klass.RAW_STRINGS),
**_quotes_to_format(TokenType.HEREDOC_STRING, klass.HEREDOC_STRINGS),
**_quotes_to_format(TokenType.UNICODE_STRING, klass.UNICODE_STRINGS),
}
klass._STRING_ESCAPES = set(klass.STRING_ESCAPES)
@ -455,6 +475,46 @@ class _Tokenizer(type):
if " " in key or any(single in key for single in klass.SINGLE_TOKENS)
)
if USE_RS_TOKENIZER:
settings = RsTokenizerSettings(
white_space={k: _TOKEN_TYPE_TO_INDEX[v] for k, v in klass.WHITE_SPACE.items()},
single_tokens={k: _TOKEN_TYPE_TO_INDEX[v] for k, v in klass.SINGLE_TOKENS.items()},
keywords={k: _TOKEN_TYPE_TO_INDEX[v] for k, v in klass.KEYWORDS.items()},
numeric_literals=klass.NUMERIC_LITERALS,
identifiers=klass._IDENTIFIERS,
identifier_escapes=klass._IDENTIFIER_ESCAPES,
string_escapes=klass._STRING_ESCAPES,
quotes=klass._QUOTES,
format_strings={
k: (v1, _TOKEN_TYPE_TO_INDEX[v2])
for k, (v1, v2) in klass._FORMAT_STRINGS.items()
},
has_bit_strings=bool(klass.BIT_STRINGS),
has_hex_strings=bool(klass.HEX_STRINGS),
comments=klass._COMMENTS,
var_single_tokens=klass.VAR_SINGLE_TOKENS,
commands={_TOKEN_TYPE_TO_INDEX[v] for v in klass.COMMANDS},
command_prefix_tokens={
_TOKEN_TYPE_TO_INDEX[v] for v in klass.COMMAND_PREFIX_TOKENS
},
)
token_types = RsTokenTypeSettings(
bit_string=_TOKEN_TYPE_TO_INDEX[TokenType.BIT_STRING],
break_=_TOKEN_TYPE_TO_INDEX[TokenType.BREAK],
dcolon=_TOKEN_TYPE_TO_INDEX[TokenType.DCOLON],
heredoc_string=_TOKEN_TYPE_TO_INDEX[TokenType.HEREDOC_STRING],
hex_string=_TOKEN_TYPE_TO_INDEX[TokenType.HEX_STRING],
identifier=_TOKEN_TYPE_TO_INDEX[TokenType.IDENTIFIER],
number=_TOKEN_TYPE_TO_INDEX[TokenType.NUMBER],
parameter=_TOKEN_TYPE_TO_INDEX[TokenType.PARAMETER],
semicolon=_TOKEN_TYPE_TO_INDEX[TokenType.SEMICOLON],
string=_TOKEN_TYPE_TO_INDEX[TokenType.STRING],
var=_TOKEN_TYPE_TO_INDEX[TokenType.VAR],
)
klass._RS_TOKENIZER = RsTokenizer(settings, token_types)
else:
klass._RS_TOKENIZER = None
return klass
@ -499,6 +559,7 @@ class Tokenizer(metaclass=_Tokenizer):
HEX_STRINGS: t.List[str | t.Tuple[str, str]] = []
RAW_STRINGS: t.List[str | t.Tuple[str, str]] = []
HEREDOC_STRINGS: t.List[str | t.Tuple[str, str]] = []
UNICODE_STRINGS: t.List[str | t.Tuple[str, str]] = []
IDENTIFIERS: t.List[str | t.Tuple[str, str]] = ['"']
IDENTIFIER_ESCAPES = ['"']
QUOTES: t.List[t.Tuple[str, str] | str] = ["'"]
@ -513,6 +574,7 @@ class Tokenizer(metaclass=_Tokenizer):
_QUOTES: t.Dict[str, str] = {}
_STRING_ESCAPES: t.Set[str] = set()
_KEYWORD_TRIE: t.Dict = {}
_RS_TOKENIZER: t.Optional[t.Any] = None
KEYWORDS: t.Dict[str, TokenType] = {
**{f"{{%{postfix}": TokenType.BLOCK_START for postfix in ("", "+", "-")},
@ -804,7 +866,6 @@ class Tokenizer(metaclass=_Tokenizer):
# handle numeric literals like in hive (3L = BIGINT)
NUMERIC_LITERALS: t.Dict[str, str] = {}
ENCODE: t.Optional[str] = None
COMMENTS = ["--", ("/*", "*/")]
@ -822,12 +883,20 @@ class Tokenizer(metaclass=_Tokenizer):
"_end",
"_peek",
"_prev_token_line",
"_rs_dialect_settings",
)
def __init__(self, dialect: DialectType = None) -> None:
from sqlglot.dialects import Dialect
self.dialect = Dialect.get_or_raise(dialect)
if USE_RS_TOKENIZER:
self._rs_dialect_settings = RsTokenizerDialectSettings(
escape_sequences=self.dialect.ESCAPE_SEQUENCES,
identifiers_can_start_with_digit=self.dialect.IDENTIFIERS_CAN_START_WITH_DIGIT,
)
self.reset()
def reset(self) -> None:
@ -847,6 +916,9 @@ class Tokenizer(metaclass=_Tokenizer):
def tokenize(self, sql: str) -> t.List[Token]:
"""Returns a list of tokens corresponding to the SQL string `sql`."""
if USE_RS_TOKENIZER:
return self.tokenize_rs(sql)
self.reset()
self.sql = sql
self.size = len(sql)
@ -910,6 +982,7 @@ class Tokenizer(metaclass=_Tokenizer):
# Ensures we don't count an extra line if we get a \r\n line break sequence
if self._char == "\r" and self._peek == "\n":
i = 2
self._start += 1
self._col = 1
self._line += 1
@ -1184,8 +1257,6 @@ class Tokenizer(metaclass=_Tokenizer):
raise TokenError(
f"Numeric string contains invalid characters from {self._line}:{self._start}"
)
else:
text = text.encode(self.ENCODE).decode(self.ENCODE) if self.ENCODE else text
self._add(token_type, text)
return True
@ -1254,3 +1325,15 @@ class Tokenizer(metaclass=_Tokenizer):
text += self.sql[current : self._current - 1]
return text
def tokenize_rs(self, sql: str) -> t.List[Token]:
if not self._RS_TOKENIZER:
raise SqlglotError("Rust tokenizer is not available")
try:
tokens = self._RS_TOKENIZER.tokenize(sql, self._rs_dialect_settings)
for token in tokens:
token.token_type = _ALL_TOKEN_TYPES[token.token_type_index]
return tokens
except Exception as e:
raise TokenError(str(e))