Merging upstream version 11.0.1.
Signed-off-by: Daniel Baumann <daniel@debian.org>
parent fdac67ef7f, commit ba0f3f0bfa
112 changed files with 126100 additions and 230 deletions
14  .github/workflows/python-package.yml (vendored)

@@ -22,5 +22,17 @@ jobs:
           python -m pip install --upgrade pip
           make install-dev
       - name: Run checks (linter, code style, tests)
+        run: make check
+      - name: Update documentation
         run: |
-          make check
+          make docs
+          git add docs
+          git config --local user.email "41898282+github-actions[bot]@users.noreply.github.com"
+          git config --local user.name "github-actions[bot]"
+          git commit -m "CI: Auto-generated documentation" -a | exit 0
+        if: ${{ matrix.python-version == '3.10' && github.event_name == 'push' }}
+      - name: Push changes
+        if: ${{ matrix.python-version == '3.10' && github.event_name == 'push' && startsWith(github.ref, 'refs/tags/v') }}
+        uses: ad-m/github-push-action@master
+        with:
+          github_token: ${{ secrets.GITHUB_TOKEN }}
33  CHANGELOG.md

@@ -1,6 +1,39 @@
 Changelog
 =========
 
+v11.0.0
+------
+
+Changes:
+
+- Breaking: Renamed ESCAPES to STRING_ESCAPES in the Tokenizer class.
+
+- New: Deployed pdoc documentation page.
+
+- New: Add support for read locking using the FOR UPDATE/SHARE syntax (e.g. MySQL).
+
+- New: Added support for CASCADE, SET NULL and SET DEFAULT constraints.
+
+- New: Added "cast" expression helper.
+
+- New: Add support for transpiling Postgres GENERATE_SERIES into Presto SEQUENCE.
+
+- Improvement: Fix tokenizing of identifier escapes.
+
+- Improvement: Fix eliminate_subqueries [bug](https://github.com/tobymao/sqlglot/commit/b5df65e3fb5ee1ebc3cbab64b6d89598cf47a10b) related to unions.
+
+- Improvement: IFNULL is now transpiled to COALESCE by default for every dialect.
+
+- Improvement: Refactored the way properties are handled. Now it's easier to add them and specify their position in a SQL expression.
+
+- Improvement: Fixed alias quoting bug.
+
+- Improvement: Fixed CUBE / ROLLUP / GROUPING SETS parsing and generation.
+
+- Improvement: Fixed get_or_raise Dialect/t.Type[Dialect] argument bug.
+
+- Improvement: Improved python type hints.
+
 v10.6.0
 ------
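Two of the entries above can be exercised directly from Python. A minimal sketch, assuming sqlglot 11.x is installed; the exact generated SQL may differ slightly between patch releases:

import sqlglot
from sqlglot import expressions as exp

# Postgres GENERATE_SERIES is now transpiled to Presto's SEQUENCE.
print(sqlglot.transpile("SELECT * FROM GENERATE_SERIES(1, 3)", read="postgres", write="presto"))

# The new "cast" helper builds a CAST node around an arbitrary expression.
print(exp.cast("x", "int").sql())  # CAST(x AS INT)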
4  Makefile

@@ -18,7 +18,7 @@ style:
 check: style test
 
 docs:
-	python pdoc/cli.py -o pdoc/docs
+	python pdoc/cli.py -o docs
 
 docs-serve:
-	python pdoc/cli.py
+	python pdoc/cli.py --port 8002
1  docs/CNAME (new file)

@@ -0,0 +1 @@
+sqlglot.com
7  docs/index.html (new file)

@@ -0,0 +1,7 @@
+<!doctype html>
+<html>
+<head>
+    <meta charset="utf-8">
+    <meta http-equiv="refresh" content="0; url=./sqlglot.html"/>
+</head>
+</html>
The following generated documentation files are new in this commit; each diff is suppressed because one or more lines are too long:

   46  docs/search.js
 1226  docs/sqlglot.html
  506  docs/sqlglot/dataframe.html
 4953  docs/sqlglot/dataframe/sql.html
  400  docs/sqlglot/dialects.html
 1434  docs/sqlglot/dialects/bigquery.html
 1077  docs/sqlglot/dialects/clickhouse.html
  704  docs/sqlglot/dialects/databricks.html
 2134  docs/sqlglot/dialects/dialect.html
 1088  docs/sqlglot/dialects/drill.html
 1028  docs/sqlglot/dialects/duckdb.html
 1461  docs/sqlglot/dialects/hive.html
 2149  docs/sqlglot/dialects/mysql.html
 1052  docs/sqlglot/dialects/oracle.html
 1245  docs/sqlglot/dialects/postgres.html
 1255  docs/sqlglot/dialects/presto.html
 1175  docs/sqlglot/dialects/redshift.html
 1528  docs/sqlglot/dialects/snowflake.html
 1136  docs/sqlglot/dialects/spark.html
  918  docs/sqlglot/dialects/sqlite.html
  658  docs/sqlglot/dialects/starrocks.html
  704  docs/sqlglot/dialects/tableau.html
  960  docs/sqlglot/dialects/teradata.html
  653  docs/sqlglot/dialects/trino.html
 1772  docs/sqlglot/dialects/tsql.html
 1560  docs/sqlglot/diff.html
  877  docs/sqlglot/errors.html
  694  docs/sqlglot/executor.html
  715  docs/sqlglot/executor/context.html
  717  docs/sqlglot/executor/env.html
 2130  docs/sqlglot/executor/python.html
  802  docs/sqlglot/executor/table.html
39484  docs/sqlglot/expressions.html
 9855  docs/sqlglot/generator.html
 1651  docs/sqlglot/helper.html
  931  docs/sqlglot/lineage.html
  264  docs/sqlglot/optimizer.html
 1179  docs/sqlglot/optimizer/annotate_types.html
  445  docs/sqlglot/optimizer/canonicalize.html
  371  docs/sqlglot/optimizer/eliminate_ctes.html
  610  docs/sqlglot/optimizer/eliminate_joins.html
  582  docs/sqlglot/optimizer/eliminate_subqueries.html
  353  docs/sqlglot/optimizer/expand_laterals.html
  321  docs/sqlglot/optimizer/expand_multi_table_selects.html
  317  docs/sqlglot/optimizer/isolate_table_selects.html
  430  docs/sqlglot/optimizer/lower_identities.html
  794  docs/sqlglot/optimizer/merge_subqueries.html
  585  docs/sqlglot/optimizer/normalize.html
  489  docs/sqlglot/optimizer/optimize_joins.html
  401  docs/sqlglot/optimizer/optimizer.html
  773  docs/sqlglot/optimizer/pushdown_predicates.html
  477  docs/sqlglot/optimizer/pushdown_projections.html
  804  docs/sqlglot/optimizer/qualify_columns.html
  427  docs/sqlglot/optimizer/qualify_tables.html
 2512  docs/sqlglot/optimizer/scope.html
 1428  docs/sqlglot/optimizer/simplify.html
  835  docs/sqlglot/optimizer/unnest_subqueries.html
 8049  docs/sqlglot/parser.html
 1995  docs/sqlglot/planner.html
 1624  docs/sqlglot/schema.html
  408  docs/sqlglot/serde.html
  385  docs/sqlglot/time.html
 6712  docs/sqlglot/tokens.html
  667  docs/sqlglot/transforms.html
  479  docs/sqlglot/trie.html
@@ -26,9 +26,9 @@ if __name__ == "__main__":
     opts = parser.parse_args()
     opts.docformat = "google"
     opts.modules = ["sqlglot"]
-    opts.footer_text = "Copyright (c) 2022 Toby Mao"
+    opts.footer_text = "Copyright (c) 2023 Toby Mao"
     opts.template_directory = Path(__file__).parent.joinpath("templates").absolute()
-    opts.edit_url = ["sqlglot=https://github.com/tobymao/sqlglot/"]
+    opts.edit_url = ["sqlglot=https://github.com/tobymao/sqlglot/tree/main/sqlglot/"]
 
     with mock.patch("pdoc.__main__.parser", **{"parse_args.return_value": opts}):
         cli()
@@ -1,5 +1,6 @@
 """
 .. include:: ../README.md
 
 ----
 """

@@ -39,7 +40,7 @@ if t.TYPE_CHECKING:
 T = t.TypeVar("T", bound=Expression)
 
 
-__version__ = "10.6.3"
+__version__ = "11.0.1"
 
 pretty = False
 """Whether to format generated SQL by default."""
@@ -2,6 +2,8 @@
 
 from __future__ import annotations
 
+import typing as t
+
 from sqlglot import exp, generator, parser, tokens, transforms
 from sqlglot.dialects.dialect import (
     Dialect,

@@ -14,8 +16,10 @@ from sqlglot.dialects.dialect import (
 from sqlglot.helper import seq_get
 from sqlglot.tokens import TokenType
 
+E = t.TypeVar("E", bound=exp.Expression)
+
 
-def _date_add(expression_class):
+def _date_add(expression_class: t.Type[E]) -> t.Callable[[t.Sequence], E]:
     def func(args):
         interval = seq_get(args, 1)
         return expression_class(

@@ -27,26 +31,26 @@ def _date_add(expression_class):
     return func
 
 
-def _date_trunc(args):
+def _date_trunc(args: t.Sequence) -> exp.Expression:
     unit = seq_get(args, 1)
     if isinstance(unit, exp.Column):
         unit = exp.Var(this=unit.name)
     return exp.DateTrunc(this=seq_get(args, 0), expression=unit)
 
 
-def _date_add_sql(data_type, kind):
+def _date_add_sql(
+    data_type: str, kind: str
+) -> t.Callable[[generator.Generator, exp.Expression], str]:
     def func(self, expression):
         this = self.sql(expression, "this")
-        unit = self.sql(expression, "unit") or "'day'"
-        expression = self.sql(expression, "expression")
-        return f"{data_type}_{kind}({this}, INTERVAL {expression} {unit})"
+        return f"{data_type}_{kind}({this}, {self.sql(exp.Interval(this=expression.expression, unit=expression.args.get('unit') or exp.Literal.string('day')))})"
 
     return func
 
 
-def _derived_table_values_to_unnest(self, expression):
+def _derived_table_values_to_unnest(self: generator.Generator, expression: exp.Values) -> str:
     if not isinstance(expression.unnest().parent, exp.From):
-        expression = transforms.remove_precision_parameterized_types(expression)
+        expression = t.cast(exp.Values, transforms.remove_precision_parameterized_types(expression))
         return self.values_sql(expression)
     rows = [tuple_exp.expressions for tuple_exp in expression.find_all(exp.Tuple)]
     structs = []
@@ -60,7 +64,7 @@ def _derived_table_values_to_unnest(self, expression):
     return self.unnest_sql(unnest_exp)
 
 
-def _returnsproperty_sql(self, expression):
+def _returnsproperty_sql(self: generator.Generator, expression: exp.ReturnsProperty) -> str:
     this = expression.this
     if isinstance(this, exp.Schema):
         this = f"{this.this} <{self.expressions(this)}>"

@@ -69,8 +73,8 @@ def _returnsproperty_sql(self, expression):
     return f"RETURNS {this}"
 
 
-def _create_sql(self, expression):
-    kind = expression.args.get("kind")
+def _create_sql(self: generator.Generator, expression: exp.Create) -> str:
+    kind = expression.args["kind"]
     returns = expression.find(exp.ReturnsProperty)
     if kind.upper() == "FUNCTION" and returns and returns.args.get("is_table"):
         expression = expression.copy()

@@ -89,6 +93,29 @@ def _create_sql(self, expression):
     return self.create_sql(expression)
 
 
+def _unqualify_unnest(expression: exp.Expression) -> exp.Expression:
+    """Remove references to unnest table aliases since bigquery doesn't allow them.
+
+    These are added by the optimizer's qualify_column step.
+    """
+    if isinstance(expression, exp.Select):
+        unnests = {
+            unnest.alias
+            for unnest in expression.args.get("from", exp.From(expressions=[])).expressions
+            if isinstance(unnest, exp.Unnest) and unnest.alias
+        }
+
+        if unnests:
+            expression = expression.copy()
+
+            for select in expression.expressions:
+                for column in select.find_all(exp.Column):
+                    if column.table in unnests:
+                        column.set("table", None)
+
+    return expression
+
+
 class BigQuery(Dialect):
     unnest_column_only = True
     time_mapping = {

@@ -110,7 +137,7 @@ class BigQuery(Dialect):
         ]
         COMMENTS = ["--", "#", ("/*", "*/")]
         IDENTIFIERS = ["`"]
-        ESCAPES = ["\\"]
+        STRING_ESCAPES = ["\\"]
         HEX_STRINGS = [("0x", ""), ("0X", "")]
 
         KEYWORDS = {

@@ -190,6 +217,9 @@ class BigQuery(Dialect):
             exp.GroupConcat: rename_func("STRING_AGG"),
             exp.ILike: no_ilike_sql,
             exp.IntDiv: rename_func("DIV"),
+            exp.Select: transforms.preprocess(
+                [_unqualify_unnest], transforms.delegate("select_sql")
+            ),
             exp.StrToTime: lambda self, e: f"PARSE_TIMESTAMP({self.format_time(e)}, {self.sql(e, 'this')})",
             exp.TimeAdd: _date_add_sql("TIME", "ADD"),
             exp.TimeSub: _date_add_sql("TIME", "SUB"),
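The new exp.Select entry relies on transforms.preprocess, which chains AST rewrites in front of a stock generator method resolved by transforms.delegate. A sketch of the composition, with a hypothetical no-op rewrite standing in for _unqualify_unnest:

from sqlglot import exp, transforms

def my_rewrite(expression: exp.Expression) -> exp.Expression:
    # Inspect or rewrite the AST here, then return it.
    return expression

SELECT_TRANSFORM = {
    exp.Select: transforms.preprocess([my_rewrite], transforms.delegate("select_sql")),
}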
@@ -9,7 +9,7 @@ from sqlglot.parser import parse_var_map
 from sqlglot.tokens import TokenType
 
 
-def _lower_func(sql):
+def _lower_func(sql: str) -> str:
     index = sql.index("(")
     return sql[:index].lower() + sql[index:]
 
@@ -11,6 +11,8 @@ from sqlglot.time import format_time
 from sqlglot.tokens import Tokenizer
 from sqlglot.trie import new_trie
 
+E = t.TypeVar("E", bound=exp.Expression)
+
 
 class Dialects(str, Enum):
     DIALECT = ""

@@ -37,14 +39,16 @@ class Dialects(str, Enum):
 
 
 class _Dialect(type):
-    classes: t.Dict[str, Dialect] = {}
+    classes: t.Dict[str, t.Type[Dialect]] = {}
 
     @classmethod
-    def __getitem__(cls, key):
+    def __getitem__(cls, key: str) -> t.Type[Dialect]:
         return cls.classes[key]
 
     @classmethod
-    def get(cls, key, default=None):
+    def get(
+        cls, key: str, default: t.Optional[t.Type[Dialect]] = None
+    ) -> t.Optional[t.Type[Dialect]]:
         return cls.classes.get(key, default)
 
     def __new__(cls, clsname, bases, attrs):

@@ -119,7 +123,7 @@ class Dialect(metaclass=_Dialect):
     generator_class = None
 
     @classmethod
-    def get_or_raise(cls, dialect):
+    def get_or_raise(cls, dialect: DialectType) -> t.Type[Dialect]:
         if not dialect:
             return cls
         if isinstance(dialect, _Dialect):

@@ -134,7 +138,9 @@
         return result
 
     @classmethod
-    def format_time(cls, expression):
+    def format_time(
+        cls, expression: t.Optional[str | exp.Expression]
+    ) -> t.Optional[exp.Expression]:
         if isinstance(expression, str):
             return exp.Literal.string(
                 format_time(

@@ -153,26 +159,28 @@
             )
         return expression
 
-    def parse(self, sql, **opts):
+    def parse(self, sql: str, **opts) -> t.List[t.Optional[exp.Expression]]:
         return self.parser(**opts).parse(self.tokenizer.tokenize(sql), sql)
 
-    def parse_into(self, expression_type, sql, **opts):
+    def parse_into(
+        self, expression_type: exp.IntoType, sql: str, **opts
+    ) -> t.List[t.Optional[exp.Expression]]:
         return self.parser(**opts).parse_into(expression_type, self.tokenizer.tokenize(sql), sql)
 
-    def generate(self, expression, **opts):
+    def generate(self, expression: t.Optional[exp.Expression], **opts) -> str:
         return self.generator(**opts).generate(expression)
 
-    def transpile(self, code, **opts):
-        return self.generate(self.parse(code), **opts)
+    def transpile(self, sql: str, **opts) -> t.List[str]:
+        return [self.generate(expression, **opts) for expression in self.parse(sql)]
 
     @property
-    def tokenizer(self):
+    def tokenizer(self) -> Tokenizer:
         if not hasattr(self, "_tokenizer"):
-            self._tokenizer = self.tokenizer_class()
+            self._tokenizer = self.tokenizer_class()  # type: ignore
         return self._tokenizer
 
-    def parser(self, **opts):
-        return self.parser_class(
+    def parser(self, **opts) -> Parser:
+        return self.parser_class(  # type: ignore
             **{
                 "index_offset": self.index_offset,
                 "unnest_column_only": self.unnest_column_only,
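Note that transpile above changes behavior as well as its annotation: it now returns one generated string per parsed statement instead of feeding the whole parse result through a single generate call. A quick sketch, assuming sqlglot 11.x:

from sqlglot.dialects.dialect import Dialect

dialect = Dialect.get_or_raise("duckdb")()
print(dialect.transpile("SELECT 1; SELECT 2"))  # ['SELECT 1', 'SELECT 2']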
@@ -182,14 +190,15 @@ class Dialect(metaclass=_Dialect):
             },
         )
 
-    def generator(self, **opts):
-        return self.generator_class(
+    def generator(self, **opts) -> Generator:
+        return self.generator_class(  # type: ignore
             **{
                 "quote_start": self.quote_start,
                 "quote_end": self.quote_end,
                 "identifier_start": self.identifier_start,
                 "identifier_end": self.identifier_end,
-                "escape": self.tokenizer_class.ESCAPES[0],
+                "string_escape": self.tokenizer_class.STRING_ESCAPES[0],
+                "identifier_escape": self.tokenizer_class.IDENTIFIER_ESCAPES[0],
                 "index_offset": self.index_offset,
                 "time_mapping": self.inverse_time_mapping,
                 "time_trie": self.inverse_time_trie,

@@ -202,11 +211,10 @@
         )
 
 
-if t.TYPE_CHECKING:
-    DialectType = t.Union[str, Dialect, t.Type[Dialect], None]
+DialectType = t.Union[str, Dialect, t.Type[Dialect], None]
 
 
-def rename_func(name):
+def rename_func(name: str) -> t.Callable[[Generator, exp.Expression], str]:
     def _rename(self, expression):
         args = flatten(expression.args.values())
         return f"{self.normalize_func(name)}({self.format_args(*args)})"

@@ -214,32 +222,34 @@ def rename_func(name):
     return _rename
 
 
-def approx_count_distinct_sql(self, expression):
+def approx_count_distinct_sql(self: Generator, expression: exp.ApproxDistinct) -> str:
     if expression.args.get("accuracy"):
         self.unsupported("APPROX_COUNT_DISTINCT does not support accuracy")
     return f"APPROX_COUNT_DISTINCT({self.format_args(expression.this)})"
 
 
-def if_sql(self, expression):
+def if_sql(self: Generator, expression: exp.If) -> str:
     expressions = self.format_args(
         expression.this, expression.args.get("true"), expression.args.get("false")
     )
     return f"IF({expressions})"
 
 
-def arrow_json_extract_sql(self, expression):
+def arrow_json_extract_sql(self: Generator, expression: exp.JSONExtract | exp.JSONBExtract) -> str:
     return self.binary(expression, "->")
 
 
-def arrow_json_extract_scalar_sql(self, expression):
+def arrow_json_extract_scalar_sql(
+    self: Generator, expression: exp.JSONExtractScalar | exp.JSONBExtractScalar
+) -> str:
     return self.binary(expression, "->>")
 
 
-def inline_array_sql(self, expression):
+def inline_array_sql(self: Generator, expression: exp.Array) -> str:
     return f"[{self.expressions(expression)}]"
 
 
-def no_ilike_sql(self, expression):
+def no_ilike_sql(self: Generator, expression: exp.ILike) -> str:
     return self.like_sql(
         exp.Like(
             this=exp.Lower(this=expression.this),

@@ -248,44 +258,44 @@ def no_ilike_sql(self, expression):
     )
 
 
-def no_paren_current_date_sql(self, expression):
+def no_paren_current_date_sql(self: Generator, expression: exp.CurrentDate) -> str:
     zone = self.sql(expression, "this")
     return f"CURRENT_DATE AT TIME ZONE {zone}" if zone else "CURRENT_DATE"
 
 
-def no_recursive_cte_sql(self, expression):
+def no_recursive_cte_sql(self: Generator, expression: exp.With) -> str:
     if expression.args.get("recursive"):
         self.unsupported("Recursive CTEs are unsupported")
         expression.args["recursive"] = False
     return self.with_sql(expression)
 
 
-def no_safe_divide_sql(self, expression):
+def no_safe_divide_sql(self: Generator, expression: exp.SafeDivide) -> str:
     n = self.sql(expression, "this")
     d = self.sql(expression, "expression")
     return f"IF({d} <> 0, {n} / {d}, NULL)"
 
 
-def no_tablesample_sql(self, expression):
+def no_tablesample_sql(self: Generator, expression: exp.TableSample) -> str:
     self.unsupported("TABLESAMPLE unsupported")
     return self.sql(expression.this)
 
 
-def no_pivot_sql(self, expression):
+def no_pivot_sql(self: Generator, expression: exp.Pivot) -> str:
     self.unsupported("PIVOT unsupported")
     return self.sql(expression)
 
 
-def no_trycast_sql(self, expression):
+def no_trycast_sql(self: Generator, expression: exp.TryCast) -> str:
     return self.cast_sql(expression)
 
 
-def no_properties_sql(self, expression):
+def no_properties_sql(self: Generator, expression: exp.Properties) -> str:
     self.unsupported("Properties unsupported")
     return ""
 
 
-def str_position_sql(self, expression):
+def str_position_sql(self: Generator, expression: exp.StrPosition) -> str:
     this = self.sql(expression, "this")
     substr = self.sql(expression, "substr")
     position = self.sql(expression, "position")

@@ -294,13 +304,15 @@ def str_position_sql(self, expression):
     return f"STRPOS({this}, {substr})"
 
 
-def struct_extract_sql(self, expression):
+def struct_extract_sql(self: Generator, expression: exp.StructExtract) -> str:
     this = self.sql(expression, "this")
     struct_key = self.sql(exp.Identifier(this=expression.expression, quoted=True))
     return f"{this}.{struct_key}"
 
 
-def var_map_sql(self, expression, map_func_name="MAP"):
+def var_map_sql(
+    self: Generator, expression: exp.Map | exp.VarMap, map_func_name: str = "MAP"
+) -> str:
     keys = expression.args["keys"]
     values = expression.args["values"]
 

@@ -315,27 +327,33 @@ def var_map_sql(self, expression, map_func_name="MAP"):
     return f"{map_func_name}({self.format_args(*args)})"
 
 
-def format_time_lambda(exp_class, dialect, default=None):
+def format_time_lambda(
+    exp_class: t.Type[E], dialect: str, default: t.Optional[bool | str] = None
+) -> t.Callable[[t.Sequence], E]:
     """Helper used for time expressions.
 
-    Args
-        exp_class (Class): the expression class to instantiate
-        dialect (string): sql dialect
-        default (Option[bool | str]): the default format, True being time
+    Args:
+        exp_class: the expression class to instantiate.
+        dialect: target sql dialect.
+        default: the default format, True being time.
+
+    Returns:
+        A callable that can be used to return the appropriately formatted time expression.
     """
 
-    def _format_time(args):
+    def _format_time(args: t.Sequence):
         return exp_class(
             this=seq_get(args, 0),
             format=Dialect[dialect].format_time(
-                seq_get(args, 1) or (Dialect[dialect].time_format if default is True else default)
+                seq_get(args, 1)
+                or (Dialect[dialect].time_format if default is True else default or None)
             ),
         )
 
     return _format_time
 
 
-def create_with_partitions_sql(self, expression):
+def create_with_partitions_sql(self: Generator, expression: exp.Create) -> str:
     """
     In Hive and Spark, the PARTITIONED BY property acts as an extension of a table's schema. When the
     PARTITIONED BY value is an array of column names, they are transformed into a schema. The corresponding

@@ -359,19 +377,21 @@ def create_with_partitions_sql(self, expression):
     return self.create_sql(expression)
 
 
-def parse_date_delta(exp_class, unit_mapping=None):
-    def inner_func(args):
+def parse_date_delta(
+    exp_class: t.Type[E], unit_mapping: t.Optional[t.Dict[str, str]] = None
+) -> t.Callable[[t.Sequence], E]:
+    def inner_func(args: t.Sequence) -> E:
         unit_based = len(args) == 3
         this = seq_get(args, 2) if unit_based else seq_get(args, 0)
         expression = seq_get(args, 1) if unit_based else seq_get(args, 1)
         unit = seq_get(args, 0) if unit_based else exp.Literal.string("DAY")
-        unit = unit_mapping.get(unit.name.lower(), unit) if unit_mapping else unit
+        unit = unit_mapping.get(unit.name.lower(), unit) if unit_mapping else unit  # type: ignore
         return exp_class(this=this, expression=expression, unit=unit)
 
     return inner_func
 
 
-def locate_to_strposition(args):
+def locate_to_strposition(args: t.Sequence) -> exp.Expression:
     return exp.StrPosition(
         this=seq_get(args, 1),
         substr=seq_get(args, 0),

@@ -379,22 +399,22 @@ def locate_to_strposition(args):
     )
 
 
-def strposition_to_locate_sql(self, expression):
+def strposition_to_locate_sql(self: Generator, expression: exp.StrPosition) -> str:
     args = self.format_args(
         expression.args.get("substr"), expression.this, expression.args.get("position")
     )
     return f"LOCATE({args})"
 
 
-def timestrtotime_sql(self, expression: exp.TimeStrToTime) -> str:
+def timestrtotime_sql(self: Generator, expression: exp.TimeStrToTime) -> str:
     return f"CAST({self.sql(expression, 'this')} AS TIMESTAMP)"
 
 
-def datestrtodate_sql(self, expression: exp.DateStrToDate) -> str:
+def datestrtodate_sql(self: Generator, expression: exp.DateStrToDate) -> str:
     return f"CAST({self.sql(expression, 'this')} AS DATE)"
 
 
-def trim_sql(self, expression):
+def trim_sql(self: Generator, expression: exp.Trim) -> str:
     target = self.sql(expression, "this")
     trim_type = self.sql(expression, "position")
     remove_chars = self.sql(expression, "expression")
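A usage sketch for format_time_lambda as now documented above: it returns a parser callback that builds the expression and normalizes the time format through the target dialect. The TO_DATE-style call below is illustrative:

import sqlglot.dialects  # ensure concrete dialects are registered
from sqlglot import expressions as exp
from sqlglot.dialects.dialect import format_time_lambda

to_date = format_time_lambda(exp.StrToDate, "hive", default=True)
# No explicit format argument, so the hive default time format is used.
node = to_date([exp.Literal.string("2023-01-01")])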
@@ -1,6 +1,7 @@
 from __future__ import annotations
 
 import re
+import typing as t
 
 from sqlglot import exp, generator, parser, tokens
 from sqlglot.dialects.dialect import (

@@ -16,35 +17,29 @@ from sqlglot.dialects.dialect import (
 )
 
 
-def _to_timestamp(args):
-    # TO_TIMESTAMP accepts either a single double argument or (text, text)
-    if len(args) == 1 and args[0].is_number:
-        return exp.UnixToTime.from_arg_list(args)
-    return format_time_lambda(exp.StrToTime, "drill")(args)
-
-
-def _str_to_time_sql(self, expression):
+def _str_to_time_sql(self: generator.Generator, expression: exp.TsOrDsToDate) -> str:
     return f"STRPTIME({self.sql(expression, 'this')}, {self.format_time(expression)})"
 
 
-def _ts_or_ds_to_date_sql(self, expression):
+def _ts_or_ds_to_date_sql(self: generator.Generator, expression: exp.TsOrDsToDate) -> str:
     time_format = self.format_time(expression)
     if time_format and time_format not in (Drill.time_format, Drill.date_format):
         return f"CAST({_str_to_time_sql(self, expression)} AS DATE)"
     return f"CAST({self.sql(expression, 'this')} AS DATE)"
 
 
-def _date_add_sql(kind):
-    def func(self, expression):
+def _date_add_sql(kind: str) -> t.Callable[[generator.Generator, exp.DateAdd | exp.DateSub], str]:
+    def func(self: generator.Generator, expression: exp.DateAdd | exp.DateSub) -> str:
         this = self.sql(expression, "this")
-        unit = expression.text("unit").upper() or "DAY"
-        expression = self.sql(expression, "expression")
-        return f"DATE_{kind}({this}, INTERVAL '{expression}' {unit})"
+        unit = exp.Var(this=expression.text("unit").upper() or "DAY")
+        return (
+            f"DATE_{kind}({this}, {self.sql(exp.Interval(this=expression.expression, unit=unit))})"
+        )
 
     return func
 
 
-def if_sql(self, expression):
+def if_sql(self: generator.Generator, expression: exp.If) -> str:
     """
     Drill requires backticks around certain SQL reserved words, IF being one of them, This function
     adds the backticks around the keyword IF.

@@ -61,7 +56,7 @@ def if_sql(self, expression):
     return f"`IF`({expressions})"
 
 
-def _str_to_date(self, expression):
+def _str_to_date(self: generator.Generator, expression: exp.StrToDate) -> str:
     this = self.sql(expression, "this")
     time_format = self.format_time(expression)
     if time_format == Drill.date_format:

@@ -111,7 +106,7 @@ class Drill(Dialect):
     class Tokenizer(tokens.Tokenizer):
         QUOTES = ["'"]
         IDENTIFIERS = ["`"]
-        ESCAPES = ["\\"]
+        STRING_ESCAPES = ["\\"]
         ENCODE = "utf-8"
 
     class Parser(parser.Parser):

@@ -168,10 +163,10 @@ class Drill(Dialect):
             exp.TimeToStr: lambda self, e: f"TO_CHAR({self.sql(e, 'this')}, {self.format_time(e)})",
             exp.TimeToUnix: rename_func("UNIX_TIMESTAMP"),
             exp.TryCast: no_trycast_sql,
-            exp.TsOrDsAdd: lambda self, e: f"DATE_ADD(CAST({self.sql(e, 'this')} AS DATE), INTERVAL '{self.sql(e, 'expression')}' DAY)",
+            exp.TsOrDsAdd: lambda self, e: f"DATE_ADD(CAST({self.sql(e, 'this')} AS DATE), {self.sql(exp.Interval(this=e.expression, unit=exp.Var(this='DAY')))})",
             exp.TsOrDsToDate: _ts_or_ds_to_date_sql,
             exp.TsOrDiToDi: lambda self, e: f"CAST(SUBSTR(REPLACE(CAST({self.sql(e, 'this')} AS VARCHAR), '-', ''), 1, 8) AS INT)",
         }
 
-        def normalize_func(self, name):
+        def normalize_func(self, name: str) -> str:
             return name if re.match(exp.SAFE_IDENTIFIER_RE, name) else f"`{name}`"
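The recurring change in this and the surrounding dialect hunks is the same: interval SQL is now produced by building an exp.Interval node and rendering it, rather than hand-assembling an INTERVAL f-string. The node in isolation, as a sketch:

from sqlglot import expressions as exp

interval = exp.Interval(this=exp.Literal.number(5), unit=exp.Var(this="DAY"))
print(interval.sql())  # INTERVAL 5 DAY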
@@ -25,10 +25,9 @@ def _str_to_time_sql(self, expression):
 
 
 def _ts_or_ds_add(self, expression):
-    this = self.sql(expression, "this")
-    e = self.sql(expression, "expression")
+    this = expression.args.get("this")
     unit = self.sql(expression, "unit").strip("'") or "DAY"
-    return f"CAST({this} AS DATE) + INTERVAL {e} {unit}"
+    return f"CAST({this} AS DATE) + {self.sql(exp.Interval(this=expression.expression, unit=unit))}"
 
 
 def _ts_or_ds_to_date_sql(self, expression):

@@ -40,9 +39,8 @@ def _ts_or_ds_to_date_sql(self, expression):
 
 def _date_add(self, expression):
     this = self.sql(expression, "this")
-    e = self.sql(expression, "expression")
     unit = self.sql(expression, "unit").strip("'") or "DAY"
-    return f"{this} + INTERVAL {e} {unit}"
+    return f"{this} + {self.sql(exp.Interval(this=expression.expression, unit=unit))}"
 
 
 def _array_sort_sql(self, expression):
@@ -172,7 +172,7 @@ class Hive(Dialect):
     class Tokenizer(tokens.Tokenizer):
         QUOTES = ["'", '"']
         IDENTIFIERS = ["`"]
-        ESCAPES = ["\\"]
+        STRING_ESCAPES = ["\\"]
         ENCODE = "utf-8"
 
         KEYWORDS = {
@@ -89,8 +89,9 @@ def _date_add_sql(kind):
     def func(self, expression):
         this = self.sql(expression, "this")
         unit = expression.text("unit").upper() or "DAY"
-        expression = self.sql(expression, "expression")
-        return f"DATE_{kind}({this}, INTERVAL {expression} {unit})"
+        return (
+            f"DATE_{kind}({this}, {self.sql(exp.Interval(this=expression.expression, unit=unit))})"
+        )
 
     return func
 

@@ -117,7 +118,7 @@ class MySQL(Dialect):
         QUOTES = ["'", '"']
         COMMENTS = ["--", "#", ("/*", "*/")]
         IDENTIFIERS = ["`"]
-        ESCAPES = ["'", "\\"]
+        STRING_ESCAPES = ["'", "\\"]
         BIT_STRINGS = [("b'", "'"), ("B'", "'"), ("0b", "")]
         HEX_STRINGS = [("x'", "'"), ("X'", "'"), ("0x", "")]
 
@@ -40,8 +40,7 @@ def _date_add_sql(kind):
 
         expression = expression.copy()
         expression.args["is_string"] = True
-        expression = self.sql(expression)
-        return f"{this} {kind} INTERVAL {expression} {unit}"
+        return f"{this} {kind} {self.sql(exp.Interval(this=expression, unit=unit))}"
 
     return func
 
@@ -37,11 +37,10 @@ class Redshift(Postgres):
         return this
 
     class Tokenizer(Postgres.Tokenizer):
-        ESCAPES = ["\\"]
+        STRING_ESCAPES = ["\\"]
 
         KEYWORDS = {
             **Postgres.Tokenizer.KEYWORDS,  # type: ignore
-            "COPY": TokenType.COMMAND,
             "ENCODE": TokenType.ENCODE,
             "GEOMETRY": TokenType.GEOMETRY,
             "GEOGRAPHY": TokenType.GEOGRAPHY,
@@ -180,7 +180,7 @@ class Snowflake(Dialect):
 
     class Tokenizer(tokens.Tokenizer):
         QUOTES = ["'", "$$"]
-        ESCAPES = ["\\", "'"]
+        STRING_ESCAPES = ["\\", "'"]
 
         SINGLE_TOKENS = {
             **tokens.Tokenizer.SINGLE_TOKENS,

@@ -191,6 +191,7 @@ class Snowflake(Dialect):
             **tokens.Tokenizer.KEYWORDS,
             "EXCLUDE": TokenType.EXCEPT,
             "MATCH_RECOGNIZE": TokenType.MATCH_RECOGNIZE,
+            "PUT": TokenType.COMMAND,
             "RENAME": TokenType.REPLACE,
             "TIMESTAMP_LTZ": TokenType.TIMESTAMPLTZ,
             "TIMESTAMP_NTZ": TokenType.TIMESTAMP,

@@ -222,6 +223,7 @@ class Snowflake(Dialect):
             exp.TimeToUnix: lambda self, e: f"EXTRACT(epoch_second FROM {self.sql(e, 'this')})",
             exp.Trim: lambda self, e: f"TRIM({self.format_args(e.this, e.expression)})",
             exp.UnixToTime: _unix_to_time_sql,
+            exp.DayOfWeek: rename_func("DAYOFWEEK"),
         }
 
         TYPE_MAPPING = {

@@ -294,3 +296,12 @@ class Snowflake(Dialect):
             kind = f" {kind_value}" if kind_value else ""
             this = f" {self.sql(expression, 'this')}"
             return f"DESCRIBE{kind}{this}"
+
+        def generatedasidentitycolumnconstraint_sql(
+            self, expression: exp.GeneratedAsIdentityColumnConstraint
+        ) -> str:
+            start = expression.args.get("start")
+            start = f" START {start}" if start else ""
+            increment = expression.args.get("increment")
+            increment = f" INCREMENT {increment}" if increment else ""
+            return f"AUTOINCREMENT{start}{increment}"
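A sketch of what the new generator method above should emit for an identity column, assuming the reading dialect fills in the start/increment args; the input DDL and exact output here are illustrative, not verified against this exact release:

import sqlglot

ddl = "CREATE TABLE t (id INT GENERATED ALWAYS AS IDENTITY (START WITH 1 INCREMENT BY 2))"
print(sqlglot.transpile(ddl, read="postgres", write="snowflake")[0])
# e.g. CREATE TABLE t (id INT AUTOINCREMENT START 1 INCREMENT 2)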
@@ -157,6 +157,7 @@ class Spark(Hive):
         TRANSFORMS.pop(exp.ILike)
 
         WRAP_DERIVED_VALUES = False
+        CREATE_FUNCTION_AS = False
 
         def cast_sql(self, expression: exp.Cast) -> str:
             if isinstance(expression.this, exp.Cast) and expression.this.is_type(
@@ -49,7 +49,6 @@ class SQLite(Dialect):
 
         KEYWORDS = {
             **tokens.Tokenizer.KEYWORDS,
-            "AUTOINCREMENT": TokenType.AUTO_INCREMENT,
         }
 
     class Parser(parser.Parser):
@@ -1,5 +1,6 @@
 """
 .. include:: ../posts/sql_diff.md
 
 ----
 """
@@ -7,10 +7,17 @@ from sqlglot.helper import AutoName
 
 
 class ErrorLevel(AutoName):
-    IGNORE = auto()  # Ignore any parser errors
-    WARN = auto()  # Log any parser errors with ERROR level
-    RAISE = auto()  # Collect all parser errors and raise a single exception
-    IMMEDIATE = auto()  # Immediately raise an exception on the first parser error
+    IGNORE = auto()
+    """Ignore all errors."""
+
+    WARN = auto()
+    """Log all errors."""
+
+    RAISE = auto()
+    """Collect all errors and raise a single exception."""
+
+    IMMEDIATE = auto()
+    """Immediately raise an exception on the first error found."""
 
 
 class SqlglotError(Exception):
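The members keep their previous semantics; only the trailing comments became docstrings, so they now appear in the generated API docs. A usage sketch:

import sqlglot
from sqlglot.errors import ErrorLevel, ParseError

try:
    sqlglot.parse_one("SELECT )", error_level=ErrorLevel.RAISE)
except ParseError as error:
    print(error)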
@@ -1,5 +1,6 @@
 """
 .. include:: ../../posts/python_sql_engine.md
 
 ----
 """
@@ -408,7 +408,7 @@ def _lambda_sql(self, e: exp.Lambda) -> str:
 
 class Python(Dialect):
     class Tokenizer(tokens.Tokenizer):
-        ESCAPES = ["\\"]
+        STRING_ESCAPES = ["\\"]
 
     class Generator(generator.Generator):
         TRANSFORMS = {
@@ -6,6 +6,7 @@ Every AST node in SQLGlot is represented by a subclass of `Expression`.
 This module contains the implementation of all supported `Expression` types. Additionally,
 it exposes a number of helper functions, which are mainly used to programmatically build
 SQL expressions, such as `sqlglot.expressions.select`.
 
 ----
 """

@@ -137,6 +138,8 @@ class Expression(metaclass=_Expression):
             return field
         if isinstance(field, (Identifier, Literal, Var)):
             return field.this
+        if isinstance(field, (Star, Null)):
+            return field.name
         return ""
 
     @property

@@ -176,13 +179,11 @@ class Expression(metaclass=_Expression):
         return self.text("alias")
 
     @property
-    def name(self):
+    def name(self) -> str:
         return self.text("this")
 
     @property
     def alias_or_name(self):
-        if isinstance(self, Null):
-            return "NULL"
         return self.alias or self.name
 
     @property
@ -589,12 +590,11 @@ class Expression(metaclass=_Expression):
|
||||||
return load(obj)
|
return load(obj)
|
||||||
|
|
||||||
|
|
||||||
if t.TYPE_CHECKING:
|
IntoType = t.Union[
|
||||||
IntoType = t.Union[
|
str,
|
||||||
str,
|
t.Type[Expression],
|
||||||
t.Type[Expression],
|
t.Collection[t.Union[str, t.Type[Expression]]],
|
||||||
t.Collection[t.Union[str, t.Type[Expression]]],
|
]
|
||||||
]
|
|
||||||
|
|
||||||
|
|
||||||
class Condition(Expression):
|
class Condition(Expression):
|
||||||
|
@@ -939,7 +939,7 @@ class EncodeColumnConstraint(ColumnConstraintKind):
 class GeneratedAsIdentityColumnConstraint(ColumnConstraintKind):
     # this: True -> ALWAYS, this: False -> BY DEFAULT
-    arg_types = {"this": True, "start": False, "increment": False}
+    arg_types = {"this": False, "start": False, "increment": False}


 class NotNullColumnConstraint(ColumnConstraintKind):
@@ -2390,7 +2390,7 @@ class Star(Expression):
     arg_types = {"except": False, "replace": False}

     @property
-    def name(self):
+    def name(self) -> str:
         return "*"

     @property
@@ -2413,6 +2413,10 @@ class Placeholder(Expression):
 class Null(Condition):
     arg_types: t.Dict[str, t.Any] = {}

+    @property
+    def name(self) -> str:
+        return "NULL"
+

 class Boolean(Condition):
     pass
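With this and the Star change above, NULL and * flow through the generic `name` / `alias_or_name` path instead of being special-cased. A quick sketch (assuming sqlglot >= 11):

    from sqlglot import exp

    print(exp.Null().name)           # NULL
    print(exp.Star().name)           # *
    print(exp.Null().alias_or_name)  # NULL, via the now-generic fallback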
@@ -2644,7 +2648,9 @@ class Div(Binary):
 class Dot(Binary):
-    pass
+    @property
+    def name(self) -> str:
+        return self.expression.name


 class DPipe(Binary):
@@ -2961,7 +2967,7 @@ class Cast(Func):
     arg_types = {"this": True, "to": True}

     @property
-    def name(self):
+    def name(self) -> str:
         return self.this.name

     @property
@@ -4027,17 +4033,39 @@ def paren(expression) -> Paren:
 SAFE_IDENTIFIER_RE = re.compile(r"^[_a-zA-Z][\w]*$")


-def to_identifier(alias, quoted=None) -> t.Optional[Identifier]:
-    if alias is None:
-        return None
-    if isinstance(alias, Identifier):
-        identifier = alias
-    elif isinstance(alias, str):
-        if quoted is None:
-            quoted = not re.match(SAFE_IDENTIFIER_RE, alias)
-        identifier = Identifier(this=alias, quoted=quoted)
-    else:
-        raise ValueError(f"Alias needs to be a string or an Identifier, got: {alias.__class__}")
-    return identifier
+@t.overload
+def to_identifier(name: None, quoted: t.Optional[bool] = None) -> None:
+    ...
+
+
+@t.overload
+def to_identifier(name: str | Identifier, quoted: t.Optional[bool] = None) -> Identifier:
+    ...
+
+
+def to_identifier(name, quoted=None):
+    """Builds an identifier.
+
+    Args:
+        name: The name to turn into an identifier.
+        quoted: Whether or not force quote the identifier.
+
+    Returns:
+        The identifier ast node.
+    """
+
+    if name is None:
+        return None
+
+    if isinstance(name, Identifier):
+        identifier = name
+    elif isinstance(name, str):
+        identifier = Identifier(
+            this=name,
+            quoted=not re.match(SAFE_IDENTIFIER_RE, name) if quoted is None else quoted,
+        )
+    else:
+        raise ValueError(f"Name needs to be a string or an Identifier, got: {name.__class__}")
+    return identifier
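The overloads let type checkers know that `None` passes through while strings become `Identifier` nodes. A small sketch of the quoting behavior (assuming sqlglot >= 11):

    from sqlglot import exp

    print(exp.to_identifier("foo").sql())      # foo (matches SAFE_IDENTIFIER_RE, left unquoted)
    print(exp.to_identifier("foo bar").sql())  # "foo bar" (auto-quoted)
    print(exp.to_identifier(None))             # None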
@@ -4112,20 +4140,31 @@ def to_column(sql_path: str | Column, **kwargs) -> Column:
     return Column(this=column_name, table=table_name, **kwargs)


-def alias_(expression, alias, table=False, dialect=None, quoted=None, **opts):
-    """
-    Create an Alias expression.
+def alias_(
+    expression: str | Expression,
+    alias: str | Identifier,
+    table: bool | t.Sequence[str | Identifier] = False,
+    quoted: t.Optional[bool] = None,
+    dialect: DialectType = None,
+    **opts,
+):
+    """Create an Alias expression.

     Example:
         >>> alias_('foo', 'bar').sql()
         'foo AS bar'

+        >>> alias_('(select 1, 2)', 'bar', table=['a', 'b']).sql()
+        '(SELECT 1, 2) AS bar(a, b)'
+
     Args:
-        expression (str | Expression): the SQL code strings to parse.
+        expression: the SQL code strings to parse.
             If an Expression instance is passed, this is used as-is.
-        alias (str | Identifier): the alias name to use. If the name has
+        alias: the alias name to use. If the name has
             special characters it is quoted.
-        table (bool): create a table alias, default false
-        dialect (str): the dialect used to parse the input expression.
+        table: Whether or not to create a table alias, can also be a list of columns.
+        quoted: whether or not to quote the alias
+        dialect: the dialect used to parse the input expression.
         **opts: other options to use to parse the input expressions.

     Returns:
@@ -4135,8 +4174,14 @@ def alias_(expression, alias, table=False, dialect=None, quoted=None, **opts):
     alias = to_identifier(alias, quoted=quoted)

     if table:
-        expression.set("alias", TableAlias(this=alias))
-        return expression
+        table_alias = TableAlias(this=alias)
+        exp.set("alias", table_alias)
+
+        if not isinstance(table, bool):
+            for column in table:
+                table_alias.append("columns", to_identifier(column, quoted=quoted))
+
+        return exp

 # We don't set the "alias" arg for Window expressions, because that would add an IDENTIFIER node in
 # the AST, representing a "named_window" [1] construct (eg. bigquery). What we want is an ALIAS node
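`table` may now be a list of column names, producing a table alias with a column list, exactly as the new doctest shows (assuming sqlglot >= 11):

    from sqlglot import exp

    print(exp.alias_("foo", "bar").sql())
    # foo AS bar
    print(exp.alias_("(select 1, 2)", "bar", table=["a", "b"]).sql())
    # (SELECT 1, 2) AS bar(a, b)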
sqlglot/generator.py

@@ -1,6 +1,7 @@
 from __future__ import annotations

 import logging
+import re
 import typing as t

 from sqlglot import exp
@@ -11,6 +12,8 @@ from sqlglot.tokens import TokenType
 logger = logging.getLogger("sqlglot")

+BACKSLASH_RE = re.compile(r"\\(?!b|f|n|r|t|0)")
+

 class Generator:
     """
@@ -28,7 +31,8 @@ class Generator:
         identify (bool): if set to True all identifiers will be delimited by the corresponding
             character.
         normalize (bool): if set to True all identifiers will lower cased
-        escape (str): specifies an escape character. Default: '.
+        string_escape (str): specifies a string escape character. Default: '.
+        identifier_escape (str): specifies an identifier escape character. Default: ".
         pad (int): determines padding in a formatted string. Default: 2.
         indent (int): determines the size of indentation in a formatted string. Default: 4.
         unnest_column_only (bool): if true unnest table aliases are considered only as column aliases
@@ -85,6 +89,9 @@ class Generator:
     # Wrap derived values in parens, usually standard but spark doesn't support it
     WRAP_DERIVED_VALUES = True

+    # Whether or not create function uses an AS before the def.
+    CREATE_FUNCTION_AS = True
+
     TYPE_MAPPING = {
         exp.DataType.Type.NCHAR: "CHAR",
         exp.DataType.Type.NVARCHAR: "VARCHAR",
@@ -154,7 +161,8 @@ class Generator:
         "identifier_end",
         "identify",
         "normalize",
-        "escape",
+        "string_escape",
+        "identifier_escape",
         "pad",
         "index_offset",
         "unnest_column_only",
@@ -167,6 +175,7 @@ class Generator:
         "_indent",
         "_replace_backslash",
         "_escaped_quote_end",
+        "_escaped_identifier_end",
         "_leading_comma",
         "_max_text_width",
         "_comments",
@@ -183,7 +192,8 @@ class Generator:
         identifier_end=None,
         identify=False,
         normalize=False,
-        escape=None,
+        string_escape=None,
+        identifier_escape=None,
         pad=2,
         indent=2,
         index_offset=0,
@@ -208,7 +218,8 @@ class Generator:
         self.identifier_end = identifier_end or '"'
         self.identify = identify
         self.normalize = normalize
-        self.escape = escape or "'"
+        self.string_escape = string_escape or "'"
+        self.identifier_escape = identifier_escape or '"'
         self.pad = pad
         self.index_offset = index_offset
         self.unnest_column_only = unnest_column_only
@@ -219,8 +230,9 @@ class Generator:
         self.max_unsupported = max_unsupported
         self.null_ordering = null_ordering
         self._indent = indent
-        self._replace_backslash = self.escape == "\\"
-        self._escaped_quote_end = self.escape + self.quote_end
+        self._replace_backslash = self.string_escape == "\\"
+        self._escaped_quote_end = self.string_escape + self.quote_end
+        self._escaped_identifier_end = self.identifier_escape + self.identifier_end
         self._leading_comma = leading_comma
         self._max_text_width = max_text_width
         self._comments = comments
@@ -441,6 +453,9 @@ class Generator:
     def generatedasidentitycolumnconstraint_sql(
         self, expression: exp.GeneratedAsIdentityColumnConstraint
     ) -> str:
+        this = ""
+        if expression.this is not None:
+            this = " ALWAYS " if expression.this else " BY DEFAULT "
         start = expression.args.get("start")
         start = f"START WITH {start}" if start else ""
         increment = expression.args.get("increment")
|
||||||
if start or increment:
|
if start or increment:
|
||||||
sequence_opts = f"{start} {increment}"
|
sequence_opts = f"{start} {increment}"
|
||||||
sequence_opts = f" ({sequence_opts.strip()})"
|
sequence_opts = f" ({sequence_opts.strip()})"
|
||||||
return (
|
return f"GENERATED{this}AS IDENTITY{sequence_opts}"
|
||||||
f"GENERATED {'ALWAYS' if expression.this else 'BY DEFAULT'} AS IDENTITY{sequence_opts}"
|
|
||||||
)
|
|
||||||
|
|
||||||
def notnullcolumnconstraint_sql(self, expression: exp.NotNullColumnConstraint) -> str:
|
def notnullcolumnconstraint_sql(self, expression: exp.NotNullColumnConstraint) -> str:
|
||||||
return f"{'' if expression.args.get('allow_null') else 'NOT '}NULL"
|
return f"{'' if expression.args.get('allow_null') else 'NOT '}NULL"
|
||||||
|
@ -496,7 +509,12 @@ class Generator:
|
||||||
properties_sql = self.sql(properties_exp, "properties")
|
properties_sql = self.sql(properties_exp, "properties")
|
||||||
begin = " BEGIN" if expression.args.get("begin") else ""
|
begin = " BEGIN" if expression.args.get("begin") else ""
|
||||||
expression_sql = self.sql(expression, "expression")
|
expression_sql = self.sql(expression, "expression")
|
||||||
expression_sql = f" AS{begin}{self.sep()}{expression_sql}" if expression_sql else ""
|
if expression_sql:
|
||||||
|
expression_sql = f"{begin}{self.sep()}{expression_sql}"
|
||||||
|
|
||||||
|
if self.CREATE_FUNCTION_AS or kind != "FUNCTION":
|
||||||
|
expression_sql = f" AS{expression_sql}"
|
||||||
|
|
||||||
temporary = " TEMPORARY" if expression.args.get("temporary") else ""
|
temporary = " TEMPORARY" if expression.args.get("temporary") else ""
|
||||||
transient = (
|
transient = (
|
||||||
" TRANSIENT" if self.CREATE_TRANSIENT and expression.args.get("transient") else ""
|
" TRANSIENT" if self.CREATE_TRANSIENT and expression.args.get("transient") else ""
|
||||||
|
@ -701,6 +719,7 @@ class Generator:
|
||||||
def identifier_sql(self, expression: exp.Identifier) -> str:
|
def identifier_sql(self, expression: exp.Identifier) -> str:
|
||||||
text = expression.name
|
text = expression.name
|
||||||
text = text.lower() if self.normalize else text
|
text = text.lower() if self.normalize else text
|
||||||
|
text = text.replace(self.identifier_end, self._escaped_identifier_end)
|
||||||
if expression.args.get("quoted") or self.identify:
|
if expression.args.get("quoted") or self.identify:
|
||||||
text = f"{self.identifier_start}{text}{self.identifier_end}"
|
text = f"{self.identifier_start}{text}{self.identifier_end}"
|
||||||
return text
|
return text
|
||||||
|
@ -1121,7 +1140,7 @@ class Generator:
|
||||||
text = expression.this or ""
|
text = expression.this or ""
|
||||||
if expression.is_string:
|
if expression.is_string:
|
||||||
if self._replace_backslash:
|
if self._replace_backslash:
|
||||||
text = text.replace("\\", "\\\\")
|
text = BACKSLASH_RE.sub(r"\\\\", text)
|
||||||
text = text.replace(self.quote_end, self._escaped_quote_end)
|
text = text.replace(self.quote_end, self._escaped_quote_end)
|
||||||
if self.pretty:
|
if self.pretty:
|
||||||
text = text.replace("\n", self.SENTINEL_LINE_BREAK)
|
text = text.replace("\n", self.SENTINEL_LINE_BREAK)
|
||||||
|
@ -1486,9 +1505,16 @@ class Generator:
|
||||||
return f"(SELECT {self.sql(unnest)})"
|
return f"(SELECT {self.sql(unnest)})"
|
||||||
|
|
||||||
def interval_sql(self, expression: exp.Interval) -> str:
|
def interval_sql(self, expression: exp.Interval) -> str:
|
||||||
this = self.sql(expression, "this")
|
this = expression.args.get("this")
|
||||||
this = f" {this}" if this else ""
|
if this:
|
||||||
unit = self.sql(expression, "unit")
|
this = (
|
||||||
|
f" {this}"
|
||||||
|
if isinstance(this, exp.Literal) or isinstance(this, exp.Paren)
|
||||||
|
else f" ({this})"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
this = ""
|
||||||
|
unit = expression.args.get("unit")
|
||||||
unit = f" {unit}" if unit else ""
|
unit = f" {unit}" if unit else ""
|
||||||
return f"INTERVAL{this}{unit}"
|
return f"INTERVAL{this}{unit}"
|
||||||
|
|
||||||
|
|
|
sqlglot/lineage.py

@@ -6,6 +6,7 @@ from dataclasses import dataclass, field
 from sqlglot import Schema, exp, maybe_parse
 from sqlglot.optimizer import Scope, build_scope, optimize
+from sqlglot.optimizer.expand_laterals import expand_laterals
 from sqlglot.optimizer.qualify_columns import qualify_columns
 from sqlglot.optimizer.qualify_tables import qualify_tables
@@ -38,7 +39,7 @@ def lineage(
     sql: str | exp.Expression,
     schema: t.Optional[t.Dict | Schema] = None,
     sources: t.Optional[t.Dict[str, str | exp.Subqueryable]] = None,
-    rules: t.Sequence[t.Callable] = (qualify_tables, qualify_columns),
+    rules: t.Sequence[t.Callable] = (qualify_tables, qualify_columns, expand_laterals),
     dialect: DialectType = None,
 ) -> Node:
     """Build the lineage graph for a column of a SQL query.
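Adding expand_laterals to the default rules lets lineage follow lateral alias references. A sketch (assuming sqlglot >= 11; the one-column schema is a toy example):

    from sqlglot.lineage import lineage

    node = lineage(
        "c",
        "SELECT x.a + 1 AS b, b + 1 AS c FROM x",
        schema={"x": {"a": "INT"}},
    )
    print(node.name)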
sqlglot/optimizer/annotate_types.py

@@ -255,12 +255,23 @@ class TypeAnnotator:
         for name, source in scope.sources.items():
             if not isinstance(source, Scope):
                 continue
-            if isinstance(source.expression, exp.Values):
+            if isinstance(source.expression, exp.UDTF):
+                values = []
+
+                if isinstance(source.expression, exp.Lateral):
+                    if isinstance(source.expression.this, exp.Explode):
+                        values = [source.expression.this.this]
+                else:
+                    values = source.expression.expressions[0].expressions
+
+                if not values:
+                    continue
+
                 selects[name] = {
                     alias: column
                     for alias, column in zip(
                         source.expression.alias_column_names,
-                        source.expression.expressions[0].expressions,
+                        values,
                     )
                 }
             else:
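Type annotation previously only looked through VALUES sources; it now handles any UDTF source, including exploded laterals. A sketch (assuming sqlglot >= 11):

    import sqlglot
    from sqlglot.optimizer.annotate_types import annotate_types

    ast = sqlglot.parse_one("SELECT v FROM (VALUES (1)) AS t(v)")
    print(annotate_types(ast).expressions[0].type)  # expected: INT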
@@ -272,7 +283,7 @@ class TypeAnnotator:
                 source = scope.sources.get(col.table)
                 if isinstance(source, exp.Table):
                     col.type = self.schema.get_column_type(source, col)
-                elif source:
+                elif source and col.table in selects:
                     col.type = selects[col.table][col.name].type
         # Then (possibly) annotate the remaining expressions in the scope
         self._maybe_annotate(scope.expression)
sqlglot/optimizer/expand_laterals.py (new file, 34 lines)

@@ -0,0 +1,34 @@
+from __future__ import annotations
+
+import typing as t
+
+from sqlglot import exp
+
+
+def expand_laterals(expression: exp.Expression) -> exp.Expression:
+    """
+    Expand lateral column alias references.
+
+    This assumes `qualify_columns` has already run.
+
+    Example:
+        >>> import sqlglot
+        >>> sql = "SELECT x.a + 1 AS b, b + 1 AS c FROM x"
+        >>> expression = sqlglot.parse_one(sql)
+        >>> expand_laterals(expression).sql()
+        'SELECT x.a + 1 AS b, x.a + 1 + 1 AS c FROM x'
+
+    Args:
+        expression: expression to optimize
+    Returns:
+        optimized expression
+    """
+    for select in expression.find_all(exp.Select):
+        alias_to_expression: t.Dict[str, exp.Expression] = {}
+        for projection in select.expressions:
+            for column in projection.find_all(exp.Column):
+                if not column.table and column.name in alias_to_expression:
+                    column.replace(alias_to_expression[column.name].copy())
+            if isinstance(projection, exp.Alias):
+                alias_to_expression[projection.alias] = projection.this
+    return expression
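Usage matches the docstring's doctest: the pass rewrites references to earlier projection aliases within the same SELECT (assuming qualify_columns has run first, per the docstring):

    import sqlglot
    from sqlglot.optimizer.expand_laterals import expand_laterals

    expression = sqlglot.parse_one("SELECT x.a + 1 AS b, b + 1 AS c FROM x")
    print(expand_laterals(expression).sql())
    # SELECT x.a + 1 AS b, x.a + 1 + 1 AS c FROM x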
sqlglot/optimizer/optimizer.py

@@ -4,6 +4,7 @@ from sqlglot.optimizer.canonicalize import canonicalize
 from sqlglot.optimizer.eliminate_ctes import eliminate_ctes
 from sqlglot.optimizer.eliminate_joins import eliminate_joins
 from sqlglot.optimizer.eliminate_subqueries import eliminate_subqueries
+from sqlglot.optimizer.expand_laterals import expand_laterals
 from sqlglot.optimizer.expand_multi_table_selects import expand_multi_table_selects
 from sqlglot.optimizer.isolate_table_selects import isolate_table_selects
 from sqlglot.optimizer.lower_identities import lower_identities
@@ -12,7 +13,7 @@ from sqlglot.optimizer.normalize import normalize
 from sqlglot.optimizer.optimize_joins import optimize_joins
 from sqlglot.optimizer.pushdown_predicates import pushdown_predicates
 from sqlglot.optimizer.pushdown_projections import pushdown_projections
-from sqlglot.optimizer.qualify_columns import qualify_columns
+from sqlglot.optimizer.qualify_columns import qualify_columns, validate_qualify_columns
 from sqlglot.optimizer.qualify_tables import qualify_tables
 from sqlglot.optimizer.unnest_subqueries import unnest_subqueries
 from sqlglot.schema import ensure_schema
@@ -22,6 +23,8 @@ RULES = (
     qualify_tables,
     isolate_table_selects,
     qualify_columns,
+    expand_laterals,
+    validate_qualify_columns,
     pushdown_projections,
     normalize,
     unnest_subqueries,
sqlglot/optimizer/pushdown_projections.py

@@ -7,7 +7,7 @@ from sqlglot.optimizer.scope import Scope, traverse_scope
 SELECT_ALL = object()

 # Selection to use if selection list is empty
-DEFAULT_SELECTION = alias("1", "_")
+DEFAULT_SELECTION = lambda: alias("1", "_")


 def pushdown_projections(expression):
@@ -93,7 +93,7 @@ def _remove_unused_selections(scope, parent_selections):
     # If there are no remaining selections, just select a single constant
     if not new_selections:
-        new_selections.append(DEFAULT_SELECTION.copy())
+        new_selections.append(DEFAULT_SELECTION())

     scope.expression.set("expressions", new_selections)
     if removed:
@@ -106,5 +106,5 @@ def _remove_indexed_selections(scope, indexes_to_remove):
         selection for i, selection in enumerate(scope.selects) if i not in indexes_to_remove
     ]
     if not new_selections:
-        new_selections.append(DEFAULT_SELECTION.copy())
+        new_selections.append(DEFAULT_SELECTION())
     scope.expression.set("expressions", new_selections)
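Turning DEFAULT_SELECTION into a factory avoids sharing one mutable AST node between queries; each empty select list now gets a fresh `1 AS _` node (the earlier `.copy()` calls hinted at the hazard). Illustration of the design choice:

    from sqlglot import alias

    DEFAULT_SELECTION = lambda: alias("1", "_")
    assert DEFAULT_SELECTION() is not DEFAULT_SELECTION()  # fresh node per call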
sqlglot/optimizer/qualify_columns.py

@@ -37,11 +37,24 @@ def qualify_columns(expression, schema):
         if not isinstance(scope.expression, exp.UDTF):
             _expand_stars(scope, resolver)
             _qualify_outputs(scope)
-        _check_unknown_tables(scope)

     return expression


+def validate_qualify_columns(expression):
+    """Raise an `OptimizeError` if any columns aren't qualified"""
+    unqualified_columns = []
+    for scope in traverse_scope(expression):
+        if isinstance(scope.expression, exp.Select):
+            unqualified_columns.extend(scope.unqualified_columns)
+            if scope.external_columns and not scope.is_correlated_subquery:
+                raise OptimizeError(f"Unknown table: {scope.external_columns[0].table}")
+
+    if unqualified_columns:
+        raise OptimizeError(f"Ambiguous columns: {unqualified_columns}")
+    return expression
+
+
 def _pop_table_column_aliases(derived_tables):
     """
     Remove table column aliases.
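Qualification and validation are now separate rules, so callers can qualify best-effort and opt into strict checking. A sketch (assuming sqlglot >= 11; the schema dict is a toy example):

    import sqlglot
    from sqlglot.optimizer.qualify_columns import qualify_columns, validate_qualify_columns

    ast = sqlglot.parse_one("SELECT a FROM x")
    qualified = qualify_columns(ast, schema={"x": {"a": "INT"}})
    validate_qualify_columns(qualified)  # raises OptimizeError if anything stayed unqualified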
@@ -199,10 +212,6 @@ def _qualify_columns(scope, resolver):
         if not column_table:
             column_table = resolver.get_table(column_name)

-            if not scope.is_subquery and not scope.is_udtf:
-                if column_table is None:
-                    raise OptimizeError(f"Ambiguous column: {column_name}")
-
         # column_table can be a '' because bigquery unnest has no table alias
         if column_table:
             column.set("table", exp.to_identifier(column_table))
@@ -231,10 +240,8 @@ def _qualify_columns(scope, resolver):
     for column in columns_missing_from_scope:
         column_table = resolver.get_table(column.name)

-        if column_table is None:
-            raise OptimizeError(f"Ambiguous column: {column.name}")
-
-        column.set("table", exp.to_identifier(column_table))
+        if column_table:
+            column.set("table", exp.to_identifier(column_table))


 def _expand_stars(scope, resolver):
@@ -322,11 +329,6 @@ def _qualify_outputs(scope):
     scope.expression.set("expressions", new_selections)


-def _check_unknown_tables(scope):
-    if scope.external_columns and not scope.is_udtf and not scope.is_correlated_subquery:
-        raise OptimizeError(f"Unknown table: {scope.external_columns[0].text('table')}")
-
-
 class _Resolver:
     """
     Helper for resolving columns.
sqlglot/optimizer/qualify_tables.py

@@ -2,7 +2,7 @@ import itertools
 from sqlglot import alias, exp
 from sqlglot.helper import csv_reader
-from sqlglot.optimizer.scope import traverse_scope
+from sqlglot.optimizer.scope import Scope, traverse_scope


 def qualify_tables(expression, db=None, catalog=None, schema=None):
@@ -25,6 +25,8 @@ def qualify_tables(expression, db=None, catalog=None, schema=None):
     """
     sequence = itertools.count()

+    next_name = lambda: f"_q_{next(sequence)}"
+
     for scope in traverse_scope(expression):
         for derived_table in scope.ctes + scope.derived_tables:
             if not derived_table.args.get("alias"):
@@ -46,7 +48,7 @@ def qualify_tables(expression, db=None, catalog=None, schema=None):
                 source = source.replace(
                     alias(
                         source.copy(),
-                        source.this if identifier else f"_q_{next(sequence)}",
+                        source.this if identifier else next_name(),
                         table=True,
                     )
                 )
@@ -58,5 +60,12 @@ def qualify_tables(expression, db=None, catalog=None, schema=None):
                     schema.add_table(
                         source, {k: type(v).__name__ for k, v in zip(header, columns)}
                     )
+            elif isinstance(source, Scope) and source.is_udtf:
+                udtf = source.expression
+                table_alias = udtf.args.get("alias") or exp.TableAlias(this=next_name())
+                udtf.set("alias", table_alias)
+
+                if not table_alias.name:
+                    table_alias.set("this", next_name())

     return expression
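Unaliased UDTF sources now receive generated `_q_N` aliases the same way derived tables do. A hedged sketch (assuming sqlglot >= 11; exact alias numbering may differ):

    import sqlglot
    from sqlglot.optimizer.qualify_tables import qualify_tables

    print(qualify_tables(sqlglot.parse_one("SELECT * FROM UNNEST(x)")).sql())
    # expected: something like SELECT * FROM UNNEST(x) AS _q_0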
sqlglot/optimizer/scope.py

@@ -237,6 +237,8 @@ class Scope:
             ancestor = column.find_ancestor(exp.Qualify, exp.Order, exp.Having, exp.Hint)
             if (
                 not ancestor
+                # Window functions can have an ORDER BY clause
+                or not isinstance(ancestor.parent, exp.Select)
                 or column.table
                 or (column.name not in named_selects and not isinstance(ancestor, exp.Hint))
             ):
@@ -479,7 +481,7 @@ def _traverse_scope(scope):
     elif isinstance(scope.expression, exp.Union):
         yield from _traverse_union(scope)
     elif isinstance(scope.expression, exp.UDTF):
-        pass
+        _set_udtf_scope(scope)
     elif isinstance(scope.expression, exp.Subquery):
         yield from _traverse_subqueries(scope)
     else:
@@ -509,6 +511,22 @@ def _traverse_union(scope):
     scope.union_scopes = [left, right]


+def _set_udtf_scope(scope):
+    parent = scope.expression.parent
+    from_ = parent.args.get("from")
+
+    if not from_:
+        return
+
+    for table in from_.expressions:
+        if isinstance(table, exp.Table):
+            scope.tables.append(table)
+        elif isinstance(table, exp.Subquery):
+            scope.subqueries.append(table)
+    _add_table_sources(scope)
+    _traverse_subqueries(scope)
+
+
 def _traverse_derived_tables(derived_tables, scope, scope_type):
     sources = {}
     is_cte = scope_type == ScopeType.CTE
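An ORDER BY nested under OVER(...) hangs off a Window node rather than the Select, so its columns are now treated as ordinary source columns instead of projection references. A sketch (assuming sqlglot >= 11):

    import sqlglot
    from sqlglot.optimizer.scope import traverse_scope

    sql = "SELECT ROW_NUMBER() OVER (ORDER BY a) AS rn FROM x"
    for scope in traverse_scope(sqlglot.parse_one(sql)):
        print(scope.columns)  # `a` resolves against x, not against the projections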
sqlglot/parser.py

@@ -194,6 +194,7 @@ class Parser(metaclass=_Parser):
         TokenType.INTERVAL,
         TokenType.LAZY,
         TokenType.LEADING,
+        TokenType.LEFT,
         TokenType.LOCAL,
         TokenType.MATERIALIZED,
         TokenType.MERGE,
@@ -208,6 +209,7 @@ class Parser(metaclass=_Parser):
         TokenType.PRECEDING,
         TokenType.RANGE,
         TokenType.REFERENCES,
+        TokenType.RIGHT,
         TokenType.ROW,
         TokenType.ROWS,
         TokenType.SCHEMA,
@@ -237,8 +239,10 @@ class Parser(metaclass=_Parser):
     TABLE_ALIAS_TOKENS = ID_VAR_TOKENS - {
         TokenType.APPLY,
+        TokenType.LEFT,
         TokenType.NATURAL,
         TokenType.OFFSET,
+        TokenType.RIGHT,
         TokenType.WINDOW,
     }
@@ -258,6 +262,8 @@ class Parser(metaclass=_Parser):
         TokenType.IDENTIFIER,
         TokenType.INDEX,
         TokenType.ISNULL,
+        TokenType.ILIKE,
+        TokenType.LIKE,
         TokenType.MERGE,
         TokenType.OFFSET,
         TokenType.PRIMARY_KEY,
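LEFT and RIGHT become usable as ordinary identifiers and aliases while staying excluded from table aliases, so `x LEFT JOIN y` keeps parsing. A sketch (assuming sqlglot >= 11):

    import sqlglot

    # LEFT as a function and as a column alias in one statement.
    print(sqlglot.parse_one("SELECT LEFT(name, 3) AS left FROM t").sql())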
@@ -971,13 +977,14 @@ class Parser(metaclass=_Parser):
         if create_token.token_type in (TokenType.FUNCTION, TokenType.PROCEDURE):
             this = self._parse_user_defined_function(kind=create_token.token_type)
             properties = self._parse_properties()
-            if self._match(TokenType.ALIAS):
-                begin = self._match(TokenType.BEGIN)
-                return_ = self._match_text_seq("RETURN")
-                expression = self._parse_statement()
-
-                if return_:
-                    expression = self.expression(exp.Return, this=expression)
+
+            self._match(TokenType.ALIAS)
+            begin = self._match(TokenType.BEGIN)
+            return_ = self._match_text_seq("RETURN")
+            expression = self._parse_statement()
+
+            if return_:
+                expression = self.expression(exp.Return, this=expression)
         elif create_token.token_type == TokenType.INDEX:
             this = self._parse_index()
         elif create_token.token_type in (
@@ -2163,7 +2170,9 @@ class Parser(metaclass=_Parser):
     ) -> t.Optional[exp.Expression]:
         if self._match(TokenType.TOP if top else TokenType.LIMIT):
             limit_paren = self._match(TokenType.L_PAREN)
-            limit_exp = self.expression(exp.Limit, this=this, expression=self._parse_number())
+            limit_exp = self.expression(
+                exp.Limit, this=this, expression=self._parse_number() if top else self._parse_term()
+            )

             if limit_paren:
                 self._match_r_paren()
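LIMIT now accepts any term (parameters, arithmetic), while TOP still requires a plain number. A sketch (assuming sqlglot >= 11):

    import sqlglot

    print(sqlglot.parse_one("SELECT * FROM t LIMIT 1 + 1").sql())
    # expected: SELECT * FROM t LIMIT 1 + 1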
@@ -2740,8 +2749,23 @@ class Parser(metaclass=_Parser):
         kind: exp.Expression

-        if self._match(TokenType.AUTO_INCREMENT):
-            kind = exp.AutoIncrementColumnConstraint()
+        if self._match_set((TokenType.AUTO_INCREMENT, TokenType.IDENTITY)):
+            start = None
+            increment = None
+
+            if self._match(TokenType.L_PAREN, advance=False):
+                args = self._parse_wrapped_csv(self._parse_bitwise)
+                start = seq_get(args, 0)
+                increment = seq_get(args, 1)
+            elif self._match_text_seq("START"):
+                start = self._parse_bitwise()
+                self._match_text_seq("INCREMENT")
+                increment = self._parse_bitwise()
+
+            if start and increment:
+                kind = exp.GeneratedAsIdentityColumnConstraint(start=start, increment=increment)
+            else:
+                kind = exp.AutoIncrementColumnConstraint()
         elif self._match(TokenType.CHECK):
             constraint = self._parse_wrapped(self._parse_conjunction)
             kind = self.expression(exp.CheckColumnConstraint, this=constraint)
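`IDENTITY (start, increment)` and `START ... INCREMENT ...` forms now build a GeneratedAsIdentityColumnConstraint; a bare AUTO_INCREMENT is unchanged. A hedged sketch (assuming the default dialect maps the IDENTITY keyword to TokenType.IDENTITY):

    import sqlglot
    from sqlglot import exp

    ast = sqlglot.parse_one("CREATE TABLE t (id INT IDENTITY (1, 2))")
    constraint = ast.find(exp.GeneratedAsIdentityColumnConstraint)
    print(constraint.args.get("start"), constraint.args.get("increment"))  # expected: 1 2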
@@ -3294,8 +3318,8 @@ class Parser(metaclass=_Parser):
         if not self._match(TokenType.EXCEPT):
             return None
         if self._match(TokenType.L_PAREN, advance=False):
-            return self._parse_wrapped_id_vars()
-        return self._parse_csv(self._parse_id_var)
+            return self._parse_wrapped_csv(self._parse_column)
+        return self._parse_csv(self._parse_column)

     def _parse_replace(self) -> t.Optional[t.List[t.Optional[exp.Expression]]]:
         if not self._match(TokenType.REPLACE):
@@ -3442,7 +3466,7 @@ class Parser(metaclass=_Parser):
     def _parse_alter(self) -> t.Optional[exp.Expression]:
         if not self._match(TokenType.TABLE):
-            return None
+            return self._parse_as_command(self._prev)

         exists = self._parse_exists()
         this = self._parse_table(schema=True)
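ALTER statements other than ALTER TABLE no longer come back as None (a parse failure); they fall back to a generic Command node. A hedged sketch (assuming sqlglot >= 11; the Command is expected to round-trip the original text):

    import sqlglot

    ast = sqlglot.parse_one("ALTER SESSION SET timezone = 'UTC'")
    print(type(ast).__name__)  # expected: Command
    print(ast.sql())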
sqlglot/tokens.py

@@ -357,7 +357,8 @@ class _Tokenizer(type):
         klass._HEX_STRINGS = cls._delimeter_list_to_dict(klass.HEX_STRINGS)
         klass._BYTE_STRINGS = cls._delimeter_list_to_dict(klass.BYTE_STRINGS)
         klass._IDENTIFIERS = cls._delimeter_list_to_dict(klass.IDENTIFIERS)
-        klass._ESCAPES = set(klass.ESCAPES)
+        klass._STRING_ESCAPES = set(klass.STRING_ESCAPES)
+        klass._IDENTIFIER_ESCAPES = set(klass.IDENTIFIER_ESCAPES)
         klass._COMMENTS = dict(
             (comment, None) if isinstance(comment, str) else (comment[0], comment[1])
             for comment in klass.COMMENTS
@@ -429,9 +430,13 @@ class Tokenizer(metaclass=_Tokenizer):
     IDENTIFIERS: t.List[str | t.Tuple[str, str]] = ['"']

-    ESCAPES = ["'"]
+    STRING_ESCAPES = ["'"]

-    _ESCAPES: t.Set[str] = set()
+    _STRING_ESCAPES: t.Set[str] = set()
+
+    IDENTIFIER_ESCAPES = ['"']
+
+    _IDENTIFIER_ESCAPES: t.Set[str] = set()

     KEYWORDS = {
         **{
@@ -469,6 +474,7 @@ class Tokenizer(metaclass=_Tokenizer):
         "ASC": TokenType.ASC,
         "AS": TokenType.ALIAS,
         "AT TIME ZONE": TokenType.AT_TIME_ZONE,
+        "AUTOINCREMENT": TokenType.AUTO_INCREMENT,
         "AUTO_INCREMENT": TokenType.AUTO_INCREMENT,
         "BEGIN": TokenType.BEGIN,
         "BETWEEN": TokenType.BETWEEN,
@@ -691,6 +697,7 @@ class Tokenizer(metaclass=_Tokenizer):
         "ALTER VIEW": TokenType.COMMAND,
         "ANALYZE": TokenType.COMMAND,
         "CALL": TokenType.COMMAND,
+        "COPY": TokenType.COMMAND,
         "EXPLAIN": TokenType.COMMAND,
         "OPTIMIZE": TokenType.COMMAND,
         "PREPARE": TokenType.COMMAND,
@@ -744,7 +751,7 @@ class Tokenizer(metaclass=_Tokenizer):
     )

     def __init__(self) -> None:
-        self._replace_backslash = "\\" in self._ESCAPES
+        self._replace_backslash = "\\" in self._STRING_ESCAPES
         self.reset()

     def reset(self) -> None:
@@ -1046,12 +1053,25 @@ class Tokenizer(metaclass=_Tokenizer):
             return True

     def _scan_identifier(self, identifier_end: str) -> None:
-        while self._peek != identifier_end:
+        text = ""
+        identifier_end_is_escape = identifier_end in self._IDENTIFIER_ESCAPES
+
+        while True:
             if self._end:
                 raise RuntimeError(f"Missing {identifier_end} from {self._line}:{self._start}")
+
             self._advance()
-        self._advance()
-        self._add(TokenType.IDENTIFIER, self._text[1:-1])
+            if self._char == identifier_end:
+                if identifier_end_is_escape and self._peek == identifier_end:
+                    text += identifier_end  # type: ignore
+                    self._advance()
+                    continue
+
+                break
+
+            text += self._char  # type: ignore
+
+        self._add(TokenType.IDENTIFIER, text)

     def _scan_var(self) -> None:
         while True:
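A doubled identifier-end character inside a quoted identifier is now consumed as an escaped quote instead of terminating the identifier early. A sketch (assuming sqlglot >= 11):

    from sqlglot.tokens import Tokenizer

    tokens = Tokenizer().tokenize('SELECT "a""b"')
    print([(token.token_type.name, token.text) for token in tokens])
    # expected: the identifier token's text is a"b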
@@ -1072,9 +1092,9 @@ class Tokenizer(metaclass=_Tokenizer):
         while True:
             if (
-                self._char in self._ESCAPES
+                self._char in self._STRING_ESCAPES
                 and self._peek
-                and (self._peek == delimiter or self._peek in self._ESCAPES)
+                and (self._peek == delimiter or self._peek in self._STRING_ESCAPES)
             ):
                 text += self._peek
                 self._advance(2)
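Custom dialects that tuned escaping must follow the ESCAPES rename. A minimal custom-dialect sketch (assuming sqlglot >= 11; MyDialect is hypothetical):

    from sqlglot import tokens
    from sqlglot.dialects.dialect import Dialect

    class MyDialect(Dialect):
        class Tokenizer(tokens.Tokenizer):
            STRING_ESCAPES = ["\\"]      # was ESCAPES before v11
            IDENTIFIER_ESCAPES = ['"']   # new in v11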