
Merging upstream version 11.0.1.

Signed-off-by: Daniel Baumann <daniel@debian.org>
Daniel Baumann 2025-02-13 15:23:26 +01:00
parent fdac67ef7f
commit ba0f3f0bfa
Signed by: daniel
GPG key ID: FBB4F0E80A80222F
112 changed files with 126100 additions and 230 deletions

@@ -22,5 +22,17 @@ jobs:
python -m pip install --upgrade pip
make install-dev
- name: Run checks (linter, code style, tests)
run: make check
- name: Update documentation
run: |
make check
make docs
git add docs
git config --local user.email "41898282+github-actions[bot]@users.noreply.github.com"
git config --local user.name "github-actions[bot]"
git commit -m "CI: Auto-generated documentation" -a | exit 0
if: ${{ matrix.python-version == '3.10' && github.event_name == 'push' }}
- name: Push changes
if: ${{ matrix.python-version == '3.10' && github.event_name == 'push' && startsWith(github.ref, 'refs/tags/v') }}
uses: ad-m/github-push-action@master
with:
github_token: ${{ secrets.GITHUB_TOKEN }}

@@ -1,6 +1,39 @@
Changelog
=========
v11.0.0
------
Changes:
- Breaking: Renamed ESCAPES to STRING_ESCAPES in the Tokenizer class.
- New: Deployed pdoc documentation page.
- New: Add support for read locking using the FOR UPDATE/SHARE syntax (e.g. MySQL).
- New: Added support for CASCADE, SET NULL and SET DEFAULT constraints.
- New: Added "cast" expression helper.
- New: Add support for transpiling Postgres GENERATE_SERIES into Presto SEQUENCE.
- Improvement: Fix tokenizing of identifier escapes.
- Improvement: Fix eliminate_subqueries [bug](https://github.com/tobymao/sqlglot/commit/b5df65e3fb5ee1ebc3cbab64b6d89598cf47a10b) related to unions.
- Improvement: IFNULL is now transpiled to COALESCE by default for every dialect.
- Improvement: Refactored the way properties are handled. Now it's easier to add them and specify their position in a SQL expression.
- Improvement: Fixed alias quoting bug.
- Improvement: Fixed CUBE / ROLLUP / GROUPING SETS parsing and generation.
- Improvement: Fixed get_or_raise Dialect/t.Type[Dialect] argument bug.
- Improvement: Improved python type hints.
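A rough sketch of two of the entries above, using the public sqlglot API (the printed strings are expectations based on the changelog, not verified output):

    import sqlglot

    # IFNULL is transpiled to COALESCE by default for every dialect.
    print(sqlglot.transpile("SELECT IFNULL(a, 0) FROM t")[0])
    # expected: SELECT COALESCE(a, 0) FROM t

    # Postgres GENERATE_SERIES transpiled into Presto SEQUENCE.
    print(sqlglot.transpile("SELECT * FROM GENERATE_SERIES(1, 3)", read="postgres", write="presto")[0])
    # expected to contain: SEQUENCE(1, 3)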
v10.6.0
------

@@ -18,7 +18,7 @@ style:
check: style test
docs:
python pdoc/cli.py -o pdoc/docs
python pdoc/cli.py -o docs
docs-serve:
python pdoc/cli.py
python pdoc/cli.py --port 8002

docs/CNAME Normal file (1 line)

@@ -0,0 +1 @@
sqlglot.com

docs/index.html Normal file (7 lines)

@@ -0,0 +1,7 @@
<!doctype html>
<html>
<head>
<meta charset="utf-8">
<meta http-equiv="refresh" content="0; url=./sqlglot.html"/>
</head>
</html>

docs/search.js Normal file (46 lines)

File diff suppressed because one or more lines are too long

docs/sqlglot.html Normal file (1226 lines)

File diff suppressed because one or more lines are too long

docs/sqlglot/dataframe.html Normal file (506 lines)

File diff suppressed because one or more lines are too long

(1 additional suppressed file diff omitted)

docs/sqlglot/dialects.html Normal file (400 lines)

File diff suppressed because one or more lines are too long

(20 additional suppressed file diffs omitted)

docs/sqlglot/diff.html Normal file (1560 lines)

File diff suppressed because one or more lines are too long

docs/sqlglot/errors.html Normal file (877 lines)

File diff suppressed because one or more lines are too long

docs/sqlglot/executor.html Normal file (694 lines)

File diff suppressed because one or more lines are too long

(4 additional suppressed file diffs omitted)

docs/sqlglot/expressions.html Normal file (39484 lines)

File diff suppressed because one or more lines are too long

docs/sqlglot/generator.html Normal file (9855 lines)

File diff suppressed because one or more lines are too long

docs/sqlglot/helper.html Normal file (1651 lines)

File diff suppressed because one or more lines are too long

docs/sqlglot/lineage.html Normal file (931 lines)

File diff suppressed because one or more lines are too long

docs/sqlglot/optimizer.html Normal file (264 lines)

File diff suppressed because one or more lines are too long

(20 additional suppressed file diffs omitted)

docs/sqlglot/parser.html Normal file (8049 lines)

File diff suppressed because one or more lines are too long

docs/sqlglot/planner.html Normal file (1995 lines)

File diff suppressed because one or more lines are too long

docs/sqlglot/schema.html Normal file (1624 lines)

File diff suppressed because one or more lines are too long

docs/sqlglot/serde.html Normal file (408 lines)

File diff suppressed because one or more lines are too long

docs/sqlglot/time.html Normal file (385 lines)

File diff suppressed because one or more lines are too long

docs/sqlglot/tokens.html Normal file (6712 lines)

File diff suppressed because one or more lines are too long

(1 additional suppressed file diff omitted)

docs/sqlglot/trie.html Normal file (479 lines)

File diff suppressed because one or more lines are too long

@@ -26,9 +26,9 @@ if __name__ == "__main__":
opts = parser.parse_args()
opts.docformat = "google"
opts.modules = ["sqlglot"]
opts.footer_text = "Copyright (c) 2022 Toby Mao"
opts.footer_text = "Copyright (c) 2023 Toby Mao"
opts.template_directory = Path(__file__).parent.joinpath("templates").absolute()
opts.edit_url = ["sqlglot=https://github.com/tobymao/sqlglot/"]
opts.edit_url = ["sqlglot=https://github.com/tobymao/sqlglot/tree/main/sqlglot/"]
with mock.patch("pdoc.__main__.parser", **{"parse_args.return_value": opts}):
cli()

@@ -1,5 +1,6 @@
"""
.. include:: ../README.md
----
"""
@@ -39,7 +40,7 @@ if t.TYPE_CHECKING:
T = t.TypeVar("T", bound=Expression)
__version__ = "10.6.3"
__version__ = "11.0.1"
pretty = False
"""Whether to format generated SQL by default."""

@@ -2,6 +2,8 @@
from __future__ import annotations
import typing as t
from sqlglot import exp, generator, parser, tokens, transforms
from sqlglot.dialects.dialect import (
Dialect,
@@ -14,8 +16,10 @@ from sqlglot.dialects.dialect import (
from sqlglot.helper import seq_get
from sqlglot.tokens import TokenType
E = t.TypeVar("E", bound=exp.Expression)
def _date_add(expression_class):
def _date_add(expression_class: t.Type[E]) -> t.Callable[[t.Sequence], E]:
def func(args):
interval = seq_get(args, 1)
return expression_class(
@@ -27,26 +31,26 @@ def _date_add(expression_class):
return func
def _date_trunc(args):
def _date_trunc(args: t.Sequence) -> exp.Expression:
unit = seq_get(args, 1)
if isinstance(unit, exp.Column):
unit = exp.Var(this=unit.name)
return exp.DateTrunc(this=seq_get(args, 0), expression=unit)
def _date_add_sql(data_type, kind):
def _date_add_sql(
data_type: str, kind: str
) -> t.Callable[[generator.Generator, exp.Expression], str]:
def func(self, expression):
this = self.sql(expression, "this")
unit = self.sql(expression, "unit") or "'day'"
expression = self.sql(expression, "expression")
return f"{data_type}_{kind}({this}, INTERVAL {expression} {unit})"
return f"{data_type}_{kind}({this}, {self.sql(exp.Interval(this=expression.expression, unit=expression.args.get('unit') or exp.Literal.string('day')))})"
return func
def _derived_table_values_to_unnest(self, expression):
def _derived_table_values_to_unnest(self: generator.Generator, expression: exp.Values) -> str:
if not isinstance(expression.unnest().parent, exp.From):
expression = transforms.remove_precision_parameterized_types(expression)
expression = t.cast(exp.Values, transforms.remove_precision_parameterized_types(expression))
return self.values_sql(expression)
rows = [tuple_exp.expressions for tuple_exp in expression.find_all(exp.Tuple)]
structs = []
@@ -60,7 +64,7 @@ def _derived_table_values_to_unnest(self, expression):
return self.unnest_sql(unnest_exp)
def _returnsproperty_sql(self, expression):
def _returnsproperty_sql(self: generator.Generator, expression: exp.ReturnsProperty) -> str:
this = expression.this
if isinstance(this, exp.Schema):
this = f"{this.this} <{self.expressions(this)}>"
@@ -69,8 +73,8 @@ def _returnsproperty_sql(self, expression):
return f"RETURNS {this}"
def _create_sql(self, expression):
kind = expression.args.get("kind")
def _create_sql(self: generator.Generator, expression: exp.Create) -> str:
kind = expression.args["kind"]
returns = expression.find(exp.ReturnsProperty)
if kind.upper() == "FUNCTION" and returns and returns.args.get("is_table"):
expression = expression.copy()
@@ -89,6 +93,29 @@ def _create_sql(self, expression):
return self.create_sql(expression)
def _unqualify_unnest(expression: exp.Expression) -> exp.Expression:
"""Remove references to unnest table aliases since bigquery doesn't allow them.
These are added by the optimizer's qualify_column step.
"""
if isinstance(expression, exp.Select):
unnests = {
unnest.alias
for unnest in expression.args.get("from", exp.From(expressions=[])).expressions
if isinstance(unnest, exp.Unnest) and unnest.alias
}
if unnests:
expression = expression.copy()
for select in expression.expressions:
for column in select.find_all(exp.Column):
if column.table in unnests:
column.set("table", None)
return expression
class BigQuery(Dialect):
unnest_column_only = True
time_mapping = {
@@ -110,7 +137,7 @@ class BigQuery(Dialect):
]
COMMENTS = ["--", "#", ("/*", "*/")]
IDENTIFIERS = ["`"]
ESCAPES = ["\\"]
STRING_ESCAPES = ["\\"]
HEX_STRINGS = [("0x", ""), ("0X", "")]
KEYWORDS = {
@@ -190,6 +217,9 @@ class BigQuery(Dialect):
exp.GroupConcat: rename_func("STRING_AGG"),
exp.ILike: no_ilike_sql,
exp.IntDiv: rename_func("DIV"),
exp.Select: transforms.preprocess(
[_unqualify_unnest], transforms.delegate("select_sql")
),
exp.StrToTime: lambda self, e: f"PARSE_TIMESTAMP({self.format_time(e)}, {self.sql(e, 'this')})",
exp.TimeAdd: _date_add_sql("TIME", "ADD"),
exp.TimeSub: _date_add_sql("TIME", "SUB"),

@@ -9,7 +9,7 @@ from sqlglot.parser import parse_var_map
from sqlglot.tokens import TokenType
def _lower_func(sql):
def _lower_func(sql: str) -> str:
index = sql.index("(")
return sql[:index].lower() + sql[index:]

@@ -11,6 +11,8 @@ from sqlglot.time import format_time
from sqlglot.tokens import Tokenizer
from sqlglot.trie import new_trie
E = t.TypeVar("E", bound=exp.Expression)
class Dialects(str, Enum):
DIALECT = ""
@@ -37,14 +39,16 @@ class Dialects(str, Enum):
class _Dialect(type):
classes: t.Dict[str, Dialect] = {}
classes: t.Dict[str, t.Type[Dialect]] = {}
@classmethod
def __getitem__(cls, key):
def __getitem__(cls, key: str) -> t.Type[Dialect]:
return cls.classes[key]
@classmethod
def get(cls, key, default=None):
def get(
cls, key: str, default: t.Optional[t.Type[Dialect]] = None
) -> t.Optional[t.Type[Dialect]]:
return cls.classes.get(key, default)
def __new__(cls, clsname, bases, attrs):
@@ -119,7 +123,7 @@ class Dialect(metaclass=_Dialect):
generator_class = None
@classmethod
def get_or_raise(cls, dialect):
def get_or_raise(cls, dialect: DialectType) -> t.Type[Dialect]:
if not dialect:
return cls
if isinstance(dialect, _Dialect):
@@ -134,7 +138,9 @@ class Dialect(metaclass=_Dialect):
return result
@classmethod
def format_time(cls, expression):
def format_time(
cls, expression: t.Optional[str | exp.Expression]
) -> t.Optional[exp.Expression]:
if isinstance(expression, str):
return exp.Literal.string(
format_time(
@@ -153,26 +159,28 @@ class Dialect(metaclass=_Dialect):
)
return expression
def parse(self, sql, **opts):
def parse(self, sql: str, **opts) -> t.List[t.Optional[exp.Expression]]:
return self.parser(**opts).parse(self.tokenizer.tokenize(sql), sql)
def parse_into(self, expression_type, sql, **opts):
def parse_into(
self, expression_type: exp.IntoType, sql: str, **opts
) -> t.List[t.Optional[exp.Expression]]:
return self.parser(**opts).parse_into(expression_type, self.tokenizer.tokenize(sql), sql)
def generate(self, expression, **opts):
def generate(self, expression: t.Optional[exp.Expression], **opts) -> str:
return self.generator(**opts).generate(expression)
def transpile(self, code, **opts):
return self.generate(self.parse(code), **opts)
def transpile(self, sql: str, **opts) -> t.List[str]:
return [self.generate(expression, **opts) for expression in self.parse(sql)]
@property
def tokenizer(self):
def tokenizer(self) -> Tokenizer:
if not hasattr(self, "_tokenizer"):
self._tokenizer = self.tokenizer_class()
self._tokenizer = self.tokenizer_class() # type: ignore
return self._tokenizer
def parser(self, **opts):
return self.parser_class(
def parser(self, **opts) -> Parser:
return self.parser_class( # type: ignore
**{
"index_offset": self.index_offset,
"unnest_column_only": self.unnest_column_only,
@@ -182,14 +190,15 @@ class Dialect(metaclass=_Dialect):
},
)
def generator(self, **opts):
return self.generator_class(
def generator(self, **opts) -> Generator:
return self.generator_class( # type: ignore
**{
"quote_start": self.quote_start,
"quote_end": self.quote_end,
"identifier_start": self.identifier_start,
"identifier_end": self.identifier_end,
"escape": self.tokenizer_class.ESCAPES[0],
"string_escape": self.tokenizer_class.STRING_ESCAPES[0],
"identifier_escape": self.tokenizer_class.IDENTIFIER_ESCAPES[0],
"index_offset": self.index_offset,
"time_mapping": self.inverse_time_mapping,
"time_trie": self.inverse_time_trie,
@@ -202,11 +211,10 @@ class Dialect(metaclass=_Dialect):
)
if t.TYPE_CHECKING:
DialectType = t.Union[str, Dialect, t.Type[Dialect], None]
DialectType = t.Union[str, Dialect, t.Type[Dialect], None]
def rename_func(name):
def rename_func(name: str) -> t.Callable[[Generator, exp.Expression], str]:
def _rename(self, expression):
args = flatten(expression.args.values())
return f"{self.normalize_func(name)}({self.format_args(*args)})"
@@ -214,32 +222,34 @@ def rename_func(name):
return _rename
def approx_count_distinct_sql(self, expression):
def approx_count_distinct_sql(self: Generator, expression: exp.ApproxDistinct) -> str:
if expression.args.get("accuracy"):
self.unsupported("APPROX_COUNT_DISTINCT does not support accuracy")
return f"APPROX_COUNT_DISTINCT({self.format_args(expression.this)})"
def if_sql(self, expression):
def if_sql(self: Generator, expression: exp.If) -> str:
expressions = self.format_args(
expression.this, expression.args.get("true"), expression.args.get("false")
)
return f"IF({expressions})"
def arrow_json_extract_sql(self, expression):
def arrow_json_extract_sql(self: Generator, expression: exp.JSONExtract | exp.JSONBExtract) -> str:
return self.binary(expression, "->")
def arrow_json_extract_scalar_sql(self, expression):
def arrow_json_extract_scalar_sql(
self: Generator, expression: exp.JSONExtractScalar | exp.JSONBExtractScalar
) -> str:
return self.binary(expression, "->>")
def inline_array_sql(self, expression):
def inline_array_sql(self: Generator, expression: exp.Array) -> str:
return f"[{self.expressions(expression)}]"
def no_ilike_sql(self, expression):
def no_ilike_sql(self: Generator, expression: exp.ILike) -> str:
return self.like_sql(
exp.Like(
this=exp.Lower(this=expression.this),
@@ -248,44 +258,44 @@ def no_ilike_sql(self, expression):
)
def no_paren_current_date_sql(self, expression):
def no_paren_current_date_sql(self: Generator, expression: exp.CurrentDate) -> str:
zone = self.sql(expression, "this")
return f"CURRENT_DATE AT TIME ZONE {zone}" if zone else "CURRENT_DATE"
def no_recursive_cte_sql(self, expression):
def no_recursive_cte_sql(self: Generator, expression: exp.With) -> str:
if expression.args.get("recursive"):
self.unsupported("Recursive CTEs are unsupported")
expression.args["recursive"] = False
return self.with_sql(expression)
def no_safe_divide_sql(self, expression):
def no_safe_divide_sql(self: Generator, expression: exp.SafeDivide) -> str:
n = self.sql(expression, "this")
d = self.sql(expression, "expression")
return f"IF({d} <> 0, {n} / {d}, NULL)"
def no_tablesample_sql(self, expression):
def no_tablesample_sql(self: Generator, expression: exp.TableSample) -> str:
self.unsupported("TABLESAMPLE unsupported")
return self.sql(expression.this)
def no_pivot_sql(self, expression):
def no_pivot_sql(self: Generator, expression: exp.Pivot) -> str:
self.unsupported("PIVOT unsupported")
return self.sql(expression)
def no_trycast_sql(self, expression):
def no_trycast_sql(self: Generator, expression: exp.TryCast) -> str:
return self.cast_sql(expression)
def no_properties_sql(self, expression):
def no_properties_sql(self: Generator, expression: exp.Properties) -> str:
self.unsupported("Properties unsupported")
return ""
def str_position_sql(self, expression):
def str_position_sql(self: Generator, expression: exp.StrPosition) -> str:
this = self.sql(expression, "this")
substr = self.sql(expression, "substr")
position = self.sql(expression, "position")
@@ -294,13 +304,15 @@ def str_position_sql(self, expression):
return f"STRPOS({this}, {substr})"
def struct_extract_sql(self, expression):
def struct_extract_sql(self: Generator, expression: exp.StructExtract) -> str:
this = self.sql(expression, "this")
struct_key = self.sql(exp.Identifier(this=expression.expression, quoted=True))
return f"{this}.{struct_key}"
def var_map_sql(self, expression, map_func_name="MAP"):
def var_map_sql(
self: Generator, expression: exp.Map | exp.VarMap, map_func_name: str = "MAP"
) -> str:
keys = expression.args["keys"]
values = expression.args["values"]
@@ -315,27 +327,33 @@ def var_map_sql(self, expression, map_func_name="MAP"):
return f"{map_func_name}({self.format_args(*args)})"
def format_time_lambda(exp_class, dialect, default=None):
def format_time_lambda(
exp_class: t.Type[E], dialect: str, default: t.Optional[bool | str] = None
) -> t.Callable[[t.Sequence], E]:
"""Helper used for time expressions.
Args
exp_class (Class): the expression class to instantiate
dialect (string): sql dialect
default (Option[bool | str]): the default format, True being time
Args:
exp_class: the expression class to instantiate.
dialect: target sql dialect.
default: the default format, True being time.
Returns:
A callable that can be used to return the appropriately formatted time expression.
"""
def _format_time(args):
def _format_time(args: t.Sequence):
return exp_class(
this=seq_get(args, 0),
format=Dialect[dialect].format_time(
seq_get(args, 1) or (Dialect[dialect].time_format if default is True else default)
seq_get(args, 1)
or (Dialect[dialect].time_format if default is True else default or None)
),
)
return _format_time
def create_with_partitions_sql(self, expression):
def create_with_partitions_sql(self: Generator, expression: exp.Create) -> str:
"""
In Hive and Spark, the PARTITIONED BY property acts as an extension of a table's schema. When the
PARTITIONED BY value is an array of column names, they are transformed into a schema. The corresponding
@@ -359,19 +377,21 @@ def create_with_partitions_sql(self, expression):
return self.create_sql(expression)
def parse_date_delta(exp_class, unit_mapping=None):
def inner_func(args):
def parse_date_delta(
exp_class: t.Type[E], unit_mapping: t.Optional[t.Dict[str, str]] = None
) -> t.Callable[[t.Sequence], E]:
def inner_func(args: t.Sequence) -> E:
unit_based = len(args) == 3
this = seq_get(args, 2) if unit_based else seq_get(args, 0)
expression = seq_get(args, 1) if unit_based else seq_get(args, 1)
unit = seq_get(args, 0) if unit_based else exp.Literal.string("DAY")
unit = unit_mapping.get(unit.name.lower(), unit) if unit_mapping else unit
unit = unit_mapping.get(unit.name.lower(), unit) if unit_mapping else unit # type: ignore
return exp_class(this=this, expression=expression, unit=unit)
return inner_func
def locate_to_strposition(args):
def locate_to_strposition(args: t.Sequence) -> exp.Expression:
return exp.StrPosition(
this=seq_get(args, 1),
substr=seq_get(args, 0),
@@ -379,22 +399,22 @@ def locate_to_strposition(args):
)
def strposition_to_locate_sql(self, expression):
def strposition_to_locate_sql(self: Generator, expression: exp.StrPosition) -> str:
args = self.format_args(
expression.args.get("substr"), expression.this, expression.args.get("position")
)
return f"LOCATE({args})"
def timestrtotime_sql(self, expression: exp.TimeStrToTime) -> str:
def timestrtotime_sql(self: Generator, expression: exp.TimeStrToTime) -> str:
return f"CAST({self.sql(expression, 'this')} AS TIMESTAMP)"
def datestrtodate_sql(self, expression: exp.DateStrToDate) -> str:
def datestrtodate_sql(self: Generator, expression: exp.DateStrToDate) -> str:
return f"CAST({self.sql(expression, 'this')} AS DATE)"
def trim_sql(self, expression):
def trim_sql(self: Generator, expression: exp.Trim) -> str:
target = self.sql(expression, "this")
trim_type = self.sql(expression, "position")
remove_chars = self.sql(expression, "expression")

@@ -1,6 +1,7 @@
from __future__ import annotations
import re
import typing as t
from sqlglot import exp, generator, parser, tokens
from sqlglot.dialects.dialect import (
@@ -16,35 +17,29 @@
)
def _to_timestamp(args):
# TO_TIMESTAMP accepts either a single double argument or (text, text)
if len(args) == 1 and args[0].is_number:
return exp.UnixToTime.from_arg_list(args)
return format_time_lambda(exp.StrToTime, "drill")(args)
def _str_to_time_sql(self, expression):
def _str_to_time_sql(self: generator.Generator, expression: exp.TsOrDsToDate) -> str:
return f"STRPTIME({self.sql(expression, 'this')}, {self.format_time(expression)})"
def _ts_or_ds_to_date_sql(self, expression):
def _ts_or_ds_to_date_sql(self: generator.Generator, expression: exp.TsOrDsToDate) -> str:
time_format = self.format_time(expression)
if time_format and time_format not in (Drill.time_format, Drill.date_format):
return f"CAST({_str_to_time_sql(self, expression)} AS DATE)"
return f"CAST({self.sql(expression, 'this')} AS DATE)"
def _date_add_sql(kind):
def func(self, expression):
def _date_add_sql(kind: str) -> t.Callable[[generator.Generator, exp.DateAdd | exp.DateSub], str]:
def func(self: generator.Generator, expression: exp.DateAdd | exp.DateSub) -> str:
this = self.sql(expression, "this")
unit = expression.text("unit").upper() or "DAY"
expression = self.sql(expression, "expression")
return f"DATE_{kind}({this}, INTERVAL '{expression}' {unit})"
unit = exp.Var(this=expression.text("unit").upper() or "DAY")
return (
f"DATE_{kind}({this}, {self.sql(exp.Interval(this=expression.expression, unit=unit))})"
)
return func
def if_sql(self, expression):
def if_sql(self: generator.Generator, expression: exp.If) -> str:
"""
Drill requires backticks around certain SQL reserved words, IF being one of them, This function
adds the backticks around the keyword IF.
@@ -61,7 +56,7 @@ def if_sql(self, expression):
return f"`IF`({expressions})"
def _str_to_date(self, expression):
def _str_to_date(self: generator.Generator, expression: exp.StrToDate) -> str:
this = self.sql(expression, "this")
time_format = self.format_time(expression)
if time_format == Drill.date_format:
@@ -111,7 +106,7 @@ class Drill(Dialect):
class Tokenizer(tokens.Tokenizer):
QUOTES = ["'"]
IDENTIFIERS = ["`"]
ESCAPES = ["\\"]
STRING_ESCAPES = ["\\"]
ENCODE = "utf-8"
class Parser(parser.Parser):
@@ -168,10 +163,10 @@ class Drill(Dialect):
exp.TimeToStr: lambda self, e: f"TO_CHAR({self.sql(e, 'this')}, {self.format_time(e)})",
exp.TimeToUnix: rename_func("UNIX_TIMESTAMP"),
exp.TryCast: no_trycast_sql,
exp.TsOrDsAdd: lambda self, e: f"DATE_ADD(CAST({self.sql(e, 'this')} AS DATE), INTERVAL '{self.sql(e, 'expression')}' DAY)",
exp.TsOrDsAdd: lambda self, e: f"DATE_ADD(CAST({self.sql(e, 'this')} AS DATE), {self.sql(exp.Interval(this=e.expression, unit=exp.Var(this='DAY')))})",
exp.TsOrDsToDate: _ts_or_ds_to_date_sql,
exp.TsOrDiToDi: lambda self, e: f"CAST(SUBSTR(REPLACE(CAST({self.sql(e, 'this')} AS VARCHAR), '-', ''), 1, 8) AS INT)",
}
def normalize_func(self, name):
def normalize_func(self, name: str) -> str:
return name if re.match(exp.SAFE_IDENTIFIER_RE, name) else f"`{name}`"

@@ -25,10 +25,9 @@ def _str_to_time_sql(self, expression):
def _ts_or_ds_add(self, expression):
this = self.sql(expression, "this")
e = self.sql(expression, "expression")
this = expression.args.get("this")
unit = self.sql(expression, "unit").strip("'") or "DAY"
return f"CAST({this} AS DATE) + INTERVAL {e} {unit}"
return f"CAST({this} AS DATE) + {self.sql(exp.Interval(this=expression.expression, unit=unit))}"
def _ts_or_ds_to_date_sql(self, expression):
@@ -40,9 +39,8 @@ def _ts_or_ds_to_date_sql(self, expression):
def _date_add(self, expression):
this = self.sql(expression, "this")
e = self.sql(expression, "expression")
unit = self.sql(expression, "unit").strip("'") or "DAY"
return f"{this} + INTERVAL {e} {unit}"
return f"{this} + {self.sql(exp.Interval(this=expression.expression, unit=unit))}"
def _array_sort_sql(self, expression):

@@ -172,7 +172,7 @@ class Hive(Dialect):
class Tokenizer(tokens.Tokenizer):
QUOTES = ["'", '"']
IDENTIFIERS = ["`"]
ESCAPES = ["\\"]
STRING_ESCAPES = ["\\"]
ENCODE = "utf-8"
KEYWORDS = {

@@ -89,8 +89,9 @@ def _date_add_sql(kind):
def func(self, expression):
this = self.sql(expression, "this")
unit = expression.text("unit").upper() or "DAY"
expression = self.sql(expression, "expression")
return f"DATE_{kind}({this}, INTERVAL {expression} {unit})"
return (
f"DATE_{kind}({this}, {self.sql(exp.Interval(this=expression.expression, unit=unit))})"
)
return func
@@ -117,7 +118,7 @@ class MySQL(Dialect):
QUOTES = ["'", '"']
COMMENTS = ["--", "#", ("/*", "*/")]
IDENTIFIERS = ["`"]
ESCAPES = ["'", "\\"]
STRING_ESCAPES = ["'", "\\"]
BIT_STRINGS = [("b'", "'"), ("B'", "'"), ("0b", "")]
HEX_STRINGS = [("x'", "'"), ("X'", "'"), ("0x", "")]

@@ -40,8 +40,7 @@ def _date_add_sql(kind):
expression = expression.copy()
expression.args["is_string"] = True
expression = self.sql(expression)
return f"{this} {kind} INTERVAL {expression} {unit}"
return f"{this} {kind} {self.sql(exp.Interval(this=expression, unit=unit))}"
return func

@@ -37,11 +37,10 @@ class Redshift(Postgres):
return this
class Tokenizer(Postgres.Tokenizer):
ESCAPES = ["\\"]
STRING_ESCAPES = ["\\"]
KEYWORDS = {
**Postgres.Tokenizer.KEYWORDS, # type: ignore
"COPY": TokenType.COMMAND,
"ENCODE": TokenType.ENCODE,
"GEOMETRY": TokenType.GEOMETRY,
"GEOGRAPHY": TokenType.GEOGRAPHY,

@@ -180,7 +180,7 @@ class Snowflake(Dialect):
class Tokenizer(tokens.Tokenizer):
QUOTES = ["'", "$$"]
ESCAPES = ["\\", "'"]
STRING_ESCAPES = ["\\", "'"]
SINGLE_TOKENS = {
**tokens.Tokenizer.SINGLE_TOKENS,
@@ -191,6 +191,7 @@ class Snowflake(Dialect):
**tokens.Tokenizer.KEYWORDS,
"EXCLUDE": TokenType.EXCEPT,
"MATCH_RECOGNIZE": TokenType.MATCH_RECOGNIZE,
"PUT": TokenType.COMMAND,
"RENAME": TokenType.REPLACE,
"TIMESTAMP_LTZ": TokenType.TIMESTAMPLTZ,
"TIMESTAMP_NTZ": TokenType.TIMESTAMP,
@@ -222,6 +223,7 @@ class Snowflake(Dialect):
exp.TimeToUnix: lambda self, e: f"EXTRACT(epoch_second FROM {self.sql(e, 'this')})",
exp.Trim: lambda self, e: f"TRIM({self.format_args(e.this, e.expression)})",
exp.UnixToTime: _unix_to_time_sql,
exp.DayOfWeek: rename_func("DAYOFWEEK"),
}
TYPE_MAPPING = {
@@ -294,3 +296,12 @@ class Snowflake(Dialect):
kind = f" {kind_value}" if kind_value else ""
this = f" {self.sql(expression, 'this')}"
return f"DESCRIBE{kind}{this}"
def generatedasidentitycolumnconstraint_sql(
self, expression: exp.GeneratedAsIdentityColumnConstraint
) -> str:
start = expression.args.get("start")
start = f" START {start}" if start else ""
increment = expression.args.get("increment")
increment = f" INCREMENT {increment}" if increment else ""
return f"AUTOINCREMENT{start}{increment}"

@@ -157,6 +157,7 @@ class Spark(Hive):
TRANSFORMS.pop(exp.ILike)
WRAP_DERIVED_VALUES = False
CREATE_FUNCTION_AS = False
def cast_sql(self, expression: exp.Cast) -> str:
if isinstance(expression.this, exp.Cast) and expression.this.is_type(

@@ -49,7 +49,6 @@ class SQLite(Dialect):
KEYWORDS = {
**tokens.Tokenizer.KEYWORDS,
"AUTOINCREMENT": TokenType.AUTO_INCREMENT,
}
class Parser(parser.Parser):

@@ -1,5 +1,6 @@
"""
.. include:: ../posts/sql_diff.md
----
"""

@@ -7,10 +7,17 @@ from sqlglot.helper import AutoName
class ErrorLevel(AutoName):
IGNORE = auto() # Ignore any parser errors
WARN = auto() # Log any parser errors with ERROR level
RAISE = auto() # Collect all parser errors and raise a single exception
IMMEDIATE = auto() # Immediately raise an exception on the first parser error
IGNORE = auto()
"""Ignore all errors."""
WARN = auto()
"""Log all errors."""
RAISE = auto()
"""Collect all errors and raise a single exception."""
IMMEDIATE = auto()
"""Immediately raise an exception on the first error found."""
class SqlglotError(Exception):
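A minimal sketch of how the four levels documented above are meant to be used; error_level is assumed to be forwarded through the parse options to the Parser:

    import sqlglot
    from sqlglot.errors import ErrorLevel, ParseError

    # RAISE collects all parser errors and raises a single ParseError;
    # ErrorLevel.IGNORE would swallow them instead.
    try:
        sqlglot.parse_one("SELECT 1 +", error_level=ErrorLevel.RAISE)
    except ParseError as e:
        print(e)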

@@ -1,5 +1,6 @@
"""
.. include:: ../../posts/python_sql_engine.md
----
"""

@@ -408,7 +408,7 @@ def _lambda_sql(self, e: exp.Lambda) -> str:
class Python(Dialect):
class Tokenizer(tokens.Tokenizer):
ESCAPES = ["\\"]
STRING_ESCAPES = ["\\"]
class Generator(generator.Generator):
TRANSFORMS = {

@@ -6,6 +6,7 @@ Every AST node in SQLGlot is represented by a subclass of `Expression`.
This module contains the implementation of all supported `Expression` types. Additionally,
it exposes a number of helper functions, which are mainly used to programmatically build
SQL expressions, such as `sqlglot.expressions.select`.
----
"""
@@ -137,6 +138,8 @@ class Expression(metaclass=_Expression):
return field
if isinstance(field, (Identifier, Literal, Var)):
return field.this
if isinstance(field, (Star, Null)):
return field.name
return ""
@property
@@ -176,13 +179,11 @@ class Expression(metaclass=_Expression):
return self.text("alias")
@property
def name(self):
def name(self) -> str:
return self.text("this")
@property
def alias_or_name(self):
if isinstance(self, Null):
return "NULL"
return self.alias or self.name
@property
@@ -589,12 +590,11 @@ class Expression(metaclass=_Expression):
return load(obj)
if t.TYPE_CHECKING:
IntoType = t.Union[
IntoType = t.Union[
str,
t.Type[Expression],
t.Collection[t.Union[str, t.Type[Expression]]],
]
]
class Condition(Expression):
@@ -939,7 +939,7 @@ class EncodeColumnConstraint(ColumnConstraintKind):
class GeneratedAsIdentityColumnConstraint(ColumnConstraintKind):
# this: True -> ALWAYS, this: False -> BY DEFAULT
arg_types = {"this": True, "start": False, "increment": False}
arg_types = {"this": False, "start": False, "increment": False}
class NotNullColumnConstraint(ColumnConstraintKind):
@@ -2390,7 +2390,7 @@ class Star(Expression):
arg_types = {"except": False, "replace": False}
@property
def name(self):
def name(self) -> str:
return "*"
@property
@@ -2413,6 +2413,10 @@ class Placeholder(Expression):
class Null(Condition):
arg_types: t.Dict[str, t.Any] = {}
@property
def name(self) -> str:
return "NULL"
class Boolean(Condition):
pass
@@ -2644,7 +2648,9 @@ class Div(Binary):
class Dot(Binary):
pass
@property
def name(self) -> str:
return self.expression.name
class DPipe(Binary):
@@ -2961,7 +2967,7 @@ class Cast(Func):
arg_types = {"this": True, "to": True}
@property
def name(self):
def name(self) -> str:
return self.this.name
@property
@@ -4027,17 +4033,39 @@ def paren(expression) -> Paren:
SAFE_IDENTIFIER_RE = re.compile(r"^[_a-zA-Z][\w]*$")
def to_identifier(alias, quoted=None) -> t.Optional[Identifier]:
if alias is None:
@t.overload
def to_identifier(name: None, quoted: t.Optional[bool] = None) -> None:
...
@t.overload
def to_identifier(name: str | Identifier, quoted: t.Optional[bool] = None) -> Identifier:
...
def to_identifier(name, quoted=None):
"""Builds an identifier.
Args:
name: The name to turn into an identifier.
quoted: Whether or not force quote the identifier.
Returns:
The identifier ast node.
"""
if name is None:
return None
if isinstance(alias, Identifier):
identifier = alias
elif isinstance(alias, str):
if quoted is None:
quoted = not re.match(SAFE_IDENTIFIER_RE, alias)
identifier = Identifier(this=alias, quoted=quoted)
if isinstance(name, Identifier):
identifier = name
elif isinstance(name, str):
identifier = Identifier(
this=name,
quoted=not re.match(SAFE_IDENTIFIER_RE, name) if quoted is None else quoted,
)
else:
raise ValueError(f"Alias needs to be a string or an Identifier, got: {alias.__class__}")
raise ValueError(f"Name needs to be a string or an Identifier, got: {name.__class__}")
return identifier
@@ -4112,20 +4140,31 @@ def to_column(sql_path: str | Column, **kwargs) -> Column:
return Column(this=column_name, table=table_name, **kwargs)
def alias_(expression, alias, table=False, dialect=None, quoted=None, **opts):
"""
Create an Alias expression.
def alias_(
expression: str | Expression,
alias: str | Identifier,
table: bool | t.Sequence[str | Identifier] = False,
quoted: t.Optional[bool] = None,
dialect: DialectType = None,
**opts,
):
"""Create an Alias expression.
Example:
>>> alias_('foo', 'bar').sql()
'foo AS bar'
>>> alias_('(select 1, 2)', 'bar', table=['a', 'b']).sql()
'(SELECT 1, 2) AS bar(a, b)'
Args:
expression (str | Expression): the SQL code strings to parse.
expression: the SQL code strings to parse.
If an Expression instance is passed, this is used as-is.
alias (str | Identifier): the alias name to use. If the name has
alias: the alias name to use. If the name has
special characters it is quoted.
table (bool): create a table alias, default false
dialect (str): the dialect used to parse the input expression.
table: Whether or not to create a table alias, can also be a list of columns.
quoted: whether or not to quote the alias
dialect: the dialect used to parse the input expression.
**opts: other options to use to parse the input expressions.
Returns:
@@ -4135,8 +4174,14 @@ def alias_(expression, alias, table=False, dialect=None, quoted=None, **opts):
alias = to_identifier(alias, quoted=quoted)
if table:
expression.set("alias", TableAlias(this=alias))
return expression
table_alias = TableAlias(this=alias)
exp.set("alias", table_alias)
if not isinstance(table, bool):
for column in table:
table_alias.append("columns", to_identifier(column, quoted=quoted))
return exp
# We don't set the "alias" arg for Window expressions, because that would add an IDENTIFIER node in
# the AST, representing a "named_window" [1] construct (eg. bigquery). What we want is an ALIAS node

@@ -1,6 +1,7 @@
from __future__ import annotations
import logging
import re
import typing as t
from sqlglot import exp
@@ -11,6 +12,8 @@ from sqlglot.tokens import TokenType
logger = logging.getLogger("sqlglot")
BACKSLASH_RE = re.compile(r"\\(?!b|f|n|r|t|0)")
class Generator:
"""
@@ -28,7 +31,8 @@ class Generator:
identify (bool): if set to True all identifiers will be delimited by the corresponding
character.
normalize (bool): if set to True all identifiers will lower cased
escape (str): specifies an escape character. Default: '.
string_escape (str): specifies a string escape character. Default: '.
identifier_escape (str): specifies an identifier escape character. Default: ".
pad (int): determines padding in a formatted string. Default: 2.
indent (int): determines the size of indentation in a formatted string. Default: 4.
unnest_column_only (bool): if true unnest table aliases are considered only as column aliases
@@ -85,6 +89,9 @@ class Generator:
# Wrap derived values in parens, usually standard but spark doesn't support it
WRAP_DERIVED_VALUES = True
# Whether or not create function uses an AS before the def.
CREATE_FUNCTION_AS = True
TYPE_MAPPING = {
exp.DataType.Type.NCHAR: "CHAR",
exp.DataType.Type.NVARCHAR: "VARCHAR",
@@ -154,7 +161,8 @@ class Generator:
"identifier_end",
"identify",
"normalize",
"escape",
"string_escape",
"identifier_escape",
"pad",
"index_offset",
"unnest_column_only",
@@ -167,6 +175,7 @@ class Generator:
"_indent",
"_replace_backslash",
"_escaped_quote_end",
"_escaped_identifier_end",
"_leading_comma",
"_max_text_width",
"_comments",
@@ -183,7 +192,8 @@ class Generator:
identifier_end=None,
identify=False,
normalize=False,
escape=None,
string_escape=None,
identifier_escape=None,
pad=2,
indent=2,
index_offset=0,
@@ -208,7 +218,8 @@ class Generator:
self.identifier_end = identifier_end or '"'
self.identify = identify
self.normalize = normalize
self.escape = escape or "'"
self.string_escape = string_escape or "'"
self.identifier_escape = identifier_escape or '"'
self.pad = pad
self.index_offset = index_offset
self.unnest_column_only = unnest_column_only
@@ -219,8 +230,9 @@ class Generator:
self.max_unsupported = max_unsupported
self.null_ordering = null_ordering
self._indent = indent
self._replace_backslash = self.escape == "\\"
self._escaped_quote_end = self.escape + self.quote_end
self._replace_backslash = self.string_escape == "\\"
self._escaped_quote_end = self.string_escape + self.quote_end
self._escaped_identifier_end = self.identifier_escape + self.identifier_end
self._leading_comma = leading_comma
self._max_text_width = max_text_width
self._comments = comments
@@ -441,6 +453,9 @@ class Generator:
def generatedasidentitycolumnconstraint_sql(
self, expression: exp.GeneratedAsIdentityColumnConstraint
) -> str:
this = ""
if expression.this is not None:
this = " ALWAYS " if expression.this else " BY DEFAULT "
start = expression.args.get("start")
start = f"START WITH {start}" if start else ""
increment = expression.args.get("increment")
@@ -449,9 +464,7 @@ class Generator:
if start or increment:
sequence_opts = f"{start} {increment}"
sequence_opts = f" ({sequence_opts.strip()})"
return (
f"GENERATED {'ALWAYS' if expression.this else 'BY DEFAULT'} AS IDENTITY{sequence_opts}"
)
return f"GENERATED{this}AS IDENTITY{sequence_opts}"
def notnullcolumnconstraint_sql(self, expression: exp.NotNullColumnConstraint) -> str:
return f"{'' if expression.args.get('allow_null') else 'NOT '}NULL"
@@ -496,7 +509,12 @@ class Generator:
properties_sql = self.sql(properties_exp, "properties")
begin = " BEGIN" if expression.args.get("begin") else ""
expression_sql = self.sql(expression, "expression")
expression_sql = f" AS{begin}{self.sep()}{expression_sql}" if expression_sql else ""
if expression_sql:
expression_sql = f"{begin}{self.sep()}{expression_sql}"
if self.CREATE_FUNCTION_AS or kind != "FUNCTION":
expression_sql = f" AS{expression_sql}"
temporary = " TEMPORARY" if expression.args.get("temporary") else ""
transient = (
" TRANSIENT" if self.CREATE_TRANSIENT and expression.args.get("transient") else ""
@@ -701,6 +719,7 @@ class Generator:
def identifier_sql(self, expression: exp.Identifier) -> str:
text = expression.name
text = text.lower() if self.normalize else text
text = text.replace(self.identifier_end, self._escaped_identifier_end)
if expression.args.get("quoted") or self.identify:
text = f"{self.identifier_start}{text}{self.identifier_end}"
return text
@@ -1121,7 +1140,7 @@ class Generator:
text = expression.this or ""
if expression.is_string:
if self._replace_backslash:
text = text.replace("\\", "\\\\")
text = BACKSLASH_RE.sub(r"\\\\", text)
text = text.replace(self.quote_end, self._escaped_quote_end)
if self.pretty:
text = text.replace("\n", self.SENTINEL_LINE_BREAK)
@@ -1486,9 +1505,16 @@ class Generator:
return f"(SELECT {self.sql(unnest)})"
def interval_sql(self, expression: exp.Interval) -> str:
this = self.sql(expression, "this")
this = f" {this}" if this else ""
unit = self.sql(expression, "unit")
this = expression.args.get("this")
if this:
this = (
f" {this}"
if isinstance(this, exp.Literal) or isinstance(this, exp.Paren)
else f" ({this})"
)
else:
this = ""
unit = expression.args.get("unit")
unit = f" {unit}" if unit else ""
return f"INTERVAL{this}{unit}"

@@ -6,6 +6,7 @@ from dataclasses import dataclass, field
from sqlglot import Schema, exp, maybe_parse
from sqlglot.optimizer import Scope, build_scope, optimize
from sqlglot.optimizer.expand_laterals import expand_laterals
from sqlglot.optimizer.qualify_columns import qualify_columns
from sqlglot.optimizer.qualify_tables import qualify_tables
@@ -38,7 +39,7 @@ def lineage(
sql: str | exp.Expression,
schema: t.Optional[t.Dict | Schema] = None,
sources: t.Optional[t.Dict[str, str | exp.Subqueryable]] = None,
rules: t.Sequence[t.Callable] = (qualify_tables, qualify_columns),
rules: t.Sequence[t.Callable] = (qualify_tables, qualify_columns, expand_laterals),
dialect: DialectType = None,
) -> Node:
"""Build the lineage graph for a column of a SQL query.

@@ -255,12 +255,23 @@ class TypeAnnotator:
for name, source in scope.sources.items():
if not isinstance(source, Scope):
continue
if isinstance(source.expression, exp.Values):
if isinstance(source.expression, exp.UDTF):
values = []
if isinstance(source.expression, exp.Lateral):
if isinstance(source.expression.this, exp.Explode):
values = [source.expression.this.this]
else:
values = source.expression.expressions[0].expressions
if not values:
continue
selects[name] = {
alias: column
for alias, column in zip(
source.expression.alias_column_names,
source.expression.expressions[0].expressions,
values,
)
}
else:
@@ -272,7 +283,7 @@ class TypeAnnotator:
source = scope.sources.get(col.table)
if isinstance(source, exp.Table):
col.type = self.schema.get_column_type(source, col)
elif source:
elif source and col.table in selects:
col.type = selects[col.table][col.name].type
# Then (possibly) annotate the remaining expressions in the scope
self._maybe_annotate(scope.expression)

@@ -0,0 +1,34 @@
from __future__ import annotations
import typing as t
from sqlglot import exp
def expand_laterals(expression: exp.Expression) -> exp.Expression:
"""
Expand lateral column alias references.
This assumes `qualify_columns` as already run.
Example:
>>> import sqlglot
>>> sql = "SELECT x.a + 1 AS b, b + 1 AS c FROM x"
>>> expression = sqlglot.parse_one(sql)
>>> expand_laterals(expression).sql()
'SELECT x.a + 1 AS b, x.a + 1 + 1 AS c FROM x'
Args:
expression: expression to optimize
Returns:
optimized expression
"""
for select in expression.find_all(exp.Select):
alias_to_expression: t.Dict[str, exp.Expression] = {}
for projection in select.expressions:
for column in projection.find_all(exp.Column):
if not column.table and column.name in alias_to_expression:
column.replace(alias_to_expression[column.name].copy())
if isinstance(projection, exp.Alias):
alias_to_expression[projection.alias] = projection.this
return expression

@@ -4,6 +4,7 @@ from sqlglot.optimizer.canonicalize import canonicalize
from sqlglot.optimizer.eliminate_ctes import eliminate_ctes
from sqlglot.optimizer.eliminate_joins import eliminate_joins
from sqlglot.optimizer.eliminate_subqueries import eliminate_subqueries
from sqlglot.optimizer.expand_laterals import expand_laterals
from sqlglot.optimizer.expand_multi_table_selects import expand_multi_table_selects
from sqlglot.optimizer.isolate_table_selects import isolate_table_selects
from sqlglot.optimizer.lower_identities import lower_identities
@@ -12,7 +13,7 @@ from sqlglot.optimizer.normalize import normalize
from sqlglot.optimizer.optimize_joins import optimize_joins
from sqlglot.optimizer.pushdown_predicates import pushdown_predicates
from sqlglot.optimizer.pushdown_projections import pushdown_projections
from sqlglot.optimizer.qualify_columns import qualify_columns
from sqlglot.optimizer.qualify_columns import qualify_columns, validate_qualify_columns
from sqlglot.optimizer.qualify_tables import qualify_tables
from sqlglot.optimizer.unnest_subqueries import unnest_subqueries
from sqlglot.schema import ensure_schema
@@ -22,6 +23,8 @@ RULES = (
qualify_tables,
isolate_table_selects,
qualify_columns,
expand_laterals,
validate_qualify_columns,
pushdown_projections,
normalize,
unnest_subqueries,

@@ -7,7 +7,7 @@ from sqlglot.optimizer.scope import Scope, traverse_scope
SELECT_ALL = object()
# Selection to use if selection list is empty
DEFAULT_SELECTION = alias("1", "_")
DEFAULT_SELECTION = lambda: alias("1", "_")
def pushdown_projections(expression):
@@ -93,7 +93,7 @@ def _remove_unused_selections(scope, parent_selections):
# If there are no remaining selections, just select a single constant
if not new_selections:
new_selections.append(DEFAULT_SELECTION.copy())
new_selections.append(DEFAULT_SELECTION())
scope.expression.set("expressions", new_selections)
if removed:
@@ -106,5 +106,5 @@ def _remove_indexed_selections(scope, indexes_to_remove):
selection for i, selection in enumerate(scope.selects) if i not in indexes_to_remove
]
if not new_selections:
new_selections.append(DEFAULT_SELECTION.copy())
new_selections.append(DEFAULT_SELECTION())
scope.expression.set("expressions", new_selections)

@@ -37,11 +37,24 @@ def qualify_columns(expression, schema):
if not isinstance(scope.expression, exp.UDTF):
_expand_stars(scope, resolver)
_qualify_outputs(scope)
_check_unknown_tables(scope)
return expression
def validate_qualify_columns(expression):
"""Raise an `OptimizeError` if any columns aren't qualified"""
unqualified_columns = []
for scope in traverse_scope(expression):
if isinstance(scope.expression, exp.Select):
unqualified_columns.extend(scope.unqualified_columns)
if scope.external_columns and not scope.is_correlated_subquery:
raise OptimizeError(f"Unknown table: {scope.external_columns[0].table}")
if unqualified_columns:
raise OptimizeError(f"Ambiguous columns: {unqualified_columns}")
return expression
def _pop_table_column_aliases(derived_tables):
"""
Remove table column aliases.
@@ -199,10 +212,6 @@ def _qualify_columns(scope, resolver):
if not column_table:
column_table = resolver.get_table(column_name)
if not scope.is_subquery and not scope.is_udtf:
if column_table is None:
raise OptimizeError(f"Ambiguous column: {column_name}")
# column_table can be a '' because bigquery unnest has no table alias
if column_table:
column.set("table", exp.to_identifier(column_table))
@@ -231,9 +240,7 @@ def _qualify_columns(scope, resolver):
for column in columns_missing_from_scope:
column_table = resolver.get_table(column.name)
if column_table is None:
raise OptimizeError(f"Ambiguous column: {column.name}")
if column_table:
column.set("table", exp.to_identifier(column_table))
@@ -322,11 +329,6 @@ def _qualify_outputs(scope):
scope.expression.set("expressions", new_selections)
def _check_unknown_tables(scope):
if scope.external_columns and not scope.is_udtf and not scope.is_correlated_subquery:
raise OptimizeError(f"Unknown table: {scope.external_columns[0].text('table')}")
class _Resolver:
"""
Helper for resolving columns.

@@ -2,7 +2,7 @@ import itertools
from sqlglot import alias, exp
from sqlglot.helper import csv_reader
from sqlglot.optimizer.scope import traverse_scope
from sqlglot.optimizer.scope import Scope, traverse_scope
def qualify_tables(expression, db=None, catalog=None, schema=None):
@@ -25,6 +25,8 @@
"""
sequence = itertools.count()
next_name = lambda: f"_q_{next(sequence)}"
for scope in traverse_scope(expression):
for derived_table in scope.ctes + scope.derived_tables:
if not derived_table.args.get("alias"):
@@ -46,7 +48,7 @@ def qualify_tables(expression, db=None, catalog=None, schema=None):
source = source.replace(
alias(
source.copy(),
source.this if identifier else f"_q_{next(sequence)}",
source.this if identifier else next_name(),
table=True,
)
)
@@ -58,5 +60,12 @@ def qualify_tables(expression, db=None, catalog=None, schema=None):
schema.add_table(
source, {k: type(v).__name__ for k, v in zip(header, columns)}
)
elif isinstance(source, Scope) and source.is_udtf:
udtf = source.expression
table_alias = udtf.args.get("alias") or exp.TableAlias(this=next_name())
udtf.set("alias", table_alias)
if not table_alias.name:
table_alias.set("this", next_name())
return expression

@@ -237,6 +237,8 @@ class Scope:
ancestor = column.find_ancestor(exp.Qualify, exp.Order, exp.Having, exp.Hint)
if (
not ancestor
# Window functions can have an ORDER BY clause
or not isinstance(ancestor.parent, exp.Select)
or column.table
or (column.name not in named_selects and not isinstance(ancestor, exp.Hint))
):
@@ -479,7 +481,7 @@ def _traverse_scope(scope):
elif isinstance(scope.expression, exp.Union):
yield from _traverse_union(scope)
elif isinstance(scope.expression, exp.UDTF):
pass
_set_udtf_scope(scope)
elif isinstance(scope.expression, exp.Subquery):
yield from _traverse_subqueries(scope)
else:
@@ -509,6 +511,22 @@ def _traverse_union(scope):
scope.union_scopes = [left, right]
def _set_udtf_scope(scope):
parent = scope.expression.parent
from_ = parent.args.get("from")
if not from_:
return
for table in from_.expressions:
if isinstance(table, exp.Table):
scope.tables.append(table)
elif isinstance(table, exp.Subquery):
scope.subqueries.append(table)
_add_table_sources(scope)
_traverse_subqueries(scope)
def _traverse_derived_tables(derived_tables, scope, scope_type):
sources = {}
is_cte = scope_type == ScopeType.CTE

@@ -194,6 +194,7 @@ class Parser(metaclass=_Parser):
TokenType.INTERVAL,
TokenType.LAZY,
TokenType.LEADING,
TokenType.LEFT,
TokenType.LOCAL,
TokenType.MATERIALIZED,
TokenType.MERGE,
@@ -208,6 +209,7 @@ class Parser(metaclass=_Parser):
TokenType.PRECEDING,
TokenType.RANGE,
TokenType.REFERENCES,
TokenType.RIGHT,
TokenType.ROW,
TokenType.ROWS,
TokenType.SCHEMA,
@@ -237,8 +239,10 @@ class Parser(metaclass=_Parser):
TABLE_ALIAS_TOKENS = ID_VAR_TOKENS - {
TokenType.APPLY,
TokenType.LEFT,
TokenType.NATURAL,
TokenType.OFFSET,
TokenType.RIGHT,
TokenType.WINDOW,
}
@@ -258,6 +262,8 @@ class Parser(metaclass=_Parser):
TokenType.IDENTIFIER,
TokenType.INDEX,
TokenType.ISNULL,
TokenType.ILIKE,
TokenType.LIKE,
TokenType.MERGE,
TokenType.OFFSET,
TokenType.PRIMARY_KEY,
@@ -971,7 +977,8 @@ class Parser(metaclass=_Parser):
if create_token.token_type in (TokenType.FUNCTION, TokenType.PROCEDURE):
this = self._parse_user_defined_function(kind=create_token.token_type)
properties = self._parse_properties()
if self._match(TokenType.ALIAS):
self._match(TokenType.ALIAS)
begin = self._match(TokenType.BEGIN)
return_ = self._match_text_seq("RETURN")
expression = self._parse_statement()
@@ -2163,7 +2170,9 @@ class Parser(metaclass=_Parser):
) -> t.Optional[exp.Expression]:
if self._match(TokenType.TOP if top else TokenType.LIMIT):
limit_paren = self._match(TokenType.L_PAREN)
limit_exp = self.expression(exp.Limit, this=this, expression=self._parse_number())
limit_exp = self.expression(
exp.Limit, this=this, expression=self._parse_number() if top else self._parse_term()
)
if limit_paren:
self._match_r_paren()
@@ -2740,7 +2749,22 @@ class Parser(metaclass=_Parser):
kind: exp.Expression
if self._match(TokenType.AUTO_INCREMENT):
if self._match_set((TokenType.AUTO_INCREMENT, TokenType.IDENTITY)):
start = None
increment = None
if self._match(TokenType.L_PAREN, advance=False):
args = self._parse_wrapped_csv(self._parse_bitwise)
start = seq_get(args, 0)
increment = seq_get(args, 1)
elif self._match_text_seq("START"):
start = self._parse_bitwise()
self._match_text_seq("INCREMENT")
increment = self._parse_bitwise()
if start and increment:
kind = exp.GeneratedAsIdentityColumnConstraint(start=start, increment=increment)
else:
kind = exp.AutoIncrementColumnConstraint()
elif self._match(TokenType.CHECK):
constraint = self._parse_wrapped(self._parse_conjunction)
@@ -3294,8 +3318,8 @@ class Parser(metaclass=_Parser):
if not self._match(TokenType.EXCEPT):
return None
if self._match(TokenType.L_PAREN, advance=False):
return self._parse_wrapped_id_vars()
return self._parse_csv(self._parse_id_var)
return self._parse_wrapped_csv(self._parse_column)
return self._parse_csv(self._parse_column)
def _parse_replace(self) -> t.Optional[t.List[t.Optional[exp.Expression]]]:
if not self._match(TokenType.REPLACE):
@ -3442,7 +3466,7 @@ class Parser(metaclass=_Parser):
def _parse_alter(self) -> t.Optional[exp.Expression]:
if not self._match(TokenType.TABLE):
return None
return self._parse_as_command(self._prev)
exists = self._parse_exists()
this = self._parse_table(schema=True)

@@ -357,7 +357,8 @@ class _Tokenizer(type):
klass._HEX_STRINGS = cls._delimeter_list_to_dict(klass.HEX_STRINGS)
klass._BYTE_STRINGS = cls._delimeter_list_to_dict(klass.BYTE_STRINGS)
klass._IDENTIFIERS = cls._delimeter_list_to_dict(klass.IDENTIFIERS)
klass._ESCAPES = set(klass.ESCAPES)
klass._STRING_ESCAPES = set(klass.STRING_ESCAPES)
klass._IDENTIFIER_ESCAPES = set(klass.IDENTIFIER_ESCAPES)
klass._COMMENTS = dict(
(comment, None) if isinstance(comment, str) else (comment[0], comment[1])
for comment in klass.COMMENTS
@@ -429,9 +430,13 @@ class Tokenizer(metaclass=_Tokenizer):
IDENTIFIERS: t.List[str | t.Tuple[str, str]] = ['"']
ESCAPES = ["'"]
STRING_ESCAPES = ["'"]
_ESCAPES: t.Set[str] = set()
_STRING_ESCAPES: t.Set[str] = set()
IDENTIFIER_ESCAPES = ['"']
_IDENTIFIER_ESCAPES: t.Set[str] = set()
KEYWORDS = {
**{
@@ -469,6 +474,7 @@ class Tokenizer(metaclass=_Tokenizer):
"ASC": TokenType.ASC,
"AS": TokenType.ALIAS,
"AT TIME ZONE": TokenType.AT_TIME_ZONE,
"AUTOINCREMENT": TokenType.AUTO_INCREMENT,
"AUTO_INCREMENT": TokenType.AUTO_INCREMENT,
"BEGIN": TokenType.BEGIN,
"BETWEEN": TokenType.BETWEEN,
@@ -691,6 +697,7 @@ class Tokenizer(metaclass=_Tokenizer):
"ALTER VIEW": TokenType.COMMAND,
"ANALYZE": TokenType.COMMAND,
"CALL": TokenType.COMMAND,
"COPY": TokenType.COMMAND,
"EXPLAIN": TokenType.COMMAND,
"OPTIMIZE": TokenType.COMMAND,
"PREPARE": TokenType.COMMAND,
@@ -744,7 +751,7 @@
)
def __init__(self) -> None:
self._replace_backslash = "\\" in self._ESCAPES
self._replace_backslash = "\\" in self._STRING_ESCAPES
self.reset()
def reset(self) -> None:
@@ -1046,12 +1053,25 @@ class Tokenizer(metaclass=_Tokenizer):
return True
def _scan_identifier(self, identifier_end: str) -> None:
while self._peek != identifier_end:
text = ""
identifier_end_is_escape = identifier_end in self._IDENTIFIER_ESCAPES
while True:
if self._end:
raise RuntimeError(f"Missing {identifier_end} from {self._line}:{self._start}")
self._advance()
if self._char == identifier_end:
if identifier_end_is_escape and self._peek == identifier_end:
text += identifier_end # type: ignore
self._advance()
self._add(TokenType.IDENTIFIER, self._text[1:-1])
continue
break
text += self._char # type: ignore
self._add(TokenType.IDENTIFIER, text)
def _scan_var(self) -> None:
while True:
@@ -1072,9 +1092,9 @@
while True:
if (
self._char in self._ESCAPES
self._char in self._STRING_ESCAPES
and self._peek
and (self._peek == delimiter or self._peek in self._ESCAPES)
and (self._peek == delimiter or self._peek in self._STRING_ESCAPES)
):
text += self._peek
self._advance(2)
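Combined with the generator's new _escaped_identifier_end handling earlier in this diff, doubled quotes inside quoted identifiers should now round-trip; a sketch (the expected output is an assumption):

    import sqlglot

    # The tokenizer unescapes "" inside a quoted identifier; the generator
    # re-escapes it when writing the identifier back out.
    print(sqlglot.parse_one('SELECT "a""b" FROM t').sql())
    # expected: SELECT "a""b" FROM t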

Some files were not shown because too many files have changed in this diff.