
Adding upstream version 11.0.1.

Signed-off-by: Daniel Baumann <daniel@debian.org>
Daniel Baumann 2025-02-13 15:22:50 +01:00
parent ab1b3ea4d6
commit e09ae33d10
Signed by: daniel
GPG key ID: FBB4F0E80A80222F
112 changed files with 126100 additions and 230 deletions

@@ -22,5 +22,17 @@ jobs:
         python -m pip install --upgrade pip
         make install-dev
     - name: Run checks (linter, code style, tests)
-      run: |
-        make check
+      run: make check
+    - name: Update documentation
+      run: |
+        make docs
+        git add docs
+        git config --local user.email "41898282+github-actions[bot]@users.noreply.github.com"
+        git config --local user.name "github-actions[bot]"
+        git commit -m "CI: Auto-generated documentation" -a | exit 0
+      if: ${{ matrix.python-version == '3.10' && github.event_name == 'push' }}
+    - name: Push changes
+      if: ${{ matrix.python-version == '3.10' && github.event_name == 'push' && startsWith(github.ref, 'refs/tags/v') }}
+      uses: ad-m/github-push-action@master
+      with:
+        github_token: ${{ secrets.GITHUB_TOKEN }}

@@ -1,6 +1,39 @@
 Changelog
 =========
+
+v11.0.0
+------
+
+Changes:
+
+- Breaking: Renamed ESCAPES to STRING_ESCAPES in the Tokenizer class.
+
+- New: Deployed pdoc documentation page.
+
+- New: Added support for read locking using the FOR UPDATE/SHARE syntax (e.g. MySQL).
+
+- New: Added support for CASCADE, SET NULL and SET DEFAULT constraints.
+
+- New: Added "cast" expression helper.
+
+- New: Added support for transpiling Postgres GENERATE_SERIES into Presto SEQUENCE.
+
+- Improvement: Fixed tokenizing of identifier escapes.
+
+- Improvement: Fixed an eliminate_subqueries [bug](https://github.com/tobymao/sqlglot/commit/b5df65e3fb5ee1ebc3cbab64b6d89598cf47a10b) related to unions.
+
+- Improvement: IFNULL is now transpiled to COALESCE by default for every dialect.
+
+- Improvement: Refactored the way properties are handled, so it's now easier to add them and to specify their position in a SQL expression.
+
+- Improvement: Fixed an alias quoting bug.
+
+- Improvement: Fixed CUBE / ROLLUP / GROUPING SETS parsing and generation.
+
+- Improvement: Fixed a get_or_raise Dialect/t.Type[Dialect] argument bug.
+
+- Improvement: Improved Python type hints.
 
 v10.6.0
 ------
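
Two of the entries above are easiest to see end to end. A minimal sketch, assuming the new helper is exposed as exp.cast and using the regular transpile API; the generated strings are indicative rather than verbatim:

```python
import sqlglot
from sqlglot import exp

# New in v11: Postgres GENERATE_SERIES is rewritten to Presto's SEQUENCE.
print(sqlglot.transpile("SELECT * FROM GENERATE_SERIES(1, 5)", read="postgres", write="presto")[0])

# New "cast" expression helper: builds a CAST node programmatically.
print(exp.cast(exp.column("x"), to="INT").sql())  # e.g. CAST(x AS INT)
```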

@@ -18,7 +18,7 @@ style:
 check: style test
 
 docs:
-	python pdoc/cli.py -o pdoc/docs
+	python pdoc/cli.py -o docs
 
 docs-serve:
-	python pdoc/cli.py
+	python pdoc/cli.py --port 8002

docs/CNAME (new file, 1 line)

@@ -0,0 +1 @@
+sqlglot.com

docs/index.html (new file, 7 lines)

@@ -0,0 +1,7 @@
+<!doctype html>
+<html>
+    <head>
+        <meta charset="utf-8">
+        <meta http-equiv="refresh" content="0; url=./sqlglot.html"/>
+    </head>
+</html>

The remaining additions are generated pdoc documentation pages whose diffs are suppressed because one or more lines are too long:

docs/search.js (new file, 46 lines)
docs/sqlglot.html (new file, 1226 lines)
docs/sqlglot/dataframe.html (new file, 506 lines)
docs/sqlglot/dialects.html (new file, 400 lines)
docs/sqlglot/diff.html (new file, 1560 lines)
docs/sqlglot/errors.html (new file, 877 lines)
docs/sqlglot/executor.html (new file, 694 lines)
docs/sqlglot/expressions.html (new file, 39484 lines)
docs/sqlglot/generator.html (new file, 9855 lines)
docs/sqlglot/helper.html (new file, 1651 lines)
docs/sqlglot/lineage.html (new file, 931 lines)
docs/sqlglot/optimizer.html (new file, 264 lines)
docs/sqlglot/parser.html (new file, 8049 lines)
docs/sqlglot/planner.html (new file, 1995 lines)
docs/sqlglot/schema.html (new file, 1624 lines)
docs/sqlglot/serde.html (new file, 408 lines)
docs/sqlglot/time.html (new file, 385 lines)
docs/sqlglot/tokens.html (new file, 6712 lines)
docs/sqlglot/trie.html (new file, 479 lines)

@@ -26,9 +26,9 @@ if __name__ == "__main__":
     opts = parser.parse_args()
     opts.docformat = "google"
     opts.modules = ["sqlglot"]
-    opts.footer_text = "Copyright (c) 2022 Toby Mao"
+    opts.footer_text = "Copyright (c) 2023 Toby Mao"
     opts.template_directory = Path(__file__).parent.joinpath("templates").absolute()
-    opts.edit_url = ["sqlglot=https://github.com/tobymao/sqlglot/"]
+    opts.edit_url = ["sqlglot=https://github.com/tobymao/sqlglot/tree/main/sqlglot/"]
 
     with mock.patch("pdoc.__main__.parser", **{"parse_args.return_value": opts}):
         cli()

@@ -1,5 +1,6 @@
 """
 .. include:: ../README.md
+
 ----
 """
@@ -39,7 +40,7 @@ if t.TYPE_CHECKING:
     T = t.TypeVar("T", bound=Expression)
 
-__version__ = "10.6.3"
+__version__ = "11.0.1"
 
 pretty = False
 """Whether to format generated SQL by default."""

@@ -2,6 +2,8 @@
 from __future__ import annotations
 
+import typing as t
+
 from sqlglot import exp, generator, parser, tokens, transforms
 from sqlglot.dialects.dialect import (
     Dialect,
@@ -14,8 +16,10 @@ from sqlglot.dialects.dialect import (
 from sqlglot.helper import seq_get
 from sqlglot.tokens import TokenType
 
+E = t.TypeVar("E", bound=exp.Expression)
+
 
-def _date_add(expression_class):
+def _date_add(expression_class: t.Type[E]) -> t.Callable[[t.Sequence], E]:
     def func(args):
         interval = seq_get(args, 1)
         return expression_class(
@@ -27,26 +31,26 @@ def _date_add(expression_class):
     return func
 
 
-def _date_trunc(args):
+def _date_trunc(args: t.Sequence) -> exp.Expression:
     unit = seq_get(args, 1)
     if isinstance(unit, exp.Column):
         unit = exp.Var(this=unit.name)
     return exp.DateTrunc(this=seq_get(args, 0), expression=unit)
 
 
-def _date_add_sql(data_type, kind):
+def _date_add_sql(
+    data_type: str, kind: str
+) -> t.Callable[[generator.Generator, exp.Expression], str]:
     def func(self, expression):
         this = self.sql(expression, "this")
-        unit = self.sql(expression, "unit") or "'day'"
-        expression = self.sql(expression, "expression")
-        return f"{data_type}_{kind}({this}, INTERVAL {expression} {unit})"
+        return f"{data_type}_{kind}({this}, {self.sql(exp.Interval(this=expression.expression, unit=expression.args.get('unit') or exp.Literal.string('day')))})"
 
     return func
 
 
-def _derived_table_values_to_unnest(self, expression):
+def _derived_table_values_to_unnest(self: generator.Generator, expression: exp.Values) -> str:
     if not isinstance(expression.unnest().parent, exp.From):
-        expression = transforms.remove_precision_parameterized_types(expression)
+        expression = t.cast(exp.Values, transforms.remove_precision_parameterized_types(expression))
         return self.values_sql(expression)
     rows = [tuple_exp.expressions for tuple_exp in expression.find_all(exp.Tuple)]
     structs = []
@@ -60,7 +64,7 @@ def _derived_table_values_to_unnest(self, expression):
     return self.unnest_sql(unnest_exp)
 
 
-def _returnsproperty_sql(self, expression):
+def _returnsproperty_sql(self: generator.Generator, expression: exp.ReturnsProperty) -> str:
     this = expression.this
     if isinstance(this, exp.Schema):
         this = f"{this.this} <{self.expressions(this)}>"
@@ -69,8 +73,8 @@ def _returnsproperty_sql(self, expression):
     return f"RETURNS {this}"
 
 
-def _create_sql(self, expression):
-    kind = expression.args.get("kind")
+def _create_sql(self: generator.Generator, expression: exp.Create) -> str:
+    kind = expression.args["kind"]
     returns = expression.find(exp.ReturnsProperty)
     if kind.upper() == "FUNCTION" and returns and returns.args.get("is_table"):
         expression = expression.copy()
@@ -89,6 +93,29 @@ def _create_sql(self, expression):
     return self.create_sql(expression)
 
 
+def _unqualify_unnest(expression: exp.Expression) -> exp.Expression:
+    """Remove references to unnest table aliases since bigquery doesn't allow them.
+
+    These are added by the optimizer's qualify_column step.
+    """
+    if isinstance(expression, exp.Select):
+        unnests = {
+            unnest.alias
+            for unnest in expression.args.get("from", exp.From(expressions=[])).expressions
+            if isinstance(unnest, exp.Unnest) and unnest.alias
+        }
+
+        if unnests:
+            expression = expression.copy()
+
+            for select in expression.expressions:
+                for column in select.find_all(exp.Column):
+                    if column.table in unnests:
+                        column.set("table", None)
+
+    return expression
+
+
 class BigQuery(Dialect):
     unnest_column_only = True
     time_mapping = {
@@ -110,7 +137,7 @@ class BigQuery(Dialect):
         ]
         COMMENTS = ["--", "#", ("/*", "*/")]
         IDENTIFIERS = ["`"]
-        ESCAPES = ["\\"]
+        STRING_ESCAPES = ["\\"]
         HEX_STRINGS = [("0x", ""), ("0X", "")]
 
         KEYWORDS = {
@@ -190,6 +217,9 @@ class BigQuery(Dialect):
             exp.GroupConcat: rename_func("STRING_AGG"),
             exp.ILike: no_ilike_sql,
             exp.IntDiv: rename_func("DIV"),
+            exp.Select: transforms.preprocess(
+                [_unqualify_unnest], transforms.delegate("select_sql")
+            ),
             exp.StrToTime: lambda self, e: f"PARSE_TIMESTAMP({self.format_time(e)}, {self.sql(e, 'this')})",
             exp.TimeAdd: _date_add_sql("TIME", "ADD"),
             exp.TimeSub: _date_add_sql("TIME", "SUB"),
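
The exp.Select entry above wires _unqualify_unnest in through the transforms.preprocess/delegate pattern, and the same hook can attach any AST rewrite to a generator. A sketch with a hypothetical no-op transform:

```python
from sqlglot import exp, transforms
from sqlglot.dialects.bigquery import BigQuery

def my_transform(expression: exp.Expression) -> exp.Expression:
    # Hypothetical rewrite: runs on every SELECT before SQL generation.
    return expression

# Preprocess the AST with my_transform, then delegate to the stock select_sql.
BigQuery.Generator.TRANSFORMS[exp.Select] = transforms.preprocess(
    [my_transform], transforms.delegate("select_sql")
)
```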

@@ -9,7 +9,7 @@ from sqlglot.parser import parse_var_map
 from sqlglot.tokens import TokenType
 
 
-def _lower_func(sql):
+def _lower_func(sql: str) -> str:
     index = sql.index("(")
     return sql[:index].lower() + sql[index:]

@@ -11,6 +11,8 @@ from sqlglot.time import format_time
 from sqlglot.tokens import Tokenizer
 from sqlglot.trie import new_trie
 
+E = t.TypeVar("E", bound=exp.Expression)
+
 
 class Dialects(str, Enum):
     DIALECT = ""
@@ -37,14 +39,16 @@ class Dialects(str, Enum):
 
 class _Dialect(type):
-    classes: t.Dict[str, Dialect] = {}
+    classes: t.Dict[str, t.Type[Dialect]] = {}
 
     @classmethod
-    def __getitem__(cls, key):
+    def __getitem__(cls, key: str) -> t.Type[Dialect]:
         return cls.classes[key]
 
     @classmethod
-    def get(cls, key, default=None):
+    def get(
+        cls, key: str, default: t.Optional[t.Type[Dialect]] = None
+    ) -> t.Optional[t.Type[Dialect]]:
         return cls.classes.get(key, default)
 
     def __new__(cls, clsname, bases, attrs):
@@ -119,7 +123,7 @@ class Dialect(metaclass=_Dialect):
     generator_class = None
 
     @classmethod
-    def get_or_raise(cls, dialect):
+    def get_or_raise(cls, dialect: DialectType) -> t.Type[Dialect]:
         if not dialect:
             return cls
         if isinstance(dialect, _Dialect):
@@ -134,7 +138,9 @@ class Dialect(metaclass=_Dialect):
         return result
 
     @classmethod
-    def format_time(cls, expression):
+    def format_time(
+        cls, expression: t.Optional[str | exp.Expression]
+    ) -> t.Optional[exp.Expression]:
         if isinstance(expression, str):
             return exp.Literal.string(
                 format_time(
@@ -153,26 +159,28 @@ class Dialect(metaclass=_Dialect):
             )
         return expression
 
-    def parse(self, sql, **opts):
+    def parse(self, sql: str, **opts) -> t.List[t.Optional[exp.Expression]]:
         return self.parser(**opts).parse(self.tokenizer.tokenize(sql), sql)
 
-    def parse_into(self, expression_type, sql, **opts):
+    def parse_into(
+        self, expression_type: exp.IntoType, sql: str, **opts
+    ) -> t.List[t.Optional[exp.Expression]]:
         return self.parser(**opts).parse_into(expression_type, self.tokenizer.tokenize(sql), sql)
 
-    def generate(self, expression, **opts):
+    def generate(self, expression: t.Optional[exp.Expression], **opts) -> str:
         return self.generator(**opts).generate(expression)
 
-    def transpile(self, code, **opts):
-        return self.generate(self.parse(code), **opts)
+    def transpile(self, sql: str, **opts) -> t.List[str]:
+        return [self.generate(expression, **opts) for expression in self.parse(sql)]
 
     @property
-    def tokenizer(self):
+    def tokenizer(self) -> Tokenizer:
         if not hasattr(self, "_tokenizer"):
-            self._tokenizer = self.tokenizer_class()
+            self._tokenizer = self.tokenizer_class()  # type: ignore
         return self._tokenizer
 
-    def parser(self, **opts):
-        return self.parser_class(
+    def parser(self, **opts) -> Parser:
+        return self.parser_class(  # type: ignore
             **{
                 "index_offset": self.index_offset,
                 "unnest_column_only": self.unnest_column_only,
@@ -182,14 +190,15 @@ class Dialect(metaclass=_Dialect):
             },
         )
 
-    def generator(self, **opts):
-        return self.generator_class(
+    def generator(self, **opts) -> Generator:
+        return self.generator_class(  # type: ignore
             **{
                 "quote_start": self.quote_start,
                 "quote_end": self.quote_end,
                 "identifier_start": self.identifier_start,
                 "identifier_end": self.identifier_end,
-                "escape": self.tokenizer_class.ESCAPES[0],
+                "string_escape": self.tokenizer_class.STRING_ESCAPES[0],
+                "identifier_escape": self.tokenizer_class.IDENTIFIER_ESCAPES[0],
                 "index_offset": self.index_offset,
                 "time_mapping": self.inverse_time_mapping,
                 "time_trie": self.inverse_time_trie,
@@ -202,11 +211,10 @@ class Dialect(metaclass=_Dialect):
         )
 
 
-if t.TYPE_CHECKING:
-    DialectType = t.Union[str, Dialect, t.Type[Dialect], None]
+DialectType = t.Union[str, Dialect, t.Type[Dialect], None]
 
 
-def rename_func(name):
+def rename_func(name: str) -> t.Callable[[Generator, exp.Expression], str]:
     def _rename(self, expression):
         args = flatten(expression.args.values())
         return f"{self.normalize_func(name)}({self.format_args(*args)})"
@@ -214,32 +222,34 @@ def rename_func(name):
     return _rename
 
 
-def approx_count_distinct_sql(self, expression):
+def approx_count_distinct_sql(self: Generator, expression: exp.ApproxDistinct) -> str:
    if expression.args.get("accuracy"):
        self.unsupported("APPROX_COUNT_DISTINCT does not support accuracy")
    return f"APPROX_COUNT_DISTINCT({self.format_args(expression.this)})"
 
 
-def if_sql(self, expression):
+def if_sql(self: Generator, expression: exp.If) -> str:
     expressions = self.format_args(
         expression.this, expression.args.get("true"), expression.args.get("false")
     )
     return f"IF({expressions})"
 
 
-def arrow_json_extract_sql(self, expression):
+def arrow_json_extract_sql(self: Generator, expression: exp.JSONExtract | exp.JSONBExtract) -> str:
     return self.binary(expression, "->")
 
 
-def arrow_json_extract_scalar_sql(self, expression):
+def arrow_json_extract_scalar_sql(
+    self: Generator, expression: exp.JSONExtractScalar | exp.JSONBExtractScalar
+) -> str:
     return self.binary(expression, "->>")
 
 
-def inline_array_sql(self, expression):
+def inline_array_sql(self: Generator, expression: exp.Array) -> str:
     return f"[{self.expressions(expression)}]"
 
 
-def no_ilike_sql(self, expression):
+def no_ilike_sql(self: Generator, expression: exp.ILike) -> str:
     return self.like_sql(
         exp.Like(
             this=exp.Lower(this=expression.this),
@@ -248,44 +258,44 @@ def no_ilike_sql(self, expression):
     )
 
 
-def no_paren_current_date_sql(self, expression):
+def no_paren_current_date_sql(self: Generator, expression: exp.CurrentDate) -> str:
     zone = self.sql(expression, "this")
     return f"CURRENT_DATE AT TIME ZONE {zone}" if zone else "CURRENT_DATE"
 
 
-def no_recursive_cte_sql(self, expression):
+def no_recursive_cte_sql(self: Generator, expression: exp.With) -> str:
     if expression.args.get("recursive"):
         self.unsupported("Recursive CTEs are unsupported")
         expression.args["recursive"] = False
     return self.with_sql(expression)
 
 
-def no_safe_divide_sql(self, expression):
+def no_safe_divide_sql(self: Generator, expression: exp.SafeDivide) -> str:
     n = self.sql(expression, "this")
     d = self.sql(expression, "expression")
     return f"IF({d} <> 0, {n} / {d}, NULL)"
 
 
-def no_tablesample_sql(self, expression):
+def no_tablesample_sql(self: Generator, expression: exp.TableSample) -> str:
     self.unsupported("TABLESAMPLE unsupported")
     return self.sql(expression.this)
 
 
-def no_pivot_sql(self, expression):
+def no_pivot_sql(self: Generator, expression: exp.Pivot) -> str:
     self.unsupported("PIVOT unsupported")
     return self.sql(expression)
 
 
-def no_trycast_sql(self, expression):
+def no_trycast_sql(self: Generator, expression: exp.TryCast) -> str:
     return self.cast_sql(expression)
 
 
-def no_properties_sql(self, expression):
+def no_properties_sql(self: Generator, expression: exp.Properties) -> str:
     self.unsupported("Properties unsupported")
     return ""
 
 
-def str_position_sql(self, expression):
+def str_position_sql(self: Generator, expression: exp.StrPosition) -> str:
     this = self.sql(expression, "this")
     substr = self.sql(expression, "substr")
     position = self.sql(expression, "position")
@@ -294,13 +304,15 @@ def str_position_sql(self, expression):
     return f"STRPOS({this}, {substr})"
 
 
-def struct_extract_sql(self, expression):
+def struct_extract_sql(self: Generator, expression: exp.StructExtract) -> str:
     this = self.sql(expression, "this")
     struct_key = self.sql(exp.Identifier(this=expression.expression, quoted=True))
     return f"{this}.{struct_key}"
 
 
-def var_map_sql(self, expression, map_func_name="MAP"):
+def var_map_sql(
+    self: Generator, expression: exp.Map | exp.VarMap, map_func_name: str = "MAP"
+) -> str:
     keys = expression.args["keys"]
     values = expression.args["values"]
@@ -315,27 +327,33 @@ def var_map_sql(self, expression, map_func_name="MAP"):
     return f"{map_func_name}({self.format_args(*args)})"
 
 
-def format_time_lambda(exp_class, dialect, default=None):
+def format_time_lambda(
+    exp_class: t.Type[E], dialect: str, default: t.Optional[bool | str] = None
+) -> t.Callable[[t.Sequence], E]:
     """Helper used for time expressions.
 
-    Args
-        exp_class (Class): the expression class to instantiate
-        dialect (string): sql dialect
-        default (Option[bool | str]): the default format, True being time
+    Args:
+        exp_class: the expression class to instantiate.
+        dialect: target sql dialect.
+        default: the default format, True being time.
+
+    Returns:
+        A callable that can be used to return the appropriately formatted time expression.
     """
 
-    def _format_time(args):
+    def _format_time(args: t.Sequence):
         return exp_class(
             this=seq_get(args, 0),
             format=Dialect[dialect].format_time(
-                seq_get(args, 1) or (Dialect[dialect].time_format if default is True else default)
+                seq_get(args, 1)
+                or (Dialect[dialect].time_format if default is True else default or None)
             ),
         )
 
     return _format_time
 
 
-def create_with_partitions_sql(self, expression):
+def create_with_partitions_sql(self: Generator, expression: exp.Create) -> str:
     """
     In Hive and Spark, the PARTITIONED BY property acts as an extension of a table's schema. When the
     PARTITIONED BY value is an array of column names, they are transformed into a schema. The corresponding
@@ -359,19 +377,21 @@ def create_with_partitions_sql(self, expression):
     return self.create_sql(expression)
 
 
-def parse_date_delta(exp_class, unit_mapping=None):
-    def inner_func(args):
+def parse_date_delta(
+    exp_class: t.Type[E], unit_mapping: t.Optional[t.Dict[str, str]] = None
+) -> t.Callable[[t.Sequence], E]:
+    def inner_func(args: t.Sequence) -> E:
         unit_based = len(args) == 3
         this = seq_get(args, 2) if unit_based else seq_get(args, 0)
         expression = seq_get(args, 1) if unit_based else seq_get(args, 1)
         unit = seq_get(args, 0) if unit_based else exp.Literal.string("DAY")
-        unit = unit_mapping.get(unit.name.lower(), unit) if unit_mapping else unit
+        unit = unit_mapping.get(unit.name.lower(), unit) if unit_mapping else unit  # type: ignore
         return exp_class(this=this, expression=expression, unit=unit)
 
     return inner_func
 
 
-def locate_to_strposition(args):
+def locate_to_strposition(args: t.Sequence) -> exp.Expression:
     return exp.StrPosition(
         this=seq_get(args, 1),
         substr=seq_get(args, 0),
@@ -379,22 +399,22 @@ def locate_to_strposition(args):
     )
 
 
-def strposition_to_locate_sql(self, expression):
+def strposition_to_locate_sql(self: Generator, expression: exp.StrPosition) -> str:
     args = self.format_args(
         expression.args.get("substr"), expression.this, expression.args.get("position")
     )
     return f"LOCATE({args})"
 
 
-def timestrtotime_sql(self, expression: exp.TimeStrToTime) -> str:
+def timestrtotime_sql(self: Generator, expression: exp.TimeStrToTime) -> str:
     return f"CAST({self.sql(expression, 'this')} AS TIMESTAMP)"
 
 
-def datestrtodate_sql(self, expression: exp.DateStrToDate) -> str:
+def datestrtodate_sql(self: Generator, expression: exp.DateStrToDate) -> str:
     return f"CAST({self.sql(expression, 'this')} AS DATE)"
 
 
-def trim_sql(self, expression):
+def trim_sql(self: Generator, expression: exp.Trim) -> str:
     target = self.sql(expression, "this")
     trim_type = self.sql(expression, "position")
     remove_chars = self.sql(expression, "expression")
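
Since the ESCAPES to STRING_ESCAPES rename is the release's one breaking change, a minimal migration sketch for a custom dialect (the dialect itself is hypothetical; the attribute names come from the diff above):

```python
from sqlglot import tokens
from sqlglot.dialects.dialect import Dialect

class MyDialect(Dialect):
    class Tokenizer(tokens.Tokenizer):
        # Before v11 this was: ESCAPES = ["\\"]
        STRING_ESCAPES = ["\\"]
        # Identifier escapes are configured separately and are passed to the
        # generator as identifier_escape (see the generator kwargs above).
        IDENTIFIER_ESCAPES = ['"']
```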

@@ -1,6 +1,7 @@
 from __future__ import annotations
 
 import re
+import typing as t
 
 from sqlglot import exp, generator, parser, tokens
 from sqlglot.dialects.dialect import (
@@ -16,35 +17,29 @@ from sqlglot.dialects.dialect import (
 )
 
 
-def _to_timestamp(args):
-    # TO_TIMESTAMP accepts either a single double argument or (text, text)
-    if len(args) == 1 and args[0].is_number:
-        return exp.UnixToTime.from_arg_list(args)
-    return format_time_lambda(exp.StrToTime, "drill")(args)
-
-
-def _str_to_time_sql(self, expression):
+def _str_to_time_sql(self: generator.Generator, expression: exp.TsOrDsToDate) -> str:
     return f"STRPTIME({self.sql(expression, 'this')}, {self.format_time(expression)})"
 
 
-def _ts_or_ds_to_date_sql(self, expression):
+def _ts_or_ds_to_date_sql(self: generator.Generator, expression: exp.TsOrDsToDate) -> str:
     time_format = self.format_time(expression)
     if time_format and time_format not in (Drill.time_format, Drill.date_format):
         return f"CAST({_str_to_time_sql(self, expression)} AS DATE)"
     return f"CAST({self.sql(expression, 'this')} AS DATE)"
 
 
-def _date_add_sql(kind):
-    def func(self, expression):
+def _date_add_sql(kind: str) -> t.Callable[[generator.Generator, exp.DateAdd | exp.DateSub], str]:
+    def func(self: generator.Generator, expression: exp.DateAdd | exp.DateSub) -> str:
         this = self.sql(expression, "this")
-        unit = expression.text("unit").upper() or "DAY"
-        expression = self.sql(expression, "expression")
-        return f"DATE_{kind}({this}, INTERVAL '{expression}' {unit})"
+        unit = exp.Var(this=expression.text("unit").upper() or "DAY")
+        return (
+            f"DATE_{kind}({this}, {self.sql(exp.Interval(this=expression.expression, unit=unit))})"
+        )
 
     return func
 
 
-def if_sql(self, expression):
+def if_sql(self: generator.Generator, expression: exp.If) -> str:
     """
     Drill requires backticks around certain SQL reserved words, IF being one of them, This function
     adds the backticks around the keyword IF.
@@ -61,7 +56,7 @@ def if_sql(self, expression):
     return f"`IF`({expressions})"
 
 
-def _str_to_date(self, expression):
+def _str_to_date(self: generator.Generator, expression: exp.StrToDate) -> str:
     this = self.sql(expression, "this")
     time_format = self.format_time(expression)
     if time_format == Drill.date_format:
@@ -111,7 +106,7 @@ class Drill(Dialect):
     class Tokenizer(tokens.Tokenizer):
         QUOTES = ["'"]
         IDENTIFIERS = ["`"]
-        ESCAPES = ["\\"]
+        STRING_ESCAPES = ["\\"]
         ENCODE = "utf-8"
 
     class Parser(parser.Parser):
@@ -168,10 +163,10 @@ class Drill(Dialect):
             exp.TimeToStr: lambda self, e: f"TO_CHAR({self.sql(e, 'this')}, {self.format_time(e)})",
             exp.TimeToUnix: rename_func("UNIX_TIMESTAMP"),
             exp.TryCast: no_trycast_sql,
-            exp.TsOrDsAdd: lambda self, e: f"DATE_ADD(CAST({self.sql(e, 'this')} AS DATE), INTERVAL '{self.sql(e, 'expression')}' DAY)",
+            exp.TsOrDsAdd: lambda self, e: f"DATE_ADD(CAST({self.sql(e, 'this')} AS DATE), {self.sql(exp.Interval(this=e.expression, unit=exp.Var(this='DAY')))})",
             exp.TsOrDsToDate: _ts_or_ds_to_date_sql,
             exp.TsOrDiToDi: lambda self, e: f"CAST(SUBSTR(REPLACE(CAST({self.sql(e, 'this')} AS VARCHAR), '-', ''), 1, 8) AS INT)",
         }
 
-        def normalize_func(self, name):
+        def normalize_func(self, name: str) -> str:
             return name if re.match(exp.SAFE_IDENTIFIER_RE, name) else f"`{name}`"

@@ -25,10 +25,9 @@ def _str_to_time_sql(self, expression):
 
 
 def _ts_or_ds_add(self, expression):
-    this = self.sql(expression, "this")
-    e = self.sql(expression, "expression")
+    this = expression.args.get("this")
     unit = self.sql(expression, "unit").strip("'") or "DAY"
-    return f"CAST({this} AS DATE) + INTERVAL {e} {unit}"
+    return f"CAST({this} AS DATE) + {self.sql(exp.Interval(this=expression.expression, unit=unit))}"
 
 
 def _ts_or_ds_to_date_sql(self, expression):
@@ -40,9 +39,8 @@ def _ts_or_ds_to_date_sql(self, expression):
 
 def _date_add(self, expression):
     this = self.sql(expression, "this")
-    e = self.sql(expression, "expression")
     unit = self.sql(expression, "unit").strip("'") or "DAY"
-    return f"{this} + INTERVAL {e} {unit}"
+    return f"{this} + {self.sql(exp.Interval(this=expression.expression, unit=unit))}"
 
 
 def _array_sort_sql(self, expression):

@@ -172,7 +172,7 @@ class Hive(Dialect):
     class Tokenizer(tokens.Tokenizer):
         QUOTES = ["'", '"']
         IDENTIFIERS = ["`"]
-        ESCAPES = ["\\"]
+        STRING_ESCAPES = ["\\"]
         ENCODE = "utf-8"
 
         KEYWORDS = {

@@ -89,8 +89,9 @@ def _date_add_sql(kind):
     def func(self, expression):
         this = self.sql(expression, "this")
         unit = expression.text("unit").upper() or "DAY"
-        expression = self.sql(expression, "expression")
-        return f"DATE_{kind}({this}, INTERVAL {expression} {unit})"
+        return (
+            f"DATE_{kind}({this}, {self.sql(exp.Interval(this=expression.expression, unit=unit))})"
+        )
 
     return func
 
@@ -117,7 +118,7 @@ class MySQL(Dialect):
         QUOTES = ["'", '"']
         COMMENTS = ["--", "#", ("/*", "*/")]
         IDENTIFIERS = ["`"]
-        ESCAPES = ["'", "\\"]
+        STRING_ESCAPES = ["'", "\\"]
 
         BIT_STRINGS = [("b'", "'"), ("B'", "'"), ("0b", "")]
         HEX_STRINGS = [("x'", "'"), ("X'", "'"), ("0x", "")]

@@ -40,8 +40,7 @@ def _date_add_sql(kind):
         expression = expression.copy()
         expression.args["is_string"] = True
-        expression = self.sql(expression)
-        return f"{this} {kind} INTERVAL {expression} {unit}"
+        return f"{this} {kind} {self.sql(exp.Interval(this=expression, unit=unit))}"
 
     return func

@@ -37,11 +37,10 @@ class Redshift(Postgres):
             return this
 
     class Tokenizer(Postgres.Tokenizer):
-        ESCAPES = ["\\"]
+        STRING_ESCAPES = ["\\"]
 
         KEYWORDS = {
             **Postgres.Tokenizer.KEYWORDS,  # type: ignore
-            "COPY": TokenType.COMMAND,
             "ENCODE": TokenType.ENCODE,
             "GEOMETRY": TokenType.GEOMETRY,
             "GEOGRAPHY": TokenType.GEOGRAPHY,

@@ -180,7 +180,7 @@ class Snowflake(Dialect):
     class Tokenizer(tokens.Tokenizer):
         QUOTES = ["'", "$$"]
-        ESCAPES = ["\\", "'"]
+        STRING_ESCAPES = ["\\", "'"]
 
         SINGLE_TOKENS = {
             **tokens.Tokenizer.SINGLE_TOKENS,
@@ -191,6 +191,7 @@ class Snowflake(Dialect):
             **tokens.Tokenizer.KEYWORDS,
             "EXCLUDE": TokenType.EXCEPT,
             "MATCH_RECOGNIZE": TokenType.MATCH_RECOGNIZE,
+            "PUT": TokenType.COMMAND,
             "RENAME": TokenType.REPLACE,
             "TIMESTAMP_LTZ": TokenType.TIMESTAMPLTZ,
             "TIMESTAMP_NTZ": TokenType.TIMESTAMP,
@@ -222,6 +223,7 @@ class Snowflake(Dialect):
             exp.TimeToUnix: lambda self, e: f"EXTRACT(epoch_second FROM {self.sql(e, 'this')})",
             exp.Trim: lambda self, e: f"TRIM({self.format_args(e.this, e.expression)})",
             exp.UnixToTime: _unix_to_time_sql,
+            exp.DayOfWeek: rename_func("DAYOFWEEK"),
         }
 
         TYPE_MAPPING = {
@@ -294,3 +296,12 @@ class Snowflake(Dialect):
             kind = f" {kind_value}" if kind_value else ""
             this = f" {self.sql(expression, 'this')}"
             return f"DESCRIBE{kind}{this}"
+
+        def generatedasidentitycolumnconstraint_sql(
+            self, expression: exp.GeneratedAsIdentityColumnConstraint
+        ) -> str:
+            start = expression.args.get("start")
+            start = f" START {start}" if start else ""
+            increment = expression.args.get("increment")
+            increment = f" INCREMENT {increment}" if increment else ""
+            return f"AUTOINCREMENT{start}{increment}"

@@ -157,6 +157,7 @@ class Spark(Hive):
         TRANSFORMS.pop(exp.ILike)
 
         WRAP_DERIVED_VALUES = False
+        CREATE_FUNCTION_AS = False
 
         def cast_sql(self, expression: exp.Cast) -> str:
             if isinstance(expression.this, exp.Cast) and expression.this.is_type(

@@ -49,7 +49,6 @@ class SQLite(Dialect):
         KEYWORDS = {
             **tokens.Tokenizer.KEYWORDS,
-            "AUTOINCREMENT": TokenType.AUTO_INCREMENT,
         }
 
     class Parser(parser.Parser):

@@ -1,5 +1,6 @@
 """
 .. include:: ../posts/sql_diff.md
+
 ----
 """

@@ -7,10 +7,17 @@ from sqlglot.helper import AutoName
 
 
 class ErrorLevel(AutoName):
-    IGNORE = auto()  # Ignore any parser errors
-    WARN = auto()  # Log any parser errors with ERROR level
-    RAISE = auto()  # Collect all parser errors and raise a single exception
-    IMMEDIATE = auto()  # Immediately raise an exception on the first parser error
+    IGNORE = auto()
+    """Ignore all errors."""
+
+    WARN = auto()
+    """Log all errors."""
+
+    RAISE = auto()
+    """Collect all errors and raise a single exception."""
+
+    IMMEDIATE = auto()
+    """Immediately raise an exception on the first error found."""
 
 
 class SqlglotError(Exception):
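
For reference, these levels are consumed through the error_level option of the top-level API; a small usage sketch:

```python
import sqlglot
from sqlglot.errors import ErrorLevel, ParseError

try:
    # RAISE (the default) collects every parse error and raises once.
    sqlglot.transpile("SELECT * FROM", error_level=ErrorLevel.RAISE)
except ParseError as e:
    print(e)

# IGNORE suppresses parser errors entirely.
sqlglot.transpile("SELECT * FROM", error_level=ErrorLevel.IGNORE)
```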

@@ -1,5 +1,6 @@
 """
 .. include:: ../../posts/python_sql_engine.md
+
 ----
 """

@@ -408,7 +408,7 @@ def _lambda_sql(self, e: exp.Lambda) -> str:
 
 class Python(Dialect):
     class Tokenizer(tokens.Tokenizer):
-        ESCAPES = ["\\"]
+        STRING_ESCAPES = ["\\"]
 
     class Generator(generator.Generator):
         TRANSFORMS = {

@@ -6,6 +6,7 @@ Every AST node in SQLGlot is represented by a subclass of `Expression`.
 This module contains the implementation of all supported `Expression` types. Additionally,
 it exposes a number of helper functions, which are mainly used to programmatically build
 SQL expressions, such as `sqlglot.expressions.select`.
+
 ----
 """
@@ -137,6 +138,8 @@ class Expression(metaclass=_Expression):
             return field
         if isinstance(field, (Identifier, Literal, Var)):
             return field.this
+        if isinstance(field, (Star, Null)):
+            return field.name
         return ""
 
     @property
@@ -176,13 +179,11 @@ class Expression(metaclass=_Expression):
         return self.text("alias")
 
     @property
-    def name(self):
+    def name(self) -> str:
         return self.text("this")
 
     @property
     def alias_or_name(self):
-        if isinstance(self, Null):
-            return "NULL"
         return self.alias or self.name
 
     @property
@@ -589,12 +590,11 @@ class Expression(metaclass=_Expression):
         return load(obj)
 
 
-if t.TYPE_CHECKING:
-    IntoType = t.Union[
-        str,
-        t.Type[Expression],
-        t.Collection[t.Union[str, t.Type[Expression]]],
-    ]
+IntoType = t.Union[
+    str,
+    t.Type[Expression],
+    t.Collection[t.Union[str, t.Type[Expression]]],
+]
 
 
 class Condition(Expression):
@@ -939,7 +939,7 @@ class EncodeColumnConstraint(ColumnConstraintKind):
 
 class GeneratedAsIdentityColumnConstraint(ColumnConstraintKind):
     # this: True -> ALWAYS, this: False -> BY DEFAULT
-    arg_types = {"this": True, "start": False, "increment": False}
+    arg_types = {"this": False, "start": False, "increment": False}
 
 
 class NotNullColumnConstraint(ColumnConstraintKind):
@@ -2390,7 +2390,7 @@ class Star(Expression):
     arg_types = {"except": False, "replace": False}
 
     @property
-    def name(self):
+    def name(self) -> str:
         return "*"
 
     @property
@@ -2413,6 +2413,10 @@ class Placeholder(Expression):
 class Null(Condition):
     arg_types: t.Dict[str, t.Any] = {}
 
+    @property
+    def name(self) -> str:
+        return "NULL"
+
 
 class Boolean(Condition):
     pass
@@ -2644,7 +2648,9 @@ class Div(Binary):
 
 
 class Dot(Binary):
-    pass
+    @property
+    def name(self) -> str:
+        return self.expression.name
 
 
 class DPipe(Binary):
@@ -2961,7 +2967,7 @@ class Cast(Func):
     arg_types = {"this": True, "to": True}
 
     @property
-    def name(self):
+    def name(self) -> str:
         return self.this.name
 
     @property
@@ -4027,17 +4033,39 @@ def paren(expression) -> Paren:
 SAFE_IDENTIFIER_RE = re.compile(r"^[_a-zA-Z][\w]*$")
 
 
-def to_identifier(alias, quoted=None) -> t.Optional[Identifier]:
-    if alias is None:
+@t.overload
+def to_identifier(name: None, quoted: t.Optional[bool] = None) -> None:
+    ...
+
+
+@t.overload
+def to_identifier(name: str | Identifier, quoted: t.Optional[bool] = None) -> Identifier:
+    ...
+
+
+def to_identifier(name, quoted=None):
+    """Builds an identifier.
+
+    Args:
+        name: The name to turn into an identifier.
+        quoted: Whether or not to force quote the identifier.
+
+    Returns:
+        The identifier ast node.
+    """
+    if name is None:
         return None
-    if isinstance(alias, Identifier):
-        identifier = alias
-    elif isinstance(alias, str):
-        if quoted is None:
-            quoted = not re.match(SAFE_IDENTIFIER_RE, alias)
-        identifier = Identifier(this=alias, quoted=quoted)
+
+    if isinstance(name, Identifier):
+        identifier = name
+    elif isinstance(name, str):
+        identifier = Identifier(
+            this=name,
+            quoted=not re.match(SAFE_IDENTIFIER_RE, name) if quoted is None else quoted,
+        )
     else:
-        raise ValueError(f"Alias needs to be a string or an Identifier, got: {alias.__class__}")
+        raise ValueError(f"Name needs to be a string or an Identifier, got: {name.__class__}")
     return identifier
@@ -4112,20 +4140,31 @@ def to_column(sql_path: str | Column, **kwargs) -> Column:
     return Column(this=column_name, table=table_name, **kwargs)
 
 
-def alias_(expression, alias, table=False, dialect=None, quoted=None, **opts):
-    """
-    Create an Alias expression.
+def alias_(
+    expression: str | Expression,
+    alias: str | Identifier,
+    table: bool | t.Sequence[str | Identifier] = False,
+    quoted: t.Optional[bool] = None,
+    dialect: DialectType = None,
+    **opts,
+):
+    """Create an Alias expression.
 
     Example:
         >>> alias_('foo', 'bar').sql()
         'foo AS bar'
 
+        >>> alias_('(select 1, 2)', 'bar', table=['a', 'b']).sql()
+        '(SELECT 1, 2) AS bar(a, b)'
+
     Args:
-        expression (str | Expression): the SQL code strings to parse.
+        expression: the SQL code strings to parse.
             If an Expression instance is passed, this is used as-is.
-        alias (str | Identifier): the alias name to use. If the name has
+        alias: the alias name to use. If the name has
             special characters it is quoted.
-        table (bool): create a table alias, default false
-        dialect (str): the dialect used to parse the input expression.
+        table: Whether or not to create a table alias, can also be a list of columns.
+        quoted: whether or not to quote the alias.
+        dialect: the dialect used to parse the input expression.
         **opts: other options to use to parse the input expressions.
 
     Returns:
@@ -4135,8 +4174,14 @@ def alias_(expression, alias, table=False, dialect=None, quoted=None, **opts):
     alias = to_identifier(alias, quoted=quoted)
 
     if table:
-        expression.set("alias", TableAlias(this=alias))
-        return expression
+        table_alias = TableAlias(this=alias)
+        exp.set("alias", table_alias)
+
+        if not isinstance(table, bool):
+            for column in table:
+                table_alias.append("columns", to_identifier(column, quoted=quoted))
+
+        return exp
 
     # We don't set the "alias" arg for Window expressions, because that would add an IDENTIFIER node in
     # the AST, representing a "named_window" [1] construct (eg. bigquery). What we want is an ALIAS node

@@ -1,6 +1,7 @@
 from __future__ import annotations
 
 import logging
+import re
 import typing as t
 
 from sqlglot import exp
@@ -11,6 +12,8 @@ from sqlglot.tokens import TokenType
 
 logger = logging.getLogger("sqlglot")
 
+BACKSLASH_RE = re.compile(r"\\(?!b|f|n|r|t|0)")
+
 
 class Generator:
     """
@@ -28,7 +31,8 @@ class Generator:
         identify (bool): if set to True all identifiers will be delimited by the corresponding
             character.
         normalize (bool): if set to True all identifiers will lower cased
-        escape (str): specifies an escape character. Default: '.
+        string_escape (str): specifies a string escape character. Default: '.
+        identifier_escape (str): specifies an identifier escape character. Default: ".
         pad (int): determines padding in a formatted string. Default: 2.
         indent (int): determines the size of indentation in a formatted string. Default: 4.
         unnest_column_only (bool): if true unnest table aliases are considered only as column aliases
@@ -85,6 +89,9 @@ class Generator:
     # Wrap derived values in parens, usually standard but spark doesn't support it
     WRAP_DERIVED_VALUES = True
 
+    # Whether or not create function uses an AS before the def.
+    CREATE_FUNCTION_AS = True
+
     TYPE_MAPPING = {
         exp.DataType.Type.NCHAR: "CHAR",
         exp.DataType.Type.NVARCHAR: "VARCHAR",
@@ -154,7 +161,8 @@ class Generator:
         "identifier_end",
         "identify",
         "normalize",
-        "escape",
+        "string_escape",
+        "identifier_escape",
         "pad",
         "index_offset",
         "unnest_column_only",
@@ -167,6 +175,7 @@ class Generator:
         "_indent",
         "_replace_backslash",
         "_escaped_quote_end",
+        "_escaped_identifier_end",
        "_leading_comma",
        "_max_text_width",
        "_comments",
@@ -183,7 +192,8 @@ class Generator:
         identifier_end=None,
         identify=False,
         normalize=False,
-        escape=None,
+        string_escape=None,
+        identifier_escape=None,
         pad=2,
         indent=2,
         index_offset=0,
@@ -208,7 +218,8 @@ class Generator:
         self.identifier_end = identifier_end or '"'
         self.identify = identify
         self.normalize = normalize
-        self.escape = escape or "'"
+        self.string_escape = string_escape or "'"
+        self.identifier_escape = identifier_escape or '"'
         self.pad = pad
         self.index_offset = index_offset
         self.unnest_column_only = unnest_column_only
@@ -219,8 +230,9 @@ class Generator:
         self.max_unsupported = max_unsupported
         self.null_ordering = null_ordering
         self._indent = indent
-        self._replace_backslash = self.escape == "\\"
-        self._escaped_quote_end = self.escape + self.quote_end
+        self._replace_backslash = self.string_escape == "\\"
+        self._escaped_quote_end = self.string_escape + self.quote_end
+        self._escaped_identifier_end = self.identifier_escape + self.identifier_end
         self._leading_comma = leading_comma
         self._max_text_width = max_text_width
         self._comments = comments
@@ -441,6 +453,9 @@ class Generator:
     def generatedasidentitycolumnconstraint_sql(
         self, expression: exp.GeneratedAsIdentityColumnConstraint
     ) -> str:
+        this = ""
+        if expression.this is not None:
+            this = " ALWAYS " if expression.this else " BY DEFAULT "
         start = expression.args.get("start")
         start = f"START WITH {start}" if start else ""
         increment = expression.args.get("increment")
@@ -449,9 +464,7 @@ class Generator:
         if start or increment:
             sequence_opts = f"{start} {increment}"
             sequence_opts = f" ({sequence_opts.strip()})"
-        return (
-            f"GENERATED {'ALWAYS' if expression.this else 'BY DEFAULT'} AS IDENTITY{sequence_opts}"
-        )
+        return f"GENERATED{this}AS IDENTITY{sequence_opts}"
 
     def notnullcolumnconstraint_sql(self, expression: exp.NotNullColumnConstraint) -> str:
         return f"{'' if expression.args.get('allow_null') else 'NOT '}NULL"
@@ -496,7 +509,12 @@ class Generator:
         properties_sql = self.sql(properties_exp, "properties")
         begin = " BEGIN" if expression.args.get("begin") else ""
         expression_sql = self.sql(expression, "expression")
-        expression_sql = f" AS{begin}{self.sep()}{expression_sql}" if expression_sql else ""
+        if expression_sql:
+            expression_sql = f"{begin}{self.sep()}{expression_sql}"
+
+            if self.CREATE_FUNCTION_AS or kind != "FUNCTION":
+                expression_sql = f" AS{expression_sql}"
+
         temporary = " TEMPORARY" if expression.args.get("temporary") else ""
         transient = (
             " TRANSIENT" if self.CREATE_TRANSIENT and expression.args.get("transient") else ""
@@ -701,6 +719,7 @@ class Generator:
     def identifier_sql(self, expression: exp.Identifier) -> str:
         text = expression.name
         text = text.lower() if self.normalize else text
+        text = text.replace(self.identifier_end, self._escaped_identifier_end)
         if expression.args.get("quoted") or self.identify:
             text = f"{self.identifier_start}{text}{self.identifier_end}"
         return text
@@ -1121,7 +1140,7 @@ class Generator:
         text = expression.this or ""
         if expression.is_string:
             if self._replace_backslash:
-                text = text.replace("\\", "\\\\")
+                text = BACKSLASH_RE.sub(r"\\\\", text)
             text = text.replace(self.quote_end, self._escaped_quote_end)
             if self.pretty:
                 text = text.replace("\n", self.SENTINEL_LINE_BREAK)
@@ -1486,9 +1505,16 @@ class Generator:
         return f"(SELECT {self.sql(unnest)})"
 
     def interval_sql(self, expression: exp.Interval) -> str:
-        this = self.sql(expression, "this")
-        this = f" {this}" if this else ""
-        unit = self.sql(expression, "unit")
+        this = expression.args.get("this")
+        if this:
+            this = (
+                f" {this}"
+                if isinstance(this, exp.Literal) or isinstance(this, exp.Paren)
+                else f" ({this})"
+            )
+        else:
+            this = ""
+        unit = expression.args.get("unit")
         unit = f" {unit}" if unit else ""
         return f"INTERVAL{this}{unit}"

@@ -6,6 +6,7 @@ from dataclasses import dataclass, field
 
 from sqlglot import Schema, exp, maybe_parse
 from sqlglot.optimizer import Scope, build_scope, optimize
+from sqlglot.optimizer.expand_laterals import expand_laterals
 from sqlglot.optimizer.qualify_columns import qualify_columns
 from sqlglot.optimizer.qualify_tables import qualify_tables
 
@@ -38,7 +39,7 @@ def lineage(
     sql: str | exp.Expression,
     schema: t.Optional[t.Dict | Schema] = None,
     sources: t.Optional[t.Dict[str, str | exp.Subqueryable]] = None,
-    rules: t.Sequence[t.Callable] = (qualify_tables, qualify_columns),
+    rules: t.Sequence[t.Callable] = (qualify_tables, qualify_columns, expand_laterals),
     dialect: DialectType = None,
 ) -> Node:
     """Build the lineage graph for a column of a SQL query.

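Adding expand_laterals to the default rule set lets lineage follow a projection that references a lateral column alias back to its source column. A sketch of a call that exercises this (the schema shape is assumed):

```python
from sqlglot.lineage import lineage

# "c" is defined via the lateral alias "b"; expand_laterals rewrites it to
# x.a + 1 + 1 before lineage is computed, so "c" traces back to x.a.
node = lineage(
    "c",
    "SELECT x.a + 1 AS b, b + 1 AS c FROM x",
    schema={"x": {"a": "int"}},
)
print(node.name)
```
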
sqlglot/optimizer/annotate_types.py View file

@@ -255,12 +255,23 @@ class TypeAnnotator:
         for name, source in scope.sources.items():
             if not isinstance(source, Scope):
                 continue
-            if isinstance(source.expression, exp.Values):
+            if isinstance(source.expression, exp.UDTF):
+                values = []
+
+                if isinstance(source.expression, exp.Lateral):
+                    if isinstance(source.expression.this, exp.Explode):
+                        values = [source.expression.this.this]
+                else:
+                    values = source.expression.expressions[0].expressions
+
+                if not values:
+                    continue
+
                 selects[name] = {
                     alias: column
                     for alias, column in zip(
                         source.expression.alias_column_names,
-                        source.expression.expressions[0].expressions,
+                        values,
                     )
                 }
             else:
@@ -272,7 +283,7 @@ class TypeAnnotator:
                 source = scope.sources.get(col.table)
                 if isinstance(source, exp.Table):
                     col.type = self.schema.get_column_type(source, col)
-                elif source:
+                elif source and col.table in selects:
                     col.type = selects[col.table][col.name].type
         # Then (possibly) annotate the remaining expressions in the scope
         self._maybe_annotate(scope.expression)
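
The new exp.UDTF branch lets the annotator pull column types out of lateral explodes as well as VALUES clauses, and the `col.table in selects` guard avoids key errors for sources it cannot resolve. A sketch using the VALUES path (the expected type is an assumption):

```python
import sqlglot
from sqlglot.optimizer.annotate_types import annotate_types

# exp.Values is a UDTF, so the widened branch still covers this case.
expression = sqlglot.parse_one("SELECT t.a FROM (VALUES (1)) AS t(a)")
annotated = annotate_types(expression)
print(annotated.selects[0].type)  # expected: INT
```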

sqlglot/optimizer/expand_laterals.py View file

@@ -0,0 +1,34 @@
+from __future__ import annotations
+
+import typing as t
+
+from sqlglot import exp
+
+
+def expand_laterals(expression: exp.Expression) -> exp.Expression:
+    """
+    Expand lateral column alias references.
+
+    This assumes `qualify_columns` has already run.
+
+    Example:
+        >>> import sqlglot
+        >>> sql = "SELECT x.a + 1 AS b, b + 1 AS c FROM x"
+        >>> expression = sqlglot.parse_one(sql)
+        >>> expand_laterals(expression).sql()
+        'SELECT x.a + 1 AS b, x.a + 1 + 1 AS c FROM x'
+
+    Args:
+        expression: expression to optimize
+    Returns:
+        optimized expression
+    """
+    for select in expression.find_all(exp.Select):
+        alias_to_expression: t.Dict[str, exp.Expression] = {}
+        for projection in select.expressions:
+            for column in projection.find_all(exp.Column):
+                if not column.table and column.name in alias_to_expression:
+                    column.replace(alias_to_expression[column.name].copy())
+            if isinstance(projection, exp.Alias):
+                alias_to_expression[projection.alias] = projection.this
+    return expression
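
As the docstring notes, the rule expects qualify_columns to have run first. A sketch of chaining the two by hand (the schema shape is assumed):

```python
import sqlglot
from sqlglot.optimizer.expand_laterals import expand_laterals
from sqlglot.optimizer.qualify_columns import qualify_columns

expression = sqlglot.parse_one("SELECT a + 1 AS b, b + 1 AS c FROM x")
expression = qualify_columns(expression, schema={"x": {"a": "int"}})  # qualify first
print(expand_laterals(expression).sql())
# expected: SELECT x.a + 1 AS b, x.a + 1 + 1 AS c FROM x
```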

sqlglot/optimizer/optimizer.py View file

@@ -4,6 +4,7 @@ from sqlglot.optimizer.canonicalize import canonicalize
 from sqlglot.optimizer.eliminate_ctes import eliminate_ctes
 from sqlglot.optimizer.eliminate_joins import eliminate_joins
 from sqlglot.optimizer.eliminate_subqueries import eliminate_subqueries
+from sqlglot.optimizer.expand_laterals import expand_laterals
 from sqlglot.optimizer.expand_multi_table_selects import expand_multi_table_selects
 from sqlglot.optimizer.isolate_table_selects import isolate_table_selects
 from sqlglot.optimizer.lower_identities import lower_identities
@@ -12,7 +13,7 @@ from sqlglot.optimizer.normalize import normalize
 from sqlglot.optimizer.optimize_joins import optimize_joins
 from sqlglot.optimizer.pushdown_predicates import pushdown_predicates
 from sqlglot.optimizer.pushdown_projections import pushdown_projections
-from sqlglot.optimizer.qualify_columns import qualify_columns
+from sqlglot.optimizer.qualify_columns import qualify_columns, validate_qualify_columns
 from sqlglot.optimizer.qualify_tables import qualify_tables
 from sqlglot.optimizer.unnest_subqueries import unnest_subqueries
 from sqlglot.schema import ensure_schema
@@ -22,6 +23,8 @@ RULES = (
     qualify_tables,
     isolate_table_selects,
     qualify_columns,
+    expand_laterals,
+    validate_qualify_columns,
     pushdown_projections,
     normalize,
     unnest_subqueries,
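
Both new rules now run by default from optimize, so lateral aliases are expanded and unqualified or unknown columns fail fast. A minimal sketch (the schema shape is assumed):

```python
import sqlglot
from sqlglot.optimizer import optimize

optimized = optimize(
    sqlglot.parse_one("SELECT a + 1 AS b, b + 1 AS c FROM x"),
    schema={"x": {"a": "int"}},
)
print(optimized.sql())
```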

sqlglot/optimizer/pushdown_projections.py View file

@@ -7,7 +7,7 @@ from sqlglot.optimizer.scope import Scope, traverse_scope
 SELECT_ALL = object()

 # Selection to use if selection list is empty
-DEFAULT_SELECTION = alias("1", "_")
+DEFAULT_SELECTION = lambda: alias("1", "_")


 def pushdown_projections(expression):
@@ -93,7 +93,7 @@ def _remove_unused_selections(scope, parent_selections):
     # If there are no remaining selections, just select a single constant
     if not new_selections:
-        new_selections.append(DEFAULT_SELECTION.copy())
+        new_selections.append(DEFAULT_SELECTION())

     scope.expression.set("expressions", new_selections)

     if removed:
@@ -106,5 +106,5 @@ def _remove_indexed_selections(scope, indexes_to_remove):
         selection for i, selection in enumerate(scope.selects) if i not in indexes_to_remove
     ]
     if not new_selections:
-        new_selections.append(DEFAULT_SELECTION.copy())
+        new_selections.append(DEFAULT_SELECTION())
     scope.expression.set("expressions", new_selections)
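
Turning DEFAULT_SELECTION into a factory means every empty selection list receives a fresh AST node instead of a defensive copy of one shared module-level instance. The difference, sketched:

```python
from sqlglot import alias

DEFAULT_SELECTION = lambda: alias("1", "_")

# Each call builds a brand-new node, so callers can no longer mutate a
# shared instance by accident (previously .copy() had to be remembered).
assert DEFAULT_SELECTION() is not DEFAULT_SELECTION()
```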

sqlglot/optimizer/qualify_columns.py View file

@@ -37,11 +37,24 @@ def qualify_columns(expression, schema):
         if not isinstance(scope.expression, exp.UDTF):
             _expand_stars(scope, resolver)
         _qualify_outputs(scope)
-        _check_unknown_tables(scope)
     return expression


+def validate_qualify_columns(expression):
+    """Raise an `OptimizeError` if any columns aren't qualified"""
+    unqualified_columns = []
+    for scope in traverse_scope(expression):
+        if isinstance(scope.expression, exp.Select):
+            unqualified_columns.extend(scope.unqualified_columns)
+            if scope.external_columns and not scope.is_correlated_subquery:
+                raise OptimizeError(f"Unknown table: {scope.external_columns[0].table}")
+
+    if unqualified_columns:
+        raise OptimizeError(f"Ambiguous columns: {unqualified_columns}")
+    return expression
+
+
 def _pop_table_column_aliases(derived_tables):
     """
     Remove table column aliases.
@@ -199,10 +212,6 @@ def _qualify_columns(scope, resolver):
             if not column_table:
                 column_table = resolver.get_table(column_name)

-                if not scope.is_subquery and not scope.is_udtf:
-                    if column_table is None:
-                        raise OptimizeError(f"Ambiguous column: {column_name}")
-
                 # column_table can be a '' because bigquery unnest has no table alias
                 if column_table:
                     column.set("table", exp.to_identifier(column_table))
@@ -231,9 +240,7 @@ def _qualify_columns(scope, resolver):
         for column in columns_missing_from_scope:
             column_table = resolver.get_table(column.name)

-            if column_table is None:
-                raise OptimizeError(f"Ambiguous column: {column.name}")
-
+            if column_table:
                 column.set("table", exp.to_identifier(column_table))
@@ -322,11 +329,6 @@ def _qualify_outputs(scope):
     scope.expression.set("expressions", new_selections)


-def _check_unknown_tables(scope):
-    if scope.external_columns and not scope.is_udtf and not scope.is_correlated_subquery:
-        raise OptimizeError(f"Unknown table: {scope.external_columns[0].text('table')}")
-
-
 class _Resolver:
     """
     Helper for resolving columns.
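
Qualification and validation are now separate passes, so callers can qualify on a best-effort basis and only opt into the strict check. A sketch of both outcomes:

```python
import sqlglot
from sqlglot.errors import OptimizeError
from sqlglot.optimizer.qualify_columns import qualify_columns, validate_qualify_columns

qualified = qualify_columns(sqlglot.parse_one("SELECT a FROM x"), schema={"x": {"a": "int"}})
validate_qualify_columns(qualified)  # passes: every column is qualified

try:
    validate_qualify_columns(sqlglot.parse_one("SELECT a FROM x"))  # never qualified
except OptimizeError as error:
    print(error)  # e.g. Ambiguous columns: [...]
```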

sqlglot/optimizer/qualify_tables.py View file

@@ -2,7 +2,7 @@ import itertools

 from sqlglot import alias, exp
 from sqlglot.helper import csv_reader
-from sqlglot.optimizer.scope import traverse_scope
+from sqlglot.optimizer.scope import Scope, traverse_scope


 def qualify_tables(expression, db=None, catalog=None, schema=None):
@@ -25,6 +25,8 @@ def qualify_tables(expression, db=None, catalog=None, schema=None):
     """
     sequence = itertools.count()
+    next_name = lambda: f"_q_{next(sequence)}"
+
     for scope in traverse_scope(expression):
         for derived_table in scope.ctes + scope.derived_tables:
             if not derived_table.args.get("alias"):
@@ -46,7 +48,7 @@ def qualify_tables(expression, db=None, catalog=None, schema=None):
                     source = source.replace(
                         alias(
                             source.copy(),
-                            source.this if identifier else f"_q_{next(sequence)}",
+                            source.this if identifier else next_name(),
                             table=True,
                         )
                     )
@@ -58,5 +60,12 @@ def qualify_tables(expression, db=None, catalog=None, schema=None):
                         schema.add_table(
                             source, {k: type(v).__name__ for k, v in zip(header, columns)}
                         )
+            elif isinstance(source, Scope) and source.is_udtf:
+                udtf = source.expression
+                table_alias = udtf.args.get("alias") or exp.TableAlias(this=next_name())
+                udtf.set("alias", table_alias)
+
+                if not table_alias.name:
+                    table_alias.set("this", next_name())

     return expression
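
UDTF sources without an alias now get one of the generated _q_N names, the same scheme already used for unaliased derived tables. A sketch (the exact SQL in and out is an assumption):

```python
import sqlglot
from sqlglot.optimizer.qualify_tables import qualify_tables

expression = sqlglot.parse_one("SELECT * FROM UNNEST(ARRAY(1, 2))")
print(qualify_tables(expression).sql())
# expected, roughly: SELECT * FROM UNNEST(ARRAY(1, 2)) AS _q_0
```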

sqlglot/optimizer/scope.py View file

@@ -237,6 +237,8 @@ class Scope:
             ancestor = column.find_ancestor(exp.Qualify, exp.Order, exp.Having, exp.Hint)
             if (
                 not ancestor
+                # Window functions can have an ORDER BY clause
+                or not isinstance(ancestor.parent, exp.Select)
                 or column.table
                 or (column.name not in named_selects and not isinstance(ancestor, exp.Hint))
             ):
@@ -479,7 +481,7 @@ def _traverse_scope(scope):
     elif isinstance(scope.expression, exp.Union):
         yield from _traverse_union(scope)
     elif isinstance(scope.expression, exp.UDTF):
-        pass
+        _set_udtf_scope(scope)
     elif isinstance(scope.expression, exp.Subquery):
         yield from _traverse_subqueries(scope)
     else:
@@ -509,6 +511,22 @@ def _traverse_union(scope):
     scope.union_scopes = [left, right]


+def _set_udtf_scope(scope):
+    parent = scope.expression.parent
+    from_ = parent.args.get("from")
+
+    if not from_:
+        return
+
+    for table in from_.expressions:
+        if isinstance(table, exp.Table):
+            scope.tables.append(table)
+        elif isinstance(table, exp.Subquery):
+            scope.subqueries.append(table)
+
+    _add_table_sources(scope)
+    _traverse_subqueries(scope)
+
+
 def _traverse_derived_tables(derived_tables, scope, scope_type):
     sources = {}
     is_cte = scope_type == ScopeType.CTE
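
The extra check on `ancestor.parent` means an ORDER BY nested inside an OVER clause is treated as a reference to source columns, not to the select's output aliases. A sketch of inspecting that (the exact contents of scope.columns are an assumption):

```python
import sqlglot
from sqlglot.optimizer.scope import build_scope

expression = sqlglot.parse_one("SELECT ROW_NUMBER() OVER (ORDER BY b) AS b FROM x")
scope = build_scope(expression)

# "b" under the window's ORDER BY should now count as a source column,
# even though the select also exposes an output alias named "b".
print([column.sql() for column in scope.columns])
```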

sqlglot/parser.py View file

@@ -194,6 +194,7 @@ class Parser(metaclass=_Parser):
         TokenType.INTERVAL,
         TokenType.LAZY,
         TokenType.LEADING,
+        TokenType.LEFT,
         TokenType.LOCAL,
         TokenType.MATERIALIZED,
         TokenType.MERGE,
@@ -208,6 +209,7 @@ class Parser(metaclass=_Parser):
         TokenType.PRECEDING,
         TokenType.RANGE,
         TokenType.REFERENCES,
+        TokenType.RIGHT,
         TokenType.ROW,
         TokenType.ROWS,
         TokenType.SCHEMA,
@@ -237,8 +239,10 @@ class Parser(metaclass=_Parser):
     TABLE_ALIAS_TOKENS = ID_VAR_TOKENS - {
         TokenType.APPLY,
+        TokenType.LEFT,
         TokenType.NATURAL,
         TokenType.OFFSET,
+        TokenType.RIGHT,
         TokenType.WINDOW,
     }
@@ -258,6 +262,8 @@ class Parser(metaclass=_Parser):
         TokenType.IDENTIFIER,
         TokenType.INDEX,
         TokenType.ISNULL,
+        TokenType.ILIKE,
+        TokenType.LIKE,
         TokenType.MERGE,
         TokenType.OFFSET,
         TokenType.PRIMARY_KEY,
@@ -971,7 +977,8 @@ class Parser(metaclass=_Parser):
         if create_token.token_type in (TokenType.FUNCTION, TokenType.PROCEDURE):
             this = self._parse_user_defined_function(kind=create_token.token_type)
             properties = self._parse_properties()
-            if self._match(TokenType.ALIAS):
+
+            self._match(TokenType.ALIAS)
             begin = self._match(TokenType.BEGIN)
             return_ = self._match_text_seq("RETURN")
             expression = self._parse_statement()
@@ -2163,7 +2170,9 @@ class Parser(metaclass=_Parser):
     ) -> t.Optional[exp.Expression]:
         if self._match(TokenType.TOP if top else TokenType.LIMIT):
             limit_paren = self._match(TokenType.L_PAREN)
-            limit_exp = self.expression(exp.Limit, this=this, expression=self._parse_number())
+            limit_exp = self.expression(
+                exp.Limit, this=this, expression=self._parse_number() if top else self._parse_term()
+            )

             if limit_paren:
                 self._match_r_paren()
@@ -2740,7 +2749,22 @@ class Parser(metaclass=_Parser):
         kind: exp.Expression

-        if self._match(TokenType.AUTO_INCREMENT):
+        if self._match_set((TokenType.AUTO_INCREMENT, TokenType.IDENTITY)):
+            start = None
+            increment = None
+
+            if self._match(TokenType.L_PAREN, advance=False):
+                args = self._parse_wrapped_csv(self._parse_bitwise)
+                start = seq_get(args, 0)
+                increment = seq_get(args, 1)
+            elif self._match_text_seq("START"):
+                start = self._parse_bitwise()
+                self._match_text_seq("INCREMENT")
+                increment = self._parse_bitwise()
+
+            if start and increment:
+                kind = exp.GeneratedAsIdentityColumnConstraint(start=start, increment=increment)
+            else:
                 kind = exp.AutoIncrementColumnConstraint()
         elif self._match(TokenType.CHECK):
             constraint = self._parse_wrapped(self._parse_conjunction)
@@ -3294,8 +3318,8 @@ class Parser(metaclass=_Parser):
         if not self._match(TokenType.EXCEPT):
             return None
         if self._match(TokenType.L_PAREN, advance=False):
-            return self._parse_wrapped_id_vars()
-        return self._parse_csv(self._parse_id_var)
+            return self._parse_wrapped_csv(self._parse_column)
+        return self._parse_csv(self._parse_column)

     def _parse_replace(self) -> t.Optional[t.List[t.Optional[exp.Expression]]]:
         if not self._match(TokenType.REPLACE):
@@ -3442,7 +3466,7 @@ class Parser(metaclass=_Parser):
     def _parse_alter(self) -> t.Optional[exp.Expression]:
         if not self._match(TokenType.TABLE):
-            return None
+            return self._parse_as_command(self._prev)

         exists = self._parse_exists()
         this = self._parse_table(schema=True)
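
Two of the parser changes above are easy to exercise directly: LIMIT now accepts a full term instead of only a number literal, and ALTER statements other than ALTER TABLE degrade to a generic command node instead of failing. Sketches (the ALTER statement is just an illustration):

```python
import sqlglot

# LIMIT with an arithmetic expression now parses.
print(sqlglot.parse_one("SELECT * FROM t LIMIT 1 + 1").sql())

# Anything after ALTER that is not TABLE falls back to exp.Command.
print(repr(sqlglot.parse_one("ALTER SESSION SET timezone = 'UTC'")))
```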

sqlglot/tokens.py View file

@@ -357,7 +357,8 @@ class _Tokenizer(type):
         klass._HEX_STRINGS = cls._delimeter_list_to_dict(klass.HEX_STRINGS)
         klass._BYTE_STRINGS = cls._delimeter_list_to_dict(klass.BYTE_STRINGS)
         klass._IDENTIFIERS = cls._delimeter_list_to_dict(klass.IDENTIFIERS)
-        klass._ESCAPES = set(klass.ESCAPES)
+        klass._STRING_ESCAPES = set(klass.STRING_ESCAPES)
+        klass._IDENTIFIER_ESCAPES = set(klass.IDENTIFIER_ESCAPES)
         klass._COMMENTS = dict(
             (comment, None) if isinstance(comment, str) else (comment[0], comment[1])
             for comment in klass.COMMENTS
@@ -429,9 +430,13 @@ class Tokenizer(metaclass=_Tokenizer):
     IDENTIFIERS: t.List[str | t.Tuple[str, str]] = ['"']

-    ESCAPES = ["'"]
+    STRING_ESCAPES = ["'"]

-    _ESCAPES: t.Set[str] = set()
+    _STRING_ESCAPES: t.Set[str] = set()
+
+    IDENTIFIER_ESCAPES = ['"']
+
+    _IDENTIFIER_ESCAPES: t.Set[str] = set()

     KEYWORDS = {
         **{
@@ -469,6 +474,7 @@ class Tokenizer(metaclass=_Tokenizer):
         "ASC": TokenType.ASC,
         "AS": TokenType.ALIAS,
         "AT TIME ZONE": TokenType.AT_TIME_ZONE,
+        "AUTOINCREMENT": TokenType.AUTO_INCREMENT,
         "AUTO_INCREMENT": TokenType.AUTO_INCREMENT,
         "BEGIN": TokenType.BEGIN,
         "BETWEEN": TokenType.BETWEEN,
@@ -691,6 +697,7 @@ class Tokenizer(metaclass=_Tokenizer):
         "ALTER VIEW": TokenType.COMMAND,
         "ANALYZE": TokenType.COMMAND,
         "CALL": TokenType.COMMAND,
+        "COPY": TokenType.COMMAND,
         "EXPLAIN": TokenType.COMMAND,
         "OPTIMIZE": TokenType.COMMAND,
         "PREPARE": TokenType.COMMAND,
@@ -744,7 +751,7 @@ class Tokenizer(metaclass=_Tokenizer):
     )

     def __init__(self) -> None:
-        self._replace_backslash = "\\" in self._ESCAPES
+        self._replace_backslash = "\\" in self._STRING_ESCAPES
         self.reset()

     def reset(self) -> None:
@@ -1046,12 +1053,25 @@ class Tokenizer(metaclass=_Tokenizer):
             return True

     def _scan_identifier(self, identifier_end: str) -> None:
-        while self._peek != identifier_end:
+        text = ""
+        identifier_end_is_escape = identifier_end in self._IDENTIFIER_ESCAPES
+
+        while True:
             if self._end:
                 raise RuntimeError(f"Missing {identifier_end} from {self._line}:{self._start}")
+
             self._advance()
-        self._advance()
-        self._add(TokenType.IDENTIFIER, self._text[1:-1])
+            if self._char == identifier_end:
+                if identifier_end_is_escape and self._peek == identifier_end:
+                    text += identifier_end  # type: ignore
+                    self._advance()
+                    continue
+
+                break
+
+            text += self._char  # type: ignore
+
+        self._add(TokenType.IDENTIFIER, text)

     def _scan_var(self) -> None:
         while True:
@@ -1072,9 +1092,9 @@ class Tokenizer(metaclass=_Tokenizer):
         while True:
             if (
-                self._char in self._ESCAPES
+                self._char in self._STRING_ESCAPES
                 and self._peek
-                and (self._peek == delimiter or self._peek in self._ESCAPES)
+                and (self._peek == delimiter or self._peek in self._STRING_ESCAPES)
             ):
                 text += self._peek
                 self._advance(2)
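
For dialect authors, the rename means escapes are now declared per kind. A minimal sketch of a custom tokenizer using both knobs (the subclass and its inputs are hypothetical):

```python
from sqlglot.tokens import Tokenizer

class MyTokenizer(Tokenizer):
    STRING_ESCAPES = ["'", "\\"]   # backslash also escapes inside string literals
    IDENTIFIER_ESCAPES = ['"']     # "" inside "..." yields a literal quote

tokens = MyTokenizer().tokenize('SELECT \'it\'\'s\' AS "a""b"')
print([(token.token_type.name, token.text) for token in tokens])
```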

Some files were not shown because too many files have changed in this diff.