1
0
Fork 0

Merging upstream version 9.0.6.

Signed-off-by: Daniel Baumann <daniel@debian.org>
This commit is contained in:
Daniel Baumann 2025-02-13 14:51:47 +01:00
parent e369f04a93
commit 69b4fb4368
Signed by: daniel
GPG key ID: FBB4F0E80A80222F
31 changed files with 694 additions and 196 deletions

View file

@ -1,3 +1,5 @@
"""## Python SQL parser, transpiler and optimizer."""
from sqlglot import expressions as exp
from sqlglot.dialects import Dialect, Dialects
from sqlglot.diff import diff
@ -24,7 +26,7 @@ from sqlglot.parser import Parser
from sqlglot.schema import MappingSchema
from sqlglot.tokens import Tokenizer, TokenType
__version__ = "9.0.3"
__version__ = "9.0.6"
pretty = False

View file

@ -1,5 +1,6 @@
from sqlglot.dialects.bigquery import BigQuery
from sqlglot.dialects.clickhouse import ClickHouse
from sqlglot.dialects.databricks import Databricks
from sqlglot.dialects.dialect import Dialect, Dialects
from sqlglot.dialects.duckdb import DuckDB
from sqlglot.dialects.hive import Hive

View file

@ -0,0 +1,21 @@
from sqlglot import exp
from sqlglot.dialects.dialect import parse_date_delta
from sqlglot.dialects.spark import Spark
from sqlglot.dialects.tsql import generate_date_delta_with_unit_sql
class Databricks(Spark):
class Parser(Spark.Parser):
FUNCTIONS = {
**Spark.Parser.FUNCTIONS,
"DATEADD": parse_date_delta(exp.DateAdd),
"DATE_ADD": parse_date_delta(exp.DateAdd),
"DATEDIFF": parse_date_delta(exp.DateDiff),
}
class Generator(Spark.Generator):
TRANSFORMS = {
**Spark.Generator.TRANSFORMS,
exp.DateAdd: generate_date_delta_with_unit_sql,
exp.DateDiff: generate_date_delta_with_unit_sql,
}

View file

@ -28,6 +28,7 @@ class Dialects(str, Enum):
TABLEAU = "tableau"
TRINO = "trino"
TSQL = "tsql"
DATABRICKS = "databricks"
class _Dialect(type):
@ -331,3 +332,15 @@ def create_with_partitions_sql(self, expression):
expression.set("this", schema)
return self.create_sql(expression)
def parse_date_delta(exp_class, unit_mapping=None):
def inner_func(args):
unit_based = len(args) == 3
this = list_get(args, 2) if unit_based else list_get(args, 0)
expression = list_get(args, 1) if unit_based else list_get(args, 1)
unit = list_get(args, 0) if unit_based else exp.Literal.string("DAY")
unit = unit_mapping.get(unit.name.lower(), unit) if unit_mapping else unit
return exp_class(this=this, expression=expression, unit=unit)
return inner_func

View file

@ -111,6 +111,7 @@ def _unnest_to_explode_sql(self, expression):
self.sql(
exp.Lateral(
this=udtf(this=expression),
view=True,
alias=exp.TableAlias(this=alias.this, columns=[column]),
)
)
@ -283,6 +284,7 @@ class Hive(Dialect):
exp.UnixToTime: rename_func("FROM_UNIXTIME"),
exp.UnixToTimeStr: rename_func("FROM_UNIXTIME"),
exp.PartitionedByProperty: lambda self, e: f"PARTITIONED BY {self.sql(e, 'value')}",
exp.NumberToStr: rename_func("FORMAT_NUMBER"),
}
WITH_PROPERTIES = {exp.AnonymousProperty}

View file

@ -115,6 +115,7 @@ class Presto(Dialect):
class Tokenizer(Tokenizer):
KEYWORDS = {
**Tokenizer.KEYWORDS,
"VARBINARY": TokenType.BINARY,
"ROW": TokenType.STRUCT,
}

View file

@ -188,6 +188,8 @@ class Snowflake(Dialect):
}
class Generator(Generator):
CREATE_TRANSIENT = True
TRANSFORMS = {
**Generator.TRANSFORMS,
exp.ArrayConcat: rename_func("ARRAY_CAT"),

View file

@ -20,6 +20,7 @@ class SQLite(Dialect):
KEYWORDS = {
**Tokenizer.KEYWORDS,
"VARBINARY": TokenType.BINARY,
"AUTOINCREMENT": TokenType.AUTO_INCREMENT,
}

View file

@ -1,5 +1,7 @@
import re
from sqlglot import exp
from sqlglot.dialects.dialect import Dialect, rename_func
from sqlglot.dialects.dialect import Dialect, parse_date_delta, rename_func
from sqlglot.expressions import DataType
from sqlglot.generator import Generator
from sqlglot.helper import list_get
@ -27,6 +29,11 @@ DATE_DELTA_INTERVAL = {
}
DATE_FMT_RE = re.compile("([dD]{1,2})|([mM]{1,2})|([yY]{1,4})|([hH]{1,2})|([sS]{1,2})")
# N = Numeric, C=Currency
TRANSPILE_SAFE_NUMBER_FMT = {"N", "C"}
def tsql_format_time_lambda(exp_class, full_format_mapping=None, default=None):
def _format_time(args):
return exp_class(
@ -42,26 +49,40 @@ def tsql_format_time_lambda(exp_class, full_format_mapping=None, default=None):
return _format_time
def parse_date_delta(exp_class):
def inner_func(args):
unit = DATE_DELTA_INTERVAL.get(list_get(args, 0).name.lower(), "day")
return exp_class(this=list_get(args, 2), expression=list_get(args, 1), unit=unit)
return inner_func
def parse_format(args):
fmt = list_get(args, 1)
number_fmt = fmt.name in TRANSPILE_SAFE_NUMBER_FMT or not DATE_FMT_RE.search(fmt.this)
if number_fmt:
return exp.NumberToStr(this=list_get(args, 0), format=fmt)
return exp.TimeToStr(
this=list_get(args, 0),
format=exp.Literal.string(
format_time(fmt.name, TSQL.format_time_mapping)
if len(fmt.name) == 1
else format_time(fmt.name, TSQL.time_mapping)
),
)
def generate_date_delta(self, e):
def generate_date_delta_with_unit_sql(self, e):
func = "DATEADD" if isinstance(e, exp.DateAdd) else "DATEDIFF"
return f"{func}({self.format_args(e.text('unit'), e.expression, e.this)})"
def generate_format_sql(self, e):
fmt = (
e.args["format"]
if isinstance(e, exp.NumberToStr)
else exp.Literal.string(format_time(e.text("format"), TSQL.inverse_time_mapping))
)
return f"FORMAT({self.format_args(e.this, fmt)})"
class TSQL(Dialect):
null_ordering = "nulls_are_small"
time_format = "'yyyy-mm-dd hh:mm:ss'"
time_mapping = {
"yyyy": "%Y",
"yy": "%y",
"year": "%Y",
"qq": "%q",
"q": "%q",
@ -101,6 +122,8 @@ class TSQL(Dialect):
"H": "%-H",
"h": "%-I",
"S": "%f",
"yyyy": "%Y",
"yy": "%y",
}
convert_format_mapping = {
@ -143,6 +166,27 @@ class TSQL(Dialect):
"120": "%Y-%m-%d %H:%M:%S",
"121": "%Y-%m-%d %H:%M:%S.%f",
}
# not sure if complete
format_time_mapping = {
"y": "%B %Y",
"d": "%m/%d/%Y",
"H": "%-H",
"h": "%-I",
"s": "%Y-%m-%d %H:%M:%S",
"D": "%A,%B,%Y",
"f": "%A,%B,%Y %-I:%M %p",
"F": "%A,%B,%Y %-I:%M:%S %p",
"g": "%m/%d/%Y %-I:%M %p",
"G": "%m/%d/%Y %-I:%M:%S %p",
"M": "%B %-d",
"m": "%B %-d",
"O": "%Y-%m-%dT%H:%M:%S",
"u": "%Y-%M-%D %H:%M:%S%z",
"U": "%A, %B %D, %Y %H:%M:%S%z",
"T": "%-I:%M:%S %p",
"t": "%-I:%M",
"Y": "%a %Y",
}
class Tokenizer(Tokenizer):
IDENTIFIERS = ['"', ("[", "]")]
@ -166,6 +210,7 @@ class TSQL(Dialect):
"SQL_VARIANT": TokenType.VARIANT,
"NVARCHAR(MAX)": TokenType.TEXT,
"VARCHAR(MAX)": TokenType.TEXT,
"TOP": TokenType.TOP,
}
class Parser(Parser):
@ -173,8 +218,8 @@ class TSQL(Dialect):
**Parser.FUNCTIONS,
"CHARINDEX": exp.StrPosition.from_arg_list,
"ISNULL": exp.Coalesce.from_arg_list,
"DATEADD": parse_date_delta(exp.DateAdd),
"DATEDIFF": parse_date_delta(exp.DateDiff),
"DATEADD": parse_date_delta(exp.DateAdd, unit_mapping=DATE_DELTA_INTERVAL),
"DATEDIFF": parse_date_delta(exp.DateDiff, unit_mapping=DATE_DELTA_INTERVAL),
"DATENAME": tsql_format_time_lambda(exp.TimeToStr, full_format_mapping=True),
"DATEPART": tsql_format_time_lambda(exp.TimeToStr),
"GETDATE": exp.CurrentDate.from_arg_list,
@ -182,6 +227,7 @@ class TSQL(Dialect):
"LEN": exp.Length.from_arg_list,
"REPLICATE": exp.Repeat.from_arg_list,
"JSON_VALUE": exp.JSONExtractScalar.from_arg_list,
"FORMAT": parse_format,
}
VAR_LENGTH_DATATYPES = {
@ -194,7 +240,7 @@ class TSQL(Dialect):
def _parse_convert(self, strict):
to = self._parse_types()
self._match(TokenType.COMMA)
this = self._parse_field()
this = self._parse_column()
# Retrieve length of datatype and override to default if not specified
if list_get(to.expressions, 0) is None and to.this in self.VAR_LENGTH_DATATYPES:
@ -238,8 +284,10 @@ class TSQL(Dialect):
TRANSFORMS = {
**Generator.TRANSFORMS,
exp.DateAdd: lambda self, e: generate_date_delta(self, e),
exp.DateDiff: lambda self, e: generate_date_delta(self, e),
exp.DateAdd: generate_date_delta_with_unit_sql,
exp.DateDiff: generate_date_delta_with_unit_sql,
exp.CurrentDate: rename_func("GETDATE"),
exp.If: rename_func("IIF"),
exp.NumberToStr: generate_format_sql,
exp.TimeToStr: generate_format_sql,
}

View file

@ -443,7 +443,7 @@ class Condition(Expression):
'x = 1 AND y = 1'
Args:
*expressions (str or Expression): the SQL code strings to parse.
*expressions (str | Expression): the SQL code strings to parse.
If an `Expression` instance is passed, it will be used as-is.
dialect (str): the dialect used to parse the input expression.
opts (kwargs): other options to use to parse the input expressions.
@ -462,7 +462,7 @@ class Condition(Expression):
'x = 1 OR y = 1'
Args:
*expressions (str or Expression): the SQL code strings to parse.
*expressions (str | Expression): the SQL code strings to parse.
If an `Expression` instance is passed, it will be used as-is.
dialect (str): the dialect used to parse the input expression.
opts (kwargs): other options to use to parse the input expressions.
@ -523,7 +523,7 @@ class Unionable(Expression):
'SELECT * FROM foo UNION SELECT * FROM bla'
Args:
expression (str or Expression): the SQL code string.
expression (str | Expression): the SQL code string.
If an `Expression` instance is passed, it will be used as-is.
distinct (bool): set the DISTINCT flag if and only if this is true.
dialect (str): the dialect used to parse the input expression.
@ -543,7 +543,7 @@ class Unionable(Expression):
'SELECT * FROM foo INTERSECT SELECT * FROM bla'
Args:
expression (str or Expression): the SQL code string.
expression (str | Expression): the SQL code string.
If an `Expression` instance is passed, it will be used as-is.
distinct (bool): set the DISTINCT flag if and only if this is true.
dialect (str): the dialect used to parse the input expression.
@ -563,7 +563,7 @@ class Unionable(Expression):
'SELECT * FROM foo EXCEPT SELECT * FROM bla'
Args:
expression (str or Expression): the SQL code string.
expression (str | Expression): the SQL code string.
If an `Expression` instance is passed, it will be used as-is.
distinct (bool): set the DISTINCT flag if and only if this is true.
dialect (str): the dialect used to parse the input expression.
@ -612,6 +612,7 @@ class Create(Expression):
"exists": False,
"properties": False,
"temporary": False,
"transient": False,
"replace": False,
"unique": False,
"materialized": False,
@ -910,7 +911,7 @@ class Join(Expression):
'JOIN x ON y = 1'
Args:
*expressions (str or Expression): the SQL code strings to parse.
*expressions (str | Expression): the SQL code strings to parse.
If an `Expression` instance is passed, it will be used as-is.
Multiple expressions are combined with an AND operator.
append (bool): if `True`, AND the new expressions to any existing expression.
@ -937,9 +938,45 @@ class Join(Expression):
return join
def using(self, *expressions, append=True, dialect=None, copy=True, **opts):
"""
Append to or set the USING expressions.
Example:
>>> import sqlglot
>>> sqlglot.parse_one("JOIN x", into=Join).using("foo", "bla").sql()
'JOIN x USING (foo, bla)'
Args:
*expressions (str | Expression): the SQL code strings to parse.
If an `Expression` instance is passed, it will be used as-is.
append (bool): if `True`, concatenate the new expressions to the existing "using" list.
Otherwise, this resets the expression.
dialect (str): the dialect used to parse the input expressions.
copy (bool): if `False`, modify this expression instance in-place.
opts (kwargs): other options to use to parse the input expressions.
Returns:
Join: the modified join expression.
"""
join = _apply_list_builder(
*expressions,
instance=self,
arg="using",
append=append,
dialect=dialect,
copy=copy,
**opts,
)
if join.kind == "CROSS":
join.set("kind", None)
return join
class Lateral(UDTF):
arg_types = {"this": True, "outer": False, "alias": False}
arg_types = {"this": True, "view": False, "outer": False, "alias": False}
# Clickhouse FROM FINAL modifier
@ -1093,7 +1130,7 @@ class Subqueryable(Unionable):
'SELECT x FROM (SELECT x FROM tbl)'
Args:
alias (str or Identifier): an optional alias for the subquery
alias (str | Identifier): an optional alias for the subquery
copy (bool): if `False`, modify this expression instance in-place.
Returns:
@ -1138,9 +1175,9 @@ class Subqueryable(Unionable):
'WITH tbl2 AS (SELECT * FROM tbl) SELECT x FROM tbl2'
Args:
alias (str or Expression): the SQL code string to parse as the table name.
alias (str | Expression): the SQL code string to parse as the table name.
If an `Expression` instance is passed, this is used as-is.
as_ (str or Expression): the SQL code string to parse as the table expression.
as_ (str | Expression): the SQL code string to parse as the table expression.
If an `Expression` instance is passed, it will be used as-is.
recursive (bool): set the RECURSIVE part of the expression. Defaults to `False`.
append (bool): if `True`, add to any existing expressions.
@ -1295,7 +1332,7 @@ class Select(Subqueryable):
'SELECT x FROM tbl'
Args:
*expressions (str or Expression): the SQL code strings to parse.
*expressions (str | Expression): the SQL code strings to parse.
If a `From` instance is passed, this is used as-is.
If another `Expression` instance is passed, it will be wrapped in a `From`.
append (bool): if `True`, add to any existing expressions.
@ -1328,7 +1365,7 @@ class Select(Subqueryable):
'SELECT x, COUNT(1) FROM tbl GROUP BY x'
Args:
*expressions (str or Expression): the SQL code strings to parse.
*expressions (str | Expression): the SQL code strings to parse.
If a `Group` instance is passed, this is used as-is.
If another `Expression` instance is passed, it will be wrapped in a `Group`.
If nothing is passed in then a group by is not applied to the expression
@ -1364,7 +1401,7 @@ class Select(Subqueryable):
'SELECT x FROM tbl ORDER BY x DESC'
Args:
*expressions (str or Expression): the SQL code strings to parse.
*expressions (str | Expression): the SQL code strings to parse.
If a `Group` instance is passed, this is used as-is.
If another `Expression` instance is passed, it will be wrapped in a `Order`.
append (bool): if `True`, add to any existing expressions.
@ -1397,7 +1434,7 @@ class Select(Subqueryable):
'SELECT x FROM tbl SORT BY x DESC'
Args:
*expressions (str or Expression): the SQL code strings to parse.
*expressions (str | Expression): the SQL code strings to parse.
If a `Group` instance is passed, this is used as-is.
If another `Expression` instance is passed, it will be wrapped in a `SORT`.
append (bool): if `True`, add to any existing expressions.
@ -1430,7 +1467,7 @@ class Select(Subqueryable):
'SELECT x FROM tbl CLUSTER BY x DESC'
Args:
*expressions (str or Expression): the SQL code strings to parse.
*expressions (str | Expression): the SQL code strings to parse.
If a `Group` instance is passed, this is used as-is.
If another `Expression` instance is passed, it will be wrapped in a `Cluster`.
append (bool): if `True`, add to any existing expressions.
@ -1463,7 +1500,7 @@ class Select(Subqueryable):
'SELECT x FROM tbl LIMIT 10'
Args:
expression (str or int or Expression): the SQL code string to parse.
expression (str | int | Expression): the SQL code string to parse.
This can also be an integer.
If a `Limit` instance is passed, this is used as-is.
If another `Expression` instance is passed, it will be wrapped in a `Limit`.
@ -1494,7 +1531,7 @@ class Select(Subqueryable):
'SELECT x FROM tbl OFFSET 10'
Args:
expression (str or int or Expression): the SQL code string to parse.
expression (str | int | Expression): the SQL code string to parse.
This can also be an integer.
If a `Offset` instance is passed, this is used as-is.
If another `Expression` instance is passed, it will be wrapped in a `Offset`.
@ -1525,7 +1562,7 @@ class Select(Subqueryable):
'SELECT x, y'
Args:
*expressions (str or Expression): the SQL code strings to parse.
*expressions (str | Expression): the SQL code strings to parse.
If an `Expression` instance is passed, it will be used as-is.
append (bool): if `True`, add to any existing expressions.
Otherwise, this resets the expressions.
@ -1555,7 +1592,7 @@ class Select(Subqueryable):
'SELECT x FROM tbl LATERAL VIEW OUTER EXPLODE(y) tbl2 AS z'
Args:
*expressions (str or Expression): the SQL code strings to parse.
*expressions (str | Expression): the SQL code strings to parse.
If an `Expression` instance is passed, it will be used as-is.
append (bool): if `True`, add to any existing expressions.
Otherwise, this resets the expressions.
@ -1582,6 +1619,7 @@ class Select(Subqueryable):
self,
expression,
on=None,
using=None,
append=True,
join_type=None,
join_alias=None,
@ -1596,15 +1634,20 @@ class Select(Subqueryable):
>>> Select().select("*").from_("tbl").join("tbl2", on="tbl1.y = tbl2.y").sql()
'SELECT * FROM tbl JOIN tbl2 ON tbl1.y = tbl2.y'
>>> Select().select("1").from_("a").join("b", using=["x", "y", "z"]).sql()
'SELECT 1 FROM a JOIN b USING (x, y, z)'
Use `join_type` to change the type of join:
>>> Select().select("*").from_("tbl").join("tbl2", on="tbl1.y = tbl2.y", join_type="left outer").sql()
'SELECT * FROM tbl LEFT OUTER JOIN tbl2 ON tbl1.y = tbl2.y'
Args:
expression (str or Expression): the SQL code string to parse.
expression (str | Expression): the SQL code string to parse.
If an `Expression` instance is passed, it will be used as-is.
on (str or Expression): optionally specify the join criteria as a SQL string.
on (str | Expression): optionally specify the join "on" criteria as a SQL string.
If an `Expression` instance is passed, it will be used as-is.
using (str | Expression): optionally specify the join "using" criteria as a SQL string.
If an `Expression` instance is passed, it will be used as-is.
append (bool): if `True`, add to any existing expressions.
Otherwise, this resets the expressions.
@ -1641,6 +1684,16 @@ class Select(Subqueryable):
on = and_(*ensure_list(on), dialect=dialect, **opts)
join.set("on", on)
if using:
join = _apply_list_builder(
*ensure_list(using),
instance=join,
arg="using",
append=append,
copy=copy,
**opts,
)
if join_alias:
join.set("this", alias_(join.args["this"], join_alias, table=True))
return _apply_list_builder(
@ -1661,7 +1714,7 @@ class Select(Subqueryable):
"SELECT x FROM tbl WHERE x = 'a' OR x < 'b'"
Args:
*expressions (str or Expression): the SQL code strings to parse.
*expressions (str | Expression): the SQL code strings to parse.
If an `Expression` instance is passed, it will be used as-is.
Multiple expressions are combined with an AND operator.
append (bool): if `True`, AND the new expressions to any existing expression.
@ -1693,7 +1746,7 @@ class Select(Subqueryable):
'SELECT x, COUNT(y) FROM tbl GROUP BY x HAVING COUNT(y) > 3'
Args:
*expressions (str or Expression): the SQL code strings to parse.
*expressions (str | Expression): the SQL code strings to parse.
If an `Expression` instance is passed, it will be used as-is.
Multiple expressions are combined with an AND operator.
append (bool): if `True`, AND the new expressions to any existing expression.
@ -1744,7 +1797,7 @@ class Select(Subqueryable):
'CREATE TABLE x AS SELECT * FROM tbl'
Args:
table (str or Expression): the SQL code string to parse as the table name.
table (str | Expression): the SQL code string to parse as the table name.
If another `Expression` instance is passed, it will be used as-is.
properties (dict): an optional mapping of table properties
dialect (str): the dialect used to parse the input table.
@ -2620,6 +2673,10 @@ class StrToUnix(Func):
arg_types = {"this": True, "format": True}
class NumberToStr(Func):
arg_types = {"this": True, "format": True}
class Struct(Func):
arg_types = {"expressions": True}
is_var_len_args = True
@ -2775,7 +2832,7 @@ def maybe_parse(
(IDENTIFIER this: x, quoted: False)
Args:
sql_or_expression (str or Expression): the SQL code string or an expression
sql_or_expression (str | Expression): the SQL code string or an expression
into (Expression): the SQLGlot Expression to parse into
dialect (str): the dialect used to parse the input expressions (in the case that an
input expression is a SQL string).
@ -2950,9 +3007,9 @@ def union(left, right, distinct=True, dialect=None, **opts):
'SELECT * FROM foo UNION SELECT * FROM bla'
Args:
left (str or Expression): the SQL code string corresponding to the left-hand side.
left (str | Expression): the SQL code string corresponding to the left-hand side.
If an `Expression` instance is passed, it will be used as-is.
right (str or Expression): the SQL code string corresponding to the right-hand side.
right (str | Expression): the SQL code string corresponding to the right-hand side.
If an `Expression` instance is passed, it will be used as-is.
distinct (bool): set the DISTINCT flag if and only if this is true.
dialect (str): the dialect used to parse the input expression.
@ -2975,9 +3032,9 @@ def intersect(left, right, distinct=True, dialect=None, **opts):
'SELECT * FROM foo INTERSECT SELECT * FROM bla'
Args:
left (str or Expression): the SQL code string corresponding to the left-hand side.
left (str | Expression): the SQL code string corresponding to the left-hand side.
If an `Expression` instance is passed, it will be used as-is.
right (str or Expression): the SQL code string corresponding to the right-hand side.
right (str | Expression): the SQL code string corresponding to the right-hand side.
If an `Expression` instance is passed, it will be used as-is.
distinct (bool): set the DISTINCT flag if and only if this is true.
dialect (str): the dialect used to parse the input expression.
@ -3000,9 +3057,9 @@ def except_(left, right, distinct=True, dialect=None, **opts):
'SELECT * FROM foo EXCEPT SELECT * FROM bla'
Args:
left (str or Expression): the SQL code string corresponding to the left-hand side.
left (str | Expression): the SQL code string corresponding to the left-hand side.
If an `Expression` instance is passed, it will be used as-is.
right (str or Expression): the SQL code string corresponding to the right-hand side.
right (str | Expression): the SQL code string corresponding to the right-hand side.
If an `Expression` instance is passed, it will be used as-is.
distinct (bool): set the DISTINCT flag if and only if this is true.
dialect (str): the dialect used to parse the input expression.
@ -3025,7 +3082,7 @@ def select(*expressions, dialect=None, **opts):
'SELECT col1, col2 FROM tbl'
Args:
*expressions (str or Expression): the SQL code string to parse as the expressions of a
*expressions (str | Expression): the SQL code string to parse as the expressions of a
SELECT statement. If an Expression instance is passed, this is used as-is.
dialect (str): the dialect used to parse the input expressions (in the case that an
input expression is a SQL string).
@ -3047,7 +3104,7 @@ def from_(*expressions, dialect=None, **opts):
'SELECT col1, col2 FROM tbl'
Args:
*expressions (str or Expression): the SQL code string to parse as the FROM expressions of a
*expressions (str | Expression): the SQL code string to parse as the FROM expressions of a
SELECT statement. If an Expression instance is passed, this is used as-is.
dialect (str): the dialect used to parse the input expression (in the case that the
input expression is a SQL string).
@ -3132,7 +3189,7 @@ def condition(expression, dialect=None, **opts):
'SELECT * FROM tbl WHERE x = 1 AND y = 1'
Args:
*expression (str or Expression): the SQL code string to parse.
*expression (str | Expression): the SQL code string to parse.
If an Expression instance is passed, this is used as-is.
dialect (str): the dialect used to parse the input expression (in the case that the
input expression is a SQL string).
@ -3159,7 +3216,7 @@ def and_(*expressions, dialect=None, **opts):
'x = 1 AND (y = 1 AND z = 1)'
Args:
*expressions (str or Expression): the SQL code strings to parse.
*expressions (str | Expression): the SQL code strings to parse.
If an Expression instance is passed, this is used as-is.
dialect (str): the dialect used to parse the input expression.
**opts: other options to use to parse the input expressions.
@ -3179,7 +3236,7 @@ def or_(*expressions, dialect=None, **opts):
'x = 1 OR (y = 1 OR z = 1)'
Args:
*expressions (str or Expression): the SQL code strings to parse.
*expressions (str | Expression): the SQL code strings to parse.
If an Expression instance is passed, this is used as-is.
dialect (str): the dialect used to parse the input expression.
**opts: other options to use to parse the input expressions.
@ -3199,7 +3256,7 @@ def not_(expression, dialect=None, **opts):
"NOT this_suit = 'black'"
Args:
expression (str or Expression): the SQL code strings to parse.
expression (str | Expression): the SQL code strings to parse.
If an Expression instance is passed, this is used as-is.
dialect (str): the dialect used to parse the input expression.
**opts: other options to use to parse the input expressions.
@ -3283,9 +3340,9 @@ def alias_(expression, alias, table=False, dialect=None, quoted=None, **opts):
'foo AS bar'
Args:
expression (str or Expression): the SQL code strings to parse.
expression (str | Expression): the SQL code strings to parse.
If an Expression instance is passed, this is used as-is.
alias (str or Identifier): the alias name to use. If the name has
alias (str | Identifier): the alias name to use. If the name has
special characters it is quoted.
table (bool): create a table alias, default false
dialect (str): the dialect used to parse the input expression.
@ -3322,9 +3379,9 @@ def subquery(expression, alias=None, dialect=None, **opts):
'SELECT x FROM (SELECT x FROM tbl) AS bar'
Args:
expression (str or Expression): the SQL code strings to parse.
expression (str | Expression): the SQL code strings to parse.
If an Expression instance is passed, this is used as-is.
alias (str or Expression): the alias name to use.
alias (str | Expression): the alias name to use.
dialect (str): the dialect used to parse the input expression.
**opts: other options to use to parse the input expressions.
@ -3340,8 +3397,8 @@ def column(col, table=None, quoted=None):
"""
Build a Column.
Args:
col (str or Expression): column name
table (str or Expression): table name
col (str | Expression): column name
table (str | Expression): table name
Returns:
Column: column instance
"""
@ -3355,9 +3412,9 @@ def table_(table, db=None, catalog=None, quoted=None, alias=None):
"""Build a Table.
Args:
table (str or Expression): column name
db (str or Expression): db name
catalog (str or Expression): catalog name
table (str | Expression): column name
db (str | Expression): db name
catalog (str | Expression): catalog name
Returns:
Table: table instance
@ -3423,7 +3480,7 @@ def convert(value):
values=[convert(v) for v in value.values()],
)
if isinstance(value, datetime.datetime):
datetime_literal = Literal.string(value.strftime("%Y-%m-%d %H:%M:%S"))
datetime_literal = Literal.string(value.strftime("%Y-%m-%d %H:%M:%S.%f%z"))
return TimeStrToTime(this=datetime_literal)
if isinstance(value, datetime.date):
date_literal = Literal.string(value.strftime("%Y-%m-%d"))

View file

@ -65,6 +65,9 @@ class Generator:
exp.VolatilityProperty: lambda self, e: self.sql(e.name),
}
# whether 'CREATE ... TRANSIENT ... TABLE' is allowed
# can override in dialects
CREATE_TRANSIENT = False
# whether or not null ordering is supported in order by
NULL_ORDERING_SUPPORTED = True
# always do union distinct or union all
@ -368,15 +371,14 @@ class Generator:
expression_sql = self.sql(expression, "expression")
expression_sql = f"AS{self.sep()}{expression_sql}" if expression_sql else ""
temporary = " TEMPORARY" if expression.args.get("temporary") else ""
transient = " TRANSIENT" if self.CREATE_TRANSIENT and expression.args.get("transient") else ""
replace = " OR REPLACE" if expression.args.get("replace") else ""
exists_sql = " IF NOT EXISTS" if expression.args.get("exists") else ""
unique = " UNIQUE" if expression.args.get("unique") else ""
materialized = " MATERIALIZED" if expression.args.get("materialized") else ""
properties = self.sql(expression, "properties")
expression_sql = (
f"CREATE{replace}{temporary}{unique}{materialized} {kind}{exists_sql} {this}{properties} {expression_sql}"
)
expression_sql = f"CREATE{replace}{temporary}{transient}{unique}{materialized} {kind}{exists_sql} {this}{properties} {expression_sql}"
return self.prepend_ctes(expression, expression_sql)
def describe_sql(self, expression):
@ -716,15 +718,21 @@ class Generator:
def lateral_sql(self, expression):
this = self.sql(expression, "this")
if isinstance(expression.this, exp.Subquery):
return f"LATERAL{self.sep()}{this}"
op_sql = self.seg(f"LATERAL VIEW{' OUTER' if expression.args.get('outer') else ''}")
return f"LATERAL {this}"
alias = expression.args["alias"]
table = alias.name
table = f" {table}" if table else table
columns = self.expressions(alias, key="columns", flat=True)
columns = f" AS {columns}" if columns else ""
return f"{op_sql}{self.sep()}{this}{table}{columns}"
if expression.args.get("view"):
op_sql = self.seg(f"LATERAL VIEW{' OUTER' if expression.args.get('outer') else ''}")
return f"{op_sql}{self.sep()}{this}{table}{columns}"
return f"LATERAL {this}{table}{columns}"
def limit_sql(self, expression):
this = self.sql(expression, "this")

View file

@ -211,21 +211,26 @@ def _qualify_columns(scope, resolver):
if column_table:
column.set("table", exp.to_identifier(column_table))
columns_missing_from_scope = []
# Determine whether each reference in the order by clause is to a column or an alias.
for ordered in scope.find_all(exp.Ordered):
for column in ordered.find_all(exp.Column):
column_table = column.table
column_name = column.name
if not column.table and column.parent is not ordered and column.name in resolver.all_columns:
columns_missing_from_scope.append(column)
if column_table or column.parent is ordered or column_name not in resolver.all_columns:
continue
# Determine whether each reference in the having clause is to a column or an alias.
for having in scope.find_all(exp.Having):
for column in having.find_all(exp.Column):
if not column.table and column.find_ancestor(exp.AggFunc) and column.name in resolver.all_columns:
columns_missing_from_scope.append(column)
column_table = resolver.get_table(column_name)
for column in columns_missing_from_scope:
column_table = resolver.get_table(column.name)
if column_table is None:
raise OptimizeError(f"Ambiguous column: {column_name}")
if column_table is None:
raise OptimizeError(f"Ambiguous column: {column.name}")
column.set("table", exp.to_identifier(column_table))
column.set("table", exp.to_identifier(column_table))
def _expand_stars(scope, resolver):

View file

@ -232,7 +232,7 @@ class Scope:
self._columns = []
for column in columns + external_columns:
ancestor = column.find_ancestor(exp.Qualify, exp.Order, exp.Hint)
ancestor = column.find_ancestor(exp.Qualify, exp.Order, exp.Having, exp.Hint)
if (
not ancestor
or column.table

View file

@ -131,6 +131,7 @@ class Parser:
TokenType.ALTER,
TokenType.ALWAYS,
TokenType.ANTI,
TokenType.APPLY,
TokenType.BEGIN,
TokenType.BOTH,
TokenType.BUCKET,
@ -190,6 +191,7 @@ class Parser:
TokenType.TABLE,
TokenType.TABLE_FORMAT,
TokenType.TEMPORARY,
TokenType.TRANSIENT,
TokenType.TOP,
TokenType.TRAILING,
TokenType.TRUNCATE,
@ -204,7 +206,7 @@ class Parser:
*TYPE_TOKENS,
}
TABLE_ALIAS_TOKENS = ID_VAR_TOKENS - {TokenType.NATURAL}
TABLE_ALIAS_TOKENS = ID_VAR_TOKENS - {TokenType.NATURAL, TokenType.APPLY}
TRIM_TYPES = {TokenType.LEADING, TokenType.TRAILING, TokenType.BOTH}
@ -685,6 +687,7 @@ class Parser:
def _parse_create(self):
replace = self._match(TokenType.OR) and self._match(TokenType.REPLACE)
temporary = self._match(TokenType.TEMPORARY)
transient = self._match(TokenType.TRANSIENT)
unique = self._match(TokenType.UNIQUE)
materialized = self._match(TokenType.MATERIALIZED)
@ -723,6 +726,7 @@ class Parser:
exists=exists,
properties=properties,
temporary=temporary,
transient=transient,
replace=replace,
unique=unique,
materialized=materialized,
@ -1057,8 +1061,8 @@ class Parser:
return self._parse_set_operations(this) if this else None
def _parse_with(self):
if not self._match(TokenType.WITH):
def _parse_with(self, skip_with_token=False):
if not skip_with_token and not self._match(TokenType.WITH):
return None
recursive = self._match(TokenType.RECURSIVE)
@ -1167,28 +1171,53 @@ class Parser:
return self.expression(exp.From, expressions=self._parse_csv(self._parse_table))
def _parse_lateral(self):
if not self._match(TokenType.LATERAL):
outer_apply = self._match_pair(TokenType.OUTER, TokenType.APPLY)
cross_apply = self._match_pair(TokenType.CROSS, TokenType.APPLY)
if outer_apply or cross_apply:
this = self._parse_select(table=True)
view = None
outer = not cross_apply
elif self._match(TokenType.LATERAL):
this = self._parse_select(table=True)
view = self._match(TokenType.VIEW)
outer = self._match(TokenType.OUTER)
else:
return None
subquery = self._parse_select(table=True)
if not this:
this = self._parse_function()
if subquery:
return self.expression(exp.Lateral, this=subquery)
table_alias = self._parse_id_var(any_token=False)
self._match(TokenType.VIEW)
outer = self._match(TokenType.OUTER)
columns = None
if self._match(TokenType.ALIAS):
columns = self._parse_csv(self._parse_id_var)
elif self._match(TokenType.L_PAREN):
columns = self._parse_csv(self._parse_id_var)
self._match(TokenType.R_PAREN)
return self.expression(
expression = self.expression(
exp.Lateral,
this=self._parse_function(),
this=this,
view=view,
outer=outer,
alias=self.expression(
exp.TableAlias,
this=self._parse_id_var(any_token=False),
columns=(self._parse_csv(self._parse_id_var) if self._match(TokenType.ALIAS) else None),
this=table_alias,
columns=columns,
),
)
if outer_apply or cross_apply:
return self.expression(
exp.Join,
this=expression,
side=None if cross_apply else "LEFT",
)
return expression
def _parse_join_side_and_kind(self):
return (
self._match(TokenType.NATURAL) and self._prev,
@ -1196,10 +1225,10 @@ class Parser:
self._match_set(self.JOIN_KINDS) and self._prev,
)
def _parse_join(self):
def _parse_join(self, skip_join_token=False):
natural, side, kind = self._parse_join_side_and_kind()
if not self._match(TokenType.JOIN):
if not skip_join_token and not self._match(TokenType.JOIN):
return None
kwargs = {"this": self._parse_table()}
@ -1425,13 +1454,13 @@ class Parser:
unpivot=unpivot,
)
def _parse_where(self):
if not self._match(TokenType.WHERE):
def _parse_where(self, skip_where_token=False):
if not skip_where_token and not self._match(TokenType.WHERE):
return None
return self.expression(exp.Where, this=self._parse_conjunction())
def _parse_group(self):
if not self._match(TokenType.GROUP_BY):
def _parse_group(self, skip_group_by_token=False):
if not skip_group_by_token and not self._match(TokenType.GROUP_BY):
return None
return self.expression(
exp.Group,
@ -1457,8 +1486,8 @@ class Parser:
return self.expression(exp.Tuple, expressions=grouping_set)
return self._parse_id_var()
def _parse_having(self):
if not self._match(TokenType.HAVING):
def _parse_having(self, skip_having_token=False):
if not skip_having_token and not self._match(TokenType.HAVING):
return None
return self.expression(exp.Having, this=self._parse_conjunction())
@ -1467,8 +1496,8 @@ class Parser:
return None
return self.expression(exp.Qualify, this=self._parse_conjunction())
def _parse_order(self, this=None):
if not self._match(TokenType.ORDER_BY):
def _parse_order(self, this=None, skip_order_token=False):
if not skip_order_token and not self._match(TokenType.ORDER_BY):
return this
return self.expression(exp.Order, this=this, expressions=self._parse_csv(self._parse_ordered))
@ -1502,7 +1531,11 @@ class Parser:
def _parse_limit(self, this=None, top=False):
if self._match(TokenType.TOP if top else TokenType.LIMIT):
return self.expression(exp.Limit, this=this, expression=self._parse_number())
limit_paren = self._match(TokenType.L_PAREN)
limit_exp = self.expression(exp.Limit, this=this, expression=self._parse_number())
if limit_paren:
self._match(TokenType.R_PAREN)
return limit_exp
if self._match(TokenType.FETCH):
direction = self._match_set((TokenType.FIRST, TokenType.NEXT))
direction = self._prev.text if direction else "FIRST"
@ -2136,7 +2169,7 @@ class Parser:
return self.expression(exp.Cast if strict else exp.TryCast, this=this, to=to)
def _parse_convert(self, strict):
this = self._parse_field()
this = self._parse_column()
if self._match(TokenType.USING):
to = self.expression(exp.CharacterSet, this=self._parse_var())
elif self._match(TokenType.COMMA):

View file

@ -43,5 +43,4 @@ def format_time(string, mapping, trie=None):
if result and end > size:
chunks.append(chars)
return "".join(mapping.get(chars, chars) for chars in chunks)

View file

@ -107,6 +107,7 @@ class TokenType(AutoName):
ANALYZE = auto()
ANTI = auto()
ANY = auto()
APPLY = auto()
ARRAY = auto()
ASC = auto()
AT_TIME_ZONE = auto()
@ -256,6 +257,7 @@ class TokenType(AutoName):
TABLE_FORMAT = auto()
TABLE_SAMPLE = auto()
TEMPORARY = auto()
TRANSIENT = auto()
TOP = auto()
THEN = auto()
TRUE = auto()
@ -560,6 +562,7 @@ class Tokenizer(metaclass=_Tokenizer):
"TABLESAMPLE": TokenType.TABLE_SAMPLE,
"TEMP": TokenType.TEMPORARY,
"TEMPORARY": TokenType.TEMPORARY,
"TRANSIENT": TokenType.TRANSIENT,
"THEN": TokenType.THEN,
"TRUE": TokenType.TRUE,
"TRAILING": TokenType.TRAILING,
@ -582,6 +585,7 @@ class Tokenizer(metaclass=_Tokenizer):
"WITH LOCAL TIME ZONE": TokenType.WITH_LOCAL_TIME_ZONE,
"WITHIN GROUP": TokenType.WITHIN_GROUP,
"WITHOUT TIME ZONE": TokenType.WITHOUT_TIME_ZONE,
"APPLY": TokenType.APPLY,
"ARRAY": TokenType.ARRAY,
"BOOL": TokenType.BOOLEAN,
"BOOLEAN": TokenType.BOOLEAN,